Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- phivenv/Lib/site-packages/torch/_C/_VariableFunctions.pyi +0 -0
- phivenv/Lib/site-packages/torch/_C/__init__.pyi +0 -0
- phivenv/Lib/site-packages/torch/_C/_aoti.pyi +164 -0
- phivenv/Lib/site-packages/torch/_C/_autograd.pyi +141 -0
- phivenv/Lib/site-packages/torch/_C/_cpu.pyi +13 -0
- phivenv/Lib/site-packages/torch/_C/_cudnn.pyi +14 -0
- phivenv/Lib/site-packages/torch/_C/_cusparselt.pyi +1 -0
- phivenv/Lib/site-packages/torch/_C/_distributed_autograd.pyi +26 -0
- phivenv/Lib/site-packages/torch/_C/_distributed_c10d.pyi +797 -0
- phivenv/Lib/site-packages/torch/_C/_distributed_rpc.pyi +188 -0
- phivenv/Lib/site-packages/torch/_C/_distributed_rpc_testing.pyi +32 -0
- phivenv/Lib/site-packages/torch/_C/_dynamo/__init__.pyi +4 -0
- phivenv/Lib/site-packages/torch/_C/_dynamo/compiled_autograd.pyi +13 -0
- phivenv/Lib/site-packages/torch/_C/_dynamo/eval_frame.pyi +71 -0
- phivenv/Lib/site-packages/torch/_C/_dynamo/guards.pyi +191 -0
- phivenv/Lib/site-packages/torch/_C/_export/__init__.pyi +9 -0
- phivenv/Lib/site-packages/torch/_C/_export/pt2_archive_constants.pyi +22 -0
- phivenv/Lib/site-packages/torch/_C/_functions.pyi +19 -0
- phivenv/Lib/site-packages/torch/_C/_functorch.pyi +86 -0
- phivenv/Lib/site-packages/torch/_C/_instruction_counter.pyi +4 -0
- phivenv/Lib/site-packages/torch/_C/_itt.pyi +5 -0
- phivenv/Lib/site-packages/torch/_C/_lazy.pyi +26 -0
- phivenv/Lib/site-packages/torch/_C/_lazy_ts_backend.pyi +12 -0
- phivenv/Lib/site-packages/torch/_C/_monitor.pyi +58 -0
- phivenv/Lib/site-packages/torch/_C/_nn.pyi +175 -0
- phivenv/Lib/site-packages/torch/_C/_nvtx.pyi +9 -0
- phivenv/Lib/site-packages/torch/_C/_onnx.pyi +39 -0
- phivenv/Lib/site-packages/torch/_C/_profiler.pyi +246 -0
- phivenv/Lib/site-packages/torch/_C/_verbose.pyi +3 -0
- phivenv/Lib/site-packages/torch/_C_flatbuffer/__init__.pyi +11 -0
- phivenv/Lib/site-packages/torch/_awaits/__init__.py +53 -0
- phivenv/Lib/site-packages/torch/_awaits/__pycache__/__init__.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/_custom_op/__init__.py +0 -0
- phivenv/Lib/site-packages/torch/_custom_op/__pycache__/__init__.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/_custom_op/__pycache__/autograd.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/_custom_op/__pycache__/impl.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/_custom_op/autograd.py +307 -0
- phivenv/Lib/site-packages/torch/_custom_op/impl.py +715 -0
- phivenv/Lib/site-packages/torch/_decomp/__init__.py +544 -0
- phivenv/Lib/site-packages/torch/_decomp/__pycache__/__init__.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/_decomp/__pycache__/decompositions_for_jvp.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/_decomp/__pycache__/decompositions_for_rng.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/_decomp/decompositions.py +0 -0
- phivenv/Lib/site-packages/torch/_decomp/decompositions_for_jvp.py +335 -0
- phivenv/Lib/site-packages/torch/_decomp/decompositions_for_rng.py +266 -0
- phivenv/Lib/site-packages/torch/_dispatch/__init__.py +0 -0
- phivenv/Lib/site-packages/torch/_dispatch/__pycache__/__init__.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/_dispatch/__pycache__/python.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/_dispatch/python.py +192 -0
.gitattributes
CHANGED
|
@@ -112,3 +112,6 @@ phivenv/Lib/site-packages/torch/lib/cpuinfo.lib filter=lfs diff=lfs merge=lfs -t
|
|
| 112 |
phivenv/Lib/site-packages/torch/lib/fbgemm.dll filter=lfs diff=lfs merge=lfs -text
|
| 113 |
phivenv/Lib/site-packages/torch/lib/c10.dll filter=lfs diff=lfs merge=lfs -text
|
| 114 |
phivenv/Lib/site-packages/torch/lib/fbgemm.lib filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
phivenv/Lib/site-packages/torch/lib/fbgemm.dll filter=lfs diff=lfs merge=lfs -text
|
| 113 |
phivenv/Lib/site-packages/torch/lib/c10.dll filter=lfs diff=lfs merge=lfs -text
|
| 114 |
phivenv/Lib/site-packages/torch/lib/fbgemm.lib filter=lfs diff=lfs merge=lfs -text
|
| 115 |
+
phivenv/Lib/site-packages/torch/lib/fmt.lib filter=lfs diff=lfs merge=lfs -text
|
| 116 |
+
phivenv/Lib/site-packages/torch/lib/libiomp5md.dll filter=lfs diff=lfs merge=lfs -text
|
| 117 |
+
phivenv/Lib/site-packages/torch/lib/libittnotify.lib filter=lfs diff=lfs merge=lfs -text
|
phivenv/Lib/site-packages/torch/_C/_VariableFunctions.pyi
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
phivenv/Lib/site-packages/torch/_C/__init__.pyi
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
phivenv/Lib/site-packages/torch/_C/_aoti.pyi
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ctypes import c_void_p
|
| 2 |
+
from typing import overload, Protocol
|
| 3 |
+
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
|
| 6 |
+
# Defined in torch/csrc/inductor/aoti_runner/pybind.cpp
|
| 7 |
+
|
| 8 |
+
# Tensor to AtenTensorHandle
|
| 9 |
+
def unsafe_alloc_void_ptrs_from_tensors(tensors: list[Tensor]) -> list[c_void_p]: ...
|
| 10 |
+
def unsafe_alloc_void_ptr_from_tensor(tensor: Tensor) -> c_void_p: ...
|
| 11 |
+
|
| 12 |
+
# AtenTensorHandle to Tensor
|
| 13 |
+
def alloc_tensors_by_stealing_from_void_ptrs(
|
| 14 |
+
handles: list[c_void_p],
|
| 15 |
+
) -> list[Tensor]: ...
|
| 16 |
+
def alloc_tensor_by_stealing_from_void_ptr(
|
| 17 |
+
handle: c_void_p,
|
| 18 |
+
) -> Tensor: ...
|
| 19 |
+
|
| 20 |
+
class AOTIModelContainerRunner(Protocol):
|
| 21 |
+
def run(
|
| 22 |
+
self, inputs: list[Tensor], stream_handle: c_void_p = ...
|
| 23 |
+
) -> list[Tensor]: ...
|
| 24 |
+
def get_call_spec(self) -> list[str]: ...
|
| 25 |
+
def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
|
| 26 |
+
def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
|
| 27 |
+
def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
|
| 28 |
+
def update_constant_buffer(
|
| 29 |
+
self,
|
| 30 |
+
tensor_map: dict[str, Tensor],
|
| 31 |
+
use_inactive: bool,
|
| 32 |
+
validate_full_updates: bool,
|
| 33 |
+
user_managed: bool = ...,
|
| 34 |
+
) -> None: ...
|
| 35 |
+
def swap_constant_buffer(self) -> None: ...
|
| 36 |
+
def free_inactive_constant_buffer(self) -> None: ...
|
| 37 |
+
|
| 38 |
+
class AOTIModelContainerRunnerCpu:
|
| 39 |
+
def __init__(self, model_so_path: str, num_models: int) -> None: ...
|
| 40 |
+
def run(
|
| 41 |
+
self, inputs: list[Tensor], stream_handle: c_void_p = ...
|
| 42 |
+
) -> list[Tensor]: ...
|
| 43 |
+
def get_call_spec(self) -> list[str]: ...
|
| 44 |
+
def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
|
| 45 |
+
def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
|
| 46 |
+
def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
|
| 47 |
+
def update_constant_buffer(
|
| 48 |
+
self,
|
| 49 |
+
tensor_map: dict[str, Tensor],
|
| 50 |
+
use_inactive: bool,
|
| 51 |
+
validate_full_updates: bool,
|
| 52 |
+
user_managed: bool = ...,
|
| 53 |
+
) -> None: ...
|
| 54 |
+
def swap_constant_buffer(self) -> None: ...
|
| 55 |
+
def free_inactive_constant_buffer(self) -> None: ...
|
| 56 |
+
|
| 57 |
+
class AOTIModelContainerRunnerCuda:
|
| 58 |
+
@overload
|
| 59 |
+
def __init__(self, model_so_path: str, num_models: int) -> None: ...
|
| 60 |
+
@overload
|
| 61 |
+
def __init__(
|
| 62 |
+
self, model_so_path: str, num_models: int, device_str: str
|
| 63 |
+
) -> None: ...
|
| 64 |
+
@overload
|
| 65 |
+
def __init__(
|
| 66 |
+
self, model_so_path: str, num_models: int, device_str: str, cubin_dir: str
|
| 67 |
+
) -> None: ...
|
| 68 |
+
def run(
|
| 69 |
+
self, inputs: list[Tensor], stream_handle: c_void_p = ...
|
| 70 |
+
) -> list[Tensor]: ...
|
| 71 |
+
def get_call_spec(self) -> list[str]: ...
|
| 72 |
+
def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
|
| 73 |
+
def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
|
| 74 |
+
def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
|
| 75 |
+
def update_constant_buffer(
|
| 76 |
+
self,
|
| 77 |
+
tensor_map: dict[str, Tensor],
|
| 78 |
+
use_inactive: bool,
|
| 79 |
+
validate_full_updates: bool,
|
| 80 |
+
user_managed: bool = ...,
|
| 81 |
+
) -> None: ...
|
| 82 |
+
def swap_constant_buffer(self) -> None: ...
|
| 83 |
+
def free_inactive_constant_buffer(self) -> None: ...
|
| 84 |
+
|
| 85 |
+
class AOTIModelContainerRunnerXpu:
|
| 86 |
+
@overload
|
| 87 |
+
def __init__(self, model_so_path: str, num_models: int) -> None: ...
|
| 88 |
+
@overload
|
| 89 |
+
def __init__(
|
| 90 |
+
self, model_so_path: str, num_models: int, device_str: str
|
| 91 |
+
) -> None: ...
|
| 92 |
+
@overload
|
| 93 |
+
def __init__(
|
| 94 |
+
self, model_so_path: str, num_models: int, device_str: str, kernel_bin_dir: str
|
| 95 |
+
) -> None: ...
|
| 96 |
+
def run(
|
| 97 |
+
self, inputs: list[Tensor], stream_handle: c_void_p = ...
|
| 98 |
+
) -> list[Tensor]: ...
|
| 99 |
+
def get_call_spec(self) -> list[str]: ...
|
| 100 |
+
def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
|
| 101 |
+
def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
|
| 102 |
+
def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
|
| 103 |
+
def update_constant_buffer(
|
| 104 |
+
self,
|
| 105 |
+
tensor_map: dict[str, Tensor],
|
| 106 |
+
use_inactive: bool,
|
| 107 |
+
validate_full_updates: bool,
|
| 108 |
+
user_managed: bool = ...,
|
| 109 |
+
) -> None: ...
|
| 110 |
+
def swap_constant_buffer(self) -> None: ...
|
| 111 |
+
def free_inactive_constant_buffer(self) -> None: ...
|
| 112 |
+
|
| 113 |
+
class AOTIModelContainerRunnerMps:
|
| 114 |
+
def __init__(self, model_so_path: str, num_models: int) -> None: ...
|
| 115 |
+
def run(
|
| 116 |
+
self, inputs: list[Tensor], stream_handle: c_void_p = ...
|
| 117 |
+
) -> list[Tensor]: ...
|
| 118 |
+
def get_call_spec(self) -> list[str]: ...
|
| 119 |
+
def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
|
| 120 |
+
def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
|
| 121 |
+
def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
|
| 122 |
+
def update_constant_buffer(
|
| 123 |
+
self,
|
| 124 |
+
tensor_map: dict[str, Tensor],
|
| 125 |
+
use_inactive: bool,
|
| 126 |
+
validate_full_updates: bool,
|
| 127 |
+
user_managed: bool = ...,
|
| 128 |
+
) -> None: ...
|
| 129 |
+
def swap_constant_buffer(self) -> None: ...
|
| 130 |
+
def free_inactive_constant_buffer(self) -> None: ...
|
| 131 |
+
|
| 132 |
+
# Defined in torch/csrc/inductor/aoti_package/pybind.cpp
|
| 133 |
+
class AOTIModelPackageLoader:
|
| 134 |
+
def __init__(
|
| 135 |
+
self,
|
| 136 |
+
model_package_path: str,
|
| 137 |
+
model_name: str,
|
| 138 |
+
run_single_threaded: bool,
|
| 139 |
+
num_runners: int,
|
| 140 |
+
device_index: int,
|
| 141 |
+
) -> None: ...
|
| 142 |
+
def get_metadata(self) -> dict[str, str]: ...
|
| 143 |
+
def run(
|
| 144 |
+
self, inputs: list[Tensor], stream_handle: c_void_p = ...
|
| 145 |
+
) -> list[Tensor]: ...
|
| 146 |
+
def boxed_run(
|
| 147 |
+
self, inputs: list[Tensor], stream_handle: c_void_p = ...
|
| 148 |
+
) -> list[Tensor]: ...
|
| 149 |
+
def get_call_spec(self) -> list[str]: ...
|
| 150 |
+
def get_constant_fqns(self) -> list[str]: ...
|
| 151 |
+
def load_constants(
|
| 152 |
+
self,
|
| 153 |
+
constants_map: dict[str, Tensor],
|
| 154 |
+
use_inactive: bool,
|
| 155 |
+
check_full_update: bool,
|
| 156 |
+
user_managed: bool = ...,
|
| 157 |
+
) -> None: ...
|
| 158 |
+
def update_constant_buffer(
|
| 159 |
+
self,
|
| 160 |
+
tensor_map: dict[str, Tensor],
|
| 161 |
+
use_inactive: bool,
|
| 162 |
+
validate_full_updates: bool,
|
| 163 |
+
user_managed: bool = ...,
|
| 164 |
+
) -> None: ...
|
phivenv/Lib/site-packages/torch/_C/_autograd.pyi
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from enum import Enum
|
| 3 |
+
from typing import Any, Callable
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch._C._profiler import (
|
| 7 |
+
_ProfilerEvent,
|
| 8 |
+
ActiveProfilerType,
|
| 9 |
+
ProfilerActivity,
|
| 10 |
+
ProfilerConfig,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
# Defined in torch/csrc/autograd/init.cpp
|
| 14 |
+
|
| 15 |
+
class DeviceType(Enum):
|
| 16 |
+
CPU = ...
|
| 17 |
+
CUDA = ...
|
| 18 |
+
XPU = ...
|
| 19 |
+
MKLDNN = ...
|
| 20 |
+
OPENGL = ...
|
| 21 |
+
OPENCL = ...
|
| 22 |
+
IDEEP = ...
|
| 23 |
+
HIP = ...
|
| 24 |
+
FPGA = ...
|
| 25 |
+
MAIA = ...
|
| 26 |
+
XLA = ...
|
| 27 |
+
MTIA = ...
|
| 28 |
+
MPS = ...
|
| 29 |
+
HPU = ...
|
| 30 |
+
Meta = ...
|
| 31 |
+
Vulkan = ...
|
| 32 |
+
Metal = ...
|
| 33 |
+
PrivateUse1 = ...
|
| 34 |
+
|
| 35 |
+
class ProfilerEvent:
|
| 36 |
+
def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ...
|
| 37 |
+
def cpu_memory_usage(self) -> int: ...
|
| 38 |
+
def cuda_elapsed_us(self, other: ProfilerEvent) -> float: ...
|
| 39 |
+
def privateuse1_elapsed_us(self, other: ProfilerEvent) -> float: ...
|
| 40 |
+
def cuda_memory_usage(self) -> int: ...
|
| 41 |
+
def device(self) -> int: ...
|
| 42 |
+
def handle(self) -> int: ...
|
| 43 |
+
def has_cuda(self) -> bool: ...
|
| 44 |
+
def is_remote(self) -> bool: ...
|
| 45 |
+
def kind(self) -> int: ...
|
| 46 |
+
def name(self) -> str: ...
|
| 47 |
+
def node_id(self) -> int: ...
|
| 48 |
+
def sequence_nr(self) -> int: ...
|
| 49 |
+
def shapes(self) -> list[list[int]]: ...
|
| 50 |
+
def thread_id(self) -> int: ...
|
| 51 |
+
def flops(self) -> float: ...
|
| 52 |
+
def is_async(self) -> bool: ...
|
| 53 |
+
|
| 54 |
+
class _KinetoEvent:
|
| 55 |
+
def name(self) -> str: ...
|
| 56 |
+
def overload_name(self) -> str: ...
|
| 57 |
+
def device_index(self) -> int: ...
|
| 58 |
+
def device_resource_id(self) -> int: ...
|
| 59 |
+
def start_ns(self) -> int: ...
|
| 60 |
+
def end_ns(self) -> int: ...
|
| 61 |
+
def duration_ns(self) -> int: ...
|
| 62 |
+
def is_async(self) -> bool: ...
|
| 63 |
+
def linked_correlation_id(self) -> int: ...
|
| 64 |
+
def shapes(self) -> list[list[int]]: ...
|
| 65 |
+
def dtypes(self) -> list[str]: ...
|
| 66 |
+
def concrete_inputs(self) -> list[Any]: ...
|
| 67 |
+
def kwinputs(self) -> dict[str, Any]: ...
|
| 68 |
+
def device_type(self) -> DeviceType: ...
|
| 69 |
+
def start_thread_id(self) -> int: ...
|
| 70 |
+
def end_thread_id(self) -> int: ...
|
| 71 |
+
def correlation_id(self) -> int: ...
|
| 72 |
+
def fwd_thread_id(self) -> int: ...
|
| 73 |
+
def stack(self) -> list[str]: ...
|
| 74 |
+
def scope(self) -> int: ...
|
| 75 |
+
def sequence_nr(self) -> int: ...
|
| 76 |
+
def flops(self) -> int: ...
|
| 77 |
+
def cuda_elapsed_us(self) -> int: ...
|
| 78 |
+
def privateuse1_elapsed_us(self) -> int: ...
|
| 79 |
+
def is_user_annotation(self) -> bool: ...
|
| 80 |
+
|
| 81 |
+
class _ProfilerResult:
|
| 82 |
+
def events(self) -> list[_KinetoEvent]: ...
|
| 83 |
+
def legacy_events(self) -> list[list[ProfilerEvent]]: ...
|
| 84 |
+
def save(self, path: str) -> None: ...
|
| 85 |
+
def experimental_event_tree(self) -> list[_ProfilerEvent]: ...
|
| 86 |
+
def trace_start_ns(self) -> int: ...
|
| 87 |
+
|
| 88 |
+
class SavedTensor: ...
|
| 89 |
+
|
| 90 |
+
def _enable_profiler(
|
| 91 |
+
config: ProfilerConfig,
|
| 92 |
+
activities: set[ProfilerActivity],
|
| 93 |
+
) -> None: ...
|
| 94 |
+
def _prepare_profiler(
|
| 95 |
+
config: ProfilerConfig,
|
| 96 |
+
activities: set[ProfilerActivity],
|
| 97 |
+
) -> None: ...
|
| 98 |
+
def _toggle_collection_dynamic(
|
| 99 |
+
enable: bool,
|
| 100 |
+
activities: set[ProfilerActivity],
|
| 101 |
+
) -> None: ...
|
| 102 |
+
def _disable_profiler() -> _ProfilerResult: ...
|
| 103 |
+
def _profiler_enabled() -> bool: ...
|
| 104 |
+
def _add_metadata_json(key: str, value: str) -> None: ...
|
| 105 |
+
def _kineto_step() -> None: ...
|
| 106 |
+
def _get_current_graph_task_keep_graph() -> bool: ...
|
| 107 |
+
def _get_sequence_nr() -> int: ...
|
| 108 |
+
def kineto_available() -> bool: ...
|
| 109 |
+
def _record_function_with_args_enter(name: str, *args) -> torch.Tensor: ...
|
| 110 |
+
def _record_function_with_args_exit(handle: torch.Tensor) -> None: ...
|
| 111 |
+
def _supported_activities() -> set[ProfilerActivity]: ...
|
| 112 |
+
def _enable_record_function(enable: bool) -> None: ...
|
| 113 |
+
def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
|
| 114 |
+
def _push_saved_tensors_default_hooks(
|
| 115 |
+
pack_hook: Callable[[torch.Tensor], Any],
|
| 116 |
+
unpack_hook: Callable[[Any], torch.Tensor],
|
| 117 |
+
) -> None: ...
|
| 118 |
+
def _pop_saved_tensors_default_hooks() -> None: ...
|
| 119 |
+
def _top_saved_tensors_default_hooks(
|
| 120 |
+
ignore_is_tracing: bool,
|
| 121 |
+
) -> tuple[Callable[[torch.Tensor], Any], Callable[[Any], torch.Tensor]]: ...
|
| 122 |
+
def _unsafe_set_version_counter(
|
| 123 |
+
t: tuple[torch.Tensor, ...], prev_version: tuple[int, ...]
|
| 124 |
+
) -> None: ...
|
| 125 |
+
def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
|
| 126 |
+
def _disable_profiler_legacy() -> list[list[ProfilerEvent]]: ...
|
| 127 |
+
def _profiler_type() -> ActiveProfilerType: ...
|
| 128 |
+
def _saved_tensors_hooks_enable() -> None: ...
|
| 129 |
+
def _saved_tensors_hooks_disable(message: str, fail_if_non_empty=True) -> None: ...
|
| 130 |
+
def _saved_tensors_hooks_get_disabled_error_message() -> str | None: ...
|
| 131 |
+
def _saved_tensors_hooks_set_tracing(is_tracing: bool) -> bool: ...
|
| 132 |
+
|
| 133 |
+
class CreationMeta(Enum):
|
| 134 |
+
DEFAULT = ...
|
| 135 |
+
IN_CUSTOM_FUNCTION = ...
|
| 136 |
+
MULTI_OUTPUT_NODE = ...
|
| 137 |
+
NO_GRAD_MODE = ...
|
| 138 |
+
INFERENCE_MODE = ...
|
| 139 |
+
|
| 140 |
+
def _set_creation_meta(t: torch.Tensor, creation_meta: CreationMeta) -> None: ...
|
| 141 |
+
def _get_creation_meta(t: torch.Tensor) -> CreationMeta: ...
|
phivenv/Lib/site-packages/torch/_C/_cpu.pyi
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.types import _bool, _int
|
| 2 |
+
|
| 3 |
+
# Defined in torch/csrc/cpu/Module.cpp
|
| 4 |
+
|
| 5 |
+
def _is_avx2_supported() -> _bool: ...
|
| 6 |
+
def _is_avx512_supported() -> _bool: ...
|
| 7 |
+
def _is_avx512_vnni_supported() -> _bool: ...
|
| 8 |
+
def _is_avx512_bf16_supported() -> _bool: ...
|
| 9 |
+
def _is_amx_tile_supported() -> _bool: ...
|
| 10 |
+
def _is_amx_fp16_supported() -> _bool: ...
|
| 11 |
+
def _init_amx() -> _bool: ...
|
| 12 |
+
def _L1d_cache_size() -> _int: ...
|
| 13 |
+
def _L2_cache_size() -> _int: ...
|
phivenv/Lib/site-packages/torch/_C/_cudnn.pyi
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import IntEnum
|
| 2 |
+
|
| 3 |
+
# Defined in torch/csrc/cuda/shared/cudnn.cpp
|
| 4 |
+
is_cuda: bool
|
| 5 |
+
|
| 6 |
+
def getRuntimeVersion() -> tuple[int, int, int]: ...
|
| 7 |
+
def getCompileVersion() -> tuple[int, int, int]: ...
|
| 8 |
+
def getVersionInt() -> int: ...
|
| 9 |
+
|
| 10 |
+
class RNNMode(IntEnum):
|
| 11 |
+
rnn_relu = ...
|
| 12 |
+
rnn_tanh = ...
|
| 13 |
+
lstm = ...
|
| 14 |
+
gru = ...
|
phivenv/Lib/site-packages/torch/_C/_cusparselt.pyi
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
def getVersionInt() -> int: ...
|
phivenv/Lib/site-packages/torch/_C/_distributed_autograd.pyi
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
# This module is defined in torch/csrc/distributed/autograd/init.cpp
|
| 6 |
+
|
| 7 |
+
class DistAutogradContext:
|
| 8 |
+
def _context_id(self) -> int: ...
|
| 9 |
+
def _recv_functions(self) -> dict[int, Any]: ...
|
| 10 |
+
def _send_functions(self) -> dict[int, Any]: ...
|
| 11 |
+
def _known_worker_ids(self) -> set[int]: ...
|
| 12 |
+
|
| 13 |
+
def _new_context() -> DistAutogradContext: ...
|
| 14 |
+
def _release_context(context_id: int) -> None: ...
|
| 15 |
+
def _get_max_id() -> int: ...
|
| 16 |
+
def _is_valid_context(worker_id: int) -> bool: ...
|
| 17 |
+
def _retrieve_context(context_id: int) -> DistAutogradContext: ...
|
| 18 |
+
def _current_context() -> DistAutogradContext: ...
|
| 19 |
+
def _init(worker_id: int) -> None: ...
|
| 20 |
+
def _get_debug_info() -> dict[str, str]: ...
|
| 21 |
+
def backward(
|
| 22 |
+
context_id: int,
|
| 23 |
+
roots: list[torch.Tensor],
|
| 24 |
+
retain_graph: bool = False,
|
| 25 |
+
) -> None: ...
|
| 26 |
+
def get_gradients(context_id: int) -> dict[torch.Tensor, torch.Tensor]: ...
|
phivenv/Lib/site-packages/torch/_C/_distributed_c10d.pyi
ADDED
|
@@ -0,0 +1,797 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
# mypy: disable-error-code="type-arg"
|
| 3 |
+
from datetime import timedelta
|
| 4 |
+
from enum import Enum
|
| 5 |
+
from typing import Any, Optional, overload, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from torch._C import ScriptObject
|
| 10 |
+
from torch._C._autograd import DeviceType
|
| 11 |
+
from torch.futures import Future
|
| 12 |
+
|
| 13 |
+
# This module is defined in torch/csrc/distributed/c10d/init.cpp
|
| 14 |
+
|
| 15 |
+
_DEFAULT_FIRST_BUCKET_BYTES: int
|
| 16 |
+
_DEFAULT_NO_TIMEOUT: timedelta
|
| 17 |
+
_DEFAULT_PG_TIMEOUT: timedelta
|
| 18 |
+
_DEFAULT_PG_NCCL_TIMEOUT: timedelta
|
| 19 |
+
|
| 20 |
+
class BuiltinCommHookType(Enum):
|
| 21 |
+
ALLREDUCE = ...
|
| 22 |
+
FP16_COMPRESS = ...
|
| 23 |
+
|
| 24 |
+
def _register_comm_hook(reducer: Reducer, state: Any, comm_hook: Any): ...
|
| 25 |
+
def _register_builtin_comm_hook(
|
| 26 |
+
reducer: Reducer,
|
| 27 |
+
comm_hook_type: BuiltinCommHookType,
|
| 28 |
+
): ...
|
| 29 |
+
def _set_global_rank(rank: int) -> None: ...
|
| 30 |
+
def _hash_tensors(tensors: list[Tensor]) -> int: ...
|
| 31 |
+
|
| 32 |
+
class GradBucket:
|
| 33 |
+
def index(self) -> int: ...
|
| 34 |
+
def buffer(self) -> Tensor: ...
|
| 35 |
+
def gradients(self) -> list[Tensor]: ...
|
| 36 |
+
def is_last(self) -> bool: ...
|
| 37 |
+
def set_buffer(self, tensor: Tensor) -> None: ...
|
| 38 |
+
def parameters(self) -> list[Tensor]: ...
|
| 39 |
+
|
| 40 |
+
class Reducer:
|
| 41 |
+
def __init__(
|
| 42 |
+
self,
|
| 43 |
+
params: list[Tensor],
|
| 44 |
+
bucket_indices: list[list[int]],
|
| 45 |
+
per_bucket_size_limits: list[int],
|
| 46 |
+
process_group: ProcessGroup,
|
| 47 |
+
expect_sparse_gradients: list[bool] = ...,
|
| 48 |
+
bucket_bytes_cap: int = ..., # kDefaultBucketBytesCap in reducer.hpp
|
| 49 |
+
find_unused_parameters: bool = ...,
|
| 50 |
+
gradient_as_bucket_view: bool = ...,
|
| 51 |
+
param_to_name_mapping: dict[int, str] = ...,
|
| 52 |
+
first_bucket_types_cap: int = ..., # kDefaultFirstBucketBytes in reducer.hpp
|
| 53 |
+
skip_all_reduce_unused_params: bool = ...,
|
| 54 |
+
use_python_reducer: bool = ...,
|
| 55 |
+
) -> None: ...
|
| 56 |
+
def prepare_for_forward(self) -> None: ...
|
| 57 |
+
def prepare_for_backward(self, output: list[Tensor]) -> None: ...
|
| 58 |
+
def get_backward_stats(self) -> list[int]: ...
|
| 59 |
+
def _install_post_backward_futures(self, futures: list[Future]) -> None: ...
|
| 60 |
+
def _rebuild_buckets(self) -> bool: ...
|
| 61 |
+
def _get_zeros_like_grad_buckets(self) -> list[GradBucket]: ...
|
| 62 |
+
def _push_all_rebuilt_params(self) -> None: ...
|
| 63 |
+
def _set_forward_pass_work_handle(
|
| 64 |
+
self,
|
| 65 |
+
work: Work,
|
| 66 |
+
use_static_world_size: bool,
|
| 67 |
+
): ...
|
| 68 |
+
def _get_local_used_map(self) -> Tensor: ...
|
| 69 |
+
def _set_ddp_runtime_logging_sample_rate(self, sample_rate: int) -> None: ...
|
| 70 |
+
def _set_static_graph(self) -> None: ...
|
| 71 |
+
def _run_comm_hook(self, bucket: GradBucket) -> Future: ...
|
| 72 |
+
def set_logger(self, logger: Logger) -> None: ...
|
| 73 |
+
def _remove_autograd_hooks(self) -> None: ...
|
| 74 |
+
def _check_reducer_finalized(self) -> None: ...
|
| 75 |
+
def _set_sparse_metadata(self, global_unique_ids: dict[str, Tensor]) -> None: ...
|
| 76 |
+
def _reset_state(self) -> None: ...
|
| 77 |
+
def _update_process_group(self, new_process_group: ProcessGroup) -> None: ...
|
| 78 |
+
|
| 79 |
+
class DDPLoggingData:
|
| 80 |
+
strs_map: dict[str, str]
|
| 81 |
+
ints_map: dict[str, int]
|
| 82 |
+
|
| 83 |
+
class Logger:
|
| 84 |
+
def __init__(self, reducer: Reducer) -> None: ...
|
| 85 |
+
def set_construction_data_and_log(
|
| 86 |
+
self,
|
| 87 |
+
module_name: str,
|
| 88 |
+
device_ids: list[int],
|
| 89 |
+
output_device: int,
|
| 90 |
+
broadcast_buffers: bool,
|
| 91 |
+
has_sync_bn: bool,
|
| 92 |
+
static_graph: bool,
|
| 93 |
+
): ...
|
| 94 |
+
def set_runtime_stats_and_log(self) -> None: ...
|
| 95 |
+
def set_error_and_log(self, error: str) -> None: ...
|
| 96 |
+
def _get_ddp_logging_data(self) -> DDPLoggingData: ...
|
| 97 |
+
def _set_comm_hook_name(self, comm_hook: str) -> None: ...
|
| 98 |
+
def _set_uneven_input_join(self) -> None: ...
|
| 99 |
+
def _set_static_graph(self) -> None: ...
|
| 100 |
+
|
| 101 |
+
class _WorkerServer:
|
| 102 |
+
def __init__(self, socket_path: str) -> None: ...
|
| 103 |
+
def shutdown(self) -> None: ...
|
| 104 |
+
|
| 105 |
+
# Accessors for the module-wide c10d debug level (see ``DebugLevel`` below).
# NOTE(review): signatures are untyped in this stub — confirm argument/return
# types against the C++ binding before tightening them.
def get_debug_level(): ...
def set_debug_level(): ...
def set_debug_level_from_env(): ...
|
| 108 |
+
|
| 109 |
+
class DebugLevel(Enum):
    """Verbosity levels for c10d debug logging."""

    # Values are supplied by the native binding; the ``...`` placeholders
    # are stub-only (at stub import time later members alias the first).
    OFF = ...
    INFO = ...
    DETAIL = ...
|
| 113 |
+
|
| 114 |
+
class ReduceOp:
    """Reduction operator passed to collective operations; wraps a ``RedOpType``."""

    def __init__(self, op: RedOpType) -> None: ...

    # Well-known operator singletons exposed by the binding.
    SUM: RedOpType = ...
    AVG: RedOpType = ...
    PRODUCT: RedOpType = ...
    MIN: RedOpType = ...
    MAX: RedOpType = ...
    BAND: RedOpType = ...   # bitwise AND
    BOR: RedOpType = ...    # bitwise OR
    BXOR: RedOpType = ...   # bitwise XOR
    PREMUL_SUM: RedOpType = ...  # see _make_nccl_premul_sum below
    UNUSED: RedOpType = ...
|
| 127 |
+
|
| 128 |
+
# mypy error being ignored:
|
| 129 |
+
# Detected enum "torch._C._distributed_c10d.ReduceOp.RedOpType" in a type
|
| 130 |
+
# stub with zero members. There is a chance this is due to a recent change
|
| 131 |
+
# in the semantics of enum membership. If so, use `member = value` to mark
|
| 132 |
+
# an enum member, instead of `member: type`
|
| 133 |
+
class RedOpType(Enum): ... # type: ignore[misc]
|
| 134 |
+
|
| 135 |
+
class BroadcastOptions:
    """Options for ``ProcessGroup.broadcast``."""

    rootRank: int    # rank that provides the data
    rootTensor: int  # index of the source tensor in the input list
    timeout: timedelta
    asyncOp: bool
|
| 140 |
+
|
| 141 |
+
class AllreduceOptions:
    """Options for ``ProcessGroup.allreduce``."""

    reduceOp: ReduceOp
    timeout: timedelta
    asyncOp: bool
    # Only meaningful for sparse all-reduce paths; None otherwise.
    sparseIndices: Optional[Tensor]
|
| 146 |
+
|
| 147 |
+
# Same option set as the plain all-reduce; the coalesced variant adds nothing.
class AllreduceCoalescedOptions(AllreduceOptions): ...
|
| 148 |
+
|
| 149 |
+
class ReduceOptions:
    """Options for ``ProcessGroup.reduce``."""

    reduceOp: ReduceOp
    rootRank: int    # rank that receives the reduced result
    rootTensor: int  # index of the destination tensor in the input list
    timeout: timedelta
    asyncOp: bool
|
| 155 |
+
|
| 156 |
+
class AllgatherOptions:
    """Options for ``ProcessGroup.allgather``."""

    timeout: timedelta
    asyncOp: bool
|
| 159 |
+
|
| 160 |
+
class GatherOptions:
    """Options for ``ProcessGroup.gather``."""

    rootRank: int  # rank that receives the gathered tensors
    timeout: timedelta
    asyncOp: bool
|
| 164 |
+
|
| 165 |
+
class ScatterOptions:
    """Options for ``ProcessGroup.scatter``."""

    rootRank: int  # rank that provides the tensors to scatter
    timeout: timedelta
    asyncOp: bool
|
| 169 |
+
|
| 170 |
+
class ReduceScatterOptions:
    """Options for ``ProcessGroup.reduce_scatter``."""

    reduceOp: ReduceOp
    timeout: timedelta
    asyncOp: bool
|
| 174 |
+
|
| 175 |
+
class BarrierOptions:
    """Options for ``ProcessGroup.barrier``."""

    device_ids: list[int]
    device: torch.device
    timeout: timedelta
    asyncOp: bool
|
| 180 |
+
|
| 181 |
+
class AllToAllOptions:
    """Options for ``ProcessGroup.alltoall`` / ``alltoall_base``."""

    timeout: timedelta
    asyncOp: bool
|
| 184 |
+
|
| 185 |
+
class Store:
    """Base key/value store used for rendezvous and out-of-band control."""

    def set(self, key: str, value: str):
        ...

    def get(self, key: str) -> bytes:
        ...

    # Atomically adds ``value`` to the counter at ``key``; returns the new total.
    def add(self, key: str, value: int) -> int:
        ...

    def check(self, keys: list[str]) -> bool:
        ...

    def compare_set(self, key: str, expected_value: str, desired_value: str) -> bytes:
        ...

    def delete_key(self, key: str) -> bool:
        ...

    def num_keys(self) -> int:
        ...

    def set_timeout(self, timeout: timedelta):
        ...

    @overload
    def wait(self, keys: list[str]):
        ...

    @overload
    def wait(self, keys: list[str], timeout: timedelta):
        ...

    # FIFO queue helpers keyed by ``key``.
    def queue_pop(self, key: str, block: bool = True) -> bytes:
        ...

    def queue_push(self, key: str, value: Union[bytes, str]) -> None:
        ...

    def queue_len(self, key: str) -> int:
        ...
|
| 206 |
+
|
| 207 |
+
class FileStore(Store):
    """Store backed by a file shared among ``numWorkers`` processes."""

    def __init__(self, path: str, numWorkers: int = ...) -> None: ...
|
| 209 |
+
|
| 210 |
+
class HashStore(Store):
    """In-process store (presumably hash-map backed — verify in the binding)."""

    def __init__(self) -> None: ...
|
| 212 |
+
|
| 213 |
+
class TCPStore(Store):
    """Store served over TCP; one master hosts it and workers connect to it."""

    def __init__(
        self,
        host_name: str,
        port: int,
        world_size: int | None = ...,
        is_master: bool = ...,          # True on the hosting process
        timeout: timedelta = ...,
        wait_for_workers: bool = ...,
        multi_tenant: bool = ...,
        master_listen_fd: int | None = ...,  # pre-bound listening socket, if any
        use_libuv: bool | None = ...,
    ) -> None: ...
    @property
    def host(self) -> str: ...
    @property
    def port(self) -> int: ...
|
| 230 |
+
|
| 231 |
+
class PrefixStore(Store):
    """Store wrapper; presumably namespaces keys with ``prefix`` before
    delegating to ``underlying_store`` — verify against the binding."""

    def __init__(self, prefix: str, store: Store) -> None: ...
    @property
    def underlying_store(self) -> Store: ...
|
| 235 |
+
|
| 236 |
+
class _ControlCollectives:
|
| 237 |
+
def barrier(self, key: str, timeout: timedelta, blocking: bool) -> None: ...
|
| 238 |
+
def broadcast_send(self, key: str, data: str, timeout: timedelta) -> None: ...
|
| 239 |
+
def broadcast_recv(self, key: str, timeout: timedelta) -> str: ...
|
| 240 |
+
def gather_send(self, key: str, data: str, timeout: timedelta) -> None: ...
|
| 241 |
+
def gather_recv(self, key: str, timeout: timedelta) -> str: ...
|
| 242 |
+
def scatter_send(self, key: str, data: str, timeout: timedelta) -> None: ...
|
| 243 |
+
def scatter_recv(self, key: str, timeout: timedelta) -> str: ...
|
| 244 |
+
def all_gather(self, key: str, data: str, timeout: timedelta) -> str: ...
|
| 245 |
+
def all_sum(self, key: str, data: int, timeout: timedelta) -> int: ...
|
| 246 |
+
|
| 247 |
+
class _StoreCollectives(_ControlCollectives):
    """Control collectives implemented on top of a ``Store``."""

    def __init__(self, store: Store, rank: int, world_size: int) -> None: ...
|
| 249 |
+
|
| 250 |
+
class _DistributedBackendOptions:
|
| 251 |
+
def __init__(self) -> None: ...
|
| 252 |
+
@property
|
| 253 |
+
def store(self) -> Store: ...
|
| 254 |
+
@store.setter
|
| 255 |
+
def store(self, store: Store) -> None: ...
|
| 256 |
+
@property
|
| 257 |
+
def group_rank(self) -> int: ...
|
| 258 |
+
@group_rank.setter
|
| 259 |
+
def group_rank(self, rank: int) -> None: ...
|
| 260 |
+
@property
|
| 261 |
+
def group_size(self) -> int: ...
|
| 262 |
+
@group_size.setter
|
| 263 |
+
def group_size(self, size: int) -> None: ...
|
| 264 |
+
@property
|
| 265 |
+
def timeout(self) -> timedelta: ...
|
| 266 |
+
@timeout.setter
|
| 267 |
+
def timeout(self, timeout: timedelta) -> None: ...
|
| 268 |
+
@property
|
| 269 |
+
def group_id(self) -> str: ...
|
| 270 |
+
@group_id.setter
|
| 271 |
+
def group_id(self, group_id: str) -> None: ...
|
| 272 |
+
@property
|
| 273 |
+
def global_ranks_in_group(self) -> list[int]: ...
|
| 274 |
+
@global_ranks_in_group.setter
|
| 275 |
+
def global_ranks_in_group(self, ranks: list[int]) -> None: ...
|
| 276 |
+
|
| 277 |
+
class Work:
    """Handle for an in-flight collective operation."""

    def is_completed(self) -> bool: ...
    def is_success(self) -> bool: ...
    def exception(self) -> Any: ...
    def wait(self, timeout: timedelta = ...) -> bool: ...
    def get_future(self) -> Future: ...
    def source_rank(self) -> int: ...
    def _source_rank(self) -> int: ...
    def result(self) -> list[Tensor]: ...
    def synchronize(self): ...
    # TorchScript interop: round-trip through a ScriptObject.
    def boxed(self) -> ScriptObject: ...
    @staticmethod
    def unbox(obj: ScriptObject) -> Work: ...
|
| 290 |
+
|
| 291 |
+
class Backend:
|
| 292 |
+
class Options:
|
| 293 |
+
def __init__(self, backend: str, timeout: timedelta = ...) -> None: ...
|
| 294 |
+
@property
|
| 295 |
+
def backend(self) -> str: ...
|
| 296 |
+
@property
|
| 297 |
+
def _timeout(self) -> timedelta: ...
|
| 298 |
+
@_timeout.setter
|
| 299 |
+
def _timeout(self, val: timedelta) -> None: ...
|
| 300 |
+
|
| 301 |
+
def __init__(
|
| 302 |
+
self,
|
| 303 |
+
rank: int,
|
| 304 |
+
size: int,
|
| 305 |
+
) -> None: ...
|
| 306 |
+
@property
|
| 307 |
+
def supports_splitting(self) -> bool: ...
|
| 308 |
+
@property
|
| 309 |
+
def supports_coalescing(self) -> bool: ...
|
| 310 |
+
@property
|
| 311 |
+
def supports_time_estimate(self) -> bool: ...
|
| 312 |
+
@property
|
| 313 |
+
def options(self) -> Options: ...
|
| 314 |
+
def rank(self) -> int: ...
|
| 315 |
+
def size(self) -> int: ...
|
| 316 |
+
def abort(self) -> None: ...
|
| 317 |
+
def shutdown(self) -> None: ...
|
| 318 |
+
def eager_connect_single_device(self, device: torch.device | None) -> None: ...
|
| 319 |
+
def _set_sequence_number_for_group(self) -> None: ...
|
| 320 |
+
def _set_default_timeout(self, timeout: timedelta) -> None: ...
|
| 321 |
+
def get_error(self) -> ErrorType: ...
|
| 322 |
+
def supports_tensor_alloc(self, device: torch.device) -> bool: ...
|
| 323 |
+
def allocate_tensor(
|
| 324 |
+
self,
|
| 325 |
+
size: int,
|
| 326 |
+
*,
|
| 327 |
+
dtype: torch.dtype,
|
| 328 |
+
device: torch.device,
|
| 329 |
+
) -> Tensor: ...
|
| 330 |
+
@property
|
| 331 |
+
def mem_allocator(self) -> Any: ...
|
| 332 |
+
|
| 333 |
+
class ProcessGroup:
|
| 334 |
+
class BackendType(Enum):
|
| 335 |
+
UNDEFINED = ...
|
| 336 |
+
GLOO = ...
|
| 337 |
+
NCCL = ...
|
| 338 |
+
UCC = ...
|
| 339 |
+
MPI = ...
|
| 340 |
+
XCCL = ...
|
| 341 |
+
CUSTOM = ...
|
| 342 |
+
|
| 343 |
+
def __init__(
|
| 344 |
+
self,
|
| 345 |
+
store: Store,
|
| 346 |
+
rank: int,
|
| 347 |
+
size: int,
|
| 348 |
+
) -> None: ...
|
| 349 |
+
def rank(self) -> int: ...
|
| 350 |
+
def size(self) -> int: ...
|
| 351 |
+
def abort(self) -> None: ...
|
| 352 |
+
def shutdown(self) -> None: ...
|
| 353 |
+
@overload
|
| 354 |
+
def broadcast(
|
| 355 |
+
self,
|
| 356 |
+
tensors: list[Tensor],
|
| 357 |
+
opts=...,
|
| 358 |
+
) -> Work: ...
|
| 359 |
+
@overload
|
| 360 |
+
def broadcast(
|
| 361 |
+
self,
|
| 362 |
+
tensor: Tensor,
|
| 363 |
+
root: int,
|
| 364 |
+
) -> Work: ...
|
| 365 |
+
@overload
|
| 366 |
+
def allreduce(
|
| 367 |
+
self,
|
| 368 |
+
tensors: list[Tensor],
|
| 369 |
+
opts: AllreduceOptions = ...,
|
| 370 |
+
) -> Work: ...
|
| 371 |
+
@overload
|
| 372 |
+
def allreduce(
|
| 373 |
+
self,
|
| 374 |
+
tensors: list[Tensor],
|
| 375 |
+
op=...,
|
| 376 |
+
) -> Work: ...
|
| 377 |
+
@overload
|
| 378 |
+
def allreduce(
|
| 379 |
+
self,
|
| 380 |
+
tensor: Tensor,
|
| 381 |
+
op=...,
|
| 382 |
+
) -> Work: ...
|
| 383 |
+
def allreduce_coalesced(
|
| 384 |
+
self,
|
| 385 |
+
tensors: list[Tensor],
|
| 386 |
+
opts=...,
|
| 387 |
+
) -> Work: ...
|
| 388 |
+
def reduce_scatter_tensor_coalesced(
|
| 389 |
+
self,
|
| 390 |
+
outputTensors: list[Tensor],
|
| 391 |
+
inputTensors: list[Tensor],
|
| 392 |
+
opts: ReduceScatterOptions | None = None,
|
| 393 |
+
) -> Work: ...
|
| 394 |
+
@overload
|
| 395 |
+
def reduce(
|
| 396 |
+
self,
|
| 397 |
+
tensors: list[Tensor],
|
| 398 |
+
opts=...,
|
| 399 |
+
) -> Work: ...
|
| 400 |
+
@overload
|
| 401 |
+
def reduce(
|
| 402 |
+
self,
|
| 403 |
+
tensor: Tensor,
|
| 404 |
+
root: int,
|
| 405 |
+
op=...,
|
| 406 |
+
) -> Work: ...
|
| 407 |
+
@overload
|
| 408 |
+
def allgather(
|
| 409 |
+
self,
|
| 410 |
+
output_tensors: list[list[Tensor]],
|
| 411 |
+
input_tensors: list[Tensor],
|
| 412 |
+
opts=...,
|
| 413 |
+
) -> Work: ...
|
| 414 |
+
@overload
|
| 415 |
+
def allgather(
|
| 416 |
+
self,
|
| 417 |
+
output_tensors: list[Tensor],
|
| 418 |
+
input_tensor: Tensor,
|
| 419 |
+
) -> Work: ...
|
| 420 |
+
def _allgather_base(
|
| 421 |
+
self,
|
| 422 |
+
output: Tensor,
|
| 423 |
+
input: Tensor,
|
| 424 |
+
opts=...,
|
| 425 |
+
) -> Work: ...
|
| 426 |
+
def allgather_coalesced(
|
| 427 |
+
self,
|
| 428 |
+
output_lists: list[list[Tensor]],
|
| 429 |
+
input_list: list[Tensor],
|
| 430 |
+
opts=...,
|
| 431 |
+
) -> Work: ...
|
| 432 |
+
def allgather_into_tensor_coalesced(
|
| 433 |
+
self,
|
| 434 |
+
output_lists: list[Tensor],
|
| 435 |
+
input_list: list[Tensor],
|
| 436 |
+
opts=...,
|
| 437 |
+
) -> Work: ...
|
| 438 |
+
@overload
|
| 439 |
+
def gather(
|
| 440 |
+
self,
|
| 441 |
+
output_tensors: list[list[Tensor]],
|
| 442 |
+
input_tensors: list[Tensor],
|
| 443 |
+
opts=...,
|
| 444 |
+
) -> Work: ...
|
| 445 |
+
@overload
|
| 446 |
+
def gather(
|
| 447 |
+
self,
|
| 448 |
+
output_tensors: list[Tensor],
|
| 449 |
+
input_tensor: Tensor,
|
| 450 |
+
root: int,
|
| 451 |
+
) -> Work: ...
|
| 452 |
+
@overload
|
| 453 |
+
def scatter(
|
| 454 |
+
self,
|
| 455 |
+
output_tensors: list[Tensor],
|
| 456 |
+
input_tensors: list[list[Tensor]],
|
| 457 |
+
opts=...,
|
| 458 |
+
) -> Work: ...
|
| 459 |
+
@overload
|
| 460 |
+
def scatter(
|
| 461 |
+
self,
|
| 462 |
+
output_tensor: Tensor,
|
| 463 |
+
input_tensors: list[Tensor],
|
| 464 |
+
root: int,
|
| 465 |
+
) -> Work: ...
|
| 466 |
+
@overload
|
| 467 |
+
def reduce_scatter(
|
| 468 |
+
self,
|
| 469 |
+
output_tensors: list[Tensor],
|
| 470 |
+
input_tensors: list[list[Tensor]],
|
| 471 |
+
opts=...,
|
| 472 |
+
) -> Work: ...
|
| 473 |
+
@overload
|
| 474 |
+
def reduce_scatter(
|
| 475 |
+
self,
|
| 476 |
+
output_tensors: Tensor,
|
| 477 |
+
input_tensor: list[Tensor],
|
| 478 |
+
) -> Work: ...
|
| 479 |
+
def _reduce_scatter_base(
|
| 480 |
+
self,
|
| 481 |
+
outputTensor: Tensor,
|
| 482 |
+
inputTensor: Tensor,
|
| 483 |
+
opts: ReduceScatterOptions | None,
|
| 484 |
+
) -> Work: ...
|
| 485 |
+
@overload
|
| 486 |
+
def alltoall_base(
|
| 487 |
+
self,
|
| 488 |
+
output_tensor: Tensor,
|
| 489 |
+
input_tensor: Tensor,
|
| 490 |
+
output_split_sizes: list[int],
|
| 491 |
+
input_split_sizes: list[int],
|
| 492 |
+
opts=...,
|
| 493 |
+
) -> Work: ...
|
| 494 |
+
@overload
|
| 495 |
+
def alltoall_base(
|
| 496 |
+
self,
|
| 497 |
+
output: Tensor,
|
| 498 |
+
input: Tensor,
|
| 499 |
+
output_split_sizes: list[int],
|
| 500 |
+
input_split_sizes: list[int],
|
| 501 |
+
) -> Work: ...
|
| 502 |
+
@overload
|
| 503 |
+
def alltoall(
|
| 504 |
+
self,
|
| 505 |
+
output_tensor: list[Tensor],
|
| 506 |
+
input_tensor: list[Tensor],
|
| 507 |
+
opts=...,
|
| 508 |
+
) -> Work: ...
|
| 509 |
+
@overload
|
| 510 |
+
def alltoall(
|
| 511 |
+
self,
|
| 512 |
+
output: list[Tensor],
|
| 513 |
+
input: list[Tensor],
|
| 514 |
+
) -> Work: ...
|
| 515 |
+
def send(
|
| 516 |
+
self,
|
| 517 |
+
tensors: list[Tensor],
|
| 518 |
+
dstRank: int,
|
| 519 |
+
tag: int,
|
| 520 |
+
) -> Work: ...
|
| 521 |
+
def recv(
|
| 522 |
+
self,
|
| 523 |
+
tensors: list[Tensor],
|
| 524 |
+
srcRank: int,
|
| 525 |
+
tag: int,
|
| 526 |
+
) -> Work: ...
|
| 527 |
+
def recv_anysource(self, tensors: list[Tensor], tag: int) -> Work: ...
|
| 528 |
+
def barrier(self, opts=...) -> Work: ...
|
| 529 |
+
def boxed(self) -> ScriptObject: ...
|
| 530 |
+
@staticmethod
|
| 531 |
+
def unbox(obj: ScriptObject) -> ProcessGroup: ...
|
| 532 |
+
def _start_coalescing(self, device: torch.device) -> None: ...
|
| 533 |
+
def _end_coalescing(self, device: torch.device) -> Work: ...
|
| 534 |
+
def _get_backend_name(self) -> str: ...
|
| 535 |
+
def _backend_id(self, backend_type: BackendType) -> int: ...
|
| 536 |
+
@property
|
| 537 |
+
def _device_types(self) -> list[torch.device]: ...
|
| 538 |
+
def _get_backend(self, device: torch.device) -> Backend: ...
|
| 539 |
+
def _set_default_backend(self, backend_type: BackendType) -> None: ...
|
| 540 |
+
def _register_backend(
|
| 541 |
+
self,
|
| 542 |
+
device: torch.device,
|
| 543 |
+
backend_type: BackendType,
|
| 544 |
+
backend: Backend | None,
|
| 545 |
+
) -> None: ...
|
| 546 |
+
def _set_group_name(self, name: str) -> None: ...
|
| 547 |
+
def _set_group_desc(self, desc: str) -> None: ...
|
| 548 |
+
def name(self) -> str: ...
|
| 549 |
+
def _has_hooks(self) -> bool: ...
|
| 550 |
+
def _wait_for_pending_works(self) -> None: ...
|
| 551 |
+
def _set_sequence_number_for_group(self) -> None: ...
|
| 552 |
+
@property
|
| 553 |
+
def bound_device_id(self) -> torch.device | None: ...
|
| 554 |
+
@bound_device_id.setter
|
| 555 |
+
def bound_device_id(self, device: torch.device | None) -> None: ...
|
| 556 |
+
@property
|
| 557 |
+
def group_name(self) -> str: ...
|
| 558 |
+
@property
|
| 559 |
+
def group_desc(self) -> str: ...
|
| 560 |
+
|
| 561 |
+
class FakeProcessGroup(Backend):
    """Backend stub used for testing without real communication."""

    def __init__(self, rank: int, world_size: int) -> None: ...
|
| 563 |
+
|
| 564 |
+
class FakeWork(Work):
    """Work stub paired with ``FakeProcessGroup``."""

    seq_id: int
    def __init__(self) -> None: ...
    def wait(self, timeout: timedelta = ...) -> bool: ...
    # NOTE(review): camelCase name mirrors the binding and intentionally
    # differs from ``Work.get_future`` — confirm before renaming.
    def getFuture(self) -> Future: ...
|
| 569 |
+
|
| 570 |
+
class ProcessGroupGloo(Backend):
|
| 571 |
+
class Device: ...
|
| 572 |
+
|
| 573 |
+
class Options(Backend.Options):
|
| 574 |
+
devices: list[ProcessGroupGloo.Device]
|
| 575 |
+
threads: int
|
| 576 |
+
global_ranks_in_group: list[int]
|
| 577 |
+
group_name: str
|
| 578 |
+
|
| 579 |
+
def __init__(self): ...
|
| 580 |
+
|
| 581 |
+
def __init__(
|
| 582 |
+
self,
|
| 583 |
+
store: Store,
|
| 584 |
+
rank: int,
|
| 585 |
+
size: int,
|
| 586 |
+
timeout: timedelta,
|
| 587 |
+
) -> None: ...
|
| 588 |
+
@staticmethod
|
| 589 |
+
def create_device(hostname="", interface="", lazy_init=None) -> Device: ...
|
| 590 |
+
@staticmethod
|
| 591 |
+
def create_default_device(lazy_init=None) -> Device: ...
|
| 592 |
+
def _set_default_timeout(self, timeout) -> None: ...
|
| 593 |
+
@property
|
| 594 |
+
def options(self) -> Options: ... # type: ignore[override]
|
| 595 |
+
|
| 596 |
+
class _ProcessGroupWrapper(Backend):
|
| 597 |
+
def __init__(self, pg: Backend, gloo_pg: ProcessGroupGloo) -> None: ...
|
| 598 |
+
wrapped_pg: Backend
|
| 599 |
+
|
| 600 |
+
class ErrorType(Enum):
    """Error state reported by ``Backend.get_error``."""

    # Values are supplied by the native binding; ``...`` is a stub placeholder.
    SUCCESS = ...
    TIMEOUT = ...
    COMM_ERROR = ...
    REMOTE_ERROR = ...
|
| 605 |
+
|
| 606 |
+
class ProcessGroupNCCL(Backend):
    """NCCL-backed process group."""

    class NCCLConfig:
        # Mirrors ncclConfig_t fields exposed by the binding.
        blocking: int
        cga_cluster_size: int
        min_ctas: int
        max_ctas: int

    class Options(Backend.Options):
        config: ProcessGroupNCCL.NCCLConfig
        is_high_priority_stream: bool
        split_from: ProcessGroupNCCL  # parent group when created via comm split
        split_color: int
        global_ranks_in_group: list[int]
        group_name: str

        def __init__(self, is_high_priority_stream: bool = False): ...

    def __init__(
        self,
        store: Store,
        rank: int,
        size: int,
        options: Options,
    ) -> None: ...
    # Group-call bracketing (ncclGroupStart/ncclGroupEnd).
    def _group_start(self) -> None: ...
    def _group_end(self) -> None: ...
    def _start_time_estimate(self) -> None: ...
    def _end_time_estimate(self) -> float: ...
    def _set_default_timeout(self, timeout) -> None: ...
    def perform_nocolor_split(self, device: torch.device) -> None: ...
    def register_mem_pool(self, pool: torch.cuda.MemPool) -> None: ...
    def deregister_mem_pool(self, pool: torch.cuda.MemPool) -> None: ...
    def comm_split_count(self) -> int: ...
    def _add_ephemeral_timeout(self, timeout: timedelta) -> None: ...
    def abort(self) -> None: ...
    def _is_initialized(self) -> bool: ...
    @property
    def uid(self) -> int: ...
    @property
    def options(self) -> Options: ...  # type: ignore[override]
    # Fixed: these are @staticmethod and must not declare ``self``; the
    # previous stub signature made type checkers demand a bogus argument.
    @staticmethod
    def get_build_nccl_version() -> tuple[int, int, int]: ...
    @staticmethod
    def get_runtime_nccl_version() -> tuple[int, int, int]: ...
|
| 650 |
+
|
| 651 |
+
class ProcessGroupUCC(Backend):
|
| 652 |
+
def __init__(
|
| 653 |
+
self,
|
| 654 |
+
store: Store,
|
| 655 |
+
rank: int,
|
| 656 |
+
size: int,
|
| 657 |
+
timeout: timedelta,
|
| 658 |
+
) -> None: ...
|
| 659 |
+
|
| 660 |
+
class ProcessGroupMPI(Backend):
|
| 661 |
+
def __init__(
|
| 662 |
+
self,
|
| 663 |
+
rank: int,
|
| 664 |
+
size: int,
|
| 665 |
+
pgComm: int,
|
| 666 |
+
) -> None: ...
|
| 667 |
+
@staticmethod
|
| 668 |
+
def create(ranks: list[int]) -> ProcessGroupMPI: ...
|
| 669 |
+
|
| 670 |
+
# Partitions ``tensors`` into gradient buckets bounded by ``bucket_size_limits``.
# Returns (bucket -> tensor indices, per-bucket size limits).
def _compute_bucket_assignment_by_size(
    tensors: list[Tensor],
    bucket_size_limits: list[int],
    expect_sparse_gradient: list[bool] = ...,
    tensor_indices: list[int] = ...,
) -> tuple[list[list[int]], list[int]]: ...
|
| 676 |
+
def _broadcast_coalesced(
|
| 677 |
+
process_group: ProcessGroup,
|
| 678 |
+
tensors: list[Tensor],
|
| 679 |
+
buffer_size: int,
|
| 680 |
+
src: int,
|
| 681 |
+
): ...
|
| 682 |
+
def _test_python_store(store: Store): ...
|
| 683 |
+
def _verify_params_across_processes(
|
| 684 |
+
process_group: ProcessGroup,
|
| 685 |
+
params: list[Tensor],
|
| 686 |
+
logger: Logger | None,
|
| 687 |
+
): ...
|
| 688 |
+
def _make_nccl_premul_sum(factor: float | list[Tensor]) -> ReduceOp: ...
|
| 689 |
+
def _register_process_group(
|
| 690 |
+
group_name: str,
|
| 691 |
+
process_group: ProcessGroup,
|
| 692 |
+
) -> None: ...
|
| 693 |
+
def _resolve_process_group(group_name: str) -> ProcessGroup: ...
|
| 694 |
+
def _register_work(tensor: torch.Tensor, work: Work) -> ProcessGroup: ...
|
| 695 |
+
def _get_work_registry_size() -> int: ...
|
| 696 |
+
def _set_allow_inflight_collective_as_graph_input(
|
| 697 |
+
value: bool,
|
| 698 |
+
) -> None: ...
|
| 699 |
+
def _allow_inflight_collective_as_graph_input() -> bool: ...
|
| 700 |
+
def _unregister_all_process_groups() -> None: ...
|
| 701 |
+
def _unregister_process_group(group_name: str) -> None: ...
|
| 702 |
+
|
| 703 |
+
# Initializes the device state in CUmodule so that it's able to perform NVSHMEM
# operations. CUmodule is a pointer to a CUDA module, carried by an int64 in
# Python. At the C++ interface, it is converted to a uintptr_t.
|
| 706 |
+
def _nvshmemx_cumodule_init(module: int) -> None: ...
|
| 707 |
+
|
| 708 |
+
# Check if NVSHMEM is available on current system.
|
| 709 |
+
def _is_nvshmem_available() -> bool: ...
|
| 710 |
+
|
| 711 |
+
class _SymmetricMemory:
|
| 712 |
+
@staticmethod
|
| 713 |
+
def set_group_info(
|
| 714 |
+
group_name: str,
|
| 715 |
+
rank: int,
|
| 716 |
+
world_size: int,
|
| 717 |
+
store: Store,
|
| 718 |
+
) -> None: ...
|
| 719 |
+
@staticmethod
|
| 720 |
+
def empty_strided_p2p(
|
| 721 |
+
size: torch.types._size,
|
| 722 |
+
stride: torch.types._size,
|
| 723 |
+
dtype: torch.dtype,
|
| 724 |
+
device: torch.device,
|
| 725 |
+
group_name: str | None = None,
|
| 726 |
+
alloc_id: int | None = None,
|
| 727 |
+
) -> torch.Tensor: ...
|
| 728 |
+
@staticmethod
|
| 729 |
+
def has_multicast_support(
|
| 730 |
+
device_type: DeviceType,
|
| 731 |
+
device_idx: int,
|
| 732 |
+
) -> bool: ...
|
| 733 |
+
@property
|
| 734 |
+
def rank(self) -> int: ...
|
| 735 |
+
@property
|
| 736 |
+
def world_size(self) -> int: ...
|
| 737 |
+
@staticmethod
|
| 738 |
+
def rendezvous(
|
| 739 |
+
tensor: torch.Tensor, group_name: str | None = None
|
| 740 |
+
) -> _SymmetricMemory: ...
|
| 741 |
+
def get_buffer(
|
| 742 |
+
self,
|
| 743 |
+
rank: int,
|
| 744 |
+
sizes: torch.types._size,
|
| 745 |
+
dtype: torch.dtype,
|
| 746 |
+
storage_offset: int | None = 0,
|
| 747 |
+
) -> torch.Tensor: ...
|
| 748 |
+
def get_signal_pad(
|
| 749 |
+
self,
|
| 750 |
+
rank: int,
|
| 751 |
+
sizes: torch.types._size = [],
|
| 752 |
+
dtype: torch.dtype | None = None,
|
| 753 |
+
storage_offset: int | None = 0,
|
| 754 |
+
) -> torch.Tensor: ...
|
| 755 |
+
def barrier(self, channel: int = 0, timeout_ms: int = 0) -> None: ...
|
| 756 |
+
def put_signal(
|
| 757 |
+
self,
|
| 758 |
+
dst_rank: int,
|
| 759 |
+
channel: int = 0,
|
| 760 |
+
timeout_ms: int = 0,
|
| 761 |
+
) -> None: ...
|
| 762 |
+
def wait_signal(
|
| 763 |
+
self,
|
| 764 |
+
src_rank: int,
|
| 765 |
+
channel: int = 0,
|
| 766 |
+
timeout_ms: int = 0,
|
| 767 |
+
) -> None: ...
|
| 768 |
+
@staticmethod
|
| 769 |
+
def memset32(
|
| 770 |
+
tensor: torch.Tensor, offset: int, val: int, count: int = 1
|
| 771 |
+
) -> torch.Tensor: ...
|
| 772 |
+
@staticmethod
|
| 773 |
+
def stream_write_value32(
|
| 774 |
+
tensor: torch.Tensor, offset: int, val: int
|
| 775 |
+
) -> torch.Tensor: ...
|
| 776 |
+
@property
|
| 777 |
+
def buffer_ptrs(self) -> list[int]: ...
|
| 778 |
+
@property
|
| 779 |
+
def buffer_ptrs_dev(self) -> int: ...
|
| 780 |
+
@property
|
| 781 |
+
def signal_pad_ptrs(self) -> list[int]: ...
|
| 782 |
+
@property
|
| 783 |
+
def signal_pad_ptrs_dev(self) -> int: ...
|
| 784 |
+
@property
|
| 785 |
+
def multicast_ptr(self) -> int: ...
|
| 786 |
+
@property
|
| 787 |
+
def buffer_size(self) -> int: ...
|
| 788 |
+
@property
|
| 789 |
+
def signal_pad_size(self) -> int: ...
|
| 790 |
+
|
| 791 |
+
class ProcessGroupXCCL(Backend):
|
| 792 |
+
def __init__(
|
| 793 |
+
self,
|
| 794 |
+
store: Store,
|
| 795 |
+
rank: int,
|
| 796 |
+
size: int,
|
| 797 |
+
): ...
|
phivenv/Lib/site-packages/torch/_C/_distributed_rpc.pyi
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
# mypy: disable-error-code="type-arg"
|
| 3 |
+
from datetime import timedelta
|
| 4 |
+
from typing import Any, Generic, overload, TypeVar
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch._C import Future
|
| 8 |
+
from torch._C._autograd import ProfilerEvent
|
| 9 |
+
from torch._C._distributed_c10d import Store
|
| 10 |
+
from torch._C._profiler import ProfilerConfig
|
| 11 |
+
|
| 12 |
+
# This module is defined in torch/csrc/distributed/rpc/init.cpp
|
| 13 |
+
|
| 14 |
+
_DEFAULT_INIT_METHOD: str
|
| 15 |
+
_DEFAULT_NUM_WORKER_THREADS: int
|
| 16 |
+
_UNSET_RPC_TIMEOUT: float
|
| 17 |
+
_DEFAULT_RPC_TIMEOUT_SEC: float
|
| 18 |
+
|
| 19 |
+
_T = TypeVar("_T")
|
| 20 |
+
|
| 21 |
+
class RpcBackendOptions:
    """Base options shared by all RPC agent backends."""

    rpc_timeout: float  # seconds (cf. _DEFAULT_RPC_TIMEOUT_SEC above)
    init_method: str    # rendezvous URL
    def __init__(
        self,
        rpc_timeout: float = ...,
        init_method: str = ...,
    ) -> None: ...
|
| 29 |
+
|
| 30 |
+
class WorkerInfo:
    """Identifies an RPC worker by its (name, id) pair."""

    def __init__(self, name: str, worker_id: int) -> None:
        ...

    @property
    def name(self) -> str:
        ...

    @property
    def id(self) -> int:
        ...

    def __eq__(self, other: object) -> bool:
        ...
|
| 37 |
+
|
| 38 |
+
class RpcAgent:
    """Base class for RPC transport agents."""

    def join(self, shutdown: bool = False, timeout: float = 0): ...
    def sync(self): ...
    def shutdown(self): ...
    @overload
    def get_worker_info(self) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, workerName: str) -> WorkerInfo: ...
    def get_worker_infos(self) -> list[WorkerInfo]: ...
    # Device placement map used when sending tensors to ``dst``.
    def _get_device_map(self, dst: WorkerInfo) -> dict[torch.device, torch.device]: ...
    def get_debug_info(self) -> dict[str, str]: ...
    def get_metrics(self) -> dict[str, str]: ...
|
| 50 |
+
|
| 51 |
+
class PyRRef(Generic[_T]):
    """Remote reference to a value of type ``_T`` owned by some worker."""

    def __init__(self, value: _T, type_hint: Any = None) -> None: ...
    def is_owner(self) -> bool: ...
    def confirmed_by_owner(self) -> bool: ...
    def owner(self) -> WorkerInfo: ...
    def owner_name(self) -> str: ...
    # Fetches a copy of the referenced value to the caller.
    def to_here(self, timeout: float = ...) -> _T: ...
    def local_value(self) -> Any: ...
    def rpc_sync(self, timeout: float = ...) -> Any: ...
    def rpc_async(self, timeout: float = ...) -> Any: ...
    def remote(self, timeout: float = ...) -> Any: ...
    # Pickle support pair.
    def _serialize(self) -> tuple: ...
    @staticmethod
    def _deserialize(tp: tuple) -> PyRRef: ...
    def _get_type(self) -> type[_T]: ...
    def _get_future(self) -> Future[_T]: ...
    def _get_profiling_future(self) -> Future[_T]: ...
    def _set_profiling_future(self, profilingFuture: Future[_T]): ...
|
| 69 |
+
|
| 70 |
+
class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
|
| 71 |
+
num_worker_threads: int
|
| 72 |
+
device_maps: dict[str, dict[torch.device, torch.device]]
|
| 73 |
+
devices: list[torch.device]
|
| 74 |
+
def __init__(
|
| 75 |
+
self,
|
| 76 |
+
num_worker_threads: int,
|
| 77 |
+
_transports: list | None,
|
| 78 |
+
_channels: list | None,
|
| 79 |
+
rpc_timeout: float = ...,
|
| 80 |
+
init_method: str = ...,
|
| 81 |
+
device_maps: dict[str, dict[torch.device, torch.device]] = {}, # noqa: B006
|
| 82 |
+
devices: list[torch.device] = [], # noqa: B006
|
| 83 |
+
) -> None: ...
|
| 84 |
+
def _set_device_map(
|
| 85 |
+
self,
|
| 86 |
+
to: str,
|
| 87 |
+
device_map: dict[torch.device, torch.device],
|
| 88 |
+
): ...
|
| 89 |
+
|
| 90 |
+
class TensorPipeAgent(RpcAgent):
|
| 91 |
+
def __init__(
|
| 92 |
+
self,
|
| 93 |
+
store: Store,
|
| 94 |
+
name: str,
|
| 95 |
+
worker_id: int,
|
| 96 |
+
world_size: int | None,
|
| 97 |
+
opts: _TensorPipeRpcBackendOptionsBase,
|
| 98 |
+
reverse_device_maps: dict[str, dict[torch.device, torch.device]],
|
| 99 |
+
devices: list[torch.device],
|
| 100 |
+
) -> None: ...
|
| 101 |
+
def join(self, shutdown: bool = False, timeout: float = 0): ...
|
| 102 |
+
def shutdown(self): ...
|
| 103 |
+
@overload
|
| 104 |
+
def get_worker_info(self) -> WorkerInfo: ...
|
| 105 |
+
@overload
|
| 106 |
+
def get_worker_info(self, workerName: str) -> WorkerInfo: ...
|
| 107 |
+
@overload
|
| 108 |
+
def get_worker_info(self, id: int) -> WorkerInfo: ...
|
| 109 |
+
def get_worker_infos(self) -> list[WorkerInfo]: ...
|
| 110 |
+
def _get_device_map(self, dst: WorkerInfo) -> dict[torch.device, torch.device]: ...
|
| 111 |
+
def _update_group_membership(
|
| 112 |
+
self,
|
| 113 |
+
worker_info: WorkerInfo,
|
| 114 |
+
my_devices: list[torch.device],
|
| 115 |
+
reverse_device_map: dict[str, dict[torch.device, torch.device]],
|
| 116 |
+
is_join: bool,
|
| 117 |
+
): ...
|
| 118 |
+
def _get_backend_options(self) -> _TensorPipeRpcBackendOptionsBase: ...
|
| 119 |
+
@property
|
| 120 |
+
def is_static_group(self) -> bool: ...
|
| 121 |
+
@property
|
| 122 |
+
def store(self) -> Store: ...
|
| 123 |
+
|
| 124 |
+
def _is_current_rpc_agent_set() -> bool: ...
|
| 125 |
+
def _get_current_rpc_agent() -> RpcAgent: ...
|
| 126 |
+
def _set_and_start_rpc_agent(agent: RpcAgent): ...
|
| 127 |
+
def _reset_current_rpc_agent(): ...
|
| 128 |
+
def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...): ...
|
| 129 |
+
def _destroy_rref_context(ignoreRRefLeak: bool): ...
|
| 130 |
+
def _rref_context_get_debug_info() -> dict[str, str]: ...
|
| 131 |
+
def _cleanup_python_rpc_handler(): ...
|
| 132 |
+
def _invoke_rpc_builtin(
|
| 133 |
+
dst: WorkerInfo,
|
| 134 |
+
opName: str,
|
| 135 |
+
rpcTimeoutSeconds: float,
|
| 136 |
+
*args: Any,
|
| 137 |
+
**kwargs: Any,
|
| 138 |
+
): ...
|
| 139 |
+
def _invoke_rpc_python_udf(
|
| 140 |
+
dst: WorkerInfo,
|
| 141 |
+
pickledPythonUDF: str,
|
| 142 |
+
tensors: list[torch.Tensor],
|
| 143 |
+
rpcTimeoutSeconds: float,
|
| 144 |
+
isAsyncExecution: bool,
|
| 145 |
+
): ...
|
| 146 |
+
def _invoke_rpc_torchscript(
|
| 147 |
+
dstWorkerName: str,
|
| 148 |
+
qualifiedNameStr: str,
|
| 149 |
+
argsTuple: tuple,
|
| 150 |
+
kwargsDict: dict,
|
| 151 |
+
rpcTimeoutSeconds: float,
|
| 152 |
+
isAsyncExecution: bool,
|
| 153 |
+
): ...
|
| 154 |
+
def _invoke_remote_builtin(
|
| 155 |
+
dst: WorkerInfo,
|
| 156 |
+
opName: str,
|
| 157 |
+
rpcTimeoutSeconds: float,
|
| 158 |
+
*args: Any,
|
| 159 |
+
**kwargs: Any,
|
| 160 |
+
): ...
|
| 161 |
+
def _invoke_remote_python_udf(
|
| 162 |
+
dst: WorkerInfo,
|
| 163 |
+
pickledPythonUDF: str,
|
| 164 |
+
tensors: list[torch.Tensor],
|
| 165 |
+
rpcTimeoutSeconds: float,
|
| 166 |
+
isAsyncExecution: bool,
|
| 167 |
+
): ...
|
| 168 |
+
def _invoke_remote_torchscript(
|
| 169 |
+
dstWorkerName: WorkerInfo,
|
| 170 |
+
qualifiedNameStr: str,
|
| 171 |
+
rpcTimeoutSeconds: float,
|
| 172 |
+
isAsyncExecution: bool,
|
| 173 |
+
*args: Any,
|
| 174 |
+
**kwargs: Any,
|
| 175 |
+
): ...
|
| 176 |
+
def get_rpc_timeout() -> float: ...
|
| 177 |
+
def enable_gil_profiling(flag: bool): ...
|
| 178 |
+
def _set_rpc_timeout(rpcTimeoutSeconds: float): ...
|
| 179 |
+
|
| 180 |
+
class RemoteProfilerManager:
|
| 181 |
+
@staticmethod
|
| 182 |
+
def set_current_profiling_key(key: str): ...
|
| 183 |
+
|
| 184 |
+
def _enable_server_process_global_profiler(new_config: ProfilerConfig): ...
|
| 185 |
+
def _disable_server_process_global_profiler() -> list[list[list[ProfilerEvent]]]: ...
|
| 186 |
+
def _set_profiler_node_id(default_node_id: int): ...
|
| 187 |
+
def _enable_jit_rref_pickle(): ...
|
| 188 |
+
def _disable_jit_rref_pickle(): ...
|
phivenv/Lib/site-packages/torch/_C/_distributed_rpc_testing.pyi
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch._C._distributed_c10d import Store
|
| 3 |
+
from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase, TensorPipeAgent
|
| 4 |
+
|
| 5 |
+
# This module is defined in torch/csrc/distributed/rpc/testing/init.cpp
|
| 6 |
+
|
| 7 |
+
class FaultyTensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
|
| 8 |
+
def __init__(
|
| 9 |
+
self,
|
| 10 |
+
num_worker_threads: int,
|
| 11 |
+
rpc_timeout: float,
|
| 12 |
+
init_method: str,
|
| 13 |
+
messages_to_fail: list[str],
|
| 14 |
+
messages_to_delay: dict[str, float],
|
| 15 |
+
num_fail_sends: int,
|
| 16 |
+
) -> None: ...
|
| 17 |
+
num_send_recv_threads: int
|
| 18 |
+
messages_to_fail: list[str]
|
| 19 |
+
messages_to_delay: dict[str, float]
|
| 20 |
+
num_fail_sends: int
|
| 21 |
+
|
| 22 |
+
class FaultyTensorPipeAgent(TensorPipeAgent):
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
store: Store,
|
| 26 |
+
name: str,
|
| 27 |
+
rank: int,
|
| 28 |
+
world_size: int,
|
| 29 |
+
options: FaultyTensorPipeRpcBackendOptions,
|
| 30 |
+
reverse_device_maps: dict[str, dict[torch.device, torch.device]],
|
| 31 |
+
devices: list[torch.device],
|
| 32 |
+
) -> None: ...
|
phivenv/Lib/site-packages/torch/_C/_dynamo/__init__.pyi
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import compiled_autograd, eval_frame, guards # noqa: F401
|
| 2 |
+
|
| 3 |
+
def strip_function_call(name: str) -> str: ...
|
| 4 |
+
def is_valid_var_name(name: str) -> bool | int: ...
|
phivenv/Lib/site-packages/torch/_C/_dynamo/compiled_autograd.pyi
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable
|
| 2 |
+
|
| 3 |
+
from torch import Tensor
|
| 4 |
+
from torch._dynamo.compiled_autograd import AutogradCompilerInstance
|
| 5 |
+
|
| 6 |
+
def set_autograd_compiler(
|
| 7 |
+
autograd_compiler: Callable[[], AutogradCompilerInstance] | None,
|
| 8 |
+
dynamic: bool,
|
| 9 |
+
) -> tuple[Callable[[], AutogradCompilerInstance] | None, bool]: ...
|
| 10 |
+
def clear_cache() -> None: ...
|
| 11 |
+
def is_cache_empty() -> bool: ...
|
| 12 |
+
def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ...
|
| 13 |
+
def call_cpp_tensor_pre_hooks(idx: int, grad: Tensor) -> Tensor: ...
|
phivenv/Lib/site-packages/torch/_C/_dynamo/eval_frame.pyi
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import enum
|
| 2 |
+
import types
|
| 3 |
+
from typing import Optional, overload
|
| 4 |
+
|
| 5 |
+
from torch._dynamo.types import (
|
| 6 |
+
DynamoCallback,
|
| 7 |
+
DynamoGuardCompleteHook,
|
| 8 |
+
DynamoGuardHook,
|
| 9 |
+
GuardFn,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
|
| 13 |
+
def set_skip_guard_eval_unsafe(value: bool) -> bool: ...
|
| 14 |
+
def get_eval_frame_callback() -> DynamoCallback: ...
|
| 15 |
+
def reset_code(code: types.CodeType) -> None: ...
|
| 16 |
+
def unsupported(obj1: object, obj2: object) -> object: ...
|
| 17 |
+
def set_code_exec_strategy(
|
| 18 |
+
code: types.CodeType, strategy: _FrameExecStrategy
|
| 19 |
+
) -> None: ...
|
| 20 |
+
def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
|
| 21 |
+
def set_guard_complete_hook(
|
| 22 |
+
hook: Optional[DynamoGuardCompleteHook],
|
| 23 |
+
) -> Optional[DynamoGuardCompleteHook]: ...
|
| 24 |
+
def raise_sigtrap() -> None: ...
|
| 25 |
+
|
| 26 |
+
class _CacheEntry:
|
| 27 |
+
def check_fn(self, *args: object, **kwargs: object) -> bool: ...
|
| 28 |
+
code: types.CodeType
|
| 29 |
+
next: _CacheEntry | None
|
| 30 |
+
|
| 31 |
+
class _ExtraState:
|
| 32 |
+
def invalidate(self, cache_entry: _CacheEntry, guard_manager: object) -> None: ...
|
| 33 |
+
|
| 34 |
+
class _FrameAction(enum.IntEnum):
|
| 35 |
+
DEFAULT = 0
|
| 36 |
+
SKIP = 1
|
| 37 |
+
RUN_ONLY = 2
|
| 38 |
+
|
| 39 |
+
class _FrameExecStrategy:
|
| 40 |
+
cur_action: _FrameAction
|
| 41 |
+
recursive_action: _FrameAction
|
| 42 |
+
|
| 43 |
+
@overload
|
| 44 |
+
def __init__(self) -> None: ...
|
| 45 |
+
@overload
|
| 46 |
+
def __init__(
|
| 47 |
+
self, cur_action: _FrameAction, recursive_action: _FrameAction
|
| 48 |
+
) -> None: ...
|
| 49 |
+
|
| 50 |
+
# This is an object that encapsulates the Python FrameType, and exposes
|
| 51 |
+
# properties Dynamo cares about for a frame.
|
| 52 |
+
class _PyInterpreterFrame:
|
| 53 |
+
f_code: types.CodeType
|
| 54 |
+
f_locals: dict[str, object]
|
| 55 |
+
f_globals: dict[str, object]
|
| 56 |
+
f_builtins: dict[str, object]
|
| 57 |
+
f_lasti: int
|
| 58 |
+
f_lineo: int
|
| 59 |
+
f_back: types.FrameType
|
| 60 |
+
# A tuple containing cell objects captured by this frame.
|
| 61 |
+
closure: tuple[types.CellType]
|
| 62 |
+
|
| 63 |
+
def _debug_get_cache_entry_list(code: types.CodeType) -> list[_CacheEntry]: ...
|
| 64 |
+
|
| 65 |
+
py_opcode_caches: list[int]
|
| 66 |
+
|
| 67 |
+
def code_framelocals_names(code: types.CodeType) -> tuple[str]: ...
|
| 68 |
+
def _load_precompile_entry(
|
| 69 |
+
code: types.CodeType, guard_manager: GuardFn, dynamo_code: types.CodeType
|
| 70 |
+
) -> None: ...
|
| 71 |
+
def _reset_precompile_entries(code: types.CodeType) -> None: ...
|
phivenv/Lib/site-packages/torch/_C/_dynamo/guards.pyi
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from typing import Any, Callable
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
class GlobalStateGuard:
|
| 7 |
+
def check(self) -> bool: ...
|
| 8 |
+
def reason(self) -> str: ...
|
| 9 |
+
|
| 10 |
+
class LeafGuard: ...
|
| 11 |
+
class GuardDebugInfo: ...
|
| 12 |
+
|
| 13 |
+
class GuardManager:
|
| 14 |
+
def check(self, value) -> bool: ...
|
| 15 |
+
def check_verbose(self, value) -> GuardDebugInfo: ...
|
| 16 |
+
|
| 17 |
+
# Accessors
|
| 18 |
+
def globals_dict_manager(
|
| 19 |
+
self,
|
| 20 |
+
f_globals: dict[str, Any],
|
| 21 |
+
source,
|
| 22 |
+
example_value,
|
| 23 |
+
guard_manager_enum,
|
| 24 |
+
) -> GuardManager: ...
|
| 25 |
+
def framelocals_manager(
|
| 26 |
+
self,
|
| 27 |
+
key: tuple[str, int],
|
| 28 |
+
source,
|
| 29 |
+
example_value,
|
| 30 |
+
guard_manager_enum,
|
| 31 |
+
) -> GuardManager: ...
|
| 32 |
+
def dict_getitem_manager(
|
| 33 |
+
self,
|
| 34 |
+
key,
|
| 35 |
+
source,
|
| 36 |
+
example_value,
|
| 37 |
+
guard_manager_enum,
|
| 38 |
+
) -> GuardManager: ...
|
| 39 |
+
def global_weakref_manager(
|
| 40 |
+
self,
|
| 41 |
+
global_name: str,
|
| 42 |
+
source,
|
| 43 |
+
example_value,
|
| 44 |
+
guard_manager_enum,
|
| 45 |
+
) -> GuardManager: ...
|
| 46 |
+
def type_manager(
|
| 47 |
+
self,
|
| 48 |
+
source,
|
| 49 |
+
example_value,
|
| 50 |
+
guard_manager_enum,
|
| 51 |
+
) -> GuardManager: ...
|
| 52 |
+
def getattr_manager(
|
| 53 |
+
self,
|
| 54 |
+
attr: str,
|
| 55 |
+
source,
|
| 56 |
+
example_value,
|
| 57 |
+
guard_manager_enum,
|
| 58 |
+
) -> GuardManager: ...
|
| 59 |
+
def tensor_property_size_manager(
|
| 60 |
+
self,
|
| 61 |
+
idx: int,
|
| 62 |
+
source,
|
| 63 |
+
example_value,
|
| 64 |
+
guard_manager_enum,
|
| 65 |
+
) -> GuardManager: ...
|
| 66 |
+
def tensor_property_shape_manager(
|
| 67 |
+
self,
|
| 68 |
+
idx: int,
|
| 69 |
+
source,
|
| 70 |
+
example_value,
|
| 71 |
+
guard_manager_enum,
|
| 72 |
+
) -> GuardManager: ...
|
| 73 |
+
def tensor_property_storage_offset_manager(
|
| 74 |
+
self,
|
| 75 |
+
idx: None,
|
| 76 |
+
source,
|
| 77 |
+
example_value,
|
| 78 |
+
guard_manager_enum,
|
| 79 |
+
) -> GuardManager: ...
|
| 80 |
+
def indexed_manager(
|
| 81 |
+
self,
|
| 82 |
+
idx: int,
|
| 83 |
+
source,
|
| 84 |
+
example_value,
|
| 85 |
+
guard_manager_enum,
|
| 86 |
+
) -> GuardManager: ...
|
| 87 |
+
def lambda_manager(
|
| 88 |
+
self,
|
| 89 |
+
python_lambda,
|
| 90 |
+
source,
|
| 91 |
+
example_value,
|
| 92 |
+
guard_manager_enum,
|
| 93 |
+
) -> GuardManager: ...
|
| 94 |
+
|
| 95 |
+
# Leaf guards
|
| 96 |
+
def add_lambda_guard(self, user_lambda, verbose_code_parts: list[str]) -> None: ...
|
| 97 |
+
def add_id_match_guard(self, id_val, verbose_code_parts: list[str]) -> None: ...
|
| 98 |
+
def add_equals_match_guard(
|
| 99 |
+
self,
|
| 100 |
+
equals_val,
|
| 101 |
+
verbose_code_parts: list[str],
|
| 102 |
+
) -> None: ...
|
| 103 |
+
def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ...
|
| 104 |
+
def add_torch_function_mode_stack_guard(
|
| 105 |
+
self, initial_stack, verbose_code_parts: list[str]
|
| 106 |
+
) -> None: ...
|
| 107 |
+
def add_mapping_keys_guard(sef, value, verbose_code_parts: list[str]) -> None: ...
|
| 108 |
+
|
| 109 |
+
class RootGuardManager(GuardManager):
|
| 110 |
+
def get_epilogue_lambda_guards(self) -> list[LeafGuard]: ...
|
| 111 |
+
def add_epilogue_lambda_guard(
|
| 112 |
+
self,
|
| 113 |
+
guard: LeafGuard,
|
| 114 |
+
verbose_code_parts: list[str],
|
| 115 |
+
) -> None: ...
|
| 116 |
+
def clone_manager(
|
| 117 |
+
self, clone_filter_fn: Callable[[GuardManager], bool]
|
| 118 |
+
) -> RootGuardManager: ...
|
| 119 |
+
|
| 120 |
+
class DictGuardManager(GuardManager):
|
| 121 |
+
def get_key_manager(
|
| 122 |
+
self,
|
| 123 |
+
index,
|
| 124 |
+
source,
|
| 125 |
+
example_value,
|
| 126 |
+
guard_manager_enum,
|
| 127 |
+
) -> GuardManager: ...
|
| 128 |
+
def get_value_manager(
|
| 129 |
+
self,
|
| 130 |
+
index,
|
| 131 |
+
source,
|
| 132 |
+
example_value,
|
| 133 |
+
guard_manager_enum,
|
| 134 |
+
) -> GuardManager: ...
|
| 135 |
+
|
| 136 |
+
def install_object_aliasing_guard(
|
| 137 |
+
guard_managers: list[GuardManager],
|
| 138 |
+
tensor_names: list[str],
|
| 139 |
+
verbose_code_parts: list[str],
|
| 140 |
+
): ...
|
| 141 |
+
def install_no_tensor_aliasing_guard(
|
| 142 |
+
guard_managers: list[GuardManager],
|
| 143 |
+
tensor_names: list[str],
|
| 144 |
+
verbose_code_parts: list[str],
|
| 145 |
+
): ...
|
| 146 |
+
def install_storage_overlapping_guard(
|
| 147 |
+
overlapping_guard_managers: list[GuardManager],
|
| 148 |
+
non_overlapping_guard_managers: list[GuardManager],
|
| 149 |
+
verbose_code_parts: list[str],
|
| 150 |
+
): ...
|
| 151 |
+
def install_symbolic_shape_guard(
|
| 152 |
+
guard_managers: list[GuardManager],
|
| 153 |
+
nargs_int: int,
|
| 154 |
+
nargs_float: int,
|
| 155 |
+
py_addr: int,
|
| 156 |
+
py_addr_keep_alive: Any,
|
| 157 |
+
verbose_code_parts: list[str],
|
| 158 |
+
): ...
|
| 159 |
+
def profile_guard_manager(
|
| 160 |
+
guard_manager: GuardManager,
|
| 161 |
+
f_locals: dict[str, Any],
|
| 162 |
+
n_iters: int,
|
| 163 |
+
) -> float: ...
|
| 164 |
+
|
| 165 |
+
class TensorGuards:
|
| 166 |
+
def __init__(
|
| 167 |
+
self,
|
| 168 |
+
*,
|
| 169 |
+
dynamic_dims_sizes: list[torch.SymInt | None] | None = None,
|
| 170 |
+
dynamic_dims_strides: list[torch.SymInt | None] | None = None,
|
| 171 |
+
) -> None: ...
|
| 172 |
+
def check(self, *args) -> bool: ...
|
| 173 |
+
def check_verbose(self, *args, tensor_check_names=None) -> bool | str: ...
|
| 174 |
+
|
| 175 |
+
def assert_size_stride(
|
| 176 |
+
item: torch.Tensor,
|
| 177 |
+
size: torch.types._size,
|
| 178 |
+
stride: torch.types._size,
|
| 179 |
+
op_name: str | None = None,
|
| 180 |
+
): ...
|
| 181 |
+
def assert_alignment(
|
| 182 |
+
item: torch.Tensor,
|
| 183 |
+
alignment: int,
|
| 184 |
+
op_name: str | None = None,
|
| 185 |
+
): ...
|
| 186 |
+
def check_obj_id(obj: object, expected: int) -> bool: ...
|
| 187 |
+
def check_type_id(obj: object, expected: int) -> bool: ...
|
| 188 |
+
def dict_version(d: dict[Any, Any]) -> int: ...
|
| 189 |
+
def compute_overlapping_tensors(
|
| 190 |
+
tensors: list[torch.Tensor], symbolic: bool = True
|
| 191 |
+
) -> set[int]: ...
|
phivenv/Lib/site-packages/torch/_C/_export/__init__.pyi
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Defined in torch/csrc/export/pybind.cpp
|
| 2 |
+
class CppExportedProgram: ...
|
| 3 |
+
|
| 4 |
+
def deserialize_exported_program(
|
| 5 |
+
serialized_program: str,
|
| 6 |
+
) -> CppExportedProgram: ...
|
| 7 |
+
def serialize_exported_program(
|
| 8 |
+
cpp_exported_program: CppExportedProgram,
|
| 9 |
+
) -> str: ...
|
phivenv/Lib/site-packages/torch/_C/_export/pt2_archive_constants.pyi
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Defined in torch/csrc/export/pt2_archive_constants.h
|
| 2 |
+
|
| 3 |
+
ARCHIVE_ROOT_NAME: str = ...
|
| 4 |
+
ARCHIVE_FORMAT_PATH: str = ...
|
| 5 |
+
ARCHIVE_FORMAT_VALUE: str = ...
|
| 6 |
+
ARCHIVE_VERSION_PATH: str = ...
|
| 7 |
+
ARCHIVE_VERSION_VALUE: str = ...
|
| 8 |
+
MODELS_DIR: str = ...
|
| 9 |
+
MODELS_FILENAME_FORMAT: str = ...
|
| 10 |
+
AOTINDUCTOR_DIR: str = ...
|
| 11 |
+
MTIA_DIR: str = ...
|
| 12 |
+
WEIGHTS_DIR: str = ...
|
| 13 |
+
WEIGHT_FILENAME_PREFIX: str = ...
|
| 14 |
+
CONSTANTS_DIR: str = ...
|
| 15 |
+
TENSOR_CONSTANT_FILENAME_PREFIX: str = ...
|
| 16 |
+
CUSTOM_OBJ_FILENAME_PREFIX: str = ...
|
| 17 |
+
SAMPLE_INPUTS_DIR: str = ...
|
| 18 |
+
SAMPLE_INPUTS_FILENAME_FORMAT: str = ...
|
| 19 |
+
EXTRA_DIR: str = ...
|
| 20 |
+
MODULE_INFO_PATH: str = ...
|
| 21 |
+
XL_MODEL_WEIGHTS_DIR: str = ...
|
| 22 |
+
XL_MODEL_WEIGHTS_PARAM_CONFIG_PATH: str = ...
|
phivenv/Lib/site-packages/torch/_C/_functions.pyi
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import AnyStr, overload
|
| 2 |
+
|
| 3 |
+
from torch import Tensor
|
| 4 |
+
|
| 5 |
+
class UndefinedGrad:
|
| 6 |
+
def __init__(self) -> None: ...
|
| 7 |
+
def __call__(self, *inputs: Tensor) -> list[Tensor]: ...
|
| 8 |
+
|
| 9 |
+
class DelayedError:
|
| 10 |
+
def __init__(self, msg: AnyStr, num_inputs: int) -> None: ...
|
| 11 |
+
|
| 12 |
+
# __call__ should really be a higher-kinded type:
|
| 13 |
+
# def __call__(self, arg: Tensor) -> Tensor: ...
|
| 14 |
+
# def __call__(self, *args: Tensor * num_inputs) -> Tuple[Tensor * num_inputs]: ...
|
| 15 |
+
|
| 16 |
+
@overload
|
| 17 |
+
def __call__(self, i0: Tensor) -> Tensor: ...
|
| 18 |
+
@overload
|
| 19 |
+
def __call__(self, *args: Tensor) -> tuple[Tensor, ...]: ...
|
phivenv/Lib/site-packages/torch/_C/_functorch.pyi
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from enum import Enum
|
| 3 |
+
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
|
| 6 |
+
# Defined in torch/csrc/functorch/init.cpp
|
| 7 |
+
|
| 8 |
+
def _set_dynamic_layer_keys_included(included: bool) -> None: ...
|
| 9 |
+
def get_unwrapped(tensor: Tensor) -> Tensor: ...
|
| 10 |
+
def is_batchedtensor(tensor: Tensor) -> bool: ...
|
| 11 |
+
def is_functionaltensor(tensor: Tensor) -> bool: ...
|
| 12 |
+
def is_functorch_wrapped_tensor(tensor: Tensor) -> bool: ...
|
| 13 |
+
def is_gradtrackingtensor(tensor: Tensor) -> bool: ...
|
| 14 |
+
def is_legacy_batchedtensor(tensor: Tensor) -> bool: ...
|
| 15 |
+
def maybe_get_bdim(tensor: Tensor) -> int: ...
|
| 16 |
+
def maybe_get_level(tensor: Tensor) -> int: ...
|
| 17 |
+
def maybe_current_level() -> int | None: ...
|
| 18 |
+
def unwrap_if_dead(tensor: Tensor) -> Tensor: ...
|
| 19 |
+
def _unwrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
|
| 20 |
+
def _wrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
|
| 21 |
+
def _unwrap_batched(tensor: Tensor, level: int) -> tuple[Tensor, int | None]: ...
|
| 22 |
+
def current_level() -> int: ...
|
| 23 |
+
def count_jvp_interpreters() -> int: ...
|
| 24 |
+
def _add_batch_dim(tensor: Tensor, bdim: int, level: int) -> Tensor: ...
|
| 25 |
+
def set_single_level_autograd_function_allowed(allowed: bool) -> None: ...
|
| 26 |
+
def get_single_level_autograd_function_allowed() -> bool: ...
|
| 27 |
+
def _unwrap_functional_tensor(tensor: Tensor, reapply_views: bool) -> Tensor: ...
|
| 28 |
+
def _wrap_functional_tensor(tensor: Tensor, level: int) -> Tensor: ...
|
| 29 |
+
def _vmap_increment_nesting(batch_size: int, randomness: str) -> int: ...
|
| 30 |
+
def _vmap_decrement_nesting() -> int: ...
|
| 31 |
+
def _grad_increment_nesting() -> int: ...
|
| 32 |
+
def _grad_decrement_nesting() -> int: ...
|
| 33 |
+
def _jvp_increment_nesting() -> int: ...
|
| 34 |
+
def _jvp_decrement_nesting() -> int: ...
|
| 35 |
+
|
| 36 |
+
# Defined in aten/src/ATen/functorch/Interpreter.h
|
| 37 |
+
class TransformType(Enum):
|
| 38 |
+
Torch = ...
|
| 39 |
+
Vmap = ...
|
| 40 |
+
Grad = ...
|
| 41 |
+
Jvp = ...
|
| 42 |
+
Functionalize = ...
|
| 43 |
+
|
| 44 |
+
class RandomnessType(Enum):
|
| 45 |
+
Error = ...
|
| 46 |
+
Same = ...
|
| 47 |
+
Different = ...
|
| 48 |
+
|
| 49 |
+
class CInterpreter:
|
| 50 |
+
def key(self) -> TransformType: ...
|
| 51 |
+
def level(self) -> int: ...
|
| 52 |
+
def serialize(self) -> bytes: ...
|
| 53 |
+
@staticmethod
|
| 54 |
+
def deserialize(bytes) -> CInterpreter: ...
|
| 55 |
+
|
| 56 |
+
class CGradInterpreterPtr:
|
| 57 |
+
def __init__(self, interpreter: CInterpreter) -> None: ...
|
| 58 |
+
def lift(self, Tensor) -> Tensor: ...
|
| 59 |
+
def prevGradMode(self) -> bool: ...
|
| 60 |
+
|
| 61 |
+
class CJvpInterpreterPtr:
|
| 62 |
+
def __init__(self, interpreter: CInterpreter) -> None: ...
|
| 63 |
+
def lift(self, Tensor) -> Tensor: ...
|
| 64 |
+
def prevFwdGradMode(self) -> bool: ...
|
| 65 |
+
|
| 66 |
+
class CFunctionalizeInterpreterPtr:
|
| 67 |
+
def __init__(self, interpreter: CInterpreter) -> None: ...
|
| 68 |
+
def key(self) -> TransformType: ...
|
| 69 |
+
def level(self) -> int: ...
|
| 70 |
+
def functionalizeAddBackViews(self) -> bool: ...
|
| 71 |
+
|
| 72 |
+
class CVmapInterpreterPtr:
|
| 73 |
+
def __init__(self, interpreter: CInterpreter) -> None: ...
|
| 74 |
+
def key(self) -> TransformType: ...
|
| 75 |
+
def level(self) -> int: ...
|
| 76 |
+
def batchSize(self) -> int: ...
|
| 77 |
+
def randomness(self) -> RandomnessType: ...
|
| 78 |
+
|
| 79 |
+
class DynamicLayer: ...
|
| 80 |
+
|
| 81 |
+
def get_dynamic_layer_stack_depth() -> int: ...
|
| 82 |
+
def get_interpreter_stack() -> list[CInterpreter]: ...
|
| 83 |
+
def peek_interpreter_stack() -> CInterpreter: ...
|
| 84 |
+
def pop_dynamic_layer_stack() -> DynamicLayer: ...
|
| 85 |
+
def pop_dynamic_layer_stack_and_undo_to_depth(int) -> None: ...
|
| 86 |
+
def push_dynamic_layer_stack(dl: DynamicLayer) -> int: ...
|
phivenv/Lib/site-packages/torch/_C/_instruction_counter.pyi
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Defined in torch/csrc/instruction_counter/Module.cpp
|
| 2 |
+
|
| 3 |
+
def start() -> int: ...
|
| 4 |
+
def end(id: int) -> int: ...
|
phivenv/Lib/site-packages/torch/_C/_itt.pyi
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Defined in torch/csrc/itt.cpp
|
| 2 |
+
def is_available() -> None: ...
|
| 3 |
+
def rangePush(message: str) -> None: ...
|
| 4 |
+
def rangePop() -> None: ...
|
| 5 |
+
def mark(message: str) -> None: ...
|
phivenv/Lib/site-packages/torch/_C/_lazy.pyi
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import Tensor
|
| 2 |
+
|
| 3 |
+
# defined in torch/csrc/lazy/python/init.cpp
|
| 4 |
+
def _mark_step(device: str, devices: list[str], wait: bool) -> None: ...
|
| 5 |
+
def _wait_device_ops(devices: list[str]) -> None: ...
|
| 6 |
+
def _reset_metrics() -> None: ...
|
| 7 |
+
def _counter_names() -> list[str]: ...
|
| 8 |
+
def _counter_value(name: str) -> int: ...
|
| 9 |
+
def _metrics_report() -> str: ...
|
| 10 |
+
def _get_graph_hash(tensors: list[Tensor]) -> str: ...
|
| 11 |
+
def _sync_multi(
|
| 12 |
+
tensors: list[Tensor],
|
| 13 |
+
devices: list[str],
|
| 14 |
+
wait: bool = True,
|
| 15 |
+
sync_ltc_data: bool = True,
|
| 16 |
+
) -> None: ...
|
| 17 |
+
def _get_tensor_id(tensor: Tensor) -> int: ...
|
| 18 |
+
def _get_tensors_text(tensors: list[Tensor]) -> str: ...
|
| 19 |
+
def _get_tensors_dot(tensors: list[Tensor]) -> str: ...
|
| 20 |
+
def _get_tensors_backend(tensors: list[Tensor]) -> str: ...
|
| 21 |
+
def _get_force_fallback() -> str: ...
|
| 22 |
+
def _set_force_fallback(newval: str) -> None: ...
|
| 23 |
+
def _clear_ir_cache() -> None: ...
|
| 24 |
+
def _dump_ir_cache(filename: str) -> None: ...
|
| 25 |
+
def _set_reuse_ir(val: bool) -> None: ...
|
| 26 |
+
def _get_default_device_type() -> str: ...
|
phivenv/Lib/site-packages/torch/_C/_lazy_ts_backend.pyi
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
# defined in torch/csrc/lazy/python/init.cpp
|
| 3 |
+
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
def _init(): ...
|
| 9 |
+
def _get_tensors_ts_device_data_node(
|
| 10 |
+
tensors: list[Tensor],
|
| 11 |
+
) -> tuple[list[int], list[Any]]: ...
|
| 12 |
+
def _run_cached_graph(hash_str: str, graph_inputs: list[Any]) -> list[Tensor]: ...
|
phivenv/Lib/site-packages/torch/_C/_monitor.pyi
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Defined in torch/csrc/monitor/python_init.cpp
|
| 2 |
+
|
| 3 |
+
import datetime
|
| 4 |
+
from enum import Enum
|
| 5 |
+
from types import TracebackType
|
| 6 |
+
from typing import Callable
|
| 7 |
+
|
| 8 |
+
class Aggregation(Enum):
|
| 9 |
+
VALUE = ...
|
| 10 |
+
MEAN = ...
|
| 11 |
+
COUNT = ...
|
| 12 |
+
SUM = ...
|
| 13 |
+
MAX = ...
|
| 14 |
+
MIN = ...
|
| 15 |
+
|
| 16 |
+
class Stat:
|
| 17 |
+
name: str
|
| 18 |
+
count: int
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
name: str,
|
| 22 |
+
aggregations: list[Aggregation],
|
| 23 |
+
window_size: int,
|
| 24 |
+
max_samples: int = -1,
|
| 25 |
+
) -> None: ...
|
| 26 |
+
def add(self, v: float) -> None: ...
|
| 27 |
+
def get(self) -> dict[Aggregation, float]: ...
|
| 28 |
+
|
| 29 |
+
class Event:
|
| 30 |
+
name: str
|
| 31 |
+
timestamp: datetime.datetime
|
| 32 |
+
data: dict[str, int | float | bool | str]
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
name: str,
|
| 36 |
+
timestamp: datetime.datetime,
|
| 37 |
+
data: dict[str, int | float | bool | str],
|
| 38 |
+
) -> None: ...
|
| 39 |
+
|
| 40 |
+
def log_event(e: Event) -> None: ...
|
| 41 |
+
|
| 42 |
+
class EventHandlerHandle: ...
|
| 43 |
+
|
| 44 |
+
def register_event_handler(handler: Callable[[Event], None]) -> EventHandlerHandle: ...
|
| 45 |
+
def unregister_event_handler(handle: EventHandlerHandle) -> None: ...
|
| 46 |
+
|
| 47 |
+
class _WaitCounterTracker:
|
| 48 |
+
def __enter__(self) -> None: ...
|
| 49 |
+
def __exit__(
|
| 50 |
+
self,
|
| 51 |
+
exc_type: type[BaseException] | None = None,
|
| 52 |
+
exc_value: BaseException | None = None,
|
| 53 |
+
traceback: TracebackType | None = None,
|
| 54 |
+
) -> None: ...
|
| 55 |
+
|
| 56 |
+
class _WaitCounter:
|
| 57 |
+
def __init__(self, key: str) -> None: ...
|
| 58 |
+
def guard(self) -> _WaitCounterTracker: ...
|
phivenv/Lib/site-packages/torch/_C/_nn.pyi
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @generated by tools/pyi/gen_pyi.py from torch/_C/_nn.pyi.in
|
| 2 |
+
# mypy: disable-error-code="type-arg"
|
| 3 |
+
|
| 4 |
+
from collections.abc import Sequence
|
| 5 |
+
from typing import Literal, overload
|
| 6 |
+
|
| 7 |
+
from torch import memory_format, Tensor
|
| 8 |
+
from torch.types import _bool, _device, _dtype, _int, _size
|
| 9 |
+
|
| 10 |
+
# Defined in tools/autograd/templates/python_nn_functions.cpp
|
| 11 |
+
|
| 12 |
+
def adaptive_avg_pool2d(input: Tensor, output_size: _int | _size) -> Tensor: ...
|
| 13 |
+
def adaptive_avg_pool3d(input: Tensor, output_size: _int | _size) -> Tensor: ...
|
| 14 |
+
def adaptive_max_pool2d(
|
| 15 |
+
input: Tensor,
|
| 16 |
+
output_size: _int | _size,
|
| 17 |
+
) -> tuple[Tensor, Tensor]: ...
|
| 18 |
+
def adaptive_max_pool3d(
|
| 19 |
+
input: Tensor,
|
| 20 |
+
output_size: _int | _size,
|
| 21 |
+
) -> tuple[Tensor, Tensor]: ...
|
| 22 |
+
def avg_pool2d(
|
| 23 |
+
input: Tensor,
|
| 24 |
+
kernel_size: _int | _size,
|
| 25 |
+
stride: _int | _size | None = None,
|
| 26 |
+
padding: _int | _size = 0,
|
| 27 |
+
ceil_mode: bool = False,
|
| 28 |
+
count_include_pad: bool = True,
|
| 29 |
+
divisor_override: int | None = None,
|
| 30 |
+
) -> Tensor: ...
|
| 31 |
+
def avg_pool3d(
|
| 32 |
+
input: Tensor,
|
| 33 |
+
kernel_size: _int | _size,
|
| 34 |
+
stride: _int | _size | None = None,
|
| 35 |
+
padding: _int | _size = 0,
|
| 36 |
+
ceil_mode: bool = False,
|
| 37 |
+
count_include_pad: bool = True,
|
| 38 |
+
divisor_override: int | None = None,
|
| 39 |
+
) -> Tensor: ...
|
| 40 |
+
def binary_cross_entropy(
|
| 41 |
+
input: Tensor,
|
| 42 |
+
target: Tensor,
|
| 43 |
+
weight: Tensor | None = None,
|
| 44 |
+
reduction: str = ...,
|
| 45 |
+
) -> Tensor: ...
|
| 46 |
+
def col2im(
|
| 47 |
+
input: Tensor,
|
| 48 |
+
output_size: _int | _size,
|
| 49 |
+
kernel_size: _int | _size,
|
| 50 |
+
dilation: _int | _size,
|
| 51 |
+
stride: _int | _size | None = None,
|
| 52 |
+
padding: _int | _size = 0,
|
| 53 |
+
) -> Tensor: ...
|
| 54 |
+
def elu_(input: Tensor, alpha: float = ...) -> Tensor: ...
|
| 55 |
+
def fractional_max_pool2d(
|
| 56 |
+
input: Tensor,
|
| 57 |
+
kernel_size: _int | _size,
|
| 58 |
+
output_size: _int | _size,
|
| 59 |
+
_random_samples: Tensor,
|
| 60 |
+
) -> tuple[Tensor, Tensor]: ...
|
| 61 |
+
def fractional_max_pool3d(
|
| 62 |
+
input: Tensor,
|
| 63 |
+
kernel_size: _int | _size,
|
| 64 |
+
output_size: _int | _size,
|
| 65 |
+
_random_samples: Tensor,
|
| 66 |
+
) -> tuple[Tensor, Tensor]: ...
|
| 67 |
+
def gelu(input: Tensor, approximate: str = ...) -> Tensor: ...
|
| 68 |
+
def hardsigmoid(input: Tensor, *, out: Tensor | None = None) -> Tensor: ...
|
| 69 |
+
def hardtanh(
|
| 70 |
+
input: Tensor,
|
| 71 |
+
min_val: float = ...,
|
| 72 |
+
max_val: float = ...,
|
| 73 |
+
*,
|
| 74 |
+
out: Tensor | None = None,
|
| 75 |
+
) -> Tensor: ...
|
| 76 |
+
def hardtanh_(
|
| 77 |
+
input: Tensor,
|
| 78 |
+
min_val: float = ...,
|
| 79 |
+
max_val: float = ...,
|
| 80 |
+
) -> Tensor: ...
|
| 81 |
+
def leaky_relu(
|
| 82 |
+
input: Tensor,
|
| 83 |
+
negative_slope: float = ...,
|
| 84 |
+
*,
|
| 85 |
+
out: Tensor | None = None,
|
| 86 |
+
) -> Tensor: ...
|
| 87 |
+
def leaky_relu_(input: Tensor, negative_slope: float = ...) -> Tensor: ...
|
| 88 |
+
def linear(
|
| 89 |
+
input: Tensor,
|
| 90 |
+
weight: Tensor,
|
| 91 |
+
bias: Tensor | None = None,
|
| 92 |
+
) -> Tensor: ...
|
| 93 |
+
def log_sigmoid(input: Tensor) -> Tensor: ...
|
| 94 |
+
def one_hot(tensor: Tensor, num_classes: int = ...) -> Tensor: ...
|
| 95 |
+
def pad(
|
| 96 |
+
input: Tensor,
|
| 97 |
+
pad: Sequence[int],
|
| 98 |
+
mode: str = ...,
|
| 99 |
+
value: float | None = None,
|
| 100 |
+
) -> Tensor: ...
|
| 101 |
+
def scaled_dot_product_attention(
|
| 102 |
+
query: Tensor,
|
| 103 |
+
key: Tensor,
|
| 104 |
+
value: Tensor,
|
| 105 |
+
attn_mask: Tensor | None = None,
|
| 106 |
+
dropout_p: float = 0.0,
|
| 107 |
+
is_causal: bool = False,
|
| 108 |
+
scale: float | None = None,
|
| 109 |
+
enable_gqa: bool = False,
|
| 110 |
+
) -> Tensor: ...
|
| 111 |
+
def softplus(
|
| 112 |
+
input: Tensor,
|
| 113 |
+
beta: float = ...,
|
| 114 |
+
threshold: float = ...,
|
| 115 |
+
) -> Tensor: ...
|
| 116 |
+
def softshrink(input: Tensor, lambd: float = ...) -> Tensor: ...
|
| 117 |
+
|
| 118 |
+
# Defined in aten/src/ATen/native/mkldnn/Linear.cpp
|
| 119 |
+
def mkldnn_linear(input: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: ...
|
| 120 |
+
|
| 121 |
+
# Defined at aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
|
| 122 |
+
def mkldnn_reorder_conv2d_weight(
|
| 123 |
+
self: Tensor,
|
| 124 |
+
padding: list,
|
| 125 |
+
stride: list,
|
| 126 |
+
dilatation: list,
|
| 127 |
+
groups: int,
|
| 128 |
+
) -> Tensor: ...
|
| 129 |
+
def mkldnn_reorder_conv3d_weight(
|
| 130 |
+
self: Tensor,
|
| 131 |
+
padding: list,
|
| 132 |
+
stride: list,
|
| 133 |
+
dilatation: list,
|
| 134 |
+
groups: int,
|
| 135 |
+
) -> Tensor: ...
|
| 136 |
+
|
| 137 |
+
# Defined in aten/src/ATen/native/mkldnn/Prelu.cpp
|
| 138 |
+
def mkldnn_prelu(input: Tensor, weight: Tensor) -> Tensor: ...
|
| 139 |
+
|
| 140 |
+
# Defined at tools/autograd/templates/python_nn_functions.cpp
|
| 141 |
+
@overload
|
| 142 |
+
def _parse_to(
|
| 143 |
+
device: _device,
|
| 144 |
+
dtype: _dtype,
|
| 145 |
+
non_blocking: _bool,
|
| 146 |
+
copy: _bool,
|
| 147 |
+
*,
|
| 148 |
+
memory_format: memory_format,
|
| 149 |
+
) -> tuple[_device, _dtype, _bool, memory_format]: ...
|
| 150 |
+
@overload
|
| 151 |
+
def _parse_to(
|
| 152 |
+
dtype: _dtype,
|
| 153 |
+
non_blocking: _bool,
|
| 154 |
+
copy: _bool,
|
| 155 |
+
*,
|
| 156 |
+
memory_format: memory_format,
|
| 157 |
+
) -> tuple[_device, _dtype, _bool, memory_format]: ...
|
| 158 |
+
@overload
|
| 159 |
+
def _parse_to(
|
| 160 |
+
tensor: Tensor,
|
| 161 |
+
non_blocking: _bool,
|
| 162 |
+
copy: _bool,
|
| 163 |
+
*,
|
| 164 |
+
memory_format: memory_format,
|
| 165 |
+
) -> tuple[_device, _dtype, _bool, memory_format]: ...
|
| 166 |
+
|
| 167 |
+
# Defined in aten/src/ATen/native/PackedSequence.cpp
|
| 168 |
+
def pad_sequence(
|
| 169 |
+
sequences: list[Tensor] | tuple[Tensor, ...],
|
| 170 |
+
batch_first: bool = False,
|
| 171 |
+
padding_value: float = 0.0,
|
| 172 |
+
padding_side: Literal["left", "right"] = "right",
|
| 173 |
+
) -> Tensor: ...
|
| 174 |
+
def flatten_dense_tensors(tensors: list[Tensor]) -> Tensor: ...
|
| 175 |
+
def unflatten_dense_tensors(flat: Tensor, tensors: list[Tensor]) -> list[Tensor]: ...
|
phivenv/Lib/site-packages/torch/_C/_nvtx.pyi
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
# Defined in torch/csrc/cuda/shared/nvtx.cpp
|
| 3 |
+
def rangePushA(message: str) -> int: ...
|
| 4 |
+
def rangePop() -> int: ...
|
| 5 |
+
def rangeStartA(message: str) -> int: ...
|
| 6 |
+
def rangeEnd(int) -> None: ...
|
| 7 |
+
def markA(message: str) -> None: ...
|
| 8 |
+
def deviceRangeStart(message: str, stream: int) -> object: ...
|
| 9 |
+
def deviceRangeEnd(range_handle: object, stream: int) -> None: ...
|
phivenv/Lib/site-packages/torch/_C/_onnx.pyi
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Defined in torch/csrc/onnx/init.cpp
|
| 2 |
+
|
| 3 |
+
from enum import Enum
|
| 4 |
+
|
| 5 |
+
PRODUCER_VERSION: str
|
| 6 |
+
|
| 7 |
+
class TensorProtoDataType(Enum):
|
| 8 |
+
UNDEFINED = ...
|
| 9 |
+
FLOAT = ...
|
| 10 |
+
UINT8 = ...
|
| 11 |
+
INT8 = ...
|
| 12 |
+
UINT16 = ...
|
| 13 |
+
INT16 = ...
|
| 14 |
+
INT32 = ...
|
| 15 |
+
INT64 = ...
|
| 16 |
+
STRING = ...
|
| 17 |
+
BOOL = ...
|
| 18 |
+
FLOAT16 = ...
|
| 19 |
+
DOUBLE = ...
|
| 20 |
+
UINT32 = ...
|
| 21 |
+
UINT64 = ...
|
| 22 |
+
COMPLEX64 = ...
|
| 23 |
+
COMPLEX128 = ...
|
| 24 |
+
BFLOAT16 = ...
|
| 25 |
+
FLOAT8E5M2 = ...
|
| 26 |
+
FLOAT8E4M3FN = ...
|
| 27 |
+
FLOAT8E5M2FNUZ = ...
|
| 28 |
+
FLOAT8E4M3FNUZ = ...
|
| 29 |
+
|
| 30 |
+
class OperatorExportTypes(Enum):
|
| 31 |
+
ONNX = ...
|
| 32 |
+
ONNX_ATEN = ...
|
| 33 |
+
ONNX_ATEN_FALLBACK = ...
|
| 34 |
+
ONNX_FALLTHROUGH = ...
|
| 35 |
+
|
| 36 |
+
class TrainingMode(Enum):
|
| 37 |
+
EVAL = ...
|
| 38 |
+
PRESERVE = ...
|
| 39 |
+
TRAINING = ...
|
phivenv/Lib/site-packages/torch/_C/_profiler.pyi
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
from typing import Literal
|
| 3 |
+
from typing_extensions import TypeAlias
|
| 4 |
+
|
| 5 |
+
from torch._C import device, dtype, layout
|
| 6 |
+
|
| 7 |
+
# defined in torch/csrc/profiler/python/init.cpp
|
| 8 |
+
|
| 9 |
+
class RecordScope(Enum):
|
| 10 |
+
FUNCTION = ...
|
| 11 |
+
BACKWARD_FUNCTION = ...
|
| 12 |
+
TORCHSCRIPT_FUNCTION = ...
|
| 13 |
+
KERNEL_FUNCTION_DTYPE = ...
|
| 14 |
+
CUSTOM_CLASS = ...
|
| 15 |
+
BUILD_FEATURE = ...
|
| 16 |
+
LITE_INTERPRETER = ...
|
| 17 |
+
USER_SCOPE = ...
|
| 18 |
+
STATIC_RUNTIME_OP = ...
|
| 19 |
+
STATIC_RUNTIME_MODEL = ...
|
| 20 |
+
|
| 21 |
+
class ProfilerState(Enum):
|
| 22 |
+
Disable = ...
|
| 23 |
+
CPU = ...
|
| 24 |
+
CUDA = ...
|
| 25 |
+
NVTX = ...
|
| 26 |
+
ITT = ...
|
| 27 |
+
KINETO = ...
|
| 28 |
+
KINETO_GPU_FALLBACK = ...
|
| 29 |
+
KINETO_PRIVATEUSE1_FALLBACK = ...
|
| 30 |
+
KINETO_PRIVATEUSE1 = ...
|
| 31 |
+
|
| 32 |
+
class ActiveProfilerType(Enum):
|
| 33 |
+
NONE = ...
|
| 34 |
+
LEGACY = ...
|
| 35 |
+
KINETO = ...
|
| 36 |
+
NVTX = ...
|
| 37 |
+
ITT = ...
|
| 38 |
+
|
| 39 |
+
class ProfilerActivity(Enum):
|
| 40 |
+
CPU = ...
|
| 41 |
+
CUDA = ...
|
| 42 |
+
XPU = ...
|
| 43 |
+
MTIA = ...
|
| 44 |
+
HPU = ...
|
| 45 |
+
PrivateUse1 = ...
|
| 46 |
+
|
| 47 |
+
class _EventType(Enum):
|
| 48 |
+
TorchOp = ...
|
| 49 |
+
Backend = ...
|
| 50 |
+
Allocation = ...
|
| 51 |
+
OutOfMemory = ...
|
| 52 |
+
PyCall = ...
|
| 53 |
+
PyCCall = ...
|
| 54 |
+
Kineto = ...
|
| 55 |
+
|
| 56 |
+
class _ExperimentalConfig:
|
| 57 |
+
def __init__(
|
| 58 |
+
self,
|
| 59 |
+
profiler_metrics: list[str] = ...,
|
| 60 |
+
profiler_measure_per_kernel: bool = ...,
|
| 61 |
+
verbose: bool = ...,
|
| 62 |
+
performance_events: list[str] = ...,
|
| 63 |
+
enable_cuda_sync_events: bool = ...,
|
| 64 |
+
) -> None: ...
|
| 65 |
+
|
| 66 |
+
class ProfilerConfig:
|
| 67 |
+
def __init__(
|
| 68 |
+
self,
|
| 69 |
+
state: ProfilerState,
|
| 70 |
+
report_input_shapes: bool,
|
| 71 |
+
profile_memory: bool,
|
| 72 |
+
with_stack: bool,
|
| 73 |
+
with_flops: bool,
|
| 74 |
+
with_modules: bool,
|
| 75 |
+
experimental_config: _ExperimentalConfig,
|
| 76 |
+
trace_id: str | None = None,
|
| 77 |
+
) -> None: ...
|
| 78 |
+
|
| 79 |
+
class _ProfilerEvent:
|
| 80 |
+
start_tid: int
|
| 81 |
+
start_time_ns: int
|
| 82 |
+
children: list[_ProfilerEvent]
|
| 83 |
+
|
| 84 |
+
# TODO(robieta): remove in favor of `self.typed`
|
| 85 |
+
extra_fields: (
|
| 86 |
+
_ExtraFields_TorchOp
|
| 87 |
+
| _ExtraFields_Backend
|
| 88 |
+
| _ExtraFields_Allocation
|
| 89 |
+
| _ExtraFields_OutOfMemory
|
| 90 |
+
| _ExtraFields_PyCall
|
| 91 |
+
| _ExtraFields_PyCCall
|
| 92 |
+
| _ExtraFields_Kineto
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
@property
|
| 96 |
+
def typed(
|
| 97 |
+
self,
|
| 98 |
+
) -> (
|
| 99 |
+
tuple[Literal[_EventType.TorchOp], _ExtraFields_TorchOp]
|
| 100 |
+
| tuple[Literal[_EventType.Backend], _ExtraFields_Backend]
|
| 101 |
+
| tuple[Literal[_EventType.Allocation], _ExtraFields_Allocation]
|
| 102 |
+
| tuple[Literal[_EventType.OutOfMemory], _ExtraFields_OutOfMemory]
|
| 103 |
+
| tuple[Literal[_EventType.PyCall], _ExtraFields_PyCall]
|
| 104 |
+
| tuple[Literal[_EventType.PyCCall], _ExtraFields_PyCCall]
|
| 105 |
+
| tuple[Literal[_EventType.Kineto], _ExtraFields_Kineto]
|
| 106 |
+
): ...
|
| 107 |
+
@property
|
| 108 |
+
def name(self) -> str: ...
|
| 109 |
+
@property
|
| 110 |
+
def tag(self) -> _EventType: ...
|
| 111 |
+
@property
|
| 112 |
+
def id(self) -> int: ...
|
| 113 |
+
@property
|
| 114 |
+
def parent(self) -> _ProfilerEvent | None: ...
|
| 115 |
+
@property
|
| 116 |
+
def correlation_id(self) -> int: ...
|
| 117 |
+
@property
|
| 118 |
+
def end_time_ns(self) -> int: ...
|
| 119 |
+
@property
|
| 120 |
+
def duration_time_ns(self) -> int: ...
|
| 121 |
+
|
| 122 |
+
class _TensorMetadata:
|
| 123 |
+
impl_ptr: int | None
|
| 124 |
+
storage_data_ptr: int | None
|
| 125 |
+
id: int | None
|
| 126 |
+
|
| 127 |
+
@property
|
| 128 |
+
def allocation_id(self) -> int | None: ...
|
| 129 |
+
@property
|
| 130 |
+
def layout(self) -> layout: ...
|
| 131 |
+
@property
|
| 132 |
+
def device(self) -> device: ...
|
| 133 |
+
@property
|
| 134 |
+
def dtype(self) -> dtype: ...
|
| 135 |
+
@property
|
| 136 |
+
def sizes(self) -> list[int]: ...
|
| 137 |
+
@property
|
| 138 |
+
def strides(self) -> list[int]: ...
|
| 139 |
+
|
| 140 |
+
Scalar: TypeAlias = int | float | bool | complex
|
| 141 |
+
Input: TypeAlias = _TensorMetadata | list[_TensorMetadata] | Scalar | None
|
| 142 |
+
|
| 143 |
+
class _ExtraFields_TorchOp:
|
| 144 |
+
name: str
|
| 145 |
+
sequence_number: int
|
| 146 |
+
allow_tf32_cublas: bool
|
| 147 |
+
|
| 148 |
+
@property
|
| 149 |
+
def inputs(self) -> list[Input]: ...
|
| 150 |
+
@property
|
| 151 |
+
def scope(self) -> RecordScope: ...
|
| 152 |
+
|
| 153 |
+
class _ExtraFields_Backend: ...
|
| 154 |
+
|
| 155 |
+
class _ExtraFields_Allocation:
|
| 156 |
+
ptr: int
|
| 157 |
+
id: int | None
|
| 158 |
+
alloc_size: int
|
| 159 |
+
total_allocated: int
|
| 160 |
+
total_reserved: int
|
| 161 |
+
|
| 162 |
+
@property
|
| 163 |
+
def allocation_id(self) -> int | None: ...
|
| 164 |
+
@property
|
| 165 |
+
def device(self) -> device: ...
|
| 166 |
+
|
| 167 |
+
class _ExtraFields_OutOfMemory: ...
|
| 168 |
+
|
| 169 |
+
class _PyFrameState:
|
| 170 |
+
line_number: int
|
| 171 |
+
function_name: str
|
| 172 |
+
|
| 173 |
+
@property
|
| 174 |
+
def file_name(self) -> str: ...
|
| 175 |
+
|
| 176 |
+
class _NNModuleInfo:
|
| 177 |
+
@property
|
| 178 |
+
def self_ptr(self) -> int: ...
|
| 179 |
+
@property
|
| 180 |
+
def cls_ptr(self) -> int: ...
|
| 181 |
+
@property
|
| 182 |
+
def cls_name(self) -> str: ...
|
| 183 |
+
@property
|
| 184 |
+
def parameters(
|
| 185 |
+
self,
|
| 186 |
+
) -> list[tuple[str, _TensorMetadata, _TensorMetadata | None]]: ...
|
| 187 |
+
|
| 188 |
+
class _OptimizerInfo:
|
| 189 |
+
@property
|
| 190 |
+
def parameters(
|
| 191 |
+
self,
|
| 192 |
+
) -> list[
|
| 193 |
+
tuple[
|
| 194 |
+
# Parameter
|
| 195 |
+
_TensorMetadata,
|
| 196 |
+
#
|
| 197 |
+
# Gradient (if present during optimizer.step())
|
| 198 |
+
_TensorMetadata | None,
|
| 199 |
+
#
|
| 200 |
+
# Optimizer state for Parameter as (name, tensor) pairs
|
| 201 |
+
list[tuple[str, _TensorMetadata]],
|
| 202 |
+
]
|
| 203 |
+
]: ...
|
| 204 |
+
|
| 205 |
+
class _ExtraFields_PyCCall:
|
| 206 |
+
@property
|
| 207 |
+
def caller(self) -> _PyFrameState: ...
|
| 208 |
+
|
| 209 |
+
class _ExtraFields_PyCall:
|
| 210 |
+
@property
|
| 211 |
+
def callsite(self) -> _PyFrameState: ...
|
| 212 |
+
@property
|
| 213 |
+
def caller(self) -> _PyFrameState: ...
|
| 214 |
+
@property
|
| 215 |
+
def module(self) -> _NNModuleInfo | None: ...
|
| 216 |
+
@property
|
| 217 |
+
def optimizer(self) -> _OptimizerInfo | None: ...
|
| 218 |
+
|
| 219 |
+
class _ExtraFields_Kineto: ...
|
| 220 |
+
|
| 221 |
+
def _add_execution_trace_observer(output_file_path: str) -> bool: ...
|
| 222 |
+
def _remove_execution_trace_observer() -> None: ...
|
| 223 |
+
def _enable_execution_trace_observer() -> None: ...
|
| 224 |
+
def _disable_execution_trace_observer() -> None: ...
|
| 225 |
+
def _set_record_concrete_inputs_enabled_val(val: bool) -> None: ...
|
| 226 |
+
def _set_fwd_bwd_enabled_val(val: bool) -> None: ...
|
| 227 |
+
def _set_cuda_sync_enabled_val(val: bool) -> None: ...
|
| 228 |
+
|
| 229 |
+
class CapturedTraceback: ...
|
| 230 |
+
|
| 231 |
+
def gather_traceback(python: bool, script: bool, cpp: bool) -> CapturedTraceback: ...
|
| 232 |
+
|
| 233 |
+
# The Dict has name, filename, line
|
| 234 |
+
def symbolize_tracebacks(
|
| 235 |
+
to_symbolize: list[CapturedTraceback],
|
| 236 |
+
) -> list[list[dict[str, str]]]: ...
|
| 237 |
+
|
| 238 |
+
class _RecordFunctionFast:
|
| 239 |
+
def __init__(
|
| 240 |
+
self,
|
| 241 |
+
name: str,
|
| 242 |
+
input_values: list | tuple | None = None,
|
| 243 |
+
keyword_values: dict | None = None,
|
| 244 |
+
) -> None: ...
|
| 245 |
+
def __enter__(self) -> None: ...
|
| 246 |
+
def __exit__(self, *exc_info: object) -> None: ...
|
phivenv/Lib/site-packages/torch/_C/_verbose.pyi
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Defined in torch/csrc/utils/verbose.cpp
|
| 2 |
+
def mkl_set_verbose(enable: int) -> int: ...
|
| 3 |
+
def mkldnn_set_verbose(level: int) -> int: ...
|
phivenv/Lib/site-packages/torch/_C_flatbuffer/__init__.pyi
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from torch._C import LiteScriptModule, ScriptModule
|
| 3 |
+
|
| 4 |
+
def _load_mobile_module_from_file(filename: str): ...
|
| 5 |
+
def _load_mobile_module_from_bytes(bytes_: bytes): ...
|
| 6 |
+
def _load_jit_module_from_file(filename: str): ...
|
| 7 |
+
def _load_jit_module_from_bytes(bytes_: bytes): ...
|
| 8 |
+
def _save_mobile_module(m: LiteScriptModule, filename: str): ...
|
| 9 |
+
def _save_jit_module(m: ScriptModule, filename: str): ...
|
| 10 |
+
def _save_mobile_module_to_bytes(m: LiteScriptModule) -> bytes: ...
|
| 11 |
+
def _save_jit_module_to_bytes(m: ScriptModule) -> bytes: ...
|
phivenv/Lib/site-packages/torch/_awaits/__init__.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Generic, TypeVar
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
__all__ = ['Await']
|
| 8 |
+
|
| 9 |
+
W = TypeVar("W")
|
| 10 |
+
|
| 11 |
+
class _PyAwaitMeta(type(torch._C._Await), type(Generic)): # type: ignore[misc, no-redef]
|
| 12 |
+
pass
|
| 13 |
+
|
| 14 |
+
class _Await(torch._C._Await, Generic[W], metaclass=_PyAwaitMeta):
|
| 15 |
+
r"""
|
| 16 |
+
Wrapper around a ``torch._C.Await`` which encapsulates delayed execution
|
| 17 |
+
of a callable. All manipulations happen with functions ``torch.jit._awaitable``,
|
| 18 |
+
``torch.jit._awaitable_wait``, ``torch.jit._awaitable_nowait``.
|
| 19 |
+
|
| 20 |
+
Torch scriptable manipulations:
|
| 21 |
+
``torch.jit._awaitable(func, *args)``
|
| 22 |
+
Creates ``Await[W]`` object, where W is return type of func.
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
``torch.jit._awaitable_wait(Await[W])``
|
| 26 |
+
Returns the result of the function, specified at ``_awaitable``, with specified arguments.
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
The result of type ``W`` of the function call. The result is owned by ``Await[W]``
|
| 30 |
+
and returned on all following ``_awaitable_wait`` calls.
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
``torch.jit._awaitable_nowait(W)``
|
| 34 |
+
Returns:
|
| 35 |
+
Trivial ``Await[W]`` with specified result.
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
Only in eager mode:
|
| 39 |
+
``fn() -> Callable[Tuple[Any], W]``
|
| 40 |
+
Returns:
|
| 41 |
+
Specified at ``_awaitable`` python function ``func``.
|
| 42 |
+
|
| 43 |
+
``args() -> Tuple[Any]``
|
| 44 |
+
Returns:
|
| 45 |
+
Specified at ``_awaitable`` python args.
|
| 46 |
+
|
| 47 |
+
``is_nowait() -> _bool``
|
| 48 |
+
Returns:
|
| 49 |
+
``True`` if this object was created via ``_awaitable_nowait`` call (trivial `Await[W]`).
|
| 50 |
+
|
| 51 |
+
In eager mode ``Await[W]`` can be used as ``W`` i.e. attributes of W can be called on ``Await[W]``,
|
| 52 |
+
``_awaitable_wait()`` call will be transparently added.
|
| 53 |
+
"""
|
phivenv/Lib/site-packages/torch/_awaits/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (2.03 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/_custom_op/__init__.py
ADDED
|
File without changes
|
phivenv/Lib/site-packages/torch/_custom_op/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (160 Bytes). View file
|
|
|
phivenv/Lib/site-packages/torch/_custom_op/__pycache__/autograd.cpython-39.pyc
ADDED
|
Binary file (8.9 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/_custom_op/__pycache__/impl.cpython-39.pyc
ADDED
|
Binary file (21.2 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/_custom_op/autograd.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import functools
|
| 3 |
+
from collections import namedtuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.utils._pytree as pytree
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# NOTE [CustomOp autograd kernel indirection]
|
| 10 |
+
# We register `inner` as the autograd kernel for this custom_op.
|
| 11 |
+
# `inner` either calls the autograd formula registered by the user,
|
| 12 |
+
# or goes into an `autograd_not_implemented` kernel.
|
| 13 |
+
#
|
| 14 |
+
# The reason why this indirection exists is
|
| 15 |
+
# so that we can swap out the autograd kernel (the PyTorch dispatcher
|
| 16 |
+
# doesn't actually allow us to do this). By default, we want
|
| 17 |
+
# the `autograd_not_implemented` behavior, but then the user may come
|
| 18 |
+
# and register something that is actually a backward formula
|
| 19 |
+
def autograd_kernel_indirection(custom_op):
|
| 20 |
+
autograd_fallback = autograd_not_implemented(custom_op)
|
| 21 |
+
|
| 22 |
+
def inner(*args, **kwargs):
|
| 23 |
+
if custom_op._has_impl("autograd"):
|
| 24 |
+
kernel = custom_op._get_impl("autograd").func
|
| 25 |
+
return kernel(*args, **kwargs)
|
| 26 |
+
# As explained in NOTE ["backward", "save_for_backward", and "autograd"],
|
| 27 |
+
# after the user gives us "backward" and "save_for_backward", we generate
|
| 28 |
+
# the "autograd" impl. If the user only provided one, then we tell
|
| 29 |
+
# the user they've done something wrong.
|
| 30 |
+
if custom_op._has_impl("save_for_backward") or custom_op._has_impl("backward"):
|
| 31 |
+
missing = (
|
| 32 |
+
"save_for_backward" if custom_op._has_impl("backward") else "backward"
|
| 33 |
+
)
|
| 34 |
+
found = "save_for_backward" if missing == "backward" else "backward"
|
| 35 |
+
loc = custom_op._get_impl(found).location
|
| 36 |
+
raise RuntimeError(
|
| 37 |
+
f"We found a '{found}' registration for {custom_op} at "
|
| 38 |
+
f"{loc} but were unable to find a '{missing}' registration. "
|
| 39 |
+
f"To use the CustomOp API to register a backward formula, "
|
| 40 |
+
f"please provide us both a backward function and a "
|
| 41 |
+
f"'save for backward' function via `impl_backward` and "
|
| 42 |
+
f"`impl_save_for_backward` respectively."
|
| 43 |
+
)
|
| 44 |
+
return autograd_fallback(*args, **kwargs)
|
| 45 |
+
|
| 46 |
+
return inner
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# TODO(#101191): Use the actual C++ autograd not implemented fallback,
|
| 50 |
+
# or change the default autograd fallback to the autograd not implemented fallback.
|
| 51 |
+
def autograd_not_implemented(custom_op):
|
| 52 |
+
def kernel(*args, **kwargs):
|
| 53 |
+
if torch.is_grad_enabled() and pytree.tree_any(
|
| 54 |
+
lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
|
| 55 |
+
):
|
| 56 |
+
raise RuntimeError("Autograd has not been implemented for operator")
|
| 57 |
+
with torch._C._AutoDispatchBelowAutograd():
|
| 58 |
+
return custom_op(*args, **kwargs)
|
| 59 |
+
|
| 60 |
+
return kernel
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def mark_non_differentiable(ctx, output, output_differentiability):
|
| 64 |
+
# Output types are restricted to be:
|
| 65 |
+
# - Tensor
|
| 66 |
+
# - Tensor[]
|
| 67 |
+
# - int, bool, Scalar, float
|
| 68 |
+
# See _check_can_register_backward
|
| 69 |
+
if output_differentiability is not None:
|
| 70 |
+
if not isinstance(output, tuple):
|
| 71 |
+
tuple_output = (output,)
|
| 72 |
+
else:
|
| 73 |
+
tuple_output = output # type: ignore[assignment]
|
| 74 |
+
assert len(output_differentiability) == len(tuple_output)
|
| 75 |
+
non_differentiable_tensors = []
|
| 76 |
+
for idx, (differentiable, out) in enumerate(
|
| 77 |
+
zip(output_differentiability, tuple_output)
|
| 78 |
+
):
|
| 79 |
+
if isinstance(out, torch.Tensor):
|
| 80 |
+
if not differentiable:
|
| 81 |
+
non_differentiable_tensors.append(out)
|
| 82 |
+
continue
|
| 83 |
+
if isinstance(out, list):
|
| 84 |
+
if not differentiable:
|
| 85 |
+
non_differentiable_tensors.extend(out)
|
| 86 |
+
continue
|
| 87 |
+
if differentiable:
|
| 88 |
+
raise RuntimeError(
|
| 89 |
+
f"With output_differentiability={output_differentiability}. "
|
| 90 |
+
f"At idx {idx}, we received an object of type {type(out)} that "
|
| 91 |
+
f"is not a Tensor, so it cannot have be marked as differentiable in "
|
| 92 |
+
f"output_differentiability."
|
| 93 |
+
)
|
| 94 |
+
if non_differentiable_tensors:
|
| 95 |
+
ctx.mark_non_differentiable(*non_differentiable_tensors)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def construct_autograd_kernel(
|
| 99 |
+
schema,
|
| 100 |
+
output_differentiability,
|
| 101 |
+
custom_op,
|
| 102 |
+
op_overload,
|
| 103 |
+
save_for_backward_fn,
|
| 104 |
+
backward_fn,
|
| 105 |
+
):
|
| 106 |
+
def apply(*args):
|
| 107 |
+
flat_args, spec = pytree.tree_flatten(args)
|
| 108 |
+
out_spec = None
|
| 109 |
+
|
| 110 |
+
def forward(ctx, *flat_args):
|
| 111 |
+
ctx.set_materialize_grads(True)
|
| 112 |
+
args = pytree.tree_unflatten(list(flat_args), spec)
|
| 113 |
+
with torch._C._AutoDispatchBelowAutograd():
|
| 114 |
+
output = op_overload(*args)
|
| 115 |
+
|
| 116 |
+
# We use the info about args to give better error messages in backward
|
| 117 |
+
args_info = namedtuple_args(schema, pytree.tree_map(type, args))
|
| 118 |
+
|
| 119 |
+
save_for_backward_fn_inputs = namedtuple_args(schema, args)
|
| 120 |
+
to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)
|
| 121 |
+
|
| 122 |
+
save_pytree_for_backward(ctx, (to_save, args_info))
|
| 123 |
+
mark_non_differentiable(ctx, output, output_differentiability)
|
| 124 |
+
|
| 125 |
+
nonlocal out_spec
|
| 126 |
+
flat_output, out_spec = pytree.tree_flatten(output)
|
| 127 |
+
return tuple(flat_output)
|
| 128 |
+
|
| 129 |
+
def backward(ctx, *flat_grad_output):
|
| 130 |
+
assert out_spec is not None
|
| 131 |
+
grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
|
| 132 |
+
saved, args_info = unpack_saved(ctx)
|
| 133 |
+
# There is nothing on the ctx object for now, it is just there so
|
| 134 |
+
# that we can add additional things in the future.
|
| 135 |
+
inner_ctx = object()
|
| 136 |
+
if not isinstance(grads, tuple):
|
| 137 |
+
grads = (grads,)
|
| 138 |
+
grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)
|
| 139 |
+
|
| 140 |
+
# Massage the grad_inputs_dict to a form acceptable by
|
| 141 |
+
# autograd.Function.
|
| 142 |
+
validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
|
| 143 |
+
return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)
|
| 144 |
+
|
| 145 |
+
generated_cls = gen_autograd_function(
|
| 146 |
+
custom_op._opname + "_customop", forward, backward
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
flat_output = generated_cls.apply(*flat_args)
|
| 150 |
+
assert out_spec is not None
|
| 151 |
+
return pytree.tree_unflatten(list(flat_output), out_spec)
|
| 152 |
+
|
| 153 |
+
return apply
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def gen_autograd_function(name, forward, backward):
|
| 157 |
+
generated_cls = type(
|
| 158 |
+
name,
|
| 159 |
+
(torch.autograd.Function,),
|
| 160 |
+
{
|
| 161 |
+
"forward": staticmethod(forward),
|
| 162 |
+
"backward": staticmethod(backward),
|
| 163 |
+
},
|
| 164 |
+
)
|
| 165 |
+
return generated_cls
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
@functools.lru_cache
|
| 169 |
+
def namedtuple_args_cls(schema):
|
| 170 |
+
attribs = [arg.name for arg in schema.arguments.flat_all]
|
| 171 |
+
name = str(schema.name) + "_args"
|
| 172 |
+
# mypy doesn't support dynamic namedtuple name
|
| 173 |
+
tuple_cls = namedtuple(name, attribs) # type: ignore[misc]
|
| 174 |
+
return tuple_cls
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def namedtuple_args(schema, args):
|
| 178 |
+
assert isinstance(args, tuple)
|
| 179 |
+
tuple_cls = namedtuple_args_cls(schema)
|
| 180 |
+
return tuple_cls(*args)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
|
| 184 |
+
def error(what):
|
| 185 |
+
backward = forward_op._get_impl("backward")
|
| 186 |
+
raise RuntimeError(
|
| 187 |
+
f"In the backward function defined for {forward_op} at "
|
| 188 |
+
f"{backward.location} using the CustomOp API, {what}"
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
if not isinstance(grad_inputs_dict, dict):
|
| 192 |
+
error(
|
| 193 |
+
f"expected the output of the backward function to be a dict but "
|
| 194 |
+
f"got {type(grad_inputs_dict)}"
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
expected_keys = {
|
| 198 |
+
arg.name
|
| 199 |
+
for arg in forward_op._schema.arguments.flat_all
|
| 200 |
+
if arg.type.is_tensor_like()
|
| 201 |
+
}
|
| 202 |
+
actual_keys = grad_inputs_dict.keys()
|
| 203 |
+
if expected_keys != actual_keys:
|
| 204 |
+
error(
|
| 205 |
+
f"expected the returned grad_input dict to have keys "
|
| 206 |
+
f"{expected_keys} but got {actual_keys}. The backward "
|
| 207 |
+
f"function must return a gradient (can be None) for each arg "
|
| 208 |
+
f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
|
| 209 |
+
f"Args declared to be non-Tensor-like types should not appear "
|
| 210 |
+
f"in the grad_input dict"
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
for name, grad in grad_inputs_dict.items():
|
| 214 |
+
arg_info = getattr(args_info, name)
|
| 215 |
+
|
| 216 |
+
if isinstance(arg_info, list):
|
| 217 |
+
if not isinstance(grad, (tuple, list)):
|
| 218 |
+
error(
|
| 219 |
+
f"for input '{name}' expected the grad_input dict to "
|
| 220 |
+
f"hold a list of gradients but got object of type "
|
| 221 |
+
f"{type(grad)}."
|
| 222 |
+
)
|
| 223 |
+
if not len(grad) == len(arg_info):
|
| 224 |
+
error(
|
| 225 |
+
f"for input '{name}' expected the grad_input dict to "
|
| 226 |
+
f"hold a list of {len(arg_info)} gradients but got "
|
| 227 |
+
f"{len(grad)}"
|
| 228 |
+
)
|
| 229 |
+
for idx, (g, info) in enumerate(zip(grad, arg_info)):
|
| 230 |
+
if g is None:
|
| 231 |
+
continue
|
| 232 |
+
if not isinstance(g, torch.Tensor):
|
| 233 |
+
error(
|
| 234 |
+
f"for input '{name}' expected the grad_input dict to "
|
| 235 |
+
f"hold a list of None or Tensor gradients but got "
|
| 236 |
+
f"object of {type(g)} at index {idx}"
|
| 237 |
+
)
|
| 238 |
+
if not issubclass(info, torch.Tensor):
|
| 239 |
+
error(
|
| 240 |
+
f"for input '{name}', got a Tensor as the gradient "
|
| 241 |
+
f"for the {idx}-th value but expected None because "
|
| 242 |
+
f"the {idx}-th value was not a Tensor (it was "
|
| 243 |
+
f"type {arg_info}"
|
| 244 |
+
)
|
| 245 |
+
continue
|
| 246 |
+
|
| 247 |
+
if grad is None:
|
| 248 |
+
continue
|
| 249 |
+
if not isinstance(grad, torch.Tensor):
|
| 250 |
+
error(
|
| 251 |
+
f"got object of type {type(grad)} as the gradient for input "
|
| 252 |
+
f"'{name}', "
|
| 253 |
+
f"but expected the gradient to be either None or a Tensor"
|
| 254 |
+
)
|
| 255 |
+
if not issubclass(arg_info, torch.Tensor):
|
| 256 |
+
error(
|
| 257 |
+
f"got a Tensor as the gradient for input '{name}' but "
|
| 258 |
+
f"expected None as the gradient because input '{name}' "
|
| 259 |
+
f"was not a Tensor (it was type {arg_info})."
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
|
| 264 |
+
result = []
|
| 265 |
+
for name, arg_info in args_info._asdict().items():
|
| 266 |
+
if name not in grad_inputs_dict:
|
| 267 |
+
result.append(pytree.tree_map(lambda x: None, arg_info))
|
| 268 |
+
continue
|
| 269 |
+
result.append(grad_inputs_dict[name])
|
| 270 |
+
return tuple(pytree.tree_leaves(result))
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
# Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
# autograd.Function prefers that users use ctx.save_for_backward to
# save Tensors (to avoid reference cycles) and for non-Tensors to go onto the
# ctx object.
def save_pytree_for_backward(ctx, stuff):
    """Stash an arbitrary pytree on ``ctx`` so ``unpack_saved`` can rebuild it."""
    leaves, spec = pytree.tree_flatten(stuff)
    tensor_mask = [isinstance(leaf, torch.Tensor) for leaf in leaves]

    ctx.spec = spec
    ctx.num_elts = len(leaves)
    # Tensors are routed through save_for_backward; their original flat
    # positions are remembered so unpack_saved can re-interleave them.
    ctx.save_for_backward(*[leaf for leaf, is_t in zip(leaves, tensor_mask) if is_t])
    ctx.tensor_idxs = [i for i, is_t in enumerate(tensor_mask) if is_t]
    # Everything else is stored directly on ctx, alongside its positions.
    ctx.saved_non_tensors = [
        leaf for leaf, is_t in zip(leaves, tensor_mask) if not is_t
    ]
    ctx.non_tensor_idxs = [i for i, is_t in enumerate(tensor_mask) if not is_t]
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
# Inverse operation to save_pytree_for_backward
def unpack_saved(ctx):
    """Rebuild the pytree that ``save_pytree_for_backward`` stored on ``ctx``."""
    leaves = [None] * ctx.num_elts
    # Re-interleave tensors (from ctx.saved_tensors) and non-tensors back
    # into their original flat positions.
    for idx, tensor in zip(ctx.tensor_idxs, ctx.saved_tensors):
        leaves[idx] = tensor
    for idx, value in zip(ctx.non_tensor_idxs, ctx.saved_non_tensors):
        leaves[idx] = value
    return pytree.tree_unflatten(leaves, ctx.spec)
|
phivenv/Lib/site-packages/torch/_custom_op/impl.py
ADDED
|
@@ -0,0 +1,715 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import dataclasses
|
| 3 |
+
import functools
|
| 4 |
+
import inspect
|
| 5 |
+
import sys
|
| 6 |
+
import typing
|
| 7 |
+
import warnings
|
| 8 |
+
import weakref
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch._C as _C
|
| 12 |
+
import torch._library.infer_schema
|
| 13 |
+
import torch.library as library
|
| 14 |
+
from torch._library.infer_schema import infer_schema
|
| 15 |
+
from torch.library import get_ctx
|
| 16 |
+
from torchgen.model import (
|
| 17 |
+
BaseTy,
|
| 18 |
+
BaseType,
|
| 19 |
+
FunctionSchema,
|
| 20 |
+
ListType,
|
| 21 |
+
OperatorName,
|
| 22 |
+
SchemaKind,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
from .autograd import autograd_kernel_indirection, construct_autograd_kernel
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
"""
|
| 29 |
+
torch._custom_op is deprecated. We shipped a production-ready version of it into torch.library.
|
| 30 |
+
Please use those APIs instead.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
__all__ = ["custom_op", "CustomOp", "get_ctx"]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Maps the user-facing device type string accepted by CustomOp.impl(...) to
# the dispatcher DispatchKey name that the kernel gets registered under.
SUPPORTED_DEVICE_TYPE_TO_KEY = {
    "cpu": "CPU",
    "cuda": "CUDA",
}

# We will not let users register CustomOps with anything that could look like
# PyTorch internals to avoid confusion.
RESERVED_NS = {
    "prim",
    "prims",
    "aten",
    "at",
    "torch",
    "pytorch",
}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def warn_deprecated():
    """Emit the DeprecationWarning for the legacy torch._custom_op API."""
    message = (
        "torch._custom_op is deprecated and will be removed in PyTorch 2.6, please "
        "use the equivalent torch.library API instead."
    )
    warnings.warn(message, DeprecationWarning)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def custom_op(
    qualname: str, manual_schema: typing.Optional[str] = None
) -> typing.Callable:
    r"""
    This API is deprecated, please use torch.library.custom_op instead
    """
    warn_deprecated()

    def inner(func):
        # `func` must be a plain Python function so its signature/annotations
        # can be inspected for schema inference and validation.
        if not inspect.isfunction(func):
            raise ValueError(
                f"custom_op(...)(func): Expected `func` to be a Python "
                f"function, got: {type(func)}"
            )

        ns, name = parse_qualname(qualname)
        validate_namespace(ns)
        # The decorated function's name must match the operator name from
        # the qualname so the returned object can stand in for `func`.
        if func.__name__ != name:
            raise ValueError(
                f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
                f"to have name '{name}' but got '{func.__name__}'. "
                f"Please either change the name of `func` or the qualname that "
                f"is passed to `custom_op`"
            )

        # Use the manually-provided schema if given; otherwise infer one
        # from `func` (assuming no mutated args).
        schema = (
            infer_schema(func, mutates_args=())
            if manual_schema is None
            else manual_schema
        )
        schema_str = f"{name}{schema}"
        function_schema = FunctionSchema.parse(schema_str)
        validate_schema(function_schema)
        # A manual schema must agree with `func`'s actual signature.
        if manual_schema is not None:
            validate_function_matches_schema(function_schema, func)

        # Register the schema with the dispatcher and obtain its handle.
        lib = library.Library(ns, "FRAGMENT")
        lib.define(schema_str)
        ophandle = find_ophandle_or_throw(ns, function_schema.name)
        result = CustomOp(
            lib, ns, function_schema, name, ophandle, _private_access=True
        )

        # Copy identity metadata so `result` presents like the original
        # function (name/module/doc).
        result.__name__ = func.__name__
        result.__module__ = func.__module__
        result.__doc__ = func.__doc__

        # Register the autograd indirection kernel; the actual backward
        # formula (if any) is supplied later via impl_backward et al.
        # weakref.proxy avoids the registration keeping `result` alive.
        library.impl(lib, result._opname, "Autograd")(
            autograd_kernel_indirection(weakref.proxy(result))
        )

        # Install a callback that turns missing-kernel dispatch errors into
        # friendlier CustomOp-specific messages.
        torch._C._dispatch_set_report_error_callback(
            ophandle, functools.partial(report_error_callback, weakref.proxy(result))
        )

        return result

    return inner
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# Global dictionary holding references to all CustomOp objects
# Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime])
# Used to query the CustomOp associated with a specific C++ dispatcher operator.
# An example usage is FakeTensor: FakeTensor checks if a specific operator
# has an implementation registered via the CustomOp API.
# Indexed by qualname (e.g. aten::foo)
# Entries are only removed by CustomOp._destroy (a test-only escape hatch).
global_registry: dict[str, "CustomOp"] = {}
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class CustomOp:
    r"""
    This API is deprecated, please use torch.library.custom_op instead
    """

    def __init__(
        self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False
    ):
        """Private constructor; build instances via ``custom_op(...)`` instead."""
        super().__init__()
        warn_deprecated()
        if not _private_access:
            raise RuntimeError(
                "The CustomOp constructor is private and we do not guarantee "
                "BC for it. Please use custom_op(...) to create a CustomOp object"
            )
        name = f"{cpp_ns}::{operator_name}"
        self._schema = schema
        self._cpp_ns = cpp_ns
        self._lib: library.Library = lib
        self._ophandle: _C._DispatchOperatorHandle = ophandle
        # Has the name of the op, e.g. "foo". We cache here for convenience.
        self._opname: str = operator_name
        # this is _opname but with namespace. e.g. "custom::foo"
        self._qualname: str = name
        self.__name__ = None  # mypy requires this
        # NB: Some of these impls are registered as kernels to DispatchKeys.
        # Modifying the _impls dict directly won't do anything in that case.
        self._impls: dict[str, typing.Optional[FuncAndLocation]] = {}
        # See NOTE [CustomOp autograd kernel indirection]
        self._registered_autograd_kernel_indirection = False

        global_registry[self._qualname] = self

    def _register_autograd_kernel_indirection(self):
        """Install the Autograd-key indirection kernel (at most once)."""
        assert not self._registered_autograd_kernel_indirection
        self._lib.impl(
            self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd"
        )
        self._registered_autograd_kernel_indirection = True

    # Records the impl and the source location in self._impls
    # Note that this doesn't cause torch.library to use the impl, that
    # needs to be done in a separate self._lib.impl call.
    def _register_impl(self, kind, func, stacklevel=2):
        if self._has_impl(kind):
            func_and_location = self._impls[kind]
            assert func_and_location is not None  # Pacify mypy
            location = func_and_location.location
            raise RuntimeError(
                f"Attempting to register a {kind} impl for operator {self._qualname} "
                f"that already has a {kind} impl registered from Python at "
                f"{location}. This is not supported."
            )
        # Record the caller's file:line so duplicate registrations can point
        # at the original registration site.
        frame = inspect.getframeinfo(sys._getframe(stacklevel))
        location = f"{frame.filename}:{frame.lineno}"
        self._impls[kind] = FuncAndLocation(func, location)

    def _get_impl(self, kind):
        """Return the FuncAndLocation recorded for ``kind`` (KeyError if absent)."""
        return self._impls[kind]

    def _has_impl(self, kind):
        """True if an impl of ``kind`` was registered via _register_impl."""
        return kind in self._impls

    def _destroy(self):
        # NOTE: [CustomOp lifetime]
        # A CustomOp, once created, lives forever. The mechanism is that the
        # global registry holds a reference to it. However, to make testing
        # easier, we want to be able to destroy CustomOp objects.
        # CustomOp._destroy does the job, though it leaves the CustomOp
        # in a garbage state.
        del self._lib

        opnamespace = getattr(torch.ops, self._cpp_ns)
        if hasattr(opnamespace, self._opname):
            delattr(opnamespace, self._opname)

        del global_registry[self._qualname]

    def __repr__(self):
        return f'<CustomOp(op="{self._qualname}")>'

    def __call__(self, *args, **kwargs):
        # Bypass torch.ops.* and directly do OperatorHandle::callBoxed.
        # Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime
        # issues from caching operators that make testing CustomOp difficult).
        result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs)
        return result

    def impl(
        self,
        device_types: typing.Union[str, typing.Iterable[str]],
        _stacklevel=2,
    ) -> typing.Callable:
        r"""
        This API is deprecated, please use torch.library.custom_op instead
        """
        if isinstance(device_types, str):
            device_types = [device_types]
        for device_type in device_types:
            validate_device_type(device_type)

        def inner(f):
            for device_type in set(device_types):
                self._check_doesnt_have_library_impl(device_type)
                self._register_impl(device_type, f, stacklevel=_stacklevel)
                dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
                library.impl(self._lib, self._opname, dispatch_key)(f)
            return f

        return inner

    def _check_doesnt_have_library_impl(self, device_type):
        """Reject a device impl if the dispatcher already has a kernel for it."""
        if self._has_impl(device_type):
            return
        key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
        if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key):
            raise RuntimeError(
                f"impl(..., device_types={device_type}): the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration."
            )

    def impl_factory(self) -> typing.Callable:
        r"""Register an implementation for a factory function."""

        def inner(f):
            self._register_impl("factory", f)
            library.impl(self._lib, self._opname, "BackendSelect")(f)
            return f

        return inner

    def impl_abstract(self, _stacklevel=2) -> typing.Callable:
        r"""
        This API is deprecated, please use torch.library.custom_op instead
        """

        def inner(f):
            self._check_doesnt_have_library_meta_impl()
            self._register_impl("abstract", f, stacklevel=_stacklevel)
            location = self._get_impl("abstract").location

            qualname = self._qualname

            # Handle DispatchKey.Meta registration
            @functools.wraps(f)
            def f_with_ctx(*args, **kwargs):
                def error_on_ctx():
                    # BUGFIX: added the missing space after "{qualname}." —
                    # the concatenated f-strings previously ran two
                    # sentences together.
                    raise RuntimeError(
                        f"Attempted to call get_ctx() for the meta implementation "
                        f"for {qualname}. "
                        f"You have presumably called get_ctx() because the operator "
                        f"has a data-dependent output shape; if so, there is no "
                        f"such meta implementation and this error is the correct "
                        f"behavior. Otherwise, please remove the call to get_ctx() "
                        f"in the implementation registered with impl_abstract "
                        f"at {location}"
                    )

                with torch._library.fake_impl.set_ctx_getter(error_on_ctx):
                    return f(*args, **kwargs)

            self._lib.impl(self._opname, f_with_ctx, "Meta")
            return f

        return inner

    def _check_can_register_backward(self):
        """Reject backward registration for schemas we cannot differentiate."""

        def error(detail):
            raise RuntimeError(
                f"Cannot use torch._custom_ops APIs to register backward "
                f"formula for {detail}. Got operator "
                f"{self._qualname} with schema: {schema}"
            )

        schema = self._schema
        if schema.kind() != SchemaKind.functional:
            error("non-functional operator")

        rets = schema.returns
        if not schema.returns:
            error("operator with no returns")

        assert len(rets) > 0
        is_non_mutating_view = any(
            r.annotation is not None and not r.annotation.is_write for r in rets
        )
        if is_non_mutating_view:
            error("operator that returns views")

        # We make assumptions about the schema's return types.
        allowed_return_types = {
            BaseType(BaseTy.int): "int",
            BaseType(BaseTy.SymInt): "SymInt",
            BaseType(BaseTy.bool): "bool",
            BaseType(BaseTy.float): "float",
            BaseType(BaseTy.Tensor): "Tensor",
            ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
        }
        for ret in schema.returns:
            if ret.type in allowed_return_types:
                continue
            error(
                f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})"
            )

    def _check_doesnt_have_library_autograd_impl(self):
        """Reject backward registration when an autograd kernel already exists."""
        if self._registered_autograd_kernel_indirection:
            return

        if _C._dispatch_has_kernel_for_dispatch_key(
            self._qualname, "CompositeImplicitAutograd"
        ):
            # BUGFIX: added the missing space between the two sentences.
            raise RuntimeError(
                f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing registration to DispatchKey::CompositeImplicitAutograd. "
                f"CompositeImplicitAutograd operators do not need an autograd formula; "
                f"instead, the operator will decompose into its constituents and those "
                f"can have autograd formulas defined on them."
            )

        # We can improve this by adding "all Autograd<BACKEND> keys", but
        # realistically people will just be using this API for CPU/CUDA for now.
        for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
            if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
                # BUGFIX: "vi a" -> "via a" in the error message.
                raise RuntimeError(
                    f"impl_backward/impl_save_for_backward: "
                    f"the operator {self._qualname} already has an Autograd kernel "
                    f"registered to DispatchKey::{key} via a pre-existing "
                    f"torch.library or TORCH_LIBRARY registration. Please either "
                    f"remove those registrations or don't use the torch._custom_ops APIs"
                )

    def _check_doesnt_have_library_meta_impl(self):
        """Reject impl_abstract when a conflicting Meta/alias kernel exists."""
        if self._has_impl("abstract"):
            return

        # If the user's operator is CompositeExplicitAutograd,
        # allow them to impl_abstract. This is being pragmatic
        # (existing custom ops may have CompositeExplicitAutograd
        # registration that don't work with Meta kernels, so this
        # gives them an escape hatch).
        if _C._dispatch_has_kernel_for_dispatch_key(
            self._qualname, "CompositeExplicitAutograd"
        ) and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
            return

        # Otherwise, if the user's already has a Meta kernel or their
        # op is CompositeImplicitAutograd or some other alias dispatch key,
        # raise.

        # Special case for CompositeImplicitAutograd
        if _C._dispatch_has_kernel_for_dispatch_key(
            self._qualname, "CompositeImplicitAutograd"
        ):
            # BUGFIX: added the missing space between the two sentences.
            raise RuntimeError(
                f"impl_abstract(...): the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing registration to DispatchKey::CompositeImplicitAutograd. "
                f"CompositeImplicitAutograd operators do not need an abstract impl; "
                f"instead, the operator will decompose into its constituents and those "
                f"can have abstract impls defined on them."
            )

        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
            raise RuntimeError(
                f"impl_abstract(...): the operator {self._qualname} "
                f"already has an DispatchKey::Meta implementation via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration. "
                f"Please either remove that registration or don't call impl_abstract."
            )

    # NOTE ["backward", "save_for_backward", and "autograd"]
    # As a part of the explicit autograd API, a user must provide us
    # a "save_for_backward" function and a "backward" function.
    # When both of these have been provided, then we automatically
    # construct the "autograd" kernel.
    def _register_autograd_kernel(self):
        assert self._has_impl("backward")
        assert self._has_impl("save_for_backward")
        kernel = construct_autograd_kernel(
            self._schema,
            self._output_differentiability,
            self,
            get_op(self._qualname),
            self._get_impl("save_for_backward").func,
            self._get_impl("backward").func,
        )
        self._register_impl("autograd", kernel)

    def impl_save_for_backward(self, _stacklevel=2):
        r"""Register a function that tells us what to save for backward.

        Please see impl_backward for more details.
        """

        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
            if self._has_impl("backward"):
                self._register_autograd_kernel()
            # NOTE(review): unlike impl/impl_abstract, this decorator does
            # not return f (the decorated name becomes None). Preserved
            # as-is for backward compatibility — confirm before changing.

        return inner

    def impl_backward(self, output_differentiability=None, _stacklevel=2):
        r"""
        This API is deprecated, please use torch.library.custom_op instead
        """
        if output_differentiability is not None:

            def yell():
                raise RuntimeError(
                    f"impl_backward(output_differentiability): expected "
                    f"output_differentiability to be a list of bools with "
                    f"length equal to the number of outputs of this CustomOp "
                    f"got: {output_differentiability}"
                )

            if not isinstance(output_differentiability, list):
                yell()
            for diff in output_differentiability:
                if not isinstance(diff, bool):
                    yell()
            if len(self._schema.returns) != len(output_differentiability):
                yell()

        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("backward", f, stacklevel=_stacklevel)
            self._output_differentiability = output_differentiability
            if self._has_impl("save_for_backward"):
                self._register_autograd_kernel()
            # NOTE(review): see impl_save_for_backward — inner intentionally
            # returns None, matching the original behavior.

        return inner
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
@dataclasses.dataclass
class FuncAndLocation:
    """Pairs a registered impl callable with the "file:line" it came from."""

    func: typing.Callable
    location: str
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName):
    """Look up the dispatcher handle for ``cpp_ns::operator_name``.

    Raises if the dispatcher has no schema registered under that name.
    """
    qualified = f"{cpp_ns}::{str(operator_name.name)}"
    if operator_name.overload_name is None:
        overload = ""
    else:
        overload = operator_name.overload_name
    return _C._dispatch_find_schema_or_throw(qualified, overload)
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def validate_namespace(ns: str) -> None:
    """Reject namespaces containing '.' or colliding with PyTorch-reserved names."""
    if "." in ns:
        message = (
            f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a '
            f"valid variable name)"
        )
        raise ValueError(message)
    if ns in RESERVED_NS:
        message = (
            f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, "
            f"please choose something else. "
        )
        raise ValueError(message)
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def validate_schema(schema: FunctionSchema) -> None:
    """Ensure ``schema`` is functional and takes no argument named 'self'."""
    is_functional = torch._library.utils.is_functional_schema(schema)
    if not is_functional:
        raise ValueError(
            f"custom_op only supports functional operators "
            f"(ops that do not mutate any inputs, do not return "
            f"views of the inputs, and has at least one return). "
            f"Got the following non-functional schema: {schema}"
        )

    # For simplicity: don't allow self arguments
    self_arg = schema.arguments.self_arg
    if self_arg is not None:
        raise ValueError(
            f"custom_op does not support arguments named 'self'. Please "
            f"rename your argument. Got: {schema}"
        )
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
def parse_qualname(qualname: str) -> tuple[str, str]:
    """Split "ns::name" into (ns, name).

    Raises ValueError when no namespace is present or when the operator
    part contains '.' (overloads are not supported by this API).
    """
    parts = qualname.split("::", 1)
    if len(parts) != 2:
        raise ValueError(
            f"Expected there to be a namespace in {qualname}, i.e. The "
            f"operator name should look something like ns::foo"
        )
    namespace, opname = parts
    if "." in opname:
        raise ValueError(
            f"The torch.custom_ops APIs do not handle overloads, "
            f"i.e. operator names with '.' in them. "
            f"Please name your operator something like ns::foo. "
            f"Got: {qualname}"
        )
    return namespace, opname
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
def validate_device_type(device_type: str) -> None:
    """Raise unless ``device_type`` is one of the supported device strings."""
    if device_type in SUPPORTED_DEVICE_TYPE_TO_KEY:
        return
    raise ValueError(
        f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type "
        f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}."
    )
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
def supported_param(param: inspect.Parameter) -> bool:
    """True for plain positional-or-keyword and keyword-only parameters.

    Positional-only, *args, and **kwargs parameters are not supported.
    """
    ok_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    return param.kind in ok_kinds
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
def validate_function_matches_schema(
    schema: FunctionSchema, func: typing.Callable
) -> None:
    """Check that ``func``'s signature lines up with a manually-given schema.

    Rejects unsupported parameter kinds, any type annotations, default
    arguments, and name/position mismatches between ``func`` and ``schema``.
    """
    sig = inspect.signature(func)
    params = sig.parameters

    if any(not supported_param(p) for p in params.values()):
        raise ValueError(
            f"custom_op(..., manual_schema)(func): positional-only args, "
            f"varargs, and kwargs are not supported. Please rewrite `func` "
            f"to not have them. Got `func` with signature: {sig}"
        )

    # With a manual schema, `func` must carry no annotations at all —
    # otherwise the two sources of type information could disagree.
    has_param_annotations = any(
        p.annotation is not inspect.Parameter.empty for p in params.values()
    )
    has_return_annotation = sig.return_annotation is not inspect.Signature.empty
    if has_param_annotations or has_return_annotation:
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func` to have no type annotations to avoid "
            f"ambiguity. Got `func` with signature: {sig}"
        )

    def params_of_kind(kind):
        # (name, Parameter) pairs of the given kind, in declaration order.
        return [(n, p) for n, p in params.items() if p.kind == kind]

    positional = params_of_kind(inspect.Parameter.POSITIONAL_OR_KEYWORD)
    kwargonly = params_of_kind(inspect.Parameter.KEYWORD_ONLY)

    def error():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func`'s signature to match `manual_schema` "
            f"(aside from type annotations). "
            f"func's signature: {sig}, manual_schema: {schema}"
        )

    def error_default_args():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): "
            f"neither func nor manual_schema should have default "
            f"arguments. Got "
            f"func's signature: {sig}, manual_schema: {schema}"
        )

    def compare(sig_args, schema_args):
        if len(sig_args) != len(schema_args):
            error()
        for (name, param), schema_arg in zip(sig_args, schema_args):
            if name != schema_arg.name:
                error()
            has_default = param.default is not inspect.Parameter.empty
            if has_default or schema_arg.default is not None:
                error_default_args()

    compare(positional, schema.arguments.flat_positional)
    compare(kwargonly, schema.arguments.flat_kwarg_only)
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
def report_error_callback(custom_op: typing.Any, key: str) -> None:
    """Raise a descriptive ``NotImplementedError`` for a CustomOp dispatch miss.

    ``key`` is the dispatch key that had no registered kernel. Each known key
    gets a targeted message pointing at the registration API that would fix
    the miss; any other key falls through to a generic message.
    """
    if key == "Undefined":
        message = (
            f"{custom_op}: There were no Tensor inputs to this operator "
            f"(e.g. you passed an empty list of Tensors). If your operator is a "
            f"factory function (that is, it takes no Tensors and constructs "
            f"a new one), then please use CustomOp.impl_factory to register "
            f"an implementation for it"
        )
    elif key == "Meta":
        message = (
            f"{custom_op}: when running with device='Meta' tensors: there is no "
            f"abstract impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors"
        )
    elif key in ("CPU", "CUDA"):
        # The dispatch key is capitalized; the registration API takes lowercase.
        device = key.lower()
        message = (
            f"{custom_op}: when running with device='{device}' tensors: there is no "
            f"{device} impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl(device_type='{device}')"
        )
    else:
        message = (
            f"{custom_op}: No implementation for dispatch key {key}. It is likely "
            f"that we have not added this functionality yet, please either open an "
            f"issue or if you're feeling adventurous, use the low-level "
            f"torch.library API"
        )
    raise NotImplementedError(message)
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
def custom_op_from_existing(op):
    """Wrap an already-registered operator overload ``op`` in a ``CustomOp``.

    A FRAGMENT library is opened on the op's namespace, and the overload's
    schema string is parsed with the namespace stripped (``CustomOp`` expects
    the bare form). The wrapper is built via the private constructor.
    """
    namespace = op.namespace
    fragment_lib = torch.library.Library(namespace, "FRAGMENT")
    op_name = op.name().split("::")[-1]
    # CustomOp expects the schema string without the namespace prefix.
    bare_schema = str(op._schema).split("::")[-1]
    parsed_schema = FunctionSchema.parse(bare_schema)
    return CustomOp(
        fragment_lib, namespace, parsed_schema, op_name, op, _private_access=True
    )
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
def get_op(qualname):
    """Resolve ``qualname`` ("ns::name") through ``torch.ops`` and return the
    ``.default`` overload, raising a helpful ``ValueError`` if any lookup
    step fails."""

    def _missing():
        raise ValueError(
            f"Could not find the operator {qualname}. Please make sure you have "
            f"already registered the operator and (if registered from C++) "
            f"loaded it via torch.ops.load_library."
        )

    ns, name = parse_qualname(qualname)
    if not hasattr(torch.ops, ns):
        _missing()
    namespace_obj = getattr(torch.ops, ns)
    if not hasattr(namespace_obj, name):
        _missing()
    overload_packet = getattr(namespace_obj, name)
    if not hasattr(overload_packet, "default"):
        _missing()
    return overload_packet.default
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
def _find_custom_op(qualname, also_check_torch_library=False):
    """Return the ``CustomOp`` registered under ``qualname``.

    The in-process ``global_registry`` is consulted first. If the op is not
    there and ``also_check_torch_library`` is set, fall back to wrapping the
    operator found via ``torch.ops``; otherwise raise ``RuntimeError``.
    """
    try:
        return global_registry[qualname]
    except KeyError:
        pass
    if not also_check_torch_library:
        raise RuntimeError(
            f'Could not find custom op "{qualname}". Did you register it via '
            f"the torch._custom_ops API?"
        )
    return custom_op_from_existing(get_op(qualname))
|
| 687 |
+
|
| 688 |
+
|
| 689 |
+
def get_abstract_impl(qualname):
|
| 690 |
+
if qualname not in torch._custom_op.impl.global_registry:
|
| 691 |
+
return None
|
| 692 |
+
custom_op = torch._custom_op.impl.global_registry[qualname]
|
| 693 |
+
if custom_op is None:
|
| 694 |
+
return None
|
| 695 |
+
if not custom_op._has_impl("abstract"):
|
| 696 |
+
return None
|
| 697 |
+
return custom_op._get_impl("abstract").func
|
| 698 |
+
|
| 699 |
+
|
| 700 |
+
def _custom_op_with_schema(qualname, schema, needs_fixed_stride_order=True):
    """Define a new operator from an explicit schema string and return its
    default overload.

    Side effects: defines the op in a FRAGMENT library on the namespace,
    registers the autograd-kernel indirection, and installs a dispatcher
    error callback so unimplemented dispatch keys produce a helpful message.

    Args:
        qualname: "ns::name" operator name.
        schema: schema string *without* the op name (appended here).
        needs_fixed_stride_order: when True, tag the op so inductor preserves
            input stride order.
    """
    ns, name = qualname.split("::")
    schema_str = f"{name}{schema}"
    function_schema = FunctionSchema.parse(schema_str)
    validate_schema(function_schema)
    tags = [torch._C.Tag.needs_fixed_stride_order] if needs_fixed_stride_order else []
    lib = library.Library(ns, "FRAGMENT")
    lib.define(schema_str, tags=tags)
    ophandle = find_ophandle_or_throw(ns, function_schema.name)
    result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
    result._register_autograd_kernel_indirection()

    # weakref.proxy avoids the dispatcher callback keeping `result` alive.
    torch._C._dispatch_set_report_error_callback(
        ophandle, functools.partial(report_error_callback, weakref.proxy(result))
    )
    return get_op(qualname)
|
phivenv/Lib/site-packages/torch/_decomp/__init__.py
ADDED
|
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import inspect
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from collections.abc import Sequence
|
| 5 |
+
from functools import lru_cache, partial, wraps
|
| 6 |
+
from itertools import chain
|
| 7 |
+
from typing import Callable, Optional, TYPE_CHECKING, TypeVar, Union
|
| 8 |
+
from typing_extensions import ParamSpec
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
if TYPE_CHECKING:
|
| 12 |
+
from torch.export.decomp_utils import CustomDecompTable
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.library
|
| 16 |
+
from torch._ops import HigherOrderOperator, OperatorBase, OpOverload, OpOverloadPacket
|
| 17 |
+
from torch._prims_common import CustomOutParamAnnotation
|
| 18 |
+
from torch._subclasses.functional_tensor import FunctionalTensor
|
| 19 |
+
from torch.utils import _pytree as pytree
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
__all__ = [
|
| 23 |
+
"decomposition_table",
|
| 24 |
+
"pre_autograd_decomposition_table",
|
| 25 |
+
"meta_table",
|
| 26 |
+
"register_decomposition",
|
| 27 |
+
"get_decompositions",
|
| 28 |
+
"core_aten_decompositions",
|
| 29 |
+
"_should_decompose_because_unsafe_op",
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
_T = TypeVar("_T")
|
| 33 |
+
_P = ParamSpec("_P")
|
| 34 |
+
|
| 35 |
+
# TODO: relax key type here; torch registrations should be possible to; but
|
| 36 |
+
# right now this type is accurate
|
| 37 |
+
global_decomposition_table: dict[str, dict[torch._ops.OperatorBase, Callable]] = (
|
| 38 |
+
defaultdict(dict)
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
decomposition_table = global_decomposition_table["post_autograd"]
|
| 42 |
+
pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"]
|
| 43 |
+
meta_table = global_decomposition_table["meta"]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _should_decompose_because_unsafe_op(op: torch._ops.OperatorBase) -> bool:
|
| 47 |
+
"""
|
| 48 |
+
Returns True if the op must always decompose in export/compile tracing system
|
| 49 |
+
|
| 50 |
+
In export, we always decompose certain CIA ops that are tagged with
|
| 51 |
+
maybe_aliasing_or_mutating because we statically need to know if the op is
|
| 52 |
+
mutating or not. But these CIA ops could have different behaviour in runtime.
|
| 53 |
+
|
| 54 |
+
native_batch_norm is a prim op which has a wrong schema and it needs to be replaced
|
| 55 |
+
with correct schema. But until then, we will force decompose it via this tag.
|
| 56 |
+
"""
|
| 57 |
+
if not isinstance(op, torch._ops.OpOverload):
|
| 58 |
+
return False
|
| 59 |
+
if torch.Tag.maybe_aliasing_or_mutating in op.tags:
|
| 60 |
+
return True
|
| 61 |
+
return op == torch.ops.aten.native_batch_norm.default
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _add_op_to_registry(registry, op, fn):
|
| 65 |
+
"""
|
| 66 |
+
This is an internal API for adding an op to the decomposition table.
|
| 67 |
+
|
| 68 |
+
If op is OpOverload, it will be added to the registry directly.
|
| 69 |
+
If op is OpOverloadPacket, all the valid op_overloads in the packet will be added to the registry.
|
| 70 |
+
"""
|
| 71 |
+
overloads: list[Union[torch._ops.OperatorBase]] = []
|
| 72 |
+
if isinstance(op, HigherOrderOperator):
|
| 73 |
+
# There's no concept of overloads for HigherOrderOperator
|
| 74 |
+
registry[op] = fn
|
| 75 |
+
return
|
| 76 |
+
elif isinstance(op, OpOverload):
|
| 77 |
+
overloads.append(op)
|
| 78 |
+
else:
|
| 79 |
+
assert isinstance(op, OpOverloadPacket)
|
| 80 |
+
for ol in op.overloads():
|
| 81 |
+
overloads.append(getattr(op, ol))
|
| 82 |
+
|
| 83 |
+
for op_overload in overloads:
|
| 84 |
+
if op_overload in registry:
|
| 85 |
+
raise RuntimeError(f"duplicate registrations for {op_overload}")
|
| 86 |
+
# TorchScript dumps a bunch of extra nonsense overloads
|
| 87 |
+
# which don't have corresponding dispatcher entries, we need
|
| 88 |
+
# to filter those out, e.g aten.add.float_int
|
| 89 |
+
if torch._C._dispatch_has_kernel(op_overload.name()):
|
| 90 |
+
registry[op_overload] = fn
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _convert_out_params(f):
    """Rewrite a decomposition whose ``out`` parameter is structured so it
    matches what native_functions.yaml expects.

    Two cases are handled:
      * ``out`` annotated as a (named) tuple: wrap ``f`` so each tuple field
        becomes its own keyword-only parameter, repacked into ``out=`` on call.
      * a single out param with a custom name (signalled by the
        CustomOutParamAnnotation annotation): wrap ``f`` so that name is
        accepted and forwarded as ``out=``.
    Functions without an ``out`` annotation are returned unchanged.
    """
    out_annotation = f.__annotations__.get("out")

    # If there are no out params, do not wrap the function.
    if not out_annotation:
        return f

    # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this
    if getattr(out_annotation, "__origin__", None) is tuple:
        sig = inspect.signature(f)
        # NOTE(review): assumes the return annotation is a NamedTuple
        # (._fields) naming the out tensors — TODO confirm against callers.
        out_names = sig.return_annotation._fields
        # If out is a tuple, we need to register a function that unpacks all the out
        # elements as this is what native_functions.yaml expects

        @wraps(f)
        def _fn(*args, **kwargs):
            out_kwargs = tuple(kwargs.pop(o, None) for o in out_names)
            # Either all of the out kwargs are set or none of them
            is_none = out_kwargs[0] is None
            assert all((o is None) == is_none for o in out_kwargs)
            return f(*args, **kwargs, out=None if is_none else out_kwargs)

        out_params = [
            inspect.Parameter(
                o,
                kind=inspect.Parameter.KEYWORD_ONLY,
                default=None,
                annotation=t,
            )
            for o, t in zip(out_names, out_annotation.__args__)
        ]
        # Drop the out parameter and concatenate the new kwargs in the signature
        params = chain((v for k, v in sig.parameters.items() if k != "out"), out_params)
        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
            parameters=params,  # type: ignore[arg-type]
            return_annotation=sig.return_annotation,
        )
        # Drop the out parameter and concatenate the new kwargs in the annotations
        _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
        for o in out_params:
            _fn.__annotations__[o.name] = o.annotation

        # Propagate that this function is wrapped by `out_wrapper`
        _fn._torch_decompositions_out_wrapper = f._torch_decompositions_out_wrapper  # type: ignore[attr-defined]

        return _fn

    # Alternatively, there may be a single tensor out parameter with a name
    # other than "out". This will need special treatment and is indicated by an
    # annotation, which we will remove here so it is not exposed after wrapping.
    custom_out_param_name = f.__annotations__.pop(CustomOutParamAnnotation, None)
    if custom_out_param_name:

        @wraps(f)
        def _fn(*args, **kwargs):
            out_kwarg = kwargs.pop(custom_out_param_name, None)
            return f(*args, **kwargs, out=out_kwarg)

        out_param = inspect.Parameter(
            custom_out_param_name,
            kind=inspect.Parameter.KEYWORD_ONLY,
            default=None,
            annotation=out_annotation,
        )

        # Drop the out parameter and concatenate the new kwarg in the signature
        sig = inspect.signature(f)
        params = chain(
            (v for k, v in sig.parameters.items() if k != "out"), (out_param,)
        )
        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
            parameters=params,  # type: ignore[arg-type]
            return_annotation=sig.return_annotation,
        )

        # Drop the out parameter and concatenate the new kwargs in the annotations
        _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
        _fn.__annotations__[out_param.name] = out_param.annotation

        return _fn

    return f
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def register_decomposition(
    aten_op, registry=None, *, type="post_autograd", unsafe=False
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    """Decorator registering a function as a Python decomposition.

    Usage::

        @register_decomposition(torch.ops.aten.clamp_min)
        def clamp_min(x):
            return torch.clamp(self, min=min)

    New decompositions are best contributed directly to PyTorch in
    torch._decomp.decompositions. This API is experimental: it will likely
    grow once decompositions become usable inside transforms (e.g. autograd)
    rather than only backend tracing, where we then need to know whether a
    decomposition can simulate a transform.

    By default the decomposition is also registered to the dispatcher's Meta
    key, replacing any existing C++ Meta implementation.

    ``aten_op`` may be a pytree of ops to register several at once.
    ``unsafe`` reuses this function for registering non-function things and
    skips the out-parameter conversion.
    """

    assert type in {"post_autograd", "pre_autograd", "meta"}

    def decomposition_decorator(fn: Callable[_P, _T]) -> Callable[_P, _T]:
        nonlocal registry
        original = fn
        # Out-param rewriting is skipped for "unsafe" (non-function) entries.
        registered_fn = fn if unsafe else _convert_out_params(fn)

        if registry is None:
            registry = global_decomposition_table[type]

        # aten_op may be a pytree, so this handles multiple ops at once.
        pytree.tree_map_(
            lambda op: _add_op_to_registry(registry, op, registered_fn), aten_op
        )
        return original

    return decomposition_decorator
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def get_decompositions(
    aten_ops: Sequence[Union[torch._ops.OperatorBase, OpOverloadPacket]],
    type: str = "post_autograd",
) -> dict[torch._ops.OperatorBase, Callable]:
    """Collect registered decompositions for the given ops.

    An ``OpOverloadPacket`` entry pulls in every registered overload of that
    packet; a single operator pulls in just its own decomposition. Ops with
    no registered decomposition are silently skipped.

    This API is experimental; a likely future alternative is the inverse
    formulation, where callers state which operators they *can* implement and
    receive decompositions for everything else.
    """
    assert type in {"post_autograd", "pre_autograd", "meta"}

    registry = global_decomposition_table[type]
    # Index the registered overloads by their packet so packet queries expand.
    overloads_by_packet = defaultdict(list)
    for registered in registry:
        if isinstance(registered, (OpOverload, OpOverloadPacket)):
            overloads_by_packet[registered.overloadpacket].append(registered)

    selected: dict[torch._ops.OperatorBase, Callable] = {}
    for op in aten_ops:
        if isinstance(op, OpOverloadPacket) and op in overloads_by_packet:
            for overload in overloads_by_packet[op]:
                selected[overload] = registry[overload]
        elif isinstance(op, (torch._ops.OperatorBase)) and op in registry:
            selected[op] = registry[op]
    return selected
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def remove_decompositions(
|
| 257 |
+
decompositions: dict[torch._ops.OperatorBase, Callable],
|
| 258 |
+
aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]],
|
| 259 |
+
) -> None:
|
| 260 |
+
"""
|
| 261 |
+
Given a dictionary of decompositions obtained from get_decompositions(), removes
|
| 262 |
+
operators associated with a list of operator overloads and overload packets passed
|
| 263 |
+
as input. If the decomposition dictionary does not contain a decomposition that is
|
| 264 |
+
specified to be removed, it is silently ignored.
|
| 265 |
+
"""
|
| 266 |
+
for op in aten_ops:
|
| 267 |
+
if isinstance(op, OpOverloadPacket):
|
| 268 |
+
for overload_name in op.overloads():
|
| 269 |
+
opo = getattr(op, overload_name)
|
| 270 |
+
decompositions.pop(opo, None)
|
| 271 |
+
elif isinstance(op, OpOverload):
|
| 272 |
+
decompositions.pop(op, None)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
# populate the table
|
| 276 |
+
import torch._decomp.decompositions
|
| 277 |
+
import torch._refs
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def core_aten_decompositions() -> "CustomDecompTable":
|
| 281 |
+
from torch.export.exported_program import default_decompositions
|
| 282 |
+
|
| 283 |
+
return default_decompositions()
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# See NOTE [Core ATen Ops]
|
| 287 |
+
#
|
| 288 |
+
# list was copied from torch/_inductor/decomposition.py
|
| 289 |
+
# excluding decompositions that results in prim ops
|
| 290 |
+
# Resulting opset of decomposition is core aten ops
|
| 291 |
+
# See NOTE [Core ATen Ops]
#
# list was copied from torch/_inductor/decomposition.py
# excluding decompositions that results in prim ops
# Resulting opset of decomposition is core aten ops
def _core_aten_decompositions_post_autograd() -> dict[
    torch._ops.OperatorBase, Callable
]:
    """Return the post-autograd decompositions whose output opset is core
    ATen ops.

    Fix: the original list contained ``aten.replication_pad2d`` twice; the
    duplicate entry is removed here (it was harmless — ``get_decompositions``
    builds a dict — but misleading).
    """
    aten = torch.ops.aten
    return get_decompositions(
        [
            aten.addcdiv,
            aten.addcdiv_,
            aten.addcmul,
            aten.addcmul_,
            aten.addr,
            aten.affine_grid_generator,
            aten.alias_copy,
            aten.all,
            aten.aminmax,
            aten.arange.default,
            aten.arange.start,
            aten.avg_pool2d_backward,
            aten.baddbmm,
            aten.binary_cross_entropy,
            aten.binary_cross_entropy_backward,
            aten.binary_cross_entropy_with_logits,
            aten.block_diag,
            aten.bernoulli.p,
            aten.bernoulli.default,
            aten.celu,
            aten.celu_,
            aten.channel_shuffle,
            aten.clamp_max,
            aten.clamp_min,
            aten.col2im,
            aten.count_nonzero,
            aten.linalg_cross,
            aten.cudnn_batch_norm,
            aten.cudnn_batch_norm_backward,
            aten.miopen_batch_norm_backward,
            aten.deg2rad,
            aten.deg2rad_,
            aten.detach,
            aten.diag_embed,
            aten.diagonal_backward,
            aten.diagonal_copy,
            aten.dot,
            aten.vdot,
            aten.elu_,
            aten.elu_backward,
            aten._embedding_bag,
            aten.embedding_dense_backward,
            aten.empty_like,
            aten._euclidean_dist.default,
            aten.expand_as,
            aten.expand_copy,
            aten.eye,
            aten.fill,
            aten.fill_,
            aten.floor_divide,
            aten.frac,
            aten.frac_,
            aten._fused_moving_avg_obs_fq_helper,
            aten.gelu_,
            aten.gelu_backward,
            aten.glu,
            aten.glu_backward,
            aten.hardshrink,
            aten.hardsigmoid,
            aten.hardsigmoid_,
            aten.hardsigmoid_backward,
            aten.hardswish,
            aten.hardswish_,
            aten.hardswish_backward,
            aten.hardtanh_,
            aten.hardtanh_backward,
            aten.heaviside,
            aten.heaviside_,
            aten.huber_loss,
            aten.huber_loss_backward,
            aten.im2col,
            aten.index_add.out,
            aten.index_add.default,
            aten.index_add_,
            aten.index_copy.out,
            aten.index_copy.default,
            aten.index_copy_,
            aten.index_fill.int_Scalar,
            aten.index_fill.int_Tensor,
            aten.index_fill.int_Scalar_out,
            aten.index_fill.int_Tensor_out,
            aten.index_fill_,
            aten.isin,
            aten.isneginf,
            aten.isposinf,
            aten.l1_loss,
            aten._lazy_clone,
            aten._test_parallel_materialize,
            aten.leaky_relu_,
            aten.leaky_relu_backward,
            aten.lerp,
            aten.lerp_,
            aten.linspace,
            aten.logaddexp,
            aten.logaddexp2,
            aten.logit,
            aten.logit_,
            aten.logit_backward,
            aten.log_sigmoid_backward,
            aten.log_sigmoid_forward,
            aten._log_softmax_backward_data,
            aten.logspace,
            aten.logsumexp.default,
            aten.masked_fill,
            aten.masked_fill_,
            aten.max_unpool2d,
            aten.max_unpool3d,
            aten.mish,
            aten.mish_,
            aten.mse_loss,
            aten.mse_loss_backward,
            aten.multi_margin_loss,
            aten.multilabel_margin_loss_forward,
            aten.mv,
            aten.mvlgamma,
            aten.mvlgamma_,
            aten.nansum,
            aten.nan_to_num,
            aten.nan_to_num_,
            aten.narrow,
            aten.native_batch_norm_backward,
            aten.native_dropout_backward,
            aten.native_group_norm_backward,
            aten.native_layer_norm_backward,
            aten.new_empty,
            aten.new_full,
            aten.new_ones,
            aten.new_zeros,
            aten.nll_loss2d_forward,
            aten.nll_loss2d_backward,
            aten.nll_loss_backward,
            aten.nll_loss_forward,
            aten.norm.ScalarOpt_dtype,
            aten.norm.Scalar,
            aten.norm.ScalarOpt_dim_dtype,
            aten.norm.ScalarOpt_dim,
            aten.norm.dtype_out,
            aten.norm.out,
            aten.norm.names_dtype_out,
            aten.norm.names_out,
            aten.norm.ScalarOpt_dtype_out,
            aten.norm.Scalar_out,
            aten.ones,
            aten.ones_like,
            aten.pixel_shuffle,
            aten.pixel_unshuffle,
            aten._prelu_kernel,
            aten._prelu_kernel_backward,
            aten._reshape_alias,
            aten.rad2deg,
            aten.rad2deg_,
            aten.reflection_pad1d,
            aten.reflection_pad1d_backward,
            aten.reflection_pad2d,
            aten.reflection_pad2d_backward,
            aten.reflection_pad3d,
            aten.reflection_pad3d_backward,
            aten.replication_pad1d,
            aten.replication_pad2d,
            aten.replication_pad3d,
            aten.renorm,
            aten.renorm_,
            aten.resize_as,
            aten.roll,
            aten.rot90,
            aten.rrelu_with_noise,
            aten.rrelu_with_noise_,
            aten.rsub,
            aten._safe_softmax,
            aten._scaled_dot_product_flash_attention_for_cpu.default,
            aten.select_backward,
            aten.select_scatter,
            aten.sgn,
            aten.sgn_,
            aten.sigmoid_backward,
            aten.silu,
            aten.silu_,
            aten.silu_backward.grad_input,
            aten.sinc,
            aten.sinc_,
            aten.slice_backward,
            aten.smooth_l1_loss,
            aten.smooth_l1_loss_backward,
            aten.soft_margin_loss,
            aten.soft_margin_loss_backward,
            aten._softmax_backward_data,
            aten.softplus,
            aten.softplus_backward,
            aten.softshrink,
            aten.special_entr,
            aten.special_log_ndtr,
            aten.special_xlog1py,
            aten.split.Tensor,
            aten.split_with_sizes_copy,
            aten.squeeze_copy,
            aten.squeeze.default,
            aten.squeeze.dim,
            aten.std.correction,
            aten.std.out,
            aten.std.correction_out,
            aten.std.names_out,
            aten.std.correction_names_out,
            aten.std_mean.correction,
            aten.std_mean.correction_out,
            aten.stack,
            aten.sum.default,
            aten.sum.out,
            aten.t,
            aten.t_copy,
            aten.take,
            aten.tanh_backward,
            aten.threshold,
            aten.threshold_,
            aten.threshold_backward,
            aten.trace,
            aten.transpose.int,
            aten.transpose_copy,
            aten.tril,
            aten.tril_,
            aten.triu,
            aten.triu_,
            aten.unbind,
            aten.unfold_backward,
            aten.unfold_copy,
            aten._unsafe_index,
            aten._unsafe_index_put,
            aten._unsafe_masked_index,
            aten._unsafe_masked_index_put_accumulate,
            aten.unsafe_split.Tensor,
            aten.unsafe_split_with_sizes,
            aten.unsqueeze_copy,
            aten._unsafe_view,
            aten.upsample_linear1d,
            aten.upsample_bilinear2d.out,
            aten.upsample_trilinear3d.out,
            aten.upsample_nearest2d_backward,
            aten.view_as_complex,
            aten.xlogy,
            aten.xlogy_,
            aten.zero,
            aten.zero_,
            aten.zeros,
            aten.zeros_like,
            aten._chunk_cat,
            aten._weight_norm_interface,
        ]
    )
|
phivenv/Lib/site-packages/torch/_decomp/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (14 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/_decomp/__pycache__/decompositions_for_jvp.cpython-39.pyc
ADDED
|
Binary file (6.7 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/_decomp/__pycache__/decompositions_for_rng.cpython-39.pyc
ADDED
|
Binary file (8.1 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/_decomp/decompositions.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
phivenv/Lib/site-packages/torch/_decomp/decompositions_for_jvp.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-decorators
|
| 2 |
+
# mypy: allow-untyped-defs
|
| 3 |
+
import inspect
|
| 4 |
+
from typing import Callable, Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch._decomp
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from torch._prims_common.wrappers import _maybe_remove_out_wrapper
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
decomposition_table = torch._decomp.decomposition_table
|
| 13 |
+
decomposition_table_for_jvp: dict[torch._ops.OperatorBase, Callable] = {}
|
| 14 |
+
register_decomposition = torch._decomp.register_decomposition
|
| 15 |
+
aten = torch.ops.aten
|
| 16 |
+
|
| 17 |
+
# NOTE: [forward-mode AD decompositions mechanism]
|
| 18 |
+
#
|
| 19 |
+
# The mechanism is in VariableType,
|
| 20 |
+
# IF any inputs have forward grad
|
| 21 |
+
# AND there is no forward AD formula implemented
|
| 22 |
+
# AND the functions are actually differentiable
|
| 23 |
+
# run the decomposition
|
| 24 |
+
# See run_jit_decomposition_with_args_for_jvp
|
| 25 |
+
# We currently use python decompositions that we torchscript.
|
| 26 |
+
#
|
| 27 |
+
# Note that we would be building the backward graph at the decomposed level
|
| 28 |
+
# too, but that is OK, because we would've errored out otherwise anyway.
|
| 29 |
+
#
|
| 30 |
+
# TODO: The mechanism we are using to register decompositions doesn't
|
| 31 |
+
# seem to be exclusively used for jvp. So open question here is whether
|
| 32 |
+
# torch/csrc/jit/runtime/decomposition_registry.cpp is being used for other things.
|
| 33 |
+
# If that is the case, we may go down the decomposition path unexpectedly
|
| 34 |
+
# (and possibly produce an unintelligible error) vs erroring out earlier and
|
| 35 |
+
# printing that the forward AD formula is not implemented.
|
| 36 |
+
#
|
| 37 |
+
# The solution to this may be to have an explicitly white list control when
|
| 38 |
+
# to enable the decomposition.
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def maybe_register_decomposition(op):
    """Best-effort decorator: register ``f`` as a decomposition for ``op``.

    If registration raises for any reason, the function is returned
    unmodified instead of propagating the error.
    """

    def decorator(f):
        try:
            return register_decomposition(op)(f)
        except Exception:
            return f

    return decorator
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# Functions where we need a special decomposition for jvp but there's another version that
|
| 52 |
+
# should be used more generally (ex. for jvp we need to recompute the mean and variance for
|
| 53 |
+
# the backwards of a normalization function. Without jvp, it should use the saved value)
|
| 54 |
+
decomposition_table_for_jvp = {}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def register_decomposition_for_jvp(fn):
    """Register ``fn``'s decomposition into the jvp-specific table."""
    return register_decomposition(fn, registry=decomposition_table_for_jvp)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _register_jit_decomposition_for_jvp(decomp, use_python=False):
    """TorchScript a registered decomposition and hand its graph to the JIT.

    Looks up ``decomp`` first in the jvp-specific table, then in the general
    decomposition table, and registers the resulting scripted graph via
    ``torch.jit._register_decomposition``. Raises RuntimeError if no
    decomposition is registered for ``decomp``.
    """
    if decomp in decomposition_table_for_jvp:
        table = decomposition_table_for_jvp
    elif decomp in decomposition_table:
        table = decomposition_table
    else:
        raise RuntimeError(f"could not find decomposition for {decomp}")
    decomp_fn = table[decomp]

    # `out_wrapper` extends a decomposition's signature with an `out`
    # parameter. However, jit will use the unwrapped function's signature
    # instead, so we unwrap here to prevent an error.
    decomp_fn = _maybe_remove_out_wrapper(decomp_fn)

    if use_python:
        # Keep the python function opaque to the compiler and script a thin
        # textual wrapper generated from its signature, e.g.:
        #   def wrapped_decomp(x: torch.Tensor, y: int, z: int):
        #       return decomp_fn(x, y, z)
        decomp_fn = torch.jit.ignore(decomp_fn)
        sig = inspect.signature(decomp_fn)
        params = ", ".join(str(p) for p in sig.parameters.values())
        args = ", ".join(sig.parameters.keys())
        f_str = f"def wrapped_decomp({params}):\n    return decomp_fn({args})\n"
        graph = torch.jit.CompilationUnit(f_str).wrapped_decomp.graph
    else:
        graph = torch.jit.script(decomp_fn).graph
    torch.jit._register_decomposition(decomp, graph)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# The only decompositions here are temporary or hacks for the purposes of jvp
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# TODO: do these also belong here?
@maybe_register_decomposition(aten.trace.default)
def trace(self: Tensor) -> Tensor:
    """Sum of the elements on the main diagonal."""
    return torch.diag(self).sum()
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@maybe_register_decomposition(aten.log_sigmoid_forward.default)
def log_sigmoid_forward(self: Tensor) -> tuple[Tensor, Tensor]:
    """Decomposition of aten.log_sigmoid_forward.

    Returns (output, buffer); the buffer is empty on CUDA/XPU (unused there)
    and holds exp(-|x|) on other devices, matching the eager kernel.
    """
    min_val = torch.minimum(self.new_zeros(()), self)
    z = torch.exp(-torch.abs(self))
    if self.is_cuda or self.is_xpu:
        buffer = self.new_zeros((0,))
    else:
        buffer = z
    return min_val - torch.log1p(z), buffer
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def recompute_mean_var(
    input: Tensor, rstd: Tensor, inner_dim_indices: list[int], keepdim: bool
):
    """Rebuild (mean, rstd) from ``input`` so they track gradients.

    Most norm backward decompositions match the core versions except here:
    the saved statistics are detached from ``input``, so we recompute them
    (recovering eps from the saved ``rstd``) to keep the gradient chain
    through ``input`` alive.
    """
    mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
    # Recover eps from the saved rstd, since rstd == 1/sqrt(var + eps).
    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
    eps = eps.detach()
    rstd = 1 / torch.sqrt(var + eps)
    return mean, rstd
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@register_decomposition_for_jvp(aten.native_layer_norm_backward)
def native_layer_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: list[int],
    mean: Tensor,
    rstd: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    output_mask: list[bool],
) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    """jvp decomposition of aten.native_layer_norm_backward.

    Same math as the core decomposition, except mean/rstd are recomputed
    via recompute_mean_var so that they carry gradients through ``input``
    (the saved statistics passed in are ignored for that purpose).
    Returns (d_input, d_weight, d_bias); grads masked off by output_mask
    come back as zero tensors rather than None (None breaks vjp).
    """
    input_shape = input.shape
    input_ndim = input.dim()

    # The trailing len(normalized_shape) dims are normalized over.
    axis = input_ndim - len(normalized_shape)
    inner_dims = input_shape[axis:]
    outer_dims = input_shape[:axis]
    inner_dim_indices = list(range(axis, input_ndim))
    outer_dim_indices = list(range(0, axis))

    # N = elements per normalization group, M = number of groups.
    N = 1
    for i in inner_dims:
        N *= i
    M = 1
    for i in outer_dims:
        M *= i
    if M <= 0 or N <= 0:
        # Degenerate (empty) input: all grads are zeros of the right shape.
        return (
            input.new_zeros(input_shape),
            input.new_zeros(input_shape[axis:]),
            input.new_zeros(input_shape[axis:]),
        )

    # Recompute statistics so they track gradients through `input`
    # (see recompute_mean_var). The saved `mean`/`rstd` args are unused.
    mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)

    x_hat = (input - mean_) * rstd_
    if weight is not None:
        grad_x_hat = grad_out * weight
    else:
        grad_x_hat = grad_out
    # Standard layer-norm backward: d_input ∝ N*g - sum(g) - x_hat*sum(g*x_hat)
    a = grad_x_hat * N
    b = torch.sum(grad_x_hat, inner_dim_indices, True)
    c1 = torch.mul(grad_x_hat, x_hat)
    c2 = torch.sum(c1, inner_dim_indices, True)
    c3 = torch.mul(x_hat, c2)
    inner = a - b - c3

    if output_mask[0]:
        d_input: Optional[Tensor] = (rstd_ / N) * inner
    else:
        d_input = torch.zeros_like(input)  # should be None but doesn't work with vjp

    if output_mask[1] and weight is not None:
        if len(outer_dim_indices) > 0:
            d_weight: Optional[Tensor] = torch.sum(
                grad_out * x_hat, outer_dim_indices, False
            )
        else:
            d_weight = grad_out * x_hat
    elif weight is not None:
        d_weight = torch.zeros_like(weight)  # should be None but doesn't work with vjp
    else:
        d_weight = torch.zeros(())  # should be None but doesn't work with vjp

    if output_mask[2] and bias is not None:
        if len(outer_dim_indices) > 0:
            d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False)
        else:
            d_bias = grad_out.clone()
    elif bias is not None:
        d_bias = torch.zeros_like(bias)  # should be None but doesn't work with vjp
    else:
        d_bias = torch.zeros(())  # should be None but doesn't work with vjp

    return (d_input, d_weight, d_bias)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def prod(x: list[int]):
    """Product of a list of ints (manual loop kept TorchScript-friendly)."""
    result = 1
    for value in x:
        result = result * value
    return result
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
@register_decomposition_for_jvp(aten.native_batch_norm_backward)
def native_batch_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    weight: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_invstd: Optional[Tensor],
    train: bool,
    eps: float,
    output_mask: list[bool],
) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    """jvp decomposition of aten.native_batch_norm_backward.

    Same math as the core decomposition, except that in training mode the
    saved mean/invstd are recomputed (recompute_mean_var) so they carry
    gradients through ``input``. Returns (grad_input, grad_weight,
    grad_bias); masked-off grads are zero tensors, not None (None breaks
    vjp).
    """
    input_shape = input.shape
    input_rank = input.dim()
    assert input_rank >= 2, "rank of the input must be at least 2"

    # Batch norm normalizes over all dims except the channel dim (axis 1).
    axis = 1
    num_features = prod(input_shape) / input_shape[axis]  # type: ignore[arg-type]
    mean = save_mean
    invstd = save_invstd
    if train:
        assert save_mean is not None and save_invstd is not None, (
            "when train=True, save_mean and save_invstd are required"
        )

        # NOTE: "reduciton_dims" spelling kept as-is; it is a local name only.
        reduciton_dims = [0] + list(range(2, input.dim()))
        assert invstd is not None  # for typing
        mean, invstd = recompute_mean_var(input, invstd, reduciton_dims, keepdim=False)
    else:
        # Eval mode: use the running statistics directly.
        assert running_mean is not None and running_var is not None
        mean = running_mean
        invstd = torch.rsqrt(running_var + eps)

    assert invstd is not None and mean is not None

    # Shape [1, C, 1, ...] for broadcasting per-channel stats over input.
    broadcast_mask = [1] * input_rank
    broadcast_mask[axis] = input_shape[axis]

    reduction_axes: list[int] = []
    for i in range(input_rank):
        if i != axis:
            reduction_axes.append(i)

    mean = torch.reshape(mean, broadcast_mask)
    norm = 1.0 / num_features
    grad_output_sum = torch.sum(grad_out, reduction_axes)
    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)

    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)

    if weight is None:
        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
    else:
        grad_scale = torch.reshape(invstd * weight, broadcast_mask)

    if train:
        # Training: subtract the projections onto mean and variance paths.
        proj = (input - mean) * proj_scale
        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
    else:
        grad_input = grad_out * grad_scale

    if output_mask[1]:
        grad_weight = dot_p * invstd
    elif weight is not None:
        grad_weight = torch.zeros_like(
            weight
        )  # should be None but doesn't work with vjp
    else:
        grad_weight = torch.zeros(())  # should be None but doesn't work with vjp

    if output_mask[2]:
        grad_bias = grad_output_sum
    else:
        grad_bias = torch.zeros_like(
            grad_output_sum
        )  # should be None but doesn't work with vjp

    return (grad_input, grad_weight, grad_bias)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
@register_decomposition_for_jvp(aten.batch_norm_backward)
def batch_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    weight: Tensor,
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_var: Optional[Tensor],
    update: bool,
    eps: float,
    output_mask: list[bool],
    reserve: Tensor,
) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    """Thin wrapper: aten.batch_norm_backward delegates to the jvp
    decomposition of native_batch_norm_backward (``reserve`` is unused)."""
    return native_batch_norm_backward(
        grad_out,
        input,
        weight,
        running_mean,
        running_var,
        save_mean,
        save_var,
        update,
        eps,
        output_mask,
    )
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# Register the scripted decompositions with the JIT runtime for jvp.
# trace needs the python path (use_python=True); the rest are scripted
# directly.
_register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True)
for _op in (
    torch.ops.aten.nll_loss_backward.default,
    torch.ops.aten.nll_loss2d_backward.default,
    torch.ops.aten._log_softmax_backward_data.default,
    torch.ops.aten._softmax_backward_data.default,
    torch.ops.aten.log_sigmoid_forward.default,
    torch.ops.aten.native_layer_norm_backward.default,
    torch.ops.aten.native_batch_norm_backward.default,
    torch.ops.aten.cudnn_batch_norm_backward.default,
    torch.ops.aten.batch_norm_backward.default,
    torch.ops.aten.miopen_batch_norm_backward.default,
):
    _register_jit_decomposition_for_jvp(_op)
del _op
|
phivenv/Lib/site-packages/torch/_decomp/decompositions_for_rng.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-decorators
|
| 2 |
+
# mypy: allow-untyped-defs
|
| 3 |
+
import functools
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
from typing import Callable
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch._decomp as decomp
|
| 9 |
+
from torch._decomp import get_decompositions
|
| 10 |
+
from torch._ops import OpOverload
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
aten = torch.ops.aten
|
| 14 |
+
|
| 15 |
+
rng_decompositions: dict[str, dict[OpOverload, Callable]] = defaultdict(dict)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def register_rng_decomposition(aten_op):
    """Register a decomposition for ``aten_op`` in the RNG registry."""
    return decomp.register_decomposition(aten_op, rng_decompositions)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def throw_on_non_cuda(device):
    """Always raises: RNG functionalization requires counter-based (Philox)
    RNG, which non-CUDA backends do not use."""
    dev = device.type
    raise RuntimeError(
        f"You are trying to functionalize a {dev} RNG operator but {dev} does not "
        f"use Philox/counter-based RNG. Therefore, functionalizing a {dev} RNG operator is "
        "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU."
    )
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# TODO - We have to register many more distributions here, and also higher level
# ops like dropout which have fused implementation and can hide the rand inside.
@register_rng_decomposition(aten.rand)
def rand(shape, dtype=None, layout=torch.strided, device=None, pin_memory=False):
    """Functional decomposition of aten.rand in terms of philox_rand."""
    if device and device.type != "cuda":
        throw_on_non_cuda(device)
    dtype = dtype or torch.float32
    seed, offset = PhiloxStateTracker.get_state_as_tuple()
    out, offset_jump = torch.ops.rngprims.philox_rand(
        shape, seed, offset, None, device, dtype
    )
    PhiloxStateTracker.advance_offset(offset_jump)
    return out
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@register_rng_decomposition(aten.rand_like)
def rand_like(
    x: torch.Tensor,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=torch.preserve_format,
):
    """Functional decomposition of aten.rand_like in terms of philox_rand."""
    device = device or x.device
    if device.type != "cuda":
        throw_on_non_cuda(device)
    dtype = dtype or x.dtype
    seed, offset = PhiloxStateTracker.get_state_as_tuple()
    out, offset_jump = torch.ops.rngprims.philox_rand(
        x.shape, seed, offset, None, device, dtype
    )
    PhiloxStateTracker.advance_offset(offset_jump)
    return out
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class PhiloxState:
    """
    Represents a PhiloxRngState - (seed, offset) where offset = base_offset +
    relative_offset. seed and base_offset basically point to the rng state just
    before tracing starts. relative offset tracks the totally consumed offset
    at trace time.
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self):
        # Empty tensors serve as "unset" sentinels until set_state is called.
        self.seed = torch.tensor(())
        self.base_offset = torch.tensor(())
        self.relative_offset = 0
        self.offset_advanced_alteast_once = False

    def validate_state(self):
        assert self.seed.numel() != 0 and self.base_offset.numel() != 0

    def advance_offset(self, consumed_offset):
        self.offset_advanced_alteast_once = True
        self.relative_offset += consumed_offset

    def set_state(self, seed, base_offset, relative_offset=0):
        self.seed = seed
        self.base_offset = base_offset
        self.relative_offset = relative_offset

    def get_state_as_tuple(self):
        self.validate_state()
        return (self.seed, self.base_offset + self.relative_offset)

    def get_state_as_tensor(self):
        # Only needed because we override get_rng_state.
        self.validate_state()
        return torch.stack([self.seed, self.base_offset + self.relative_offset])

    def set_state_from_tensor(self, state):
        # Only needed because we override set_rng_state.
        self.seed, self.base_offset = torch.unbind(state)
        self.relative_offset = 0
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class PhiloxStateTracker:
    """
    Singleton class to track the philox rng state during AOT Autograd tracing.
    For each aot tracing instance, AOT Autograd resets this tracker and keeps
    track of both forward and backward offsets. At runtime, we only care about
    the total consumed forward and backward offsets. For dynamic shapes, these
    offsets are a function of input shapes. Therefore, the AOT generated graphs
    have additional outputs that compute total consumed forward and backward
    offsets.
    """

    running_state: PhiloxState
    fwd_state: PhiloxState
    bwd_state: PhiloxState

    def __enter__(self):
        PhiloxStateTracker.reset()
        return self

    def __exit__(self, exc_type, exc_cal, exc_tb):
        PhiloxStateTracker.reset()

    @classmethod
    def reset(cls):
        cls.running_state = PhiloxState()
        cls.fwd_state = PhiloxState()
        cls.bwd_state = PhiloxState()

    @classmethod
    def mark_beginning_of_forward(cls):
        # Route subsequent state queries/updates to the forward state.
        cls.running_state = cls.fwd_state

    @classmethod
    def mark_beginning_of_backward(cls):
        # Route subsequent state queries/updates to the backward state.
        cls.running_state = cls.bwd_state

    @classmethod
    def record_state(cls, seed, offset, mode):
        # Records the seed and offset tensors, which are later used to invoke
        # the philox_rand functional primitives.
        if mode == "forward":
            cls.fwd_state.set_state(seed, offset)
            cls.mark_beginning_of_forward()
        else:
            assert mode == "backward"
            cls.bwd_state.set_state(seed, offset)

    @classmethod
    def get_state_as_tensor(cls):
        # Exists only because get_rng_state/set_rng_state are overridden during
        # tracing, and get_rng_state expects a tensor output; returning a
        # (seed, offset) tuple would upset other parts of the program like
        # ctx.saved_tensors.
        #
        # A bad consequence is that if the user saves and restores rng state,
        # the generated code first concats (seed, offset) into a tensor for
        # get_rng_state and then splits it back apart in set_rng_state.
        #
        # TODO: Investigate if there is be a better way to wrap the tuple in a
        # false Tensor object, and then desugar it later on.
        return cls.running_state.get_state_as_tensor()

    @classmethod
    def get_state_as_tuple(cls):
        return cls.running_state.get_state_as_tuple()

    @classmethod
    def set_state_from_tensor(cls, x):
        # Only needed because we override set_rng_state; see the comment in
        # get_state_as_tensor.
        cls.running_state.set_state_from_tensor(x)

    @classmethod
    def advance_offset(cls, consumed_offset):
        cls.running_state.advance_offset(consumed_offset)

    @classmethod
    def get_current_relative_offset(cls):
        return cls.running_state.relative_offset

    @staticmethod
    def multiple_of_4(offset):
        # torch cuda rng state offset must be a multiple of 4. For inductor,
        # as we sum up all the numel, the result might not be a multiple of 4;
        # round up here.
        return (offset + 3) // 4 * 4

    @classmethod
    def get_updated_fwd_offset(cls):
        # Short circuit if no rand ops were observed
        if not cls.fwd_state.offset_advanced_alteast_once:
            return cls.fwd_state.base_offset
        return cls.multiple_of_4(
            cls.fwd_state.base_offset + cls.fwd_state.relative_offset
        )

    @classmethod
    def get_updated_bwd_offset(cls):
        # Short circuit if no rand ops were observed
        if not cls.bwd_state.offset_advanced_alteast_once:
            return cls.bwd_state.base_offset
        return cls.multiple_of_4(
            cls.bwd_state.base_offset + cls.bwd_state.relative_offset
        )
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# Additional decompositions that eventually call rand_like internally.
# Registering them in rng_decompositions ensures functionalization of the
# rand_like ops used inside these decomps. The list is copied from the
# inductor codebase, which uses it for a similar purpose.
#
# Caution - These decomps do not have same accuracy as that of eager. However,
# we can't just disable them with a config flag like fallback_random, because
# for functionalization of rng ops, we have to decompose these ops.
extra_random_decomps = get_decompositions(
    [
        aten.cauchy,
        aten.cauchy_,
        aten.exponential,
        aten.exponential_,
        aten.geometric,
        aten.geometric_,
        aten.native_dropout,
        aten.normal,
        aten.normal_,
        aten.normal_functional,
        aten.log_normal,
        aten.log_normal_,
        aten.rrelu_with_noise,
        aten.rrelu_with_noise_,
        aten.uniform_,
    ]
)
register_extra_random_decomp = functools.partial(
    decomp.register_decomposition, registry=extra_random_decomps
)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
@register_extra_random_decomp([aten.bernoulli_])
def bernoulli_(self, p=0.5):
    """In-place Bernoulli via functional rand_like; CPU falls back to eager."""
    if self.device == torch.device("cpu"):
        return NotImplemented
    return self.copy_(torch.rand_like(self, dtype=torch.float32) < p)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
@register_extra_random_decomp([aten.bernoulli.p])
def bernoulli_p(self, p=0.5, *, generator=None):
    """Out-of-place Bernoulli via functional rand_like; CPU falls back to eager."""
    if self.device == torch.device("cpu"):
        return NotImplemented
    assert generator is None
    return torch.rand_like(self, dtype=torch.float32) < p
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# Expose the extra decompositions through the main RNG registry so they are
# picked up during rng functionalization.
rng_decompositions.update(extra_random_decomps)  # type: ignore[arg-type]
|
phivenv/Lib/site-packages/torch/_dispatch/__init__.py
ADDED
|
File without changes
|
phivenv/Lib/site-packages/torch/_dispatch/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (159 Bytes). View file
|
|
|
phivenv/Lib/site-packages/torch/_dispatch/__pycache__/python.cpython-39.pyc
ADDED
|
Binary file (6.92 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/_dispatch/python.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import itertools
|
| 3 |
+
import unittest.mock
|
| 4 |
+
from collections.abc import Iterator
|
| 5 |
+
from contextlib import contextmanager
|
| 6 |
+
from typing import Callable, TypeVar, Union
|
| 7 |
+
from typing_extensions import ParamSpec
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch._C
|
| 11 |
+
import torch._ops
|
| 12 |
+
import torch.utils._python_dispatch
|
| 13 |
+
import torch.utils._pytree as pytree
|
| 14 |
+
from torch._C import DispatchKey
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]
|
| 18 |
+
|
| 19 |
+
no_python_dispatcher = torch._C._DisablePythonDispatcher
|
| 20 |
+
enable_python_dispatcher = torch._C._EnablePythonDispatcher
|
| 21 |
+
enable_pre_dispatch = torch._C._EnablePreDispatch
|
| 22 |
+
|
| 23 |
+
CROSSREF_FUNCTIONALIZE = False
|
| 24 |
+
|
| 25 |
+
_P = ParamSpec("_P")
|
| 26 |
+
_T = TypeVar("_T")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
    """
    Warning: the set of overloads this will report is very subtle. It is
    precisely the set of torch.ops functions that have actually been accessed
    from Python (e.g., we actually called torch.ops.aten.blah at some point).
    This is DIFFERENT from the set of registered operators, which will in
    general be a larger set, as it would include all operators for which C++
    static initializers or Python operator registration ran. This does not
    eagerly populate the list on torch.ops.aten; this list is lazy!

    In other words, this is good for traversing over everything that has an
    OpOverload object allocated in Python. We use it for cache invalidation,
    but don't rely on this list being complete.

    Note that even if we did report all C++ registered overloads, this isn't
    guaranteed to be complete either, as a subsequent lazy load of a library
    which triggers more registrations could add more things to the set.
    """
    for namespace in torch.ops:
        holder = getattr(torch.ops, namespace)
        for op_name in holder:
            packet = getattr(holder, op_name)
            for overload_name in packet:
                yield getattr(packet, overload_name)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@contextmanager
def suspend_functionalization():
    """Temporarily disable the Functionalize dispatch key, restoring it
    (with the saved reapply_views setting) on exit."""
    was_enabled = torch._C._dispatch_tls_is_dispatch_key_included(
        torch._C.DispatchKey.Functionalize
    )
    reapply_views = torch._C._functionalization_reapply_views_tls()
    if was_enabled:
        torch._disable_functionalization()
    try:
        yield
    finally:
        if was_enabled:
            torch._enable_functionalization(reapply_views=reapply_views)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def check_tensor_metadata_matches(nv, rv, desc):
    """Assert that two tensors agree on size, dtype, and significant strides.

    ``desc`` is a zero-arg callable producing a context string for the
    assertion messages (lazy so the string is only built on failure).
    """
    assert callable(desc)
    assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}"
    assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}"
    same_strides, idx = torch._prims_common.check_significant_strides(
        nv, rv, only_cuda=False
    )
    assert same_strides, (
        f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"
    )
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def check_metadata_matches(n, r, desc):
    """Flatten two pytrees and assert pairwise tensor metadata agreement.

    Pairs whose reference-side value is not a tensor are skipped.
    ``desc`` is a lazy label factory, evaluated only in failure messages.
    """
    assert callable(desc)
    n_vals, _n_spec = pytree.tree_flatten(n)
    r_vals, _r_spec = pytree.tree_flatten(r)
    # TODO: test the specs match; empirically sometimes we have a tuple
    # on one side and a list on the other
    assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
    for i, (nv, rv) in enumerate(zip(n_vals, r_vals)):
        if not isinstance(rv, torch.Tensor):
            continue
        check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}")
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class Lit:
    """A 'literal' wrapper: ``repr()`` is the raw wrapped string, unquoted.

    Used so that generated code fragments (e.g. ``torch.empty_strided(...)``)
    print as code inside repr'd containers rather than as quoted strings.
    """

    def __init__(self, s):
        # Stored verbatim; emitted untouched by __repr__.
        self.s = s

    def __repr__(self):
        return self.s
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _fmt(a: object) -> object:
|
| 104 |
+
if isinstance(a, torch.Tensor):
|
| 105 |
+
return Lit(
|
| 106 |
+
f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})"
|
| 107 |
+
)
|
| 108 |
+
else:
|
| 109 |
+
return a
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def make_crossref_functionalize(
    op: torch._ops.OpOverload[_P, _T], final_key: DispatchKey
) -> Union[Callable[_P, _T], DispatchKey]:
    """Build a handler that cross-checks ``op`` against a fake-tensor run.

    The returned handler runs ``op`` twice: once on fakeified inputs under a
    fresh ``FakeTensorMode`` (with functionalization and current dispatch
    modes suspended), and once for real via the ``final_key`` dispatch
    entry. It then asserts that the fake run's output metadata (sizes,
    dtypes, strides — see ``check_metadata_matches``) agrees with the real
    output before returning the real result.

    Returns ``final_key`` itself (i.e. "don't wrap, dispatch normally") for
    ``aten.lift_fresh.default``, which this check does not handle.
    """
    from torch._subclasses.fake_tensor import FakeTensorMode

    # This case is pretty weird, suppress it for now
    if op == torch.ops.aten.lift_fresh.default:
        return final_key

    def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        fake_mode = FakeTensorMode()

        # Convert a (possibly functional) tensor into a fake tensor;
        # non-tensor values pass through untouched.
        def fakeify_defun(t):
            if isinstance(t, torch.Tensor):
                if torch._is_functional_tensor(t):
                    # Unwrap the functional tensor before fakeifying.
                    r = torch._from_functional_tensor(t)
                    # NB: This assumes that the inner tensor sizes/strides match
                    # the outer tensor sizes/strides. This doesn't necessarily have to
                    # be the case, see discussion at
                    # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456
                    assert t.size() == r.size()
                    assert t.stride() == r.stride()
                else:
                    r = t
                # TODO: suppress guards
                return fake_mode.from_tensor(r)
            return t

        # Detach tensors so the copies saved for error reporting don't hold
        # autograd state; leave non-tensors alone.
        def maybe_detach(t):
            if isinstance(t, torch.Tensor):
                return t.detach()
            else:
                return t

        # TODO: This probably does the wrong thing if you're running other
        # substantive modes with the normal op outside here
        with (
            torch.utils._python_dispatch._disable_current_modes(),
            suspend_functionalization(),
        ):
            # Fake run: fakeify the inputs, keep detached copies for the
            # failure message, then run the op under the fake mode.
            f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
            orig_f_args, orig_f_kwargs = pytree.tree_map(
                maybe_detach, (f_args, f_kwargs)
            )
            with fake_mode:
                f_r = op(*f_args, **f_kwargs)
        # Real run: dispatch directly to final_key with the original inputs
        # (deliberately OUTSIDE the mode-suspension block above).
        r = op._op_dk(final_key, *args, **kwargs)

        # Lazily format a repro-style call string for assertion messages;
        # only evaluated if a metadata mismatch is actually detected.
        def desc():
            fmt_args = ", ".join(
                itertools.chain(
                    (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
                    (
                        f"{k}={pytree.tree_map(_fmt, v)}"
                        for k, v in orig_f_kwargs.items()
                    ),
                )
            )
            return f"{op}({fmt_args})"

        check_metadata_matches(f_r, r, desc)
        return r

    return handler
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# NB: enabling this is slow, don't do it in a hot loop. This is purely
# for debugging purposes.
@contextmanager
def enable_crossref_functionalize():
    """Enable crossref-functionalize checking for the duration of the block.

    Flushes the cached Functionalize dispatch entry of every Python-loaded
    overload on entry and again on exit (so stale handlers are never used),
    and runs the body with the Python dispatcher enabled and
    ``CROSSREF_FUNCTIONALIZE`` patched to ``True``.
    """

    def _flush_functionalize_caches():
        # Drop any cached Functionalize kernel for every loaded overload.
        for overload in all_py_loaded_overloads():
            overload._uncache_dispatch(torch._C.DispatchKey.Functionalize)

    _flush_functionalize_caches()
    try:
        with (
            enable_python_dispatcher(),
            unittest.mock.patch("torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True),
        ):
            yield
    finally:
        _flush_functionalize_caches()
|