cranky-coder08 committed
Commit 59f1501 · verified · 1 parent: 44823a3

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set.
Files changed (50)
  1. .gitattributes +3 -0
  2. phivenv/Lib/site-packages/torch/_C/_VariableFunctions.pyi +0 -0
  3. phivenv/Lib/site-packages/torch/_C/__init__.pyi +0 -0
  4. phivenv/Lib/site-packages/torch/_C/_aoti.pyi +164 -0
  5. phivenv/Lib/site-packages/torch/_C/_autograd.pyi +141 -0
  6. phivenv/Lib/site-packages/torch/_C/_cpu.pyi +13 -0
  7. phivenv/Lib/site-packages/torch/_C/_cudnn.pyi +14 -0
  8. phivenv/Lib/site-packages/torch/_C/_cusparselt.pyi +1 -0
  9. phivenv/Lib/site-packages/torch/_C/_distributed_autograd.pyi +26 -0
  10. phivenv/Lib/site-packages/torch/_C/_distributed_c10d.pyi +797 -0
  11. phivenv/Lib/site-packages/torch/_C/_distributed_rpc.pyi +188 -0
  12. phivenv/Lib/site-packages/torch/_C/_distributed_rpc_testing.pyi +32 -0
  13. phivenv/Lib/site-packages/torch/_C/_dynamo/__init__.pyi +4 -0
  14. phivenv/Lib/site-packages/torch/_C/_dynamo/compiled_autograd.pyi +13 -0
  15. phivenv/Lib/site-packages/torch/_C/_dynamo/eval_frame.pyi +71 -0
  16. phivenv/Lib/site-packages/torch/_C/_dynamo/guards.pyi +191 -0
  17. phivenv/Lib/site-packages/torch/_C/_export/__init__.pyi +9 -0
  18. phivenv/Lib/site-packages/torch/_C/_export/pt2_archive_constants.pyi +22 -0
  19. phivenv/Lib/site-packages/torch/_C/_functions.pyi +19 -0
  20. phivenv/Lib/site-packages/torch/_C/_functorch.pyi +86 -0
  21. phivenv/Lib/site-packages/torch/_C/_instruction_counter.pyi +4 -0
  22. phivenv/Lib/site-packages/torch/_C/_itt.pyi +5 -0
  23. phivenv/Lib/site-packages/torch/_C/_lazy.pyi +26 -0
  24. phivenv/Lib/site-packages/torch/_C/_lazy_ts_backend.pyi +12 -0
  25. phivenv/Lib/site-packages/torch/_C/_monitor.pyi +58 -0
  26. phivenv/Lib/site-packages/torch/_C/_nn.pyi +175 -0
  27. phivenv/Lib/site-packages/torch/_C/_nvtx.pyi +9 -0
  28. phivenv/Lib/site-packages/torch/_C/_onnx.pyi +39 -0
  29. phivenv/Lib/site-packages/torch/_C/_profiler.pyi +246 -0
  30. phivenv/Lib/site-packages/torch/_C/_verbose.pyi +3 -0
  31. phivenv/Lib/site-packages/torch/_C_flatbuffer/__init__.pyi +11 -0
  32. phivenv/Lib/site-packages/torch/_awaits/__init__.py +53 -0
  33. phivenv/Lib/site-packages/torch/_awaits/__pycache__/__init__.cpython-39.pyc +0 -0
  34. phivenv/Lib/site-packages/torch/_custom_op/__init__.py +0 -0
  35. phivenv/Lib/site-packages/torch/_custom_op/__pycache__/__init__.cpython-39.pyc +0 -0
  36. phivenv/Lib/site-packages/torch/_custom_op/__pycache__/autograd.cpython-39.pyc +0 -0
  37. phivenv/Lib/site-packages/torch/_custom_op/__pycache__/impl.cpython-39.pyc +0 -0
  38. phivenv/Lib/site-packages/torch/_custom_op/autograd.py +307 -0
  39. phivenv/Lib/site-packages/torch/_custom_op/impl.py +715 -0
  40. phivenv/Lib/site-packages/torch/_decomp/__init__.py +544 -0
  41. phivenv/Lib/site-packages/torch/_decomp/__pycache__/__init__.cpython-39.pyc +0 -0
  42. phivenv/Lib/site-packages/torch/_decomp/__pycache__/decompositions_for_jvp.cpython-39.pyc +0 -0
  43. phivenv/Lib/site-packages/torch/_decomp/__pycache__/decompositions_for_rng.cpython-39.pyc +0 -0
  44. phivenv/Lib/site-packages/torch/_decomp/decompositions.py +0 -0
  45. phivenv/Lib/site-packages/torch/_decomp/decompositions_for_jvp.py +335 -0
  46. phivenv/Lib/site-packages/torch/_decomp/decompositions_for_rng.py +266 -0
  47. phivenv/Lib/site-packages/torch/_dispatch/__init__.py +0 -0
  48. phivenv/Lib/site-packages/torch/_dispatch/__pycache__/__init__.cpython-39.pyc +0 -0
  49. phivenv/Lib/site-packages/torch/_dispatch/__pycache__/python.cpython-39.pyc +0 -0
  50. phivenv/Lib/site-packages/torch/_dispatch/python.py +192 -0
.gitattributes CHANGED
@@ -112,3 +112,6 @@ phivenv/Lib/site-packages/torch/lib/cpuinfo.lib filter=lfs diff=lfs merge=lfs -text
  phivenv/Lib/site-packages/torch/lib/fbgemm.dll filter=lfs diff=lfs merge=lfs -text
  phivenv/Lib/site-packages/torch/lib/c10.dll filter=lfs diff=lfs merge=lfs -text
  phivenv/Lib/site-packages/torch/lib/fbgemm.lib filter=lfs diff=lfs merge=lfs -text
+ phivenv/Lib/site-packages/torch/lib/fmt.lib filter=lfs diff=lfs merge=lfs -text
+ phivenv/Lib/site-packages/torch/lib/libiomp5md.dll filter=lfs diff=lfs merge=lfs -text
+ phivenv/Lib/site-packages/torch/lib/libittnotify.lib filter=lfs diff=lfs merge=lfs -text
phivenv/Lib/site-packages/torch/_C/_VariableFunctions.pyi ADDED
The diff for this file is too large to render. See raw diff
 
phivenv/Lib/site-packages/torch/_C/__init__.pyi ADDED
The diff for this file is too large to render. See raw diff
 
phivenv/Lib/site-packages/torch/_C/_aoti.pyi ADDED
@@ -0,0 +1,164 @@
+ from ctypes import c_void_p
+ from typing import overload, Protocol
+
+ from torch import Tensor
+
+ # Defined in torch/csrc/inductor/aoti_runner/pybind.cpp
+
+ # Tensor to AtenTensorHandle
+ def unsafe_alloc_void_ptrs_from_tensors(tensors: list[Tensor]) -> list[c_void_p]: ...
+ def unsafe_alloc_void_ptr_from_tensor(tensor: Tensor) -> c_void_p: ...
+
+ # AtenTensorHandle to Tensor
+ def alloc_tensors_by_stealing_from_void_ptrs(
+     handles: list[c_void_p],
+ ) -> list[Tensor]: ...
+ def alloc_tensor_by_stealing_from_void_ptr(
+     handle: c_void_p,
+ ) -> Tensor: ...
+
+ class AOTIModelContainerRunner(Protocol):
+     def run(
+         self, inputs: list[Tensor], stream_handle: c_void_p = ...
+     ) -> list[Tensor]: ...
+     def get_call_spec(self) -> list[str]: ...
+     def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
+     def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
+     def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
+     def update_constant_buffer(
+         self,
+         tensor_map: dict[str, Tensor],
+         use_inactive: bool,
+         validate_full_updates: bool,
+         user_managed: bool = ...,
+     ) -> None: ...
+     def swap_constant_buffer(self) -> None: ...
+     def free_inactive_constant_buffer(self) -> None: ...
+
+ class AOTIModelContainerRunnerCpu:
+     def __init__(self, model_so_path: str, num_models: int) -> None: ...
+     def run(
+         self, inputs: list[Tensor], stream_handle: c_void_p = ...
+     ) -> list[Tensor]: ...
+     def get_call_spec(self) -> list[str]: ...
+     def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
+     def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
+     def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
+     def update_constant_buffer(
+         self,
+         tensor_map: dict[str, Tensor],
+         use_inactive: bool,
+         validate_full_updates: bool,
+         user_managed: bool = ...,
+     ) -> None: ...
+     def swap_constant_buffer(self) -> None: ...
+     def free_inactive_constant_buffer(self) -> None: ...
+
+ class AOTIModelContainerRunnerCuda:
+     @overload
+     def __init__(self, model_so_path: str, num_models: int) -> None: ...
+     @overload
+     def __init__(
+         self, model_so_path: str, num_models: int, device_str: str
+     ) -> None: ...
+     @overload
+     def __init__(
+         self, model_so_path: str, num_models: int, device_str: str, cubin_dir: str
+     ) -> None: ...
+     def run(
+         self, inputs: list[Tensor], stream_handle: c_void_p = ...
+     ) -> list[Tensor]: ...
+     def get_call_spec(self) -> list[str]: ...
+     def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
+     def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
+     def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
+     def update_constant_buffer(
+         self,
+         tensor_map: dict[str, Tensor],
+         use_inactive: bool,
+         validate_full_updates: bool,
+         user_managed: bool = ...,
+     ) -> None: ...
+     def swap_constant_buffer(self) -> None: ...
+     def free_inactive_constant_buffer(self) -> None: ...
+
+ class AOTIModelContainerRunnerXpu:
+     @overload
+     def __init__(self, model_so_path: str, num_models: int) -> None: ...
+     @overload
+     def __init__(
+         self, model_so_path: str, num_models: int, device_str: str
+     ) -> None: ...
+     @overload
+     def __init__(
+         self, model_so_path: str, num_models: int, device_str: str, kernel_bin_dir: str
+     ) -> None: ...
+     def run(
+         self, inputs: list[Tensor], stream_handle: c_void_p = ...
+     ) -> list[Tensor]: ...
+     def get_call_spec(self) -> list[str]: ...
+     def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
+     def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
+     def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
+     def update_constant_buffer(
+         self,
+         tensor_map: dict[str, Tensor],
+         use_inactive: bool,
+         validate_full_updates: bool,
+         user_managed: bool = ...,
+     ) -> None: ...
+     def swap_constant_buffer(self) -> None: ...
+     def free_inactive_constant_buffer(self) -> None: ...
+
+ class AOTIModelContainerRunnerMps:
+     def __init__(self, model_so_path: str, num_models: int) -> None: ...
+     def run(
+         self, inputs: list[Tensor], stream_handle: c_void_p = ...
+     ) -> list[Tensor]: ...
+     def get_call_spec(self) -> list[str]: ...
+     def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
+     def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
+     def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
+     def update_constant_buffer(
+         self,
+         tensor_map: dict[str, Tensor],
+         use_inactive: bool,
+         validate_full_updates: bool,
+         user_managed: bool = ...,
+     ) -> None: ...
+     def swap_constant_buffer(self) -> None: ...
+     def free_inactive_constant_buffer(self) -> None: ...
+
+ # Defined in torch/csrc/inductor/aoti_package/pybind.cpp
+ class AOTIModelPackageLoader:
+     def __init__(
+         self,
+         model_package_path: str,
+         model_name: str,
+         run_single_threaded: bool,
+         num_runners: int,
+         device_index: int,
+     ) -> None: ...
+     def get_metadata(self) -> dict[str, str]: ...
+     def run(
+         self, inputs: list[Tensor], stream_handle: c_void_p = ...
+     ) -> list[Tensor]: ...
+     def boxed_run(
+         self, inputs: list[Tensor], stream_handle: c_void_p = ...
+     ) -> list[Tensor]: ...
+     def get_call_spec(self) -> list[str]: ...
+     def get_constant_fqns(self) -> list[str]: ...
+     def load_constants(
+         self,
+         constants_map: dict[str, Tensor],
+         use_inactive: bool,
+         check_full_update: bool,
+         user_managed: bool = ...,
+     ) -> None: ...
+     def update_constant_buffer(
+         self,
+         tensor_map: dict[str, Tensor],
+         use_inactive: bool,
+         validate_full_updates: bool,
+         user_managed: bool = ...,
+     ) -> None: ...
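The stub above declares the AOTInductor runner bindings. As a minimal sketch of how the package loader declared here might be driven (the package path, model name, and input shape below are hypothetical; real code would normally go through the public torch._inductor.aoti_compile_and_package / aoti_load_package wrappers):

import torch
from torch._C._aoti import AOTIModelPackageLoader

# Hypothetical .pt2 package produced earlier by torch._inductor.aoti_compile_and_package.
loader = AOTIModelPackageLoader("model.pt2", "model", False, 1, 0)
print(loader.get_metadata())               # build metadata recorded in the package
outputs = loader.run([torch.randn(4, 8)])  # stream_handle left at its default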
phivenv/Lib/site-packages/torch/_C/_autograd.pyi ADDED
@@ -0,0 +1,141 @@
+ # mypy: allow-untyped-defs
+ from enum import Enum
+ from typing import Any, Callable
+
+ import torch
+ from torch._C._profiler import (
+     _ProfilerEvent,
+     ActiveProfilerType,
+     ProfilerActivity,
+     ProfilerConfig,
+ )
+
+ # Defined in torch/csrc/autograd/init.cpp
+
+ class DeviceType(Enum):
+     CPU = ...
+     CUDA = ...
+     XPU = ...
+     MKLDNN = ...
+     OPENGL = ...
+     OPENCL = ...
+     IDEEP = ...
+     HIP = ...
+     FPGA = ...
+     MAIA = ...
+     XLA = ...
+     MTIA = ...
+     MPS = ...
+     HPU = ...
+     Meta = ...
+     Vulkan = ...
+     Metal = ...
+     PrivateUse1 = ...
+
+ class ProfilerEvent:
+     def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ...
+     def cpu_memory_usage(self) -> int: ...
+     def cuda_elapsed_us(self, other: ProfilerEvent) -> float: ...
+     def privateuse1_elapsed_us(self, other: ProfilerEvent) -> float: ...
+     def cuda_memory_usage(self) -> int: ...
+     def device(self) -> int: ...
+     def handle(self) -> int: ...
+     def has_cuda(self) -> bool: ...
+     def is_remote(self) -> bool: ...
+     def kind(self) -> int: ...
+     def name(self) -> str: ...
+     def node_id(self) -> int: ...
+     def sequence_nr(self) -> int: ...
+     def shapes(self) -> list[list[int]]: ...
+     def thread_id(self) -> int: ...
+     def flops(self) -> float: ...
+     def is_async(self) -> bool: ...
+
+ class _KinetoEvent:
+     def name(self) -> str: ...
+     def overload_name(self) -> str: ...
+     def device_index(self) -> int: ...
+     def device_resource_id(self) -> int: ...
+     def start_ns(self) -> int: ...
+     def end_ns(self) -> int: ...
+     def duration_ns(self) -> int: ...
+     def is_async(self) -> bool: ...
+     def linked_correlation_id(self) -> int: ...
+     def shapes(self) -> list[list[int]]: ...
+     def dtypes(self) -> list[str]: ...
+     def concrete_inputs(self) -> list[Any]: ...
+     def kwinputs(self) -> dict[str, Any]: ...
+     def device_type(self) -> DeviceType: ...
+     def start_thread_id(self) -> int: ...
+     def end_thread_id(self) -> int: ...
+     def correlation_id(self) -> int: ...
+     def fwd_thread_id(self) -> int: ...
+     def stack(self) -> list[str]: ...
+     def scope(self) -> int: ...
+     def sequence_nr(self) -> int: ...
+     def flops(self) -> int: ...
+     def cuda_elapsed_us(self) -> int: ...
+     def privateuse1_elapsed_us(self) -> int: ...
+     def is_user_annotation(self) -> bool: ...
+
+ class _ProfilerResult:
+     def events(self) -> list[_KinetoEvent]: ...
+     def legacy_events(self) -> list[list[ProfilerEvent]]: ...
+     def save(self, path: str) -> None: ...
+     def experimental_event_tree(self) -> list[_ProfilerEvent]: ...
+     def trace_start_ns(self) -> int: ...
+
+ class SavedTensor: ...
+
+ def _enable_profiler(
+     config: ProfilerConfig,
+     activities: set[ProfilerActivity],
+ ) -> None: ...
+ def _prepare_profiler(
+     config: ProfilerConfig,
+     activities: set[ProfilerActivity],
+ ) -> None: ...
+ def _toggle_collection_dynamic(
+     enable: bool,
+     activities: set[ProfilerActivity],
+ ) -> None: ...
+ def _disable_profiler() -> _ProfilerResult: ...
+ def _profiler_enabled() -> bool: ...
+ def _add_metadata_json(key: str, value: str) -> None: ...
+ def _kineto_step() -> None: ...
+ def _get_current_graph_task_keep_graph() -> bool: ...
+ def _get_sequence_nr() -> int: ...
+ def kineto_available() -> bool: ...
+ def _record_function_with_args_enter(name: str, *args) -> torch.Tensor: ...
+ def _record_function_with_args_exit(handle: torch.Tensor) -> None: ...
+ def _supported_activities() -> set[ProfilerActivity]: ...
+ def _enable_record_function(enable: bool) -> None: ...
+ def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
+ def _push_saved_tensors_default_hooks(
+     pack_hook: Callable[[torch.Tensor], Any],
+     unpack_hook: Callable[[Any], torch.Tensor],
+ ) -> None: ...
+ def _pop_saved_tensors_default_hooks() -> None: ...
+ def _top_saved_tensors_default_hooks(
+     ignore_is_tracing: bool,
+ ) -> tuple[Callable[[torch.Tensor], Any], Callable[[Any], torch.Tensor]]: ...
+ def _unsafe_set_version_counter(
+     t: tuple[torch.Tensor, ...], prev_version: tuple[int, ...]
+ ) -> None: ...
+ def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
+ def _disable_profiler_legacy() -> list[list[ProfilerEvent]]: ...
+ def _profiler_type() -> ActiveProfilerType: ...
+ def _saved_tensors_hooks_enable() -> None: ...
+ def _saved_tensors_hooks_disable(message: str, fail_if_non_empty=True) -> None: ...
+ def _saved_tensors_hooks_get_disabled_error_message() -> str | None: ...
+ def _saved_tensors_hooks_set_tracing(is_tracing: bool) -> bool: ...
+
+ class CreationMeta(Enum):
+     DEFAULT = ...
+     IN_CUSTOM_FUNCTION = ...
+     MULTI_OUTPUT_NODE = ...
+     NO_GRAD_MODE = ...
+     INFERENCE_MODE = ...
+
+ def _set_creation_meta(t: torch.Tensor, creation_meta: CreationMeta) -> None: ...
+ def _get_creation_meta(t: torch.Tensor) -> CreationMeta: ...
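A minimal sketch of the capability queries declared above; the stateful bindings here (e.g. _enable_profiler/_disable_profiler) are normally driven indirectly through the public torch.profiler API rather than called directly:

from torch._C._autograd import _supported_activities, kineto_available

print(kineto_available())        # True when torch was built with the Kineto backend
print(_supported_activities())   # e.g. {ProfilerActivity.CPU, ProfilerActivity.CUDA}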
phivenv/Lib/site-packages/torch/_C/_cpu.pyi ADDED
@@ -0,0 +1,13 @@
+ from torch.types import _bool, _int
+
+ # Defined in torch/csrc/cpu/Module.cpp
+
+ def _is_avx2_supported() -> _bool: ...
+ def _is_avx512_supported() -> _bool: ...
+ def _is_avx512_vnni_supported() -> _bool: ...
+ def _is_avx512_bf16_supported() -> _bool: ...
+ def _is_amx_tile_supported() -> _bool: ...
+ def _is_amx_fp16_supported() -> _bool: ...
+ def _init_amx() -> _bool: ...
+ def _L1d_cache_size() -> _int: ...
+ def _L2_cache_size() -> _int: ...
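A minimal sketch of querying these CPU-capability helpers (similar information is exposed publicly via torch.backends.cpu):

from torch._C import _cpu

# Feature probes return plain bools at runtime (_bool is just an alias).
if _cpu._is_avx512_supported():
    print("AVX-512 available; L2 cache size:", _cpu._L2_cache_size())
elif _cpu._is_avx2_supported():
    print("AVX2 available")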
phivenv/Lib/site-packages/torch/_C/_cudnn.pyi ADDED
@@ -0,0 +1,14 @@
+ from enum import IntEnum
+
+ # Defined in torch/csrc/cuda/shared/cudnn.cpp
+ is_cuda: bool
+
+ def getRuntimeVersion() -> tuple[int, int, int]: ...
+ def getCompileVersion() -> tuple[int, int, int]: ...
+ def getVersionInt() -> int: ...
+
+ class RNNMode(IntEnum):
+     rnn_relu = ...
+     rnn_tanh = ...
+     lstm = ...
+     gru = ...
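A minimal sketch of the version helpers declared above; note that torch._C._cudnn is only importable in CUDA-enabled builds of torch:

from torch._C import _cudnn

print(_cudnn.getCompileVersion())   # cuDNN version torch was compiled against
print(_cudnn.getRuntimeVersion())   # cuDNN version loaded at runtime
print(_cudnn.RNNMode.lstm)          # IntEnum member used by the RNN bindings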
phivenv/Lib/site-packages/torch/_C/_cusparselt.pyi ADDED
@@ -0,0 +1 @@
+ def getVersionInt() -> int: ...
phivenv/Lib/site-packages/torch/_C/_distributed_autograd.pyi ADDED
@@ -0,0 +1,26 @@
+ from typing import Any
+
+ import torch
+
+ # This module is defined in torch/csrc/distributed/autograd/init.cpp
+
+ class DistAutogradContext:
+     def _context_id(self) -> int: ...
+     def _recv_functions(self) -> dict[int, Any]: ...
+     def _send_functions(self) -> dict[int, Any]: ...
+     def _known_worker_ids(self) -> set[int]: ...
+
+ def _new_context() -> DistAutogradContext: ...
+ def _release_context(context_id: int) -> None: ...
+ def _get_max_id() -> int: ...
+ def _is_valid_context(worker_id: int) -> bool: ...
+ def _retrieve_context(context_id: int) -> DistAutogradContext: ...
+ def _current_context() -> DistAutogradContext: ...
+ def _init(worker_id: int) -> None: ...
+ def _get_debug_info() -> dict[str, str]: ...
+ def backward(
+     context_id: int,
+     roots: list[torch.Tensor],
+     retain_graph: bool = False,
+ ) -> None: ...
+ def get_gradients(context_id: int) -> dict[torch.Tensor, torch.Tensor]: ...
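These bindings back the public torch.distributed.autograd package. A minimal sketch of the intended flow, assuming torch.distributed.rpc.init_rpc(...) has already run in this process:

import torch
import torch.distributed.autograd as dist_autograd

with dist_autograd.context() as context_id:   # wraps _new_context/_release_context
    loss = torch.ones(2, 2, requires_grad=True).sum()
    dist_autograd.backward(context_id, [loss])        # the backward() declared above
    grads = dist_autograd.get_gradients(context_id)   # dict[Tensor, Tensor]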
phivenv/Lib/site-packages/torch/_C/_distributed_c10d.pyi ADDED
@@ -0,0 +1,797 @@
+ # mypy: allow-untyped-defs
+ # mypy: disable-error-code="type-arg"
+ from datetime import timedelta
+ from enum import Enum
+ from typing import Any, Optional, overload, Union
+
+ import torch
+ from torch import Tensor
+ from torch._C import ScriptObject
+ from torch._C._autograd import DeviceType
+ from torch.futures import Future
+
+ # This module is defined in torch/csrc/distributed/c10d/init.cpp
+
+ _DEFAULT_FIRST_BUCKET_BYTES: int
+ _DEFAULT_NO_TIMEOUT: timedelta
+ _DEFAULT_PG_TIMEOUT: timedelta
+ _DEFAULT_PG_NCCL_TIMEOUT: timedelta
+
+ class BuiltinCommHookType(Enum):
+     ALLREDUCE = ...
+     FP16_COMPRESS = ...
+
+ def _register_comm_hook(reducer: Reducer, state: Any, comm_hook: Any): ...
+ def _register_builtin_comm_hook(
+     reducer: Reducer,
+     comm_hook_type: BuiltinCommHookType,
+ ): ...
+ def _set_global_rank(rank: int) -> None: ...
+ def _hash_tensors(tensors: list[Tensor]) -> int: ...
+
+ class GradBucket:
+     def index(self) -> int: ...
+     def buffer(self) -> Tensor: ...
+     def gradients(self) -> list[Tensor]: ...
+     def is_last(self) -> bool: ...
+     def set_buffer(self, tensor: Tensor) -> None: ...
+     def parameters(self) -> list[Tensor]: ...
+
+ class Reducer:
+     def __init__(
+         self,
+         params: list[Tensor],
+         bucket_indices: list[list[int]],
+         per_bucket_size_limits: list[int],
+         process_group: ProcessGroup,
+         expect_sparse_gradients: list[bool] = ...,
+         bucket_bytes_cap: int = ...,  # kDefaultBucketBytesCap in reducer.hpp
+         find_unused_parameters: bool = ...,
+         gradient_as_bucket_view: bool = ...,
+         param_to_name_mapping: dict[int, str] = ...,
+         first_bucket_types_cap: int = ...,  # kDefaultFirstBucketBytes in reducer.hpp
+         skip_all_reduce_unused_params: bool = ...,
+         use_python_reducer: bool = ...,
+     ) -> None: ...
+     def prepare_for_forward(self) -> None: ...
+     def prepare_for_backward(self, output: list[Tensor]) -> None: ...
+     def get_backward_stats(self) -> list[int]: ...
+     def _install_post_backward_futures(self, futures: list[Future]) -> None: ...
+     def _rebuild_buckets(self) -> bool: ...
+     def _get_zeros_like_grad_buckets(self) -> list[GradBucket]: ...
+     def _push_all_rebuilt_params(self) -> None: ...
+     def _set_forward_pass_work_handle(
+         self,
+         work: Work,
+         use_static_world_size: bool,
+     ): ...
+     def _get_local_used_map(self) -> Tensor: ...
+     def _set_ddp_runtime_logging_sample_rate(self, sample_rate: int) -> None: ...
+     def _set_static_graph(self) -> None: ...
+     def _run_comm_hook(self, bucket: GradBucket) -> Future: ...
+     def set_logger(self, logger: Logger) -> None: ...
+     def _remove_autograd_hooks(self) -> None: ...
+     def _check_reducer_finalized(self) -> None: ...
+     def _set_sparse_metadata(self, global_unique_ids: dict[str, Tensor]) -> None: ...
+     def _reset_state(self) -> None: ...
+     def _update_process_group(self, new_process_group: ProcessGroup) -> None: ...
+
+ class DDPLoggingData:
+     strs_map: dict[str, str]
+     ints_map: dict[str, int]
+
+ class Logger:
+     def __init__(self, reducer: Reducer) -> None: ...
+     def set_construction_data_and_log(
+         self,
+         module_name: str,
+         device_ids: list[int],
+         output_device: int,
+         broadcast_buffers: bool,
+         has_sync_bn: bool,
+         static_graph: bool,
+     ): ...
+     def set_runtime_stats_and_log(self) -> None: ...
+     def set_error_and_log(self, error: str) -> None: ...
+     def _get_ddp_logging_data(self) -> DDPLoggingData: ...
+     def _set_comm_hook_name(self, comm_hook: str) -> None: ...
+     def _set_uneven_input_join(self) -> None: ...
+     def _set_static_graph(self) -> None: ...
+
+ class _WorkerServer:
+     def __init__(self, socket_path: str) -> None: ...
+     def shutdown(self) -> None: ...
+
+ def get_debug_level(): ...
+ def set_debug_level(): ...
+ def set_debug_level_from_env(): ...
+
+ class DebugLevel(Enum):
+     OFF = ...
+     INFO = ...
+     DETAIL = ...
+
+ class ReduceOp:
+     def __init__(self, op: RedOpType) -> None: ...
+
+     SUM: RedOpType = ...
+     AVG: RedOpType = ...
+     PRODUCT: RedOpType = ...
+     MIN: RedOpType = ...
+     MAX: RedOpType = ...
+     BAND: RedOpType = ...
+     BOR: RedOpType = ...
+     BXOR: RedOpType = ...
+     PREMUL_SUM: RedOpType = ...
+     UNUSED: RedOpType = ...
+
+     # mypy error being ignored:
+     # Detected enum "torch._C._distributed_c10d.ReduceOp.RedOpType" in a type
+     # stub with zero members. There is a chance this is due to a recent change
+     # in the semantics of enum membership. If so, use `member = value` to mark
+     # an enum member, instead of `member: type`
+     class RedOpType(Enum): ...  # type: ignore[misc]
+
+ class BroadcastOptions:
+     rootRank: int
+     rootTensor: int
+     timeout: timedelta
+     asyncOp: bool
+
+ class AllreduceOptions:
+     reduceOp: ReduceOp
+     timeout: timedelta
+     asyncOp: bool
+     sparseIndices: Optional[Tensor]
+
+ class AllreduceCoalescedOptions(AllreduceOptions): ...
+
+ class ReduceOptions:
+     reduceOp: ReduceOp
+     rootRank: int
+     rootTensor: int
+     timeout: timedelta
+     asyncOp: bool
+
+ class AllgatherOptions:
+     timeout: timedelta
+     asyncOp: bool
+
+ class GatherOptions:
+     rootRank: int
+     timeout: timedelta
+     asyncOp: bool
+
+ class ScatterOptions:
+     rootRank: int
+     timeout: timedelta
+     asyncOp: bool
+
+ class ReduceScatterOptions:
+     reduceOp: ReduceOp
+     timeout: timedelta
+     asyncOp: bool
+
+ class BarrierOptions:
+     device_ids: list[int]
+     device: torch.device
+     timeout: timedelta
+     asyncOp: bool
+
+ class AllToAllOptions:
+     timeout: timedelta
+     asyncOp: bool
+
+ class Store:
+     def set(self, key: str, value: str): ...
+     def get(self, key: str) -> bytes: ...
+     def add(self, key: str, value: int) -> int: ...
+     def check(self, keys: list[str]) -> bool: ...
+     def compare_set(
+         self,
+         key: str,
+         expected_value: str,
+         desired_value: str,
+     ) -> bytes: ...
+     def delete_key(self, key: str) -> bool: ...
+     def num_keys(self) -> int: ...
+     def set_timeout(self, timeout: timedelta): ...
+     @overload
+     def wait(self, keys: list[str]): ...
+     @overload
+     def wait(self, keys: list[str], timeout: timedelta): ...
+     def queue_pop(self, key: str, block: bool = True) -> bytes: ...
+     def queue_push(self, key: str, value: Union[bytes, str]) -> None: ...
+     def queue_len(self, key: str) -> int: ...
+
+ class FileStore(Store):
+     def __init__(self, path: str, numWorkers: int = ...) -> None: ...
+
+ class HashStore(Store):
+     def __init__(self) -> None: ...
+
+ class TCPStore(Store):
+     def __init__(
+         self,
+         host_name: str,
+         port: int,
+         world_size: int | None = ...,
+         is_master: bool = ...,
+         timeout: timedelta = ...,
+         wait_for_workers: bool = ...,
+         multi_tenant: bool = ...,
+         master_listen_fd: int | None = ...,
+         use_libuv: bool | None = ...,
+     ) -> None: ...
+     @property
+     def host(self) -> str: ...
+     @property
+     def port(self) -> int: ...
+
+ class PrefixStore(Store):
+     def __init__(self, prefix: str, store: Store) -> None: ...
+     @property
+     def underlying_store(self) -> Store: ...
+
+ class _ControlCollectives:
+     def barrier(self, key: str, timeout: timedelta, blocking: bool) -> None: ...
+     def broadcast_send(self, key: str, data: str, timeout: timedelta) -> None: ...
+     def broadcast_recv(self, key: str, timeout: timedelta) -> str: ...
+     def gather_send(self, key: str, data: str, timeout: timedelta) -> None: ...
+     def gather_recv(self, key: str, timeout: timedelta) -> str: ...
+     def scatter_send(self, key: str, data: str, timeout: timedelta) -> None: ...
+     def scatter_recv(self, key: str, timeout: timedelta) -> str: ...
+     def all_gather(self, key: str, data: str, timeout: timedelta) -> str: ...
+     def all_sum(self, key: str, data: int, timeout: timedelta) -> int: ...
+
+ class _StoreCollectives(_ControlCollectives):
+     def __init__(self, store: Store, rank: int, world_size: int) -> None: ...
+
+ class _DistributedBackendOptions:
+     def __init__(self) -> None: ...
+     @property
+     def store(self) -> Store: ...
+     @store.setter
+     def store(self, store: Store) -> None: ...
+     @property
+     def group_rank(self) -> int: ...
+     @group_rank.setter
+     def group_rank(self, rank: int) -> None: ...
+     @property
+     def group_size(self) -> int: ...
+     @group_size.setter
+     def group_size(self, size: int) -> None: ...
+     @property
+     def timeout(self) -> timedelta: ...
+     @timeout.setter
+     def timeout(self, timeout: timedelta) -> None: ...
+     @property
+     def group_id(self) -> str: ...
+     @group_id.setter
+     def group_id(self, group_id: str) -> None: ...
+     @property
+     def global_ranks_in_group(self) -> list[int]: ...
+     @global_ranks_in_group.setter
+     def global_ranks_in_group(self, ranks: list[int]) -> None: ...
+
+ class Work:
+     def is_completed(self) -> bool: ...
+     def is_success(self) -> bool: ...
+     def exception(self) -> Any: ...
+     def wait(self, timeout: timedelta = ...) -> bool: ...
+     def get_future(self) -> Future: ...
+     def source_rank(self) -> int: ...
+     def _source_rank(self) -> int: ...
+     def result(self) -> list[Tensor]: ...
+     def synchronize(self): ...
+     def boxed(self) -> ScriptObject: ...
+     @staticmethod
+     def unbox(obj: ScriptObject) -> Work: ...
+
+ class Backend:
+     class Options:
+         def __init__(self, backend: str, timeout: timedelta = ...) -> None: ...
+         @property
+         def backend(self) -> str: ...
+         @property
+         def _timeout(self) -> timedelta: ...
+         @_timeout.setter
+         def _timeout(self, val: timedelta) -> None: ...
+
+     def __init__(
+         self,
+         rank: int,
+         size: int,
+     ) -> None: ...
+     @property
+     def supports_splitting(self) -> bool: ...
+     @property
+     def supports_coalescing(self) -> bool: ...
+     @property
+     def supports_time_estimate(self) -> bool: ...
+     @property
+     def options(self) -> Options: ...
+     def rank(self) -> int: ...
+     def size(self) -> int: ...
+     def abort(self) -> None: ...
+     def shutdown(self) -> None: ...
+     def eager_connect_single_device(self, device: torch.device | None) -> None: ...
+     def _set_sequence_number_for_group(self) -> None: ...
+     def _set_default_timeout(self, timeout: timedelta) -> None: ...
+     def get_error(self) -> ErrorType: ...
+     def supports_tensor_alloc(self, device: torch.device) -> bool: ...
+     def allocate_tensor(
+         self,
+         size: int,
+         *,
+         dtype: torch.dtype,
+         device: torch.device,
+     ) -> Tensor: ...
+     @property
+     def mem_allocator(self) -> Any: ...
+
+ class ProcessGroup:
+     class BackendType(Enum):
+         UNDEFINED = ...
+         GLOO = ...
+         NCCL = ...
+         UCC = ...
+         MPI = ...
+         XCCL = ...
+         CUSTOM = ...
+
+     def __init__(
+         self,
+         store: Store,
+         rank: int,
+         size: int,
+     ) -> None: ...
+     def rank(self) -> int: ...
+     def size(self) -> int: ...
+     def abort(self) -> None: ...
+     def shutdown(self) -> None: ...
+     @overload
+     def broadcast(
+         self,
+         tensors: list[Tensor],
+         opts=...,
+     ) -> Work: ...
+     @overload
+     def broadcast(
+         self,
+         tensor: Tensor,
+         root: int,
+     ) -> Work: ...
+     @overload
+     def allreduce(
+         self,
+         tensors: list[Tensor],
+         opts: AllreduceOptions = ...,
+     ) -> Work: ...
+     @overload
+     def allreduce(
+         self,
+         tensors: list[Tensor],
+         op=...,
+     ) -> Work: ...
+     @overload
+     def allreduce(
+         self,
+         tensor: Tensor,
+         op=...,
+     ) -> Work: ...
+     def allreduce_coalesced(
+         self,
+         tensors: list[Tensor],
+         opts=...,
+     ) -> Work: ...
+     def reduce_scatter_tensor_coalesced(
+         self,
+         outputTensors: list[Tensor],
+         inputTensors: list[Tensor],
+         opts: ReduceScatterOptions | None = None,
+     ) -> Work: ...
+     @overload
+     def reduce(
+         self,
+         tensors: list[Tensor],
+         opts=...,
+     ) -> Work: ...
+     @overload
+     def reduce(
+         self,
+         tensor: Tensor,
+         root: int,
+         op=...,
+     ) -> Work: ...
+     @overload
+     def allgather(
+         self,
+         output_tensors: list[list[Tensor]],
+         input_tensors: list[Tensor],
+         opts=...,
+     ) -> Work: ...
+     @overload
+     def allgather(
+         self,
+         output_tensors: list[Tensor],
+         input_tensor: Tensor,
+     ) -> Work: ...
+     def _allgather_base(
+         self,
+         output: Tensor,
+         input: Tensor,
+         opts=...,
+     ) -> Work: ...
+     def allgather_coalesced(
+         self,
+         output_lists: list[list[Tensor]],
+         input_list: list[Tensor],
+         opts=...,
+     ) -> Work: ...
+     def allgather_into_tensor_coalesced(
+         self,
+         output_lists: list[Tensor],
+         input_list: list[Tensor],
+         opts=...,
+     ) -> Work: ...
+     @overload
+     def gather(
+         self,
+         output_tensors: list[list[Tensor]],
+         input_tensors: list[Tensor],
+         opts=...,
+     ) -> Work: ...
+     @overload
+     def gather(
+         self,
+         output_tensors: list[Tensor],
+         input_tensor: Tensor,
+         root: int,
+     ) -> Work: ...
+     @overload
+     def scatter(
+         self,
+         output_tensors: list[Tensor],
+         input_tensors: list[list[Tensor]],
+         opts=...,
+     ) -> Work: ...
+     @overload
+     def scatter(
+         self,
+         output_tensor: Tensor,
+         input_tensors: list[Tensor],
+         root: int,
+     ) -> Work: ...
+     @overload
+     def reduce_scatter(
+         self,
+         output_tensors: list[Tensor],
+         input_tensors: list[list[Tensor]],
+         opts=...,
+     ) -> Work: ...
+     @overload
+     def reduce_scatter(
+         self,
+         output_tensors: Tensor,
+         input_tensor: list[Tensor],
+     ) -> Work: ...
+     def _reduce_scatter_base(
+         self,
+         outputTensor: Tensor,
+         inputTensor: Tensor,
+         opts: ReduceScatterOptions | None,
+     ) -> Work: ...
+     @overload
+     def alltoall_base(
+         self,
+         output_tensor: Tensor,
+         input_tensor: Tensor,
+         output_split_sizes: list[int],
+         input_split_sizes: list[int],
+         opts=...,
+     ) -> Work: ...
+     @overload
+     def alltoall_base(
+         self,
+         output: Tensor,
+         input: Tensor,
+         output_split_sizes: list[int],
+         input_split_sizes: list[int],
+     ) -> Work: ...
+     @overload
+     def alltoall(
+         self,
+         output_tensor: list[Tensor],
+         input_tensor: list[Tensor],
+         opts=...,
+     ) -> Work: ...
+     @overload
+     def alltoall(
+         self,
+         output: list[Tensor],
+         input: list[Tensor],
+     ) -> Work: ...
+     def send(
+         self,
+         tensors: list[Tensor],
+         dstRank: int,
+         tag: int,
+     ) -> Work: ...
+     def recv(
+         self,
+         tensors: list[Tensor],
+         srcRank: int,
+         tag: int,
+     ) -> Work: ...
+     def recv_anysource(self, tensors: list[Tensor], tag: int) -> Work: ...
+     def barrier(self, opts=...) -> Work: ...
+     def boxed(self) -> ScriptObject: ...
+     @staticmethod
+     def unbox(obj: ScriptObject) -> ProcessGroup: ...
+     def _start_coalescing(self, device: torch.device) -> None: ...
+     def _end_coalescing(self, device: torch.device) -> Work: ...
+     def _get_backend_name(self) -> str: ...
+     def _backend_id(self, backend_type: BackendType) -> int: ...
+     @property
+     def _device_types(self) -> list[torch.device]: ...
+     def _get_backend(self, device: torch.device) -> Backend: ...
+     def _set_default_backend(self, backend_type: BackendType) -> None: ...
+     def _register_backend(
+         self,
+         device: torch.device,
+         backend_type: BackendType,
+         backend: Backend | None,
+     ) -> None: ...
+     def _set_group_name(self, name: str) -> None: ...
+     def _set_group_desc(self, desc: str) -> None: ...
+     def name(self) -> str: ...
+     def _has_hooks(self) -> bool: ...
+     def _wait_for_pending_works(self) -> None: ...
+     def _set_sequence_number_for_group(self) -> None: ...
+     @property
+     def bound_device_id(self) -> torch.device | None: ...
+     @bound_device_id.setter
+     def bound_device_id(self, device: torch.device | None) -> None: ...
+     @property
+     def group_name(self) -> str: ...
+     @property
+     def group_desc(self) -> str: ...
+
+ class FakeProcessGroup(Backend):
+     def __init__(self, rank: int, world_size: int) -> None: ...
+
+ class FakeWork(Work):
+     seq_id: int
+     def __init__(self) -> None: ...
+     def wait(self, timeout: timedelta = ...) -> bool: ...
+     def getFuture(self) -> Future: ...
+
+ class ProcessGroupGloo(Backend):
+     class Device: ...
+
+     class Options(Backend.Options):
+         devices: list[ProcessGroupGloo.Device]
+         threads: int
+         global_ranks_in_group: list[int]
+         group_name: str
+
+         def __init__(self): ...
+
+     def __init__(
+         self,
+         store: Store,
+         rank: int,
+         size: int,
+         timeout: timedelta,
+     ) -> None: ...
+     @staticmethod
+     def create_device(hostname="", interface="", lazy_init=None) -> Device: ...
+     @staticmethod
+     def create_default_device(lazy_init=None) -> Device: ...
+     def _set_default_timeout(self, timeout) -> None: ...
+     @property
+     def options(self) -> Options: ...  # type: ignore[override]
+
+ class _ProcessGroupWrapper(Backend):
+     def __init__(self, pg: Backend, gloo_pg: ProcessGroupGloo) -> None: ...
+     wrapped_pg: Backend
+
+ class ErrorType(Enum):
+     SUCCESS = ...
+     TIMEOUT = ...
+     COMM_ERROR = ...
+     REMOTE_ERROR = ...
+
+ class ProcessGroupNCCL(Backend):
+     class NCCLConfig:
+         blocking: int
+         cga_cluster_size: int
+         min_ctas: int
+         max_ctas: int
+
+     class Options(Backend.Options):
+         config: ProcessGroupNCCL.NCCLConfig
+         is_high_priority_stream: bool
+         split_from: ProcessGroupNCCL
+         split_color: int
+         global_ranks_in_group: list[int]
+         group_name: str
+
+         def __init__(self, is_high_priority_stream: bool = False): ...
+
+     def __init__(
+         self,
+         store: Store,
+         rank: int,
+         size: int,
+         options: Options,
+     ) -> None: ...
+     def _group_start(self) -> None: ...
+     def _group_end(self) -> None: ...
+     def _start_time_estimate(self) -> None: ...
+     def _end_time_estimate(self) -> float: ...
+     def _set_default_timeout(self, timeout) -> None: ...
+     def perform_nocolor_split(self, device: torch.device) -> None: ...
+     def register_mem_pool(self, pool: torch.cuda.MemPool) -> None: ...
+     def deregister_mem_pool(self, pool: torch.cuda.MemPool) -> None: ...
+     def comm_split_count(self) -> int: ...
+     def _add_ephemeral_timeout(self, timeout: timedelta) -> None: ...
+     def abort(self) -> None: ...
+     def _is_initialized(self) -> bool: ...
+     @property
+     def uid(self) -> int: ...
+     @property
+     def options(self) -> Options: ...  # type: ignore[override]
+     @staticmethod
+     def get_build_nccl_version(self) -> tuple[int, int, int]: ...
+     @staticmethod
+     def get_runtime_nccl_version(self) -> tuple[int, int, int]: ...
+
+ class ProcessGroupUCC(Backend):
+     def __init__(
+         self,
+         store: Store,
+         rank: int,
+         size: int,
+         timeout: timedelta,
+     ) -> None: ...
+
+ class ProcessGroupMPI(Backend):
+     def __init__(
+         self,
+         rank: int,
+         size: int,
+         pgComm: int,
+     ) -> None: ...
+     @staticmethod
+     def create(ranks: list[int]) -> ProcessGroupMPI: ...
+
+ def _compute_bucket_assignment_by_size(
+     tensors: list[Tensor],
+     bucket_size_limits: list[int],
+     expect_sparse_gradient: list[bool] = ...,
+     tensor_indices: list[int] = ...,
+ ) -> tuple[list[list[int]], list[int]]: ...
+ def _broadcast_coalesced(
+     process_group: ProcessGroup,
+     tensors: list[Tensor],
+     buffer_size: int,
+     src: int,
+ ): ...
+ def _test_python_store(store: Store): ...
+ def _verify_params_across_processes(
+     process_group: ProcessGroup,
+     params: list[Tensor],
+     logger: Logger | None,
+ ): ...
+ def _make_nccl_premul_sum(factor: float | list[Tensor]) -> ReduceOp: ...
+ def _register_process_group(
+     group_name: str,
+     process_group: ProcessGroup,
+ ) -> None: ...
+ def _resolve_process_group(group_name: str) -> ProcessGroup: ...
+ def _register_work(tensor: torch.Tensor, work: Work) -> ProcessGroup: ...
+ def _get_work_registry_size() -> int: ...
+ def _set_allow_inflight_collective_as_graph_input(
+     value: bool,
+ ) -> None: ...
+ def _allow_inflight_collective_as_graph_input() -> bool: ...
+ def _unregister_all_process_groups() -> None: ...
+ def _unregister_process_group(group_name: str) -> None: ...
+
+ # Initializes the device state in CUmodule so that it's able to perform NVSHMEM
+ # operations. CUmodule is a pointer to a CUDA module, carried by an int64 in
+ # Python. At the C++ interface, it is converted to a uintptr_t.
+ def _nvshmemx_cumodule_init(module: int) -> None: ...
+
+ # Check if NVSHMEM is available on the current system.
+ def _is_nvshmem_available() -> bool: ...
+
+ class _SymmetricMemory:
+     @staticmethod
+     def set_group_info(
+         group_name: str,
+         rank: int,
+         world_size: int,
+         store: Store,
+     ) -> None: ...
+     @staticmethod
+     def empty_strided_p2p(
+         size: torch.types._size,
+         stride: torch.types._size,
+         dtype: torch.dtype,
+         device: torch.device,
+         group_name: str | None = None,
+         alloc_id: int | None = None,
+     ) -> torch.Tensor: ...
+     @staticmethod
+     def has_multicast_support(
+         device_type: DeviceType,
+         device_idx: int,
+     ) -> bool: ...
+     @property
+     def rank(self) -> int: ...
+     @property
+     def world_size(self) -> int: ...
+     @staticmethod
+     def rendezvous(
+         tensor: torch.Tensor, group_name: str | None = None
+     ) -> _SymmetricMemory: ...
+     def get_buffer(
+         self,
+         rank: int,
+         sizes: torch.types._size,
+         dtype: torch.dtype,
+         storage_offset: int | None = 0,
+     ) -> torch.Tensor: ...
+     def get_signal_pad(
+         self,
+         rank: int,
+         sizes: torch.types._size = [],
+         dtype: torch.dtype | None = None,
+         storage_offset: int | None = 0,
+     ) -> torch.Tensor: ...
+     def barrier(self, channel: int = 0, timeout_ms: int = 0) -> None: ...
+     def put_signal(
+         self,
+         dst_rank: int,
+         channel: int = 0,
+         timeout_ms: int = 0,
+     ) -> None: ...
+     def wait_signal(
+         self,
+         src_rank: int,
+         channel: int = 0,
+         timeout_ms: int = 0,
+     ) -> None: ...
+     @staticmethod
+     def memset32(
+         tensor: torch.Tensor, offset: int, val: int, count: int = 1
+     ) -> torch.Tensor: ...
+     @staticmethod
+     def stream_write_value32(
+         tensor: torch.Tensor, offset: int, val: int
+     ) -> torch.Tensor: ...
+     @property
+     def buffer_ptrs(self) -> list[int]: ...
+     @property
+     def buffer_ptrs_dev(self) -> int: ...
+     @property
+     def signal_pad_ptrs(self) -> list[int]: ...
+     @property
+     def signal_pad_ptrs_dev(self) -> int: ...
+     @property
+     def multicast_ptr(self) -> int: ...
+     @property
+     def buffer_size(self) -> int: ...
+     @property
+     def signal_pad_size(self) -> int: ...
+
+ class ProcessGroupXCCL(Backend):
+     def __init__(
+         self,
+         store: Store,
+         rank: int,
+         size: int,
+     ): ...
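A minimal sketch of the Store rendezvous API declared above, runnable in a single process; distributed jobs normally construct one of these indirectly through torch.distributed.init_process_group:

from datetime import timedelta
from torch._C._distributed_c10d import TCPStore

# Single-process sketch: this process acts as the master (world_size=1);
# the host/port values are hypothetical.
store = TCPStore("127.0.0.1", 29500, 1, True, timeout=timedelta(seconds=30))
store.set("status", "ready")
print(store.get("status"))   # b'ready' -- get() returns bytes per the stub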
phivenv/Lib/site-packages/torch/_C/_distributed_rpc.pyi ADDED
@@ -0,0 +1,188 @@
+ # mypy: allow-untyped-defs
+ # mypy: disable-error-code="type-arg"
+ from datetime import timedelta
+ from typing import Any, Generic, overload, TypeVar
+
+ import torch
+ from torch._C import Future
+ from torch._C._autograd import ProfilerEvent
+ from torch._C._distributed_c10d import Store
+ from torch._C._profiler import ProfilerConfig
+
+ # This module is defined in torch/csrc/distributed/rpc/init.cpp
+
+ _DEFAULT_INIT_METHOD: str
+ _DEFAULT_NUM_WORKER_THREADS: int
+ _UNSET_RPC_TIMEOUT: float
+ _DEFAULT_RPC_TIMEOUT_SEC: float
+
+ _T = TypeVar("_T")
+
+ class RpcBackendOptions:
+     rpc_timeout: float
+     init_method: str
+     def __init__(
+         self,
+         rpc_timeout: float = ...,
+         init_method: str = ...,
+     ) -> None: ...
+
+ class WorkerInfo:
+     def __init__(self, name: str, worker_id: int) -> None: ...
+     @property
+     def name(self) -> str: ...
+     @property
+     def id(self) -> int: ...
+     def __eq__(self, other: object) -> bool: ...
+
+ class RpcAgent:
+     def join(self, shutdown: bool = False, timeout: float = 0): ...
+     def sync(self): ...
+     def shutdown(self): ...
+     @overload
+     def get_worker_info(self) -> WorkerInfo: ...
+     @overload
+     def get_worker_info(self, workerName: str) -> WorkerInfo: ...
+     def get_worker_infos(self) -> list[WorkerInfo]: ...
+     def _get_device_map(self, dst: WorkerInfo) -> dict[torch.device, torch.device]: ...
+     def get_debug_info(self) -> dict[str, str]: ...
+     def get_metrics(self) -> dict[str, str]: ...
+
+ class PyRRef(Generic[_T]):
+     def __init__(self, value: _T, type_hint: Any = None) -> None: ...
+     def is_owner(self) -> bool: ...
+     def confirmed_by_owner(self) -> bool: ...
+     def owner(self) -> WorkerInfo: ...
+     def owner_name(self) -> str: ...
+     def to_here(self, timeout: float = ...) -> _T: ...
+     def local_value(self) -> Any: ...
+     def rpc_sync(self, timeout: float = ...) -> Any: ...
+     def rpc_async(self, timeout: float = ...) -> Any: ...
+     def remote(self, timeout: float = ...) -> Any: ...
+     def _serialize(self) -> tuple: ...
+     @staticmethod
+     def _deserialize(tp: tuple) -> PyRRef: ...
+     def _get_type(self) -> type[_T]: ...
+     def _get_future(self) -> Future[_T]: ...
+     def _get_profiling_future(self) -> Future[_T]: ...
+     def _set_profiling_future(self, profilingFuture: Future[_T]): ...
+
+ class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
+     num_worker_threads: int
+     device_maps: dict[str, dict[torch.device, torch.device]]
+     devices: list[torch.device]
+     def __init__(
+         self,
+         num_worker_threads: int,
+         _transports: list | None,
+         _channels: list | None,
+         rpc_timeout: float = ...,
+         init_method: str = ...,
+         device_maps: dict[str, dict[torch.device, torch.device]] = {},  # noqa: B006
+         devices: list[torch.device] = [],  # noqa: B006
+     ) -> None: ...
+     def _set_device_map(
+         self,
+         to: str,
+         device_map: dict[torch.device, torch.device],
+     ): ...
+
+ class TensorPipeAgent(RpcAgent):
+     def __init__(
+         self,
+         store: Store,
+         name: str,
+         worker_id: int,
+         world_size: int | None,
+         opts: _TensorPipeRpcBackendOptionsBase,
+         reverse_device_maps: dict[str, dict[torch.device, torch.device]],
+         devices: list[torch.device],
+     ) -> None: ...
+     def join(self, shutdown: bool = False, timeout: float = 0): ...
+     def shutdown(self): ...
+     @overload
+     def get_worker_info(self) -> WorkerInfo: ...
+     @overload
+     def get_worker_info(self, workerName: str) -> WorkerInfo: ...
+     @overload
+     def get_worker_info(self, id: int) -> WorkerInfo: ...
+     def get_worker_infos(self) -> list[WorkerInfo]: ...
+     def _get_device_map(self, dst: WorkerInfo) -> dict[torch.device, torch.device]: ...
+     def _update_group_membership(
+         self,
+         worker_info: WorkerInfo,
+         my_devices: list[torch.device],
+         reverse_device_map: dict[str, dict[torch.device, torch.device]],
+         is_join: bool,
+     ): ...
+     def _get_backend_options(self) -> _TensorPipeRpcBackendOptionsBase: ...
+     @property
+     def is_static_group(self) -> bool: ...
+     @property
+     def store(self) -> Store: ...
+
+ def _is_current_rpc_agent_set() -> bool: ...
+ def _get_current_rpc_agent() -> RpcAgent: ...
+ def _set_and_start_rpc_agent(agent: RpcAgent): ...
+ def _reset_current_rpc_agent(): ...
+ def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...): ...
+ def _destroy_rref_context(ignoreRRefLeak: bool): ...
+ def _rref_context_get_debug_info() -> dict[str, str]: ...
+ def _cleanup_python_rpc_handler(): ...
+ def _invoke_rpc_builtin(
+     dst: WorkerInfo,
+     opName: str,
+     rpcTimeoutSeconds: float,
+     *args: Any,
+     **kwargs: Any,
+ ): ...
+ def _invoke_rpc_python_udf(
+     dst: WorkerInfo,
+     pickledPythonUDF: str,
+     tensors: list[torch.Tensor],
+     rpcTimeoutSeconds: float,
+     isAsyncExecution: bool,
+ ): ...
+ def _invoke_rpc_torchscript(
+     dstWorkerName: str,
+     qualifiedNameStr: str,
+     argsTuple: tuple,
+     kwargsDict: dict,
+     rpcTimeoutSeconds: float,
+     isAsyncExecution: bool,
+ ): ...
+ def _invoke_remote_builtin(
+     dst: WorkerInfo,
+     opName: str,
+     rpcTimeoutSeconds: float,
+     *args: Any,
+     **kwargs: Any,
+ ): ...
+ def _invoke_remote_python_udf(
+     dst: WorkerInfo,
+     pickledPythonUDF: str,
+     tensors: list[torch.Tensor],
+     rpcTimeoutSeconds: float,
+     isAsyncExecution: bool,
+ ): ...
+ def _invoke_remote_torchscript(
+     dstWorkerName: WorkerInfo,
+     qualifiedNameStr: str,
+     rpcTimeoutSeconds: float,
+     isAsyncExecution: bool,
+     *args: Any,
+     **kwargs: Any,
+ ): ...
+ def get_rpc_timeout() -> float: ...
+ def enable_gil_profiling(flag: bool): ...
+ def _set_rpc_timeout(rpcTimeoutSeconds: float): ...
+
+ class RemoteProfilerManager:
+     @staticmethod
+     def set_current_profiling_key(key: str): ...
+
+ def _enable_server_process_global_profiler(new_config: ProfilerConfig): ...
+ def _disable_server_process_global_profiler() -> list[list[list[ProfilerEvent]]]: ...
+ def _set_profiler_node_id(default_node_id: int): ...
+ def _enable_jit_rref_pickle(): ...
+ def _disable_jit_rref_pickle(): ...
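User code reaches these agent bindings through the public torch.distributed.rpc package; a minimal single-worker sketch, assuming MASTER_ADDR and MASTER_PORT are set in the environment:

import torch.distributed.rpc as rpc

rpc.init_rpc("worker0", rank=0, world_size=1)  # creates and starts a TensorPipeAgent
info = rpc.get_worker_info()                   # the WorkerInfo type declared above
print(info.name, info.id)
rpc.shutdown()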
phivenv/Lib/site-packages/torch/_C/_distributed_rpc_testing.pyi ADDED
@@ -0,0 +1,32 @@
+ import torch
+ from torch._C._distributed_c10d import Store
+ from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase, TensorPipeAgent
+
+ # This module is defined in torch/csrc/distributed/rpc/testing/init.cpp
+
+ class FaultyTensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
+     def __init__(
+         self,
+         num_worker_threads: int,
+         rpc_timeout: float,
+         init_method: str,
+         messages_to_fail: list[str],
+         messages_to_delay: dict[str, float],
+         num_fail_sends: int,
+     ) -> None: ...
+     num_send_recv_threads: int
+     messages_to_fail: list[str]
+     messages_to_delay: dict[str, float]
+     num_fail_sends: int
+
+ class FaultyTensorPipeAgent(TensorPipeAgent):
+     def __init__(
+         self,
+         store: Store,
+         name: str,
+         rank: int,
+         world_size: int,
+         options: FaultyTensorPipeRpcBackendOptions,
+         reverse_device_maps: dict[str, dict[torch.device, torch.device]],
+         devices: list[torch.device],
+     ) -> None: ...
phivenv/Lib/site-packages/torch/_C/_dynamo/__init__.pyi ADDED
@@ -0,0 +1,4 @@
+ from . import compiled_autograd, eval_frame, guards  # noqa: F401
+
+ def strip_function_call(name: str) -> str: ...
+ def is_valid_var_name(name: str) -> bool | int: ...
phivenv/Lib/site-packages/torch/_C/_dynamo/compiled_autograd.pyi ADDED
@@ -0,0 +1,13 @@
+ from typing import Callable
+
+ from torch import Tensor
+ from torch._dynamo.compiled_autograd import AutogradCompilerInstance
+
+ def set_autograd_compiler(
+     autograd_compiler: Callable[[], AutogradCompilerInstance] | None,
+     dynamic: bool,
+ ) -> tuple[Callable[[], AutogradCompilerInstance] | None, bool]: ...
+ def clear_cache() -> None: ...
+ def is_cache_empty() -> bool: ...
+ def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ...
+ def call_cpp_tensor_pre_hooks(idx: int, grad: Tensor) -> Tensor: ...
phivenv/Lib/site-packages/torch/_C/_dynamo/eval_frame.pyi ADDED
@@ -0,0 +1,71 @@
+ import enum
+ import types
+ from typing import Optional, overload
+
+ from torch._dynamo.types import (
+     DynamoCallback,
+     DynamoGuardCompleteHook,
+     DynamoGuardHook,
+     GuardFn,
+ )
+
+ def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
+ def set_skip_guard_eval_unsafe(value: bool) -> bool: ...
+ def get_eval_frame_callback() -> DynamoCallback: ...
+ def reset_code(code: types.CodeType) -> None: ...
+ def unsupported(obj1: object, obj2: object) -> object: ...
+ def set_code_exec_strategy(
+     code: types.CodeType, strategy: _FrameExecStrategy
+ ) -> None: ...
+ def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
+ def set_guard_complete_hook(
+     hook: Optional[DynamoGuardCompleteHook],
+ ) -> Optional[DynamoGuardCompleteHook]: ...
+ def raise_sigtrap() -> None: ...
+
+ class _CacheEntry:
+     def check_fn(self, *args: object, **kwargs: object) -> bool: ...
+     code: types.CodeType
+     next: _CacheEntry | None
+
+ class _ExtraState:
+     def invalidate(self, cache_entry: _CacheEntry, guard_manager: object) -> None: ...
+
+ class _FrameAction(enum.IntEnum):
+     DEFAULT = 0
+     SKIP = 1
+     RUN_ONLY = 2
+
+ class _FrameExecStrategy:
+     cur_action: _FrameAction
+     recursive_action: _FrameAction
+
+     @overload
+     def __init__(self) -> None: ...
+     @overload
+     def __init__(
+         self, cur_action: _FrameAction, recursive_action: _FrameAction
+     ) -> None: ...
+
+ # This is an object that encapsulates the Python FrameType, and exposes
+ # properties Dynamo cares about for a frame.
+ class _PyInterpreterFrame:
+     f_code: types.CodeType
+     f_locals: dict[str, object]
+     f_globals: dict[str, object]
+     f_builtins: dict[str, object]
+     f_lasti: int
+     f_lineno: int
+     f_back: types.FrameType
+     # A tuple containing cell objects captured by this frame.
+     closure: tuple[types.CellType]
+
+ def _debug_get_cache_entry_list(code: types.CodeType) -> list[_CacheEntry]: ...
+
+ py_opcode_caches: list[int]
+
+ def code_framelocals_names(code: types.CodeType) -> tuple[str]: ...
+ def _load_precompile_entry(
+     code: types.CodeType, guard_manager: GuardFn, dynamo_code: types.CodeType
+ ) -> None: ...
+ def _reset_precompile_entries(code: types.CodeType) -> None: ...
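A minimal sketch of the cache-inspection helper declared above, which walks the _CacheEntry list Dynamo attaches to a compiled function's original code object:

import torch
from torch._C._dynamo.eval_frame import _debug_get_cache_entry_list

def f(x):
    return x + 1

torch.compile(f)(torch.randn(3))        # trigger one compilation of f
for entry in _debug_get_cache_entry_list(f.__code__):
    print(entry.code)                   # the transformed code object for this entry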
phivenv/Lib/site-packages/torch/_C/_dynamo/guards.pyi ADDED
@@ -0,0 +1,191 @@
+ # mypy: allow-untyped-defs
+ from typing import Any, Callable
+
+ import torch
+
+ class GlobalStateGuard:
+     def check(self) -> bool: ...
+     def reason(self) -> str: ...
+
+ class LeafGuard: ...
+ class GuardDebugInfo: ...
+
+ class GuardManager:
+     def check(self, value) -> bool: ...
+     def check_verbose(self, value) -> GuardDebugInfo: ...
+
+     # Accessors
+     def globals_dict_manager(
+         self,
+         f_globals: dict[str, Any],
+         source,
+         example_value,
+         guard_manager_enum,
+     ) -> GuardManager: ...
+     def framelocals_manager(
+         self,
+         key: tuple[str, int],
+         source,
+         example_value,
+         guard_manager_enum,
+     ) -> GuardManager: ...
+     def dict_getitem_manager(
+         self,
+         key,
+         source,
+         example_value,
+         guard_manager_enum,
+     ) -> GuardManager: ...
+     def global_weakref_manager(
+         self,
+         global_name: str,
+         source,
+         example_value,
+         guard_manager_enum,
+     ) -> GuardManager: ...
+     def type_manager(
+         self,
+         source,
+         example_value,
+         guard_manager_enum,
+     ) -> GuardManager: ...
+     def getattr_manager(
+         self,
+         attr: str,
+         source,
+         example_value,
+         guard_manager_enum,
+     ) -> GuardManager: ...
+     def tensor_property_size_manager(
+         self,
+         idx: int,
+         source,
+         example_value,
+         guard_manager_enum,
+     ) -> GuardManager: ...
+     def tensor_property_shape_manager(
+         self,
+         idx: int,
+         source,
+         example_value,
+         guard_manager_enum,
+     ) -> GuardManager: ...
+     def tensor_property_storage_offset_manager(
+         self,
+         idx: None,
+         source,
+         example_value,
+         guard_manager_enum,
+     ) -> GuardManager: ...
+     def indexed_manager(
+         self,
+         idx: int,
+         source,
+         example_value,
+         guard_manager_enum,
+     ) -> GuardManager: ...
+     def lambda_manager(
+         self,
+         python_lambda,
+         source,
+         example_value,
+         guard_manager_enum,
+     ) -> GuardManager: ...
+
+     # Leaf guards
+     def add_lambda_guard(self, user_lambda, verbose_code_parts: list[str]) -> None: ...
+     def add_id_match_guard(self, id_val, verbose_code_parts: list[str]) -> None: ...
+     def add_equals_match_guard(
+         self,
+         equals_val,
+         verbose_code_parts: list[str],
+     ) -> None: ...
+     def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ...
+     def add_torch_function_mode_stack_guard(
+         self, initial_stack, verbose_code_parts: list[str]
+     ) -> None: ...
+     def add_mapping_keys_guard(self, value, verbose_code_parts: list[str]) -> None: ...
+
+ class RootGuardManager(GuardManager):
+     def get_epilogue_lambda_guards(self) -> list[LeafGuard]: ...
+     def add_epilogue_lambda_guard(
+         self,
+         guard: LeafGuard,
+         verbose_code_parts: list[str],
+     ) -> None: ...
+     def clone_manager(
+         self, clone_filter_fn: Callable[[GuardManager], bool]
+     ) -> RootGuardManager: ...
+
+ class DictGuardManager(GuardManager):
+     def get_key_manager(
+         self,
+         index,
+         source,
+         example_value,
+         guard_manager_enum,
+     ) -> GuardManager: ...
+     def get_value_manager(
+         self,
+         index,
+         source,
+         example_value,
+         guard_manager_enum,
+     ) -> GuardManager: ...
+
+ def install_object_aliasing_guard(
+     guard_managers: list[GuardManager],
+     tensor_names: list[str],
+     verbose_code_parts: list[str],
+ ): ...
+ def install_no_tensor_aliasing_guard(
+     guard_managers: list[GuardManager],
+     tensor_names: list[str],
+     verbose_code_parts: list[str],
+ ): ...
+ def install_storage_overlapping_guard(
+     overlapping_guard_managers: list[GuardManager],
+     non_overlapping_guard_managers: list[GuardManager],
+     verbose_code_parts: list[str],
+ ): ...
+ def install_symbolic_shape_guard(
+     guard_managers: list[GuardManager],
+     nargs_int: int,
+     nargs_float: int,
+     py_addr: int,
+     py_addr_keep_alive: Any,
+     verbose_code_parts: list[str],
+ ): ...
+ def profile_guard_manager(
+     guard_manager: GuardManager,
+     f_locals: dict[str, Any],
+     n_iters: int,
+ ) -> float: ...
+
+ class TensorGuards:
+     def __init__(
+         self,
+         *,
+         dynamic_dims_sizes: list[torch.SymInt | None] | None = None,
+         dynamic_dims_strides: list[torch.SymInt | None] | None = None,
+     ) -> None: ...
+     def check(self, *args) -> bool: ...
+     def check_verbose(self, *args, tensor_check_names=None) -> bool | str: ...
+
+ def assert_size_stride(
+     item: torch.Tensor,
+     size: torch.types._size,
+     stride: torch.types._size,
+     op_name: str | None = None,
+ ): ...
+ def assert_alignment(
+     item: torch.Tensor,
+     alignment: int,
+     op_name: str | None = None,
+ ): ...
+ def check_obj_id(obj: object, expected: int) -> bool: ...
+ def check_type_id(obj: object, expected: int) -> bool: ...
+ def dict_version(d: dict[Any, Any]) -> int: ...
+ def compute_overlapping_tensors(
+     tensors: list[torch.Tensor], symbolic: bool = True
+ ) -> set[int]: ...
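The standalone checkers at the bottom of this stub can be exercised directly; a minimal sketch, with semantics as assumed from the signatures above:

    import torch
    from torch._C._dynamo.guards import assert_size_stride, check_type_id

    t = torch.empty(2, 3)
    assert_size_stride(t, (2, 3), (3, 1))      # passes for a contiguous 2x3 tensor
    print(check_type_id(t, id(torch.Tensor)))  # fast `type(t) is torch.Tensor` check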
phivenv/Lib/site-packages/torch/_C/_export/__init__.pyi ADDED
@@ -0,0 +1,9 @@
+ # Defined in torch/csrc/export/pybind.cpp
+ class CppExportedProgram: ...
+
+ def deserialize_exported_program(
+     serialized_program: str,
+ ) -> CppExportedProgram: ...
+ def serialize_exported_program(
+     cpp_exported_program: CppExportedProgram,
+ ) -> str: ...
phivenv/Lib/site-packages/torch/_C/_export/pt2_archive_constants.pyi ADDED
@@ -0,0 +1,22 @@
+ # Defined in torch/csrc/export/pt2_archive_constants.h
+
+ ARCHIVE_ROOT_NAME: str = ...
+ ARCHIVE_FORMAT_PATH: str = ...
+ ARCHIVE_FORMAT_VALUE: str = ...
+ ARCHIVE_VERSION_PATH: str = ...
+ ARCHIVE_VERSION_VALUE: str = ...
+ MODELS_DIR: str = ...
+ MODELS_FILENAME_FORMAT: str = ...
+ AOTINDUCTOR_DIR: str = ...
+ MTIA_DIR: str = ...
+ WEIGHTS_DIR: str = ...
+ WEIGHT_FILENAME_PREFIX: str = ...
+ CONSTANTS_DIR: str = ...
+ TENSOR_CONSTANT_FILENAME_PREFIX: str = ...
+ CUSTOM_OBJ_FILENAME_PREFIX: str = ...
+ SAMPLE_INPUTS_DIR: str = ...
+ SAMPLE_INPUTS_FILENAME_FORMAT: str = ...
+ EXTRA_DIR: str = ...
+ MODULE_INFO_PATH: str = ...
+ XL_MODEL_WEIGHTS_DIR: str = ...
+ XL_MODEL_WEIGHTS_PARAM_CONFIG_PATH: str = ...
phivenv/Lib/site-packages/torch/_C/_functions.pyi ADDED
@@ -0,0 +1,19 @@
+ from typing import AnyStr, overload
+
+ from torch import Tensor
+
+ class UndefinedGrad:
+     def __init__(self) -> None: ...
+     def __call__(self, *inputs: Tensor) -> list[Tensor]: ...
+
+ class DelayedError:
+     def __init__(self, msg: AnyStr, num_inputs: int) -> None: ...
+
+     # __call__ should really be a higher-kinded type:
+     #   def __call__(self, arg: Tensor) -> Tensor: ...
+     #   def __call__(self, *args: Tensor * num_inputs) -> Tuple[Tensor * num_inputs]: ...
+
+     @overload
+     def __call__(self, i0: Tensor) -> Tensor: ...
+     @overload
+     def __call__(self, *args: Tensor) -> tuple[Tensor, ...]: ...
phivenv/Lib/site-packages/torch/_C/_functorch.pyi ADDED
@@ -0,0 +1,86 @@
+ # mypy: allow-untyped-defs
+ from enum import Enum
+
+ from torch import Tensor
+
+ # Defined in torch/csrc/functorch/init.cpp
+
+ def _set_dynamic_layer_keys_included(included: bool) -> None: ...
+ def get_unwrapped(tensor: Tensor) -> Tensor: ...
+ def is_batchedtensor(tensor: Tensor) -> bool: ...
+ def is_functionaltensor(tensor: Tensor) -> bool: ...
+ def is_functorch_wrapped_tensor(tensor: Tensor) -> bool: ...
+ def is_gradtrackingtensor(tensor: Tensor) -> bool: ...
+ def is_legacy_batchedtensor(tensor: Tensor) -> bool: ...
+ def maybe_get_bdim(tensor: Tensor) -> int: ...
+ def maybe_get_level(tensor: Tensor) -> int: ...
+ def maybe_current_level() -> int | None: ...
+ def unwrap_if_dead(tensor: Tensor) -> Tensor: ...
+ def _unwrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
+ def _wrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
+ def _unwrap_batched(tensor: Tensor, level: int) -> tuple[Tensor, int | None]: ...
+ def current_level() -> int: ...
+ def count_jvp_interpreters() -> int: ...
+ def _add_batch_dim(tensor: Tensor, bdim: int, level: int) -> Tensor: ...
+ def set_single_level_autograd_function_allowed(allowed: bool) -> None: ...
+ def get_single_level_autograd_function_allowed() -> bool: ...
+ def _unwrap_functional_tensor(tensor: Tensor, reapply_views: bool) -> Tensor: ...
+ def _wrap_functional_tensor(tensor: Tensor, level: int) -> Tensor: ...
+ def _vmap_increment_nesting(batch_size: int, randomness: str) -> int: ...
+ def _vmap_decrement_nesting() -> int: ...
+ def _grad_increment_nesting() -> int: ...
+ def _grad_decrement_nesting() -> int: ...
+ def _jvp_increment_nesting() -> int: ...
+ def _jvp_decrement_nesting() -> int: ...
+
+ # Defined in aten/src/ATen/functorch/Interpreter.h
+ class TransformType(Enum):
+     Torch = ...
+     Vmap = ...
+     Grad = ...
+     Jvp = ...
+     Functionalize = ...
+
+ class RandomnessType(Enum):
+     Error = ...
+     Same = ...
+     Different = ...
+
+ class CInterpreter:
+     def key(self) -> TransformType: ...
+     def level(self) -> int: ...
+     def serialize(self) -> bytes: ...
+     @staticmethod
+     def deserialize(data: bytes) -> CInterpreter: ...
+
+ class CGradInterpreterPtr:
+     def __init__(self, interpreter: CInterpreter) -> None: ...
+     def lift(self, tensor: Tensor) -> Tensor: ...
+     def prevGradMode(self) -> bool: ...
+
+ class CJvpInterpreterPtr:
+     def __init__(self, interpreter: CInterpreter) -> None: ...
+     def lift(self, tensor: Tensor) -> Tensor: ...
+     def prevFwdGradMode(self) -> bool: ...
+
+ class CFunctionalizeInterpreterPtr:
+     def __init__(self, interpreter: CInterpreter) -> None: ...
+     def key(self) -> TransformType: ...
+     def level(self) -> int: ...
+     def functionalizeAddBackViews(self) -> bool: ...
+
+ class CVmapInterpreterPtr:
+     def __init__(self, interpreter: CInterpreter) -> None: ...
+     def key(self) -> TransformType: ...
+     def level(self) -> int: ...
+     def batchSize(self) -> int: ...
+     def randomness(self) -> RandomnessType: ...
+
+ class DynamicLayer: ...
+
+ def get_dynamic_layer_stack_depth() -> int: ...
+ def get_interpreter_stack() -> list[CInterpreter]: ...
+ def peek_interpreter_stack() -> CInterpreter: ...
+ def pop_dynamic_layer_stack() -> DynamicLayer: ...
+ def pop_dynamic_layer_stack_and_undo_to_depth(depth: int) -> None: ...
+ def push_dynamic_layer_stack(dl: DynamicLayer) -> int: ...
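These bindings back the torch.func/vmap machinery; a sketch showing the wrapper predicates observed from inside a vmap'd function (`probe` is a made-up example):

    import torch
    from torch._C._functorch import is_batchedtensor, maybe_get_bdim

    def probe(x):
        # Inside vmap, x arrives wrapped as a BatchedTensor at the current level.
        print(is_batchedtensor(x), maybe_get_bdim(x))
        return x * 2

    torch.vmap(probe)(torch.randn(3, 4))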
phivenv/Lib/site-packages/torch/_C/_instruction_counter.pyi ADDED
@@ -0,0 +1,4 @@
+ # Defined in torch/csrc/instruction_counter/Module.cpp
+
+ def start() -> int: ...
+ def end(id: int) -> int: ...
phivenv/Lib/site-packages/torch/_C/_itt.pyi ADDED
@@ -0,0 +1,5 @@
+ # Defined in torch/csrc/itt.cpp
+ def is_available() -> None: ...
+ def rangePush(message: str) -> None: ...
+ def rangePop() -> None: ...
+ def mark(message: str) -> None: ...
phivenv/Lib/site-packages/torch/_C/_lazy.pyi ADDED
@@ -0,0 +1,26 @@
+ from torch import Tensor
+
+ # defined in torch/csrc/lazy/python/init.cpp
+ def _mark_step(device: str, devices: list[str], wait: bool) -> None: ...
+ def _wait_device_ops(devices: list[str]) -> None: ...
+ def _reset_metrics() -> None: ...
+ def _counter_names() -> list[str]: ...
+ def _counter_value(name: str) -> int: ...
+ def _metrics_report() -> str: ...
+ def _get_graph_hash(tensors: list[Tensor]) -> str: ...
+ def _sync_multi(
+     tensors: list[Tensor],
+     devices: list[str],
+     wait: bool = True,
+     sync_ltc_data: bool = True,
+ ) -> None: ...
+ def _get_tensor_id(tensor: Tensor) -> int: ...
+ def _get_tensors_text(tensors: list[Tensor]) -> str: ...
+ def _get_tensors_dot(tensors: list[Tensor]) -> str: ...
+ def _get_tensors_backend(tensors: list[Tensor]) -> str: ...
+ def _get_force_fallback() -> str: ...
+ def _set_force_fallback(newval: str) -> None: ...
+ def _clear_ir_cache() -> None: ...
+ def _dump_ir_cache(filename: str) -> None: ...
+ def _set_reuse_ir(val: bool) -> None: ...
+ def _get_default_device_type() -> str: ...
phivenv/Lib/site-packages/torch/_C/_lazy_ts_backend.pyi ADDED
@@ -0,0 +1,12 @@
+ # mypy: allow-untyped-defs
+ # defined in torch/csrc/lazy/python/init.cpp
+
+ from typing import Any
+
+ from torch import Tensor
+
+ def _init(): ...
+ def _get_tensors_ts_device_data_node(
+     tensors: list[Tensor],
+ ) -> tuple[list[int], list[Any]]: ...
+ def _run_cached_graph(hash_str: str, graph_inputs: list[Any]) -> list[Tensor]: ...
phivenv/Lib/site-packages/torch/_C/_monitor.pyi ADDED
@@ -0,0 +1,58 @@
+ # Defined in torch/csrc/monitor/python_init.cpp
+
+ import datetime
+ from enum import Enum
+ from types import TracebackType
+ from typing import Callable
+
+ class Aggregation(Enum):
+     VALUE = ...
+     MEAN = ...
+     COUNT = ...
+     SUM = ...
+     MAX = ...
+     MIN = ...
+
+ class Stat:
+     name: str
+     count: int
+     def __init__(
+         self,
+         name: str,
+         aggregations: list[Aggregation],
+         window_size: int,
+         max_samples: int = -1,
+     ) -> None: ...
+     def add(self, v: float) -> None: ...
+     def get(self) -> dict[Aggregation, float]: ...
+
+ class Event:
+     name: str
+     timestamp: datetime.datetime
+     data: dict[str, int | float | bool | str]
+     def __init__(
+         self,
+         name: str,
+         timestamp: datetime.datetime,
+         data: dict[str, int | float | bool | str],
+     ) -> None: ...
+
+ def log_event(e: Event) -> None: ...
+
+ class EventHandlerHandle: ...
+
+ def register_event_handler(handler: Callable[[Event], None]) -> EventHandlerHandle: ...
+ def unregister_event_handler(handle: EventHandlerHandle) -> None: ...
+
+ class _WaitCounterTracker:
+     def __enter__(self) -> None: ...
+     def __exit__(
+         self,
+         exc_type: type[BaseException] | None = None,
+         exc_value: BaseException | None = None,
+         traceback: TracebackType | None = None,
+     ) -> None: ...
+
+ class _WaitCounter:
+     def __init__(self, key: str) -> None: ...
+     def guard(self) -> _WaitCounterTracker: ...
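A minimal usage sketch for the Stat API above, via its public torch.monitor surface (the metric name is made up; aggregates are reported per completed window):

    from torch.monitor import Aggregation, Stat

    # Aggregate raw samples over fixed-size windows of 3 samples each.
    latency = Stat("request_latency_ms", [Aggregation.MEAN, Aggregation.MAX], 3)
    for sample in (10.0, 20.0, 30.0):
        latency.add(sample)
    print(latency.get())  # aggregates for the most recently completed window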
phivenv/Lib/site-packages/torch/_C/_nn.pyi ADDED
@@ -0,0 +1,175 @@
+ # @generated by tools/pyi/gen_pyi.py from torch/_C/_nn.pyi.in
+ # mypy: disable-error-code="type-arg"
+
+ from collections.abc import Sequence
+ from typing import Literal, overload
+
+ from torch import memory_format, Tensor
+ from torch.types import _bool, _device, _dtype, _int, _size
+
+ # Defined in tools/autograd/templates/python_nn_functions.cpp
+
+ def adaptive_avg_pool2d(input: Tensor, output_size: _int | _size) -> Tensor: ...
+ def adaptive_avg_pool3d(input: Tensor, output_size: _int | _size) -> Tensor: ...
+ def adaptive_max_pool2d(
+     input: Tensor,
+     output_size: _int | _size,
+ ) -> tuple[Tensor, Tensor]: ...
+ def adaptive_max_pool3d(
+     input: Tensor,
+     output_size: _int | _size,
+ ) -> tuple[Tensor, Tensor]: ...
+ def avg_pool2d(
+     input: Tensor,
+     kernel_size: _int | _size,
+     stride: _int | _size | None = None,
+     padding: _int | _size = 0,
+     ceil_mode: bool = False,
+     count_include_pad: bool = True,
+     divisor_override: int | None = None,
+ ) -> Tensor: ...
+ def avg_pool3d(
+     input: Tensor,
+     kernel_size: _int | _size,
+     stride: _int | _size | None = None,
+     padding: _int | _size = 0,
+     ceil_mode: bool = False,
+     count_include_pad: bool = True,
+     divisor_override: int | None = None,
+ ) -> Tensor: ...
+ def binary_cross_entropy(
+     input: Tensor,
+     target: Tensor,
+     weight: Tensor | None = None,
+     reduction: str = ...,
+ ) -> Tensor: ...
+ def col2im(
+     input: Tensor,
+     output_size: _int | _size,
+     kernel_size: _int | _size,
+     dilation: _int | _size,
+     stride: _int | _size | None = None,
+     padding: _int | _size = 0,
+ ) -> Tensor: ...
+ def elu_(input: Tensor, alpha: float = ...) -> Tensor: ...
+ def fractional_max_pool2d(
+     input: Tensor,
+     kernel_size: _int | _size,
+     output_size: _int | _size,
+     _random_samples: Tensor,
+ ) -> tuple[Tensor, Tensor]: ...
+ def fractional_max_pool3d(
+     input: Tensor,
+     kernel_size: _int | _size,
+     output_size: _int | _size,
+     _random_samples: Tensor,
+ ) -> tuple[Tensor, Tensor]: ...
+ def gelu(input: Tensor, approximate: str = ...) -> Tensor: ...
+ def hardsigmoid(input: Tensor, *, out: Tensor | None = None) -> Tensor: ...
+ def hardtanh(
+     input: Tensor,
+     min_val: float = ...,
+     max_val: float = ...,
+     *,
+     out: Tensor | None = None,
+ ) -> Tensor: ...
+ def hardtanh_(
+     input: Tensor,
+     min_val: float = ...,
+     max_val: float = ...,
+ ) -> Tensor: ...
+ def leaky_relu(
+     input: Tensor,
+     negative_slope: float = ...,
+     *,
+     out: Tensor | None = None,
+ ) -> Tensor: ...
+ def leaky_relu_(input: Tensor, negative_slope: float = ...) -> Tensor: ...
+ def linear(
+     input: Tensor,
+     weight: Tensor,
+     bias: Tensor | None = None,
+ ) -> Tensor: ...
+ def log_sigmoid(input: Tensor) -> Tensor: ...
+ def one_hot(tensor: Tensor, num_classes: int = ...) -> Tensor: ...
+ def pad(
+     input: Tensor,
+     pad: Sequence[int],
+     mode: str = ...,
+     value: float | None = None,
+ ) -> Tensor: ...
+ def scaled_dot_product_attention(
+     query: Tensor,
+     key: Tensor,
+     value: Tensor,
+     attn_mask: Tensor | None = None,
+     dropout_p: float = 0.0,
+     is_causal: bool = False,
+     scale: float | None = None,
+     enable_gqa: bool = False,
+ ) -> Tensor: ...
+ def softplus(
+     input: Tensor,
+     beta: float = ...,
+     threshold: float = ...,
+ ) -> Tensor: ...
+ def softshrink(input: Tensor, lambd: float = ...) -> Tensor: ...
+
+ # Defined in aten/src/ATen/native/mkldnn/Linear.cpp
+ def mkldnn_linear(input: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: ...
+
+ # Defined at aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
+ def mkldnn_reorder_conv2d_weight(
+     self: Tensor,
+     padding: list,
+     stride: list,
+     dilation: list,
+     groups: int,
+ ) -> Tensor: ...
+ def mkldnn_reorder_conv3d_weight(
+     self: Tensor,
+     padding: list,
+     stride: list,
+     dilation: list,
+     groups: int,
+ ) -> Tensor: ...
+
+ # Defined in aten/src/ATen/native/mkldnn/Prelu.cpp
+ def mkldnn_prelu(input: Tensor, weight: Tensor) -> Tensor: ...
+
+ # Defined at tools/autograd/templates/python_nn_functions.cpp
+ @overload
+ def _parse_to(
+     device: _device,
+     dtype: _dtype,
+     non_blocking: _bool,
+     copy: _bool,
+     *,
+     memory_format: memory_format,
+ ) -> tuple[_device, _dtype, _bool, memory_format]: ...
+ @overload
+ def _parse_to(
+     dtype: _dtype,
+     non_blocking: _bool,
+     copy: _bool,
+     *,
+     memory_format: memory_format,
+ ) -> tuple[_device, _dtype, _bool, memory_format]: ...
+ @overload
+ def _parse_to(
+     tensor: Tensor,
+     non_blocking: _bool,
+     copy: _bool,
+     *,
+     memory_format: memory_format,
+ ) -> tuple[_device, _dtype, _bool, memory_format]: ...
+
+ # Defined in aten/src/ATen/native/PackedSequence.cpp
+ def pad_sequence(
+     sequences: list[Tensor] | tuple[Tensor, ...],
+     batch_first: bool = False,
+     padding_value: float = 0.0,
+     padding_side: Literal["left", "right"] = "right",
+ ) -> Tensor: ...
+ def flatten_dense_tensors(tensors: list[Tensor]) -> Tensor: ...
+ def unflatten_dense_tensors(flat: Tensor, tensors: list[Tensor]) -> list[Tensor]: ...
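The `pad_sequence` binding above is what `torch.nn.utils.rnn.pad_sequence` dispatches to; a quick example of the batching behavior:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
    padded = pad_sequence(seqs, batch_first=True, padding_value=0)
    print(padded)  # tensor([[1, 2, 3], [4, 5, 0]])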
phivenv/Lib/site-packages/torch/_C/_nvtx.pyi ADDED
@@ -0,0 +1,9 @@
+ # mypy: allow-untyped-defs
+ # Defined in torch/csrc/cuda/shared/nvtx.cpp
+ def rangePushA(message: str) -> int: ...
+ def rangePop() -> int: ...
+ def rangeStartA(message: str) -> int: ...
+ def rangeEnd(range_id: int) -> None: ...
+ def markA(message: str) -> None: ...
+ def deviceRangeStart(message: str, stream: int) -> object: ...
+ def deviceRangeEnd(range_handle: object, stream: int) -> None: ...
phivenv/Lib/site-packages/torch/_C/_onnx.pyi ADDED
@@ -0,0 +1,39 @@
+ # Defined in torch/csrc/onnx/init.cpp
+
+ from enum import Enum
+
+ PRODUCER_VERSION: str
+
+ class TensorProtoDataType(Enum):
+     UNDEFINED = ...
+     FLOAT = ...
+     UINT8 = ...
+     INT8 = ...
+     UINT16 = ...
+     INT16 = ...
+     INT32 = ...
+     INT64 = ...
+     STRING = ...
+     BOOL = ...
+     FLOAT16 = ...
+     DOUBLE = ...
+     UINT32 = ...
+     UINT64 = ...
+     COMPLEX64 = ...
+     COMPLEX128 = ...
+     BFLOAT16 = ...
+     FLOAT8E5M2 = ...
+     FLOAT8E4M3FN = ...
+     FLOAT8E5M2FNUZ = ...
+     FLOAT8E4M3FNUZ = ...
+
+ class OperatorExportTypes(Enum):
+     ONNX = ...
+     ONNX_ATEN = ...
+     ONNX_ATEN_FALLBACK = ...
+     ONNX_FALLTHROUGH = ...
+
+ class TrainingMode(Enum):
+     EVAL = ...
+     PRESERVE = ...
+     TRAINING = ...
phivenv/Lib/site-packages/torch/_C/_profiler.pyi ADDED
@@ -0,0 +1,246 @@
+ from enum import Enum
+ from typing import Literal
+ from typing_extensions import TypeAlias
+
+ from torch._C import device, dtype, layout
+
+ # defined in torch/csrc/profiler/python/init.cpp
+
+ class RecordScope(Enum):
+     FUNCTION = ...
+     BACKWARD_FUNCTION = ...
+     TORCHSCRIPT_FUNCTION = ...
+     KERNEL_FUNCTION_DTYPE = ...
+     CUSTOM_CLASS = ...
+     BUILD_FEATURE = ...
+     LITE_INTERPRETER = ...
+     USER_SCOPE = ...
+     STATIC_RUNTIME_OP = ...
+     STATIC_RUNTIME_MODEL = ...
+
+ class ProfilerState(Enum):
+     Disable = ...
+     CPU = ...
+     CUDA = ...
+     NVTX = ...
+     ITT = ...
+     KINETO = ...
+     KINETO_GPU_FALLBACK = ...
+     KINETO_PRIVATEUSE1_FALLBACK = ...
+     KINETO_PRIVATEUSE1 = ...
+
+ class ActiveProfilerType(Enum):
+     NONE = ...
+     LEGACY = ...
+     KINETO = ...
+     NVTX = ...
+     ITT = ...
+
+ class ProfilerActivity(Enum):
+     CPU = ...
+     CUDA = ...
+     XPU = ...
+     MTIA = ...
+     HPU = ...
+     PrivateUse1 = ...
+
+ class _EventType(Enum):
+     TorchOp = ...
+     Backend = ...
+     Allocation = ...
+     OutOfMemory = ...
+     PyCall = ...
+     PyCCall = ...
+     Kineto = ...
+
+ class _ExperimentalConfig:
+     def __init__(
+         self,
+         profiler_metrics: list[str] = ...,
+         profiler_measure_per_kernel: bool = ...,
+         verbose: bool = ...,
+         performance_events: list[str] = ...,
+         enable_cuda_sync_events: bool = ...,
+     ) -> None: ...
+
+ class ProfilerConfig:
+     def __init__(
+         self,
+         state: ProfilerState,
+         report_input_shapes: bool,
+         profile_memory: bool,
+         with_stack: bool,
+         with_flops: bool,
+         with_modules: bool,
+         experimental_config: _ExperimentalConfig,
+         trace_id: str | None = None,
+     ) -> None: ...
+
+ class _ProfilerEvent:
+     start_tid: int
+     start_time_ns: int
+     children: list[_ProfilerEvent]
+
+     # TODO(robieta): remove in favor of `self.typed`
+     extra_fields: (
+         _ExtraFields_TorchOp
+         | _ExtraFields_Backend
+         | _ExtraFields_Allocation
+         | _ExtraFields_OutOfMemory
+         | _ExtraFields_PyCall
+         | _ExtraFields_PyCCall
+         | _ExtraFields_Kineto
+     )
+
+     @property
+     def typed(
+         self,
+     ) -> (
+         tuple[Literal[_EventType.TorchOp], _ExtraFields_TorchOp]
+         | tuple[Literal[_EventType.Backend], _ExtraFields_Backend]
+         | tuple[Literal[_EventType.Allocation], _ExtraFields_Allocation]
+         | tuple[Literal[_EventType.OutOfMemory], _ExtraFields_OutOfMemory]
+         | tuple[Literal[_EventType.PyCall], _ExtraFields_PyCall]
+         | tuple[Literal[_EventType.PyCCall], _ExtraFields_PyCCall]
+         | tuple[Literal[_EventType.Kineto], _ExtraFields_Kineto]
+     ): ...
+     @property
+     def name(self) -> str: ...
+     @property
+     def tag(self) -> _EventType: ...
+     @property
+     def id(self) -> int: ...
+     @property
+     def parent(self) -> _ProfilerEvent | None: ...
+     @property
+     def correlation_id(self) -> int: ...
+     @property
+     def end_time_ns(self) -> int: ...
+     @property
+     def duration_time_ns(self) -> int: ...
+
+ class _TensorMetadata:
+     impl_ptr: int | None
+     storage_data_ptr: int | None
+     id: int | None
+
+     @property
+     def allocation_id(self) -> int | None: ...
+     @property
+     def layout(self) -> layout: ...
+     @property
+     def device(self) -> device: ...
+     @property
+     def dtype(self) -> dtype: ...
+     @property
+     def sizes(self) -> list[int]: ...
+     @property
+     def strides(self) -> list[int]: ...
+
+ Scalar: TypeAlias = int | float | bool | complex
+ Input: TypeAlias = _TensorMetadata | list[_TensorMetadata] | Scalar | None
+
+ class _ExtraFields_TorchOp:
+     name: str
+     sequence_number: int
+     allow_tf32_cublas: bool
+
+     @property
+     def inputs(self) -> list[Input]: ...
+     @property
+     def scope(self) -> RecordScope: ...
+
+ class _ExtraFields_Backend: ...
+
+ class _ExtraFields_Allocation:
+     ptr: int
+     id: int | None
+     alloc_size: int
+     total_allocated: int
+     total_reserved: int
+
+     @property
+     def allocation_id(self) -> int | None: ...
+     @property
+     def device(self) -> device: ...
+
+ class _ExtraFields_OutOfMemory: ...
+
+ class _PyFrameState:
+     line_number: int
+     function_name: str
+
+     @property
+     def file_name(self) -> str: ...
+
+ class _NNModuleInfo:
+     @property
+     def self_ptr(self) -> int: ...
+     @property
+     def cls_ptr(self) -> int: ...
+     @property
+     def cls_name(self) -> str: ...
+     @property
+     def parameters(
+         self,
+     ) -> list[tuple[str, _TensorMetadata, _TensorMetadata | None]]: ...
+
+ class _OptimizerInfo:
+     @property
+     def parameters(
+         self,
+     ) -> list[
+         tuple[
+             # Parameter
+             _TensorMetadata,
+             # Gradient (if present during optimizer.step())
+             _TensorMetadata | None,
+             # Optimizer state for Parameter as (name, tensor) pairs
+             list[tuple[str, _TensorMetadata]],
+         ]
+     ]: ...
+
+ class _ExtraFields_PyCCall:
+     @property
+     def caller(self) -> _PyFrameState: ...
+
+ class _ExtraFields_PyCall:
+     @property
+     def callsite(self) -> _PyFrameState: ...
+     @property
+     def caller(self) -> _PyFrameState: ...
+     @property
+     def module(self) -> _NNModuleInfo | None: ...
+     @property
+     def optimizer(self) -> _OptimizerInfo | None: ...
+
+ class _ExtraFields_Kineto: ...
+
+ def _add_execution_trace_observer(output_file_path: str) -> bool: ...
+ def _remove_execution_trace_observer() -> None: ...
+ def _enable_execution_trace_observer() -> None: ...
+ def _disable_execution_trace_observer() -> None: ...
+ def _set_record_concrete_inputs_enabled_val(val: bool) -> None: ...
+ def _set_fwd_bwd_enabled_val(val: bool) -> None: ...
+ def _set_cuda_sync_enabled_val(val: bool) -> None: ...
+
+ class CapturedTraceback: ...
+
+ def gather_traceback(python: bool, script: bool, cpp: bool) -> CapturedTraceback: ...
+
+ # Each dict has keys: name, filename, line
+ def symbolize_tracebacks(
+     to_symbolize: list[CapturedTraceback],
+ ) -> list[list[dict[str, str]]]: ...
+
+ class _RecordFunctionFast:
+     def __init__(
+         self,
+         name: str,
+         input_values: list | tuple | None = None,
+         keyword_values: dict | None = None,
+     ) -> None: ...
+     def __enter__(self) -> None: ...
+     def __exit__(self, *exc_info: object) -> None: ...
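`_RecordFunctionFast` above is a low-overhead way to label a region in profiler traces; a sketch using the private API (`custom_region` is a made-up label):

    import torch
    from torch._C._profiler import _RecordFunctionFast

    x = torch.randn(64, 64)
    with torch.profiler.profile() as prof:
        with _RecordFunctionFast("custom_region"):
            x @ x  # work attributed to the labeled region
    print(prof.key_averages().table(sort_by="cpu_time_total"))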
phivenv/Lib/site-packages/torch/_C/_verbose.pyi ADDED
@@ -0,0 +1,3 @@
+ # Defined in torch/csrc/utils/verbose.cpp
+ def mkl_set_verbose(enable: int) -> int: ...
+ def mkldnn_set_verbose(level: int) -> int: ...
phivenv/Lib/site-packages/torch/_C_flatbuffer/__init__.pyi ADDED
@@ -0,0 +1,11 @@
+ # mypy: allow-untyped-defs
+ from torch._C import LiteScriptModule, ScriptModule
+
+ def _load_mobile_module_from_file(filename: str): ...
+ def _load_mobile_module_from_bytes(bytes_: bytes): ...
+ def _load_jit_module_from_file(filename: str): ...
+ def _load_jit_module_from_bytes(bytes_: bytes): ...
+ def _save_mobile_module(m: LiteScriptModule, filename: str): ...
+ def _save_jit_module(m: ScriptModule, filename: str): ...
+ def _save_mobile_module_to_bytes(m: LiteScriptModule) -> bytes: ...
+ def _save_jit_module_to_bytes(m: ScriptModule) -> bytes: ...
phivenv/Lib/site-packages/torch/_awaits/__init__.py ADDED
@@ -0,0 +1,53 @@
+ from __future__ import annotations
+
+ from typing import Generic, TypeVar
+
+ import torch
+
+ __all__ = ['Await']
+
+ W = TypeVar("W")
+
+ class _PyAwaitMeta(type(torch._C._Await), type(Generic)):  # type: ignore[misc, no-redef]
+     pass
+
+ class _Await(torch._C._Await, Generic[W], metaclass=_PyAwaitMeta):
+     r"""
+     Wrapper around a ``torch._C.Await`` that encapsulates the delayed execution
+     of a callable. All manipulation happens through the functions
+     ``torch.jit._awaitable``, ``torch.jit._awaitable_wait``, and
+     ``torch.jit._awaitable_nowait``.
+
+     Torch-scriptable manipulations:
+
+     ``torch.jit._awaitable(func, *args)``
+         Creates an ``Await[W]`` object, where ``W`` is the return type of ``func``.
+
+     ``torch.jit._awaitable_wait(Await[W])``
+         Runs the function specified at ``_awaitable`` with the specified arguments.
+
+         Returns:
+             The result of type ``W`` of the function call. The result is owned by
+             ``Await[W]`` and returned on all following ``_awaitable_wait`` calls.
+
+     ``torch.jit._awaitable_nowait(W)``
+         Returns:
+             A trivial ``Await[W]`` with the specified result.
+
+     Only in eager mode:
+
+     ``fn() -> Callable[Tuple[Any], W]``
+         Returns:
+             The Python function ``func`` specified at ``_awaitable``.
+
+     ``args() -> Tuple[Any]``
+         Returns:
+             The Python args specified at ``_awaitable``.
+
+     ``is_nowait() -> _bool``
+         Returns:
+             ``True`` if this object was created via an ``_awaitable_nowait`` call
+             (a trivial ``Await[W]``).
+
+     In eager mode an ``Await[W]`` can be used as a ``W``: attributes of ``W`` can be
+     accessed on ``Await[W]``, and an ``_awaitable_wait()`` call is inserted
+     transparently.
+     """
phivenv/Lib/site-packages/torch/_awaits/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (2.03 kB).
phivenv/Lib/site-packages/torch/_custom_op/__init__.py ADDED
File without changes
phivenv/Lib/site-packages/torch/_custom_op/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (160 Bytes).
phivenv/Lib/site-packages/torch/_custom_op/__pycache__/autograd.cpython-39.pyc ADDED
Binary file (8.9 kB).
phivenv/Lib/site-packages/torch/_custom_op/__pycache__/impl.cpython-39.pyc ADDED
Binary file (21.2 kB).
phivenv/Lib/site-packages/torch/_custom_op/autograd.py ADDED
@@ -0,0 +1,307 @@
+ # mypy: allow-untyped-defs
+ import functools
+ from collections import namedtuple
+
+ import torch
+ import torch.utils._pytree as pytree
+
+
+ # NOTE [CustomOp autograd kernel indirection]
+ # We register `inner` as the autograd kernel for this custom_op.
+ # `inner` either calls the autograd formula registered by the user,
+ # or goes into an `autograd_not_implemented` kernel.
+ #
+ # The reason why this indirection exists is
+ # so that we can swap out the autograd kernel (the PyTorch dispatcher
+ # doesn't actually allow us to do this). By default, we want
+ # the `autograd_not_implemented` behavior, but then the user may come
+ # and register something that is actually a backward formula
+ def autograd_kernel_indirection(custom_op):
+     autograd_fallback = autograd_not_implemented(custom_op)
+
+     def inner(*args, **kwargs):
+         if custom_op._has_impl("autograd"):
+             kernel = custom_op._get_impl("autograd").func
+             return kernel(*args, **kwargs)
+         # As explained in NOTE ["backward", "save_for_backward", and "autograd"],
+         # after the user gives us "backward" and "save_for_backward", we generate
+         # the "autograd" impl. If the user only provided one, then we tell
+         # the user they've done something wrong.
+         if custom_op._has_impl("save_for_backward") or custom_op._has_impl("backward"):
+             missing = (
+                 "save_for_backward" if custom_op._has_impl("backward") else "backward"
+             )
+             found = "save_for_backward" if missing == "backward" else "backward"
+             loc = custom_op._get_impl(found).location
+             raise RuntimeError(
+                 f"We found a '{found}' registration for {custom_op} at "
+                 f"{loc} but were unable to find a '{missing}' registration. "
+                 f"To use the CustomOp API to register a backward formula, "
+                 f"please provide us both a backward function and a "
+                 f"'save for backward' function via `impl_backward` and "
+                 f"`impl_save_for_backward` respectively."
+             )
+         return autograd_fallback(*args, **kwargs)
+
+     return inner
+
+
+ # TODO(#101191): Use the actual C++ autograd not implemented fallback,
+ # or change the default autograd fallback to the autograd not implemented fallback.
+ def autograd_not_implemented(custom_op):
+     def kernel(*args, **kwargs):
+         if torch.is_grad_enabled() and pytree.tree_any(
+             lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
+         ):
+             raise RuntimeError("Autograd has not been implemented for operator")
+         with torch._C._AutoDispatchBelowAutograd():
+             return custom_op(*args, **kwargs)
+
+     return kernel
+
+
+ def mark_non_differentiable(ctx, output, output_differentiability):
+     # Output types are restricted to be:
+     #   - Tensor
+     #   - Tensor[]
+     #   - int, bool, Scalar, float
+     # See _check_can_register_backward
+     if output_differentiability is not None:
+         if not isinstance(output, tuple):
+             tuple_output = (output,)
+         else:
+             tuple_output = output  # type: ignore[assignment]
+         assert len(output_differentiability) == len(tuple_output)
+         non_differentiable_tensors = []
+         for idx, (differentiable, out) in enumerate(
+             zip(output_differentiability, tuple_output)
+         ):
+             if isinstance(out, torch.Tensor):
+                 if not differentiable:
+                     non_differentiable_tensors.append(out)
+                 continue
+             if isinstance(out, list):
+                 if not differentiable:
+                     non_differentiable_tensors.extend(out)
+                 continue
+             if differentiable:
+                 raise RuntimeError(
+                     f"With output_differentiability={output_differentiability}. "
+                     f"At idx {idx}, we received an object of type {type(out)} that "
+                     f"is not a Tensor, so it cannot be marked as differentiable in "
+                     f"output_differentiability."
+                 )
+         if non_differentiable_tensors:
+             ctx.mark_non_differentiable(*non_differentiable_tensors)
+
+
+ def construct_autograd_kernel(
+     schema,
+     output_differentiability,
+     custom_op,
+     op_overload,
+     save_for_backward_fn,
+     backward_fn,
+ ):
+     def apply(*args):
+         flat_args, spec = pytree.tree_flatten(args)
+         out_spec = None
+
+         def forward(ctx, *flat_args):
+             ctx.set_materialize_grads(True)
+             args = pytree.tree_unflatten(list(flat_args), spec)
+             with torch._C._AutoDispatchBelowAutograd():
+                 output = op_overload(*args)
+
+             # We use the info about args to give better error messages in backward
+             args_info = namedtuple_args(schema, pytree.tree_map(type, args))
+
+             save_for_backward_fn_inputs = namedtuple_args(schema, args)
+             to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)
+
+             save_pytree_for_backward(ctx, (to_save, args_info))
+             mark_non_differentiable(ctx, output, output_differentiability)
+
+             nonlocal out_spec
+             flat_output, out_spec = pytree.tree_flatten(output)
+             return tuple(flat_output)
+
+         def backward(ctx, *flat_grad_output):
+             assert out_spec is not None
+             grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
+             saved, args_info = unpack_saved(ctx)
+             # There is nothing on the ctx object for now, it is just there so
+             # that we can add additional things in the future.
+             inner_ctx = object()
+             if not isinstance(grads, tuple):
+                 grads = (grads,)
+             grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)
+
+             # Massage the grad_inputs_dict to a form acceptable by
+             # autograd.Function.
+             validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
+             return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)
+
+         generated_cls = gen_autograd_function(
+             custom_op._opname + "_customop", forward, backward
+         )
+
+         flat_output = generated_cls.apply(*flat_args)
+         assert out_spec is not None
+         return pytree.tree_unflatten(list(flat_output), out_spec)
+
+     return apply
+
+
+ def gen_autograd_function(name, forward, backward):
+     generated_cls = type(
+         name,
+         (torch.autograd.Function,),
+         {
+             "forward": staticmethod(forward),
+             "backward": staticmethod(backward),
+         },
+     )
+     return generated_cls
+
+
+ @functools.lru_cache
+ def namedtuple_args_cls(schema):
+     attribs = [arg.name for arg in schema.arguments.flat_all]
+     name = str(schema.name) + "_args"
+     # mypy doesn't support dynamic namedtuple name
+     tuple_cls = namedtuple(name, attribs)  # type: ignore[misc]
+     return tuple_cls
+
+
+ def namedtuple_args(schema, args):
+     assert isinstance(args, tuple)
+     tuple_cls = namedtuple_args_cls(schema)
+     return tuple_cls(*args)
+
+
+ def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
+     def error(what):
+         backward = forward_op._get_impl("backward")
+         raise RuntimeError(
+             f"In the backward function defined for {forward_op} at "
+             f"{backward.location} using the CustomOp API, {what}"
+         )
+
+     if not isinstance(grad_inputs_dict, dict):
+         error(
+             f"expected the output of the backward function to be a dict but "
+             f"got {type(grad_inputs_dict)}"
+         )
+
+     expected_keys = {
+         arg.name
+         for arg in forward_op._schema.arguments.flat_all
+         if arg.type.is_tensor_like()
+     }
+     actual_keys = grad_inputs_dict.keys()
+     if expected_keys != actual_keys:
+         error(
+             f"expected the returned grad_input dict to have keys "
+             f"{expected_keys} but got {actual_keys}. The backward "
+             f"function must return a gradient (can be None) for each arg "
+             f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
+             f"Args declared to be non-Tensor-like types should not appear "
+             f"in the grad_input dict"
+         )
+
+     for name, grad in grad_inputs_dict.items():
+         arg_info = getattr(args_info, name)
+
+         if isinstance(arg_info, list):
+             if not isinstance(grad, (tuple, list)):
+                 error(
+                     f"for input '{name}' expected the grad_input dict to "
+                     f"hold a list of gradients but got object of type "
+                     f"{type(grad)}."
+                 )
+             if not len(grad) == len(arg_info):
+                 error(
+                     f"for input '{name}' expected the grad_input dict to "
+                     f"hold a list of {len(arg_info)} gradients but got "
+                     f"{len(grad)}"
+                 )
+             for idx, (g, info) in enumerate(zip(grad, arg_info)):
+                 if g is None:
+                     continue
+                 if not isinstance(g, torch.Tensor):
+                     error(
+                         f"for input '{name}' expected the grad_input dict to "
+                         f"hold a list of None or Tensor gradients but got "
+                         f"object of {type(g)} at index {idx}"
+                     )
+                 if not issubclass(info, torch.Tensor):
+                     error(
+                         f"for input '{name}', got a Tensor as the gradient "
+                         f"for the {idx}-th value but expected None because "
+                         f"the {idx}-th value was not a Tensor (it was "
+                         f"type {arg_info})"
+                     )
+             continue
+
+         if grad is None:
+             continue
+         if not isinstance(grad, torch.Tensor):
+             error(
+                 f"got object of type {type(grad)} as the gradient for input "
+                 f"'{name}', "
+                 f"but expected the gradient to be either None or a Tensor"
+             )
+         if not issubclass(arg_info, torch.Tensor):
+             error(
+                 f"got a Tensor as the gradient for input '{name}' but "
+                 f"expected None as the gradient because input '{name}' "
+                 f"was not a Tensor (it was type {arg_info})."
+             )
+
+
+ def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
+     result = []
+     for name, arg_info in args_info._asdict().items():
+         if name not in grad_inputs_dict:
+             result.append(pytree.tree_map(lambda x: None, arg_info))
+             continue
+         result.append(grad_inputs_dict[name])
+     return tuple(pytree.tree_leaves(result))
+
+
+ # Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
+ # autograd.Function prefers that users use ctx.save_for_backward to
+ # save Tensors (to avoid reference cycles) and for non-Tensors to go onto the
+ # ctx object.
+ def save_pytree_for_backward(ctx, stuff):
+     flat_stuff, spec = pytree.tree_flatten(stuff)
+     num_elts = len(flat_stuff)
+     tensor_idxs = [
+         idx for idx, thing in enumerate(flat_stuff) if isinstance(thing, torch.Tensor)
+     ]
+     non_tensor_idxs = [
+         idx
+         for idx, thing in enumerate(flat_stuff)
+         if not isinstance(thing, torch.Tensor)
+     ]
+     tensors = [thing for thing in flat_stuff if isinstance(thing, torch.Tensor)]
+     non_tensors = [thing for thing in flat_stuff if not isinstance(thing, torch.Tensor)]
+
+     ctx.spec = spec
+     ctx.num_elts = num_elts
+     ctx.save_for_backward(*tensors)
+     ctx.tensor_idxs = tensor_idxs
+     ctx.saved_non_tensors = non_tensors
+     ctx.non_tensor_idxs = non_tensor_idxs
+
+
+ # Inverse operation to save_pytree_for_backward
+ def unpack_saved(ctx):
+     flat_stuff = [None] * ctx.num_elts
+     for tensor, idx in zip(ctx.saved_tensors, ctx.tensor_idxs):
+         flat_stuff[idx] = tensor
+     for non_tensor, idx in zip(ctx.saved_non_tensors, ctx.non_tensor_idxs):
+         flat_stuff[idx] = non_tensor
+     stuff = pytree.tree_unflatten(flat_stuff, ctx.spec)
+     return stuff
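A sketch of the `save_pytree_for_backward` / `unpack_saved` pair from this file used in a hand-written `autograd.Function` (`Scale` is a made-up example; the generated kernels above use the pair the same way):

    import torch
    from torch._custom_op.autograd import save_pytree_for_backward, unpack_saved

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, factor):
            # Tensors in the pytree go through ctx.save_for_backward;
            # the non-Tensor `factor` is stashed on the ctx object directly.
            save_pytree_for_backward(ctx, {"factor": factor})
            return x * factor

        @staticmethod
        def backward(ctx, grad):
            saved = unpack_saved(ctx)
            return grad * saved["factor"], None

    x = torch.randn(3, requires_grad=True)
    Scale.apply(x, 2.0).sum().backward()
    print(x.grad)  # every element equals 2.0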
phivenv/Lib/site-packages/torch/_custom_op/impl.py ADDED
@@ -0,0 +1,715 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import dataclasses
3
+ import functools
4
+ import inspect
5
+ import sys
6
+ import typing
7
+ import warnings
8
+ import weakref
9
+
10
+ import torch
11
+ import torch._C as _C
12
+ import torch._library.infer_schema
13
+ import torch.library as library
14
+ from torch._library.infer_schema import infer_schema
15
+ from torch.library import get_ctx
16
+ from torchgen.model import (
17
+ BaseTy,
18
+ BaseType,
19
+ FunctionSchema,
20
+ ListType,
21
+ OperatorName,
22
+ SchemaKind,
23
+ )
24
+
25
+ from .autograd import autograd_kernel_indirection, construct_autograd_kernel
26
+
27
+
28
+ """
29
+ torch._custom_op is deprecated. We shipped a production-ready version of it into torch.library.
30
+ Please use those APIs instead.
31
+ """
32
+
33
+ __all__ = ["custom_op", "CustomOp", "get_ctx"]
34
+
35
+
36
+ SUPPORTED_DEVICE_TYPE_TO_KEY = {
37
+ "cpu": "CPU",
38
+ "cuda": "CUDA",
39
+ }
40
+
41
+ # We will not let users register CustomOps with anything that could look like
42
+ # PyTorch internals to avoid confusion.
43
+ RESERVED_NS = {
44
+ "prim",
45
+ "prims",
46
+ "aten",
47
+ "at",
48
+ "torch",
49
+ "pytorch",
50
+ }
51
+
52
+
53
+ def warn_deprecated():
54
+ warnings.warn(
55
+ "torch._custom_op is deprecated and will be removed in PyTorch 2.6, please "
56
+ "use the equivalent torch.library API instead.",
57
+ DeprecationWarning,
58
+ )
59
+
60
+
61
+ def custom_op(
62
+ qualname: str, manual_schema: typing.Optional[str] = None
63
+ ) -> typing.Callable:
64
+ r"""
65
+ This API is deprecated, please use torch.library.custom_op instead
66
+ """
67
+ warn_deprecated()
68
+
69
+ def inner(func):
70
+ if not inspect.isfunction(func):
71
+ raise ValueError(
72
+ f"custom_op(...)(func): Expected `func` to be a Python "
73
+ f"function, got: {type(func)}"
74
+ )
75
+
76
+ ns, name = parse_qualname(qualname)
77
+ validate_namespace(ns)
78
+ if func.__name__ != name:
79
+ raise ValueError(
80
+ f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
81
+ f"to have name '{name}' but got '{func.__name__}'. "
82
+ f"Please either change the name of `func` or the qualname that "
83
+ f"is passed to `custom_op`"
84
+ )
85
+
86
+ schema = (
87
+ infer_schema(func, mutates_args=())
88
+ if manual_schema is None
89
+ else manual_schema
90
+ )
91
+ schema_str = f"{name}{schema}"
92
+ function_schema = FunctionSchema.parse(schema_str)
93
+ validate_schema(function_schema)
94
+ if manual_schema is not None:
95
+ validate_function_matches_schema(function_schema, func)
96
+
97
+ lib = library.Library(ns, "FRAGMENT")
98
+ lib.define(schema_str)
99
+ ophandle = find_ophandle_or_throw(ns, function_schema.name)
100
+ result = CustomOp(
101
+ lib, ns, function_schema, name, ophandle, _private_access=True
102
+ )
103
+
104
+ result.__name__ = func.__name__
105
+ result.__module__ = func.__module__
106
+ result.__doc__ = func.__doc__
107
+
108
+ library.impl(lib, result._opname, "Autograd")(
109
+ autograd_kernel_indirection(weakref.proxy(result))
110
+ )
111
+
112
+ torch._C._dispatch_set_report_error_callback(
113
+ ophandle, functools.partial(report_error_callback, weakref.proxy(result))
114
+ )
115
+
116
+ return result
117
+
118
+ return inner
119
+
120
+
121
+ # Global dictionary holding references to all CustomOp objects
122
+ # Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime])
123
+ # Used to query the CustomOp associated with a specific C++ dispatcher operator.
124
+ # An example usage is FakeTensor: FakeTensor checks if a specific operator
125
+ # has an implementation registered via the CustomOp API.
126
+ # Indexed by qualname (e.g. aten::foo)
127
+ global_registry: dict[str, "CustomOp"] = {}
128
+
129
+
130
+ class CustomOp:
131
+ r"""
132
+ This API is deprecated, please use torch.library.custom_op instead
133
+ """
134
+
135
+ def __init__(
136
+ self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False
137
+ ):
138
+ super().__init__()
139
+ warn_deprecated()
140
+ if not _private_access:
141
+ raise RuntimeError(
142
+ "The CustomOp constructor is private and we do not guarantee "
143
+ "BC for it. Please use custom_op(...) to create a CustomOp object"
144
+ )
145
+ name = f"{cpp_ns}::{operator_name}"
146
+ self._schema = schema
147
+ self._cpp_ns = cpp_ns
148
+ self._lib: library.Library = lib
149
+ self._ophandle: _C._DispatchOperatorHandle = ophandle
150
+ # Has the name of the op, e.g. "foo". We cache here for convenience.
151
+ self._opname: str = operator_name
152
+ # this is _opname but with namespace. e.g. "custom::foo"
153
+ self._qualname: str = name
154
+ self.__name__ = None # mypy requires this
155
+ # NB: Some of these impls are registered as kernels to DispatchKeys.
156
+ # Modifying the _impls dict directly won't do anything in that case.
157
+ self._impls: dict[str, typing.Optional[FuncAndLocation]] = {}
158
+ # See NOTE [CustomOp autograd kernel indirection]
159
+ self._registered_autograd_kernel_indirection = False
160
+
161
+ global_registry[self._qualname] = self
162
+
163
+ def _register_autograd_kernel_indirection(self):
164
+ assert not self._registered_autograd_kernel_indirection
165
+ self._lib.impl(
166
+ self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd"
167
+ )
168
+ self._registered_autograd_kernel_indirection = True
169
+
170
+ # Records the impl and the source location in self._impls
171
+ # Note that this doesn't cause torch.library to use the impl, that
172
+ # needs to be done in a separate self._lib.impl call.
173
+ def _register_impl(self, kind, func, stacklevel=2):
174
+ if self._has_impl(kind):
175
+ func_and_location = self._impls[kind]
176
+ assert func_and_location is not None # Pacify mypy
177
+ location = func_and_location.location
178
+ raise RuntimeError(
179
+ f"Attempting to register a {kind} impl for operator {self._qualname} "
180
+ f"that already has a {kind} impl registered from Python at "
181
+ f"{location}. This is not supported."
182
+ )
183
+ frame = inspect.getframeinfo(sys._getframe(stacklevel))
184
+ location = f"{frame.filename}:{frame.lineno}"
185
+ self._impls[kind] = FuncAndLocation(func, location)
186
+
187
+ def _get_impl(self, kind):
188
+ return self._impls[kind]
189
+
190
+ def _has_impl(self, kind):
191
+ return kind in self._impls
192
+
193
+ def _destroy(self):
194
+ # NOTE: [CustomOp lifetime]
195
+ # A CustomOp, once created, lives forever. The mechanism is that the
196
+ # global registry holds a reference to it. However, to make testing
197
+ # easier, we want to be able to destroy CustomOp objects.
198
+ # CustomOp._destroy does the job, though it leaves the CustomOp
199
+ # in a garbage state.
200
+ del self._lib
201
+
202
+ opnamespace = getattr(torch.ops, self._cpp_ns)
203
+ if hasattr(opnamespace, self._opname):
204
+ delattr(opnamespace, self._opname)
205
+
206
+ del global_registry[self._qualname]
207
+
208
+ def __repr__(self):
209
+ return f'<CustomOp(op="{self._qualname}")>'
210
+
211
+ def __call__(self, *args, **kwargs):
212
+ # Bypass torch.ops.* and directly do OperatorHandle::callBoxed.
213
+ # Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime
214
+ # issues from caching operators that make testing CustomOp difficult).
215
+ result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs)
216
+ return result
217
+
218
+ def impl(
219
+ self,
220
+ device_types: typing.Union[str, typing.Iterable[str]],
221
+ _stacklevel=2,
222
+ ) -> typing.Callable:
223
+ r"""
224
+ This API is deprecated, please use torch.library.custom_op instead
225
+ """
226
+ if isinstance(device_types, str):
227
+ device_types = [device_types]
228
+ for device_type in device_types:
229
+ validate_device_type(device_type)
230
+
231
+ def inner(f):
232
+ for device_type in set(device_types):
233
+ self._check_doesnt_have_library_impl(device_type)
234
+ self._register_impl(device_type, f, stacklevel=_stacklevel)
235
+ dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
236
+ library.impl(self._lib, self._opname, dispatch_key)(f)
237
+ return f
238
+
239
+ return inner
240
+
241
+ def _check_doesnt_have_library_impl(self, device_type):
242
+ if self._has_impl(device_type):
243
+ return
244
+ key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
245
+ if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key):
246
+ raise RuntimeError(
247
+ f"impl(..., device_types={device_type}): the operator {self._qualname} "
248
+ f"already has an implementation for this device type via a "
249
+ f"pre-existing torch.library or TORCH_LIBRARY registration."
250
+ )
251
+
252
+ def impl_factory(self) -> typing.Callable:
253
+ r"""Register an implementation for a factory function."""
254
+
255
+ def inner(f):
256
+ self._register_impl("factory", f)
257
+ library.impl(self._lib, self._opname, "BackendSelect")(f)
258
+ return f
259
+
260
+ return inner
261
+
262
+ def impl_abstract(self, _stacklevel=2) -> typing.Callable:
263
+ r"""
264
+ This API is deprecated, please use torch.library.custom_op instead
265
+ """
266
+
267
+ def inner(f):
268
+ self._check_doesnt_have_library_meta_impl()
269
+ self._register_impl("abstract", f, stacklevel=_stacklevel)
270
+ location = self._get_impl("abstract").location
271
+
272
+ qualname = self._qualname
273
+
274
+ # Handle DispatchKey.Meta registration
275
+ @functools.wraps(f)
276
+ def f_with_ctx(*args, **kwargs):
277
+ def error_on_ctx():
278
+ raise RuntimeError(
279
+ f"Attempted to call get_ctx() for the meta implementation "
280
+ f"for {qualname}."
281
+ f"You have presumably called get_ctx() because the operator "
282
+ f"has a data-dependent output shape; if so, there is no "
283
+ f"such meta implementation and this error is the correct "
284
+ f"behavior. Otherwise, please remove the call to get_ctx() "
285
+ f"in the implementation registered with impl_abstract "
286
+ f"at {location}"
287
+ )
288
+
289
+ with torch._library.fake_impl.set_ctx_getter(error_on_ctx):
290
+ return f(*args, **kwargs)
291
+
292
+ self._lib.impl(self._opname, f_with_ctx, "Meta")
293
+ return f
294
+
295
+ return inner
296
+
297
+ def _check_can_register_backward(self):
298
+ def error(detail):
299
+ raise RuntimeError(
300
+ f"Cannot use torch._custom_ops APIs to register backward "
301
+ f"formula for {detail}. Got operator "
302
+ f"{self._qualname} with schema: {schema}"
303
+ )
304
+
305
+ schema = self._schema
306
+ if schema.kind() != SchemaKind.functional:
307
+ error("non-functional operator")
308
+
309
+ rets = schema.returns
310
+ if not schema.returns:
311
+ error("operator with no returns")
312
+
313
+ assert len(rets) > 0
314
+ is_non_mutating_view = any(
315
+ r.annotation is not None and not r.annotation.is_write for r in rets
316
+ )
317
+ if is_non_mutating_view:
318
+ error("operator that returns views")
319
+
320
+ # We make assumptions about the schema's return types.
321
+ allowed_return_types = {
322
+ BaseType(BaseTy.int): "int",
323
+ BaseType(BaseTy.SymInt): "SymInt",
324
+ BaseType(BaseTy.bool): "bool",
325
+ BaseType(BaseTy.float): "float",
326
+ BaseType(BaseTy.Tensor): "Tensor",
327
+ ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
328
+ }
329
+ for ret in schema.returns:
330
+ if ret.type in allowed_return_types:
331
+ continue
332
+ error(
333
+ f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})"
334
+ )
335
+
336
+ def _check_doesnt_have_library_autograd_impl(self):
337
+ if self._registered_autograd_kernel_indirection:
338
+ return
339
+
340
+ if _C._dispatch_has_kernel_for_dispatch_key(
341
+ self._qualname, "CompositeImplicitAutograd"
342
+ ):
343
+ raise RuntimeError(
344
+ f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
345
+ f"already has an implementation for this device type via a "
346
+ f"pre-existing registration to DispatchKey::CompositeImplicitAutograd."
347
+ f"CompositeImplicitAutograd operators do not need an autograd formula; "
348
+ f"instead, the operator will decompose into its constituents and those "
349
+ f"can have autograd formulas defined on them."
350
+ )
351
+
352
+ # We can improve this by adding "all Autograd<BACKEND> keys", but
353
+ # realistically people will just be using this API for CPU/CUDA for now.
354
+ for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
355
+ if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
356
+ raise RuntimeError(
357
+ f"impl_backward/impl_save_for_backward: "
358
+ f"the operator {self._qualname} already has an Autograd kernel "
359
+ f"registered to DispatchKey::{key} vi a pre-existing "
360
+ f"torch.library or TORCH_LIBRARY registration. Please either "
361
+ f"remove those registrations or don't use the torch._custom_ops APIs"
362
+ )
363
+
364
+ def _check_doesnt_have_library_meta_impl(self):
365
+ if self._has_impl("abstract"):
366
+ return
367
+
368
+ # If the user's operator is CompositeExplicitAutograd,
369
+ # allow them to impl_abstract. This is being pragmatic
370
+ # (existing custom ops may have CompositeExplicitAutograd
371
+ # registrations that don't work with Meta kernels, so this
372
+ # gives them an escape hatch).
373
+ if _C._dispatch_has_kernel_for_dispatch_key(
374
+ self._qualname, "CompositeExplicitAutograd"
375
+ ) and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
376
+ return
377
+
378
+ # Otherwise, if the user already has a Meta kernel or their
379
+ # op is CompositeImplicitAutograd or some other alias dispatch key,
380
+ # raise.
381
+
382
+ # Special case for CompositeImplicitAutograd
383
+ if _C._dispatch_has_kernel_for_dispatch_key(
384
+ self._qualname, "CompositeImplicitAutograd"
385
+ ):
386
+ raise RuntimeError(
387
+ f"impl_abstract(...): the operator {self._qualname} "
388
+ f"already has an implementation for this device type via a "
389
+ f"pre-existing registration to DispatchKey::CompositeImplicitAutograd."
390
+ f"CompositeImplicitAutograd operators do not need an abstract impl; "
391
+ f"instead, the operator will decompose into its constituents and those "
392
+ f"can have abstract impls defined on them."
393
+ )
394
+
395
+ if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
396
+ raise RuntimeError(
397
+ f"impl_abstract(...): the operator {self._qualname} "
398
+ f"already has an DispatchKey::Meta implementation via a "
399
+ f"pre-existing torch.library or TORCH_LIBRARY registration. "
400
+ f"Please either remove that registration or don't call impl_abstract."
401
+ )
402
+
403
+ # NOTE ["backward", "save_for_backward", and "autograd"]
404
+ # As a part of the explicit autograd API, a user must provide us
405
+ # a "save_for_backward" function and a "backward" function.
406
+ # When both of these have been provided, then we automatically
407
+ # construct the "autograd" kernel.
408
+ def _register_autograd_kernel(self):
409
+ assert self._has_impl("backward")
410
+ assert self._has_impl("save_for_backward")
411
+ kernel = construct_autograd_kernel(
412
+ self._schema,
413
+ self._output_differentiability,
414
+ self,
415
+ get_op(self._qualname),
416
+ self._get_impl("save_for_backward").func,
417
+ self._get_impl("backward").func,
418
+ )
419
+ self._register_impl("autograd", kernel)
420
+
421
+ def impl_save_for_backward(self, _stacklevel=2):
422
+ r"""Register a function that tells us what to save for backward.
423
+
424
+ Please see impl_backward for more details.
425
+ """
426
+
427
+ def inner(f):
428
+ self._check_can_register_backward()
429
+ self._check_doesnt_have_library_autograd_impl()
430
+ if not self._registered_autograd_kernel_indirection:
431
+ self._register_autograd_kernel_indirection()
432
+ self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
433
+ if self._has_impl("backward"):
434
+ self._register_autograd_kernel()
435
+
436
+ return inner
437
+
438
+ def impl_backward(self, output_differentiability=None, _stacklevel=2):
439
+ r"""
440
+ This API is deprecated, please use torch.library.custom_op instead
441
+ """
442
+ if output_differentiability is not None:
443
+
444
+ def yell():
445
+ raise RuntimeError(
446
+ f"impl_backward(output_differentiability): expected "
447
+ f"output_differentiability to be a list of bools with "
448
+ f"length equal to the number of outputs of this CustomOp "
449
+ f"got: {output_differentiability}"
450
+ )
451
+
452
+ if not isinstance(output_differentiability, list):
453
+ yell()
454
+ for diff in output_differentiability:
455
+ if not isinstance(diff, bool):
456
+ yell()
457
+ if len(self._schema.returns) != len(output_differentiability):
458
+ yell()
459
+
460
+ def inner(f):
461
+ self._check_can_register_backward()
462
+ self._check_doesnt_have_library_autograd_impl()
463
+ if not self._registered_autograd_kernel_indirection:
464
+ self._register_autograd_kernel_indirection()
465
+ self._register_impl("backward", f, stacklevel=_stacklevel)
466
+ self._output_differentiability = output_differentiability
467
+ if self._has_impl("save_for_backward"):
468
+ self._register_autograd_kernel()
469
+
470
+ return inner
471
+
472
+
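+ # Editor's sketch (hypothetical; "mylib::mysin" and both function bodies are
+ # made up) of the save_for_backward/backward pairing described in
+ # NOTE ["backward", "save_for_backward", and "autograd"] above:
+ #
+ #     op = _find_custom_op("mylib::mysin")
+ #
+ #     @op.impl_save_for_backward()
+ #     def mysin_save_for_backward(inputs, output):
+ #         return inputs.x  # save only what backward needs
+ #
+ #     @op.impl_backward()
+ #     def mysin_backward(ctx, saved, grad):
+ #         return {"x": grad * saved.cos()}
+ #
+ # Once both are registered, _register_autograd_kernel() runs automatically.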
473
+ @dataclasses.dataclass
474
+ class FuncAndLocation:
475
+ func: typing.Callable
476
+ location: str
477
+
478
+
479
+ def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName):
480
+ overload_name = (
481
+ "" if operator_name.overload_name is None else operator_name.overload_name
482
+ )
483
+ return _C._dispatch_find_schema_or_throw(
484
+ f"{cpp_ns}::{str(operator_name.name)}", overload_name
485
+ )
486
+
487
+
488
+ def validate_namespace(ns: str) -> None:
489
+ if "." in ns:
490
+ raise ValueError(
491
+ f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a '
492
+ f"valid variable name)"
493
+ )
494
+ if ns in RESERVED_NS:
495
+ raise ValueError(
496
+ f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, "
497
+ f"please choose something else. "
498
+ )
499
+
500
+
501
+ def validate_schema(schema: FunctionSchema) -> None:
502
+ if not torch._library.utils.is_functional_schema(schema):
503
+ raise ValueError(
504
+ f"custom_op only supports functional operators "
505
+ f"(ops that do not mutate any inputs, do not return "
506
+ f"views of the inputs, and has at least one return). "
507
+ f"Got the following non-functional schema: {schema}"
508
+ )
509
+
510
+ # For simplicity: don't allow self arguments
511
+ if schema.arguments.self_arg is not None:
512
+ raise ValueError(
513
+ f"custom_op does not support arguments named 'self'. Please "
514
+ f"rename your argument. Got: {schema}"
515
+ )
516
+
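+ # Editor's note (illustrative schemas only): validate_schema above would
+ # reject "foo(Tensor(a!) x) -> Tensor" (mutates x) and
+ # "foo(Tensor self) -> Tensor" (argument named self).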
517
+
518
+ def parse_qualname(qualname: str) -> tuple[str, str]:
519
+ names = qualname.split("::", 1)
520
+ if len(names) != 2:
521
+ raise ValueError(
522
+ f"Expected there to be a namespace in {qualname}, i.e. The "
523
+ f"operator name should look something like ns::foo"
524
+ )
525
+ if "." in names[1]:
526
+ raise ValueError(
527
+ f"The torch.custom_ops APIs do not handle overloads, "
528
+ f"i.e. operator names with '.' in them. "
529
+ f"Please name your operator something like ns::foo. "
530
+ f"Got: {qualname}"
531
+ )
532
+ return names[0], names[1]
533
+
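+ # For illustration (hypothetical names): parse_qualname("mylib::foo") returns
+ # ("mylib", "foo"); "foo" (no namespace) and "mylib::foo.out" (an overload)
+ # both raise ValueError.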
534
+
535
+ def validate_device_type(device_type: str) -> None:
536
+ if device_type not in SUPPORTED_DEVICE_TYPE_TO_KEY:
537
+ raise ValueError(
538
+ f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type "
539
+ f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}."
540
+ )
541
+
542
+
543
+ def supported_param(param: inspect.Parameter) -> bool:
544
+ return param.kind in (
545
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
546
+ inspect.Parameter.KEYWORD_ONLY,
547
+ )
548
+
549
+
550
+ def validate_function_matches_schema(
551
+ schema: FunctionSchema, func: typing.Callable
552
+ ) -> None:
553
+ sig = inspect.signature(func)
554
+
555
+ if not all(supported_param(p) for _, p in sig.parameters.items()):
556
+ raise ValueError(
557
+ f"custom_op(..., manual_schema)(func): positional-only args, "
558
+ f"varargs, and kwargs are not supported. Please rewrite `func` "
559
+ f"to not have them. Got `func` with signature: {sig}"
560
+ )
561
+
562
+ if (
563
+ any(
564
+ p.annotation is not inspect.Parameter.empty
565
+ for _, p in sig.parameters.items()
566
+ )
567
+ or sig.return_annotation is not inspect.Signature.empty
568
+ ):
569
+ raise ValueError(
570
+ f"custom_op(..., manual_schema)(func): When passing in a manual "
571
+ f"schema, we expect `func` to have no type annotations to avoid "
572
+ f"ambiguity. Got `func` with signature: {sig}"
573
+ )
574
+
575
+ positional = [
576
+ (name, param)
577
+ for name, param in sig.parameters.items()
578
+ if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
579
+ ]
580
+ kwargonly = [
581
+ (name, param)
582
+ for name, param in sig.parameters.items()
583
+ if param.kind == inspect.Parameter.KEYWORD_ONLY
584
+ ]
585
+
586
+ def error():
587
+ raise ValueError(
588
+ f"custom_op(..., manual_schema)(func): When passing in a manual "
589
+ f"schema, we expect `func`'s signature to match `manual_schema` "
590
+ f"(aside from type annotations). "
591
+ f"func's signature: {sig}, manual_schema: {schema}"
592
+ )
593
+
594
+ def error_default_args():
595
+ raise ValueError(
596
+ f"custom_op(..., manual_schema)(func): "
597
+ f"neither func nor manual_schema should have default "
598
+ f"arguments. Got "
599
+ f"func's signature: {sig}, manual_schema: {schema}"
600
+ )
601
+
602
+ def compare(sig_args, schema_args):
603
+ if len(sig_args) != len(schema_args):
604
+ error()
605
+ for (name, param), arg in zip(sig_args, schema_args):
606
+ if name != arg.name:
607
+ error()
608
+ if param.default is not inspect.Parameter.empty or arg.default is not None:
609
+ error_default_args()
610
+
611
+ compare(positional, schema.arguments.flat_positional)
612
+ compare(kwargonly, schema.arguments.flat_kwarg_only)
613
+
614
+
615
+ def report_error_callback(custom_op: typing.Any, key: str) -> None:
616
+ if key == "Undefined":
617
+ raise NotImplementedError(
618
+ f"{custom_op}: There were no Tensor inputs to this operator "
619
+ f"(e.g. you passed an empty list of Tensors). If your operator is a "
620
+ f"factory function (that is, it takes no Tensors and constructs "
621
+ f"a new one), then please use CustomOp.impl_factory to register "
622
+ f"an implementation for it"
623
+ )
624
+ if key == "Meta":
625
+ raise NotImplementedError(
626
+ f"{custom_op}: when running with device='Meta' tensors: there is no "
627
+ f"abstract impl registered for this CustomOp. Please register one via "
628
+ f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors"
629
+ )
630
+ if key in ("CPU", "CUDA"):
631
+ device = key.lower()
632
+ raise NotImplementedError(
633
+ f"{custom_op}: when running with device='{device}' tensors: there is no "
634
+ f"{device} impl registered for this CustomOp. Please register one via "
635
+ f"CustomOp.impl(device_type='{device}')"
636
+ )
637
+ raise NotImplementedError(
638
+ f"{custom_op}: No implementation for dispatch key {key}. It is likely "
639
+ f"that we have not added this functionality yet, please either open an "
640
+ f"issue or if you're feeling adventurous, use the low-level "
641
+ f"torch.library API"
642
+ )
643
+
644
+
645
+ def custom_op_from_existing(op):
646
+ ns = op.namespace
647
+ lib = torch.library.Library(ns, "FRAGMENT")
648
+ name = op.name().split("::")[-1]
649
+ schema_str = str(op._schema)
650
+ # CustomOp expects the schema string without the namespace
651
+ schema_str = schema_str.split("::")[-1]
652
+ schema = FunctionSchema.parse(schema_str)
653
+ return CustomOp(lib, ns, schema, name, op, _private_access=True)
654
+
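+ # Editor's note (hypothetical op name): custom_op_from_existing(
+ # torch.ops.mylib.foo.default) wraps an already-registered overload in a
+ # CustomOp backed by a FRAGMENT library for the same namespace.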
655
+
656
+ def get_op(qualname):
657
+ def error_not_found():
658
+ raise ValueError(
659
+ f"Could not find the operator {qualname}. Please make sure you have "
660
+ f"already registered the operator and (if registered from C++) "
661
+ f"loaded it via torch.ops.load_library."
662
+ )
663
+
664
+ ns, name = parse_qualname(qualname)
665
+ if not hasattr(torch.ops, ns):
666
+ error_not_found()
667
+ opnamespace = getattr(torch.ops, ns)
668
+ if not hasattr(opnamespace, name):
669
+ error_not_found()
670
+ packet = getattr(opnamespace, name)
671
+ if not hasattr(packet, "default"):
672
+ error_not_found()
673
+ return packet.default
674
+
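+ # For illustration (hypothetical name): get_op("mylib::foo") resolves to
+ # torch.ops.mylib.foo.default, and calls error_not_found() if any lookup
+ # step fails.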
675
+
676
+ def _find_custom_op(qualname, also_check_torch_library=False):
677
+ if qualname in global_registry:
678
+ return global_registry[qualname]
679
+ if not also_check_torch_library:
680
+ raise RuntimeError(
681
+ f'Could not find custom op "{qualname}". Did you register it via '
682
+ f"the torch._custom_ops API?"
683
+ )
684
+ overload = get_op(qualname)
685
+ result = custom_op_from_existing(overload)
686
+ return result
687
+
688
+
689
+ def get_abstract_impl(qualname):
690
+ if qualname not in torch._custom_op.impl.global_registry:
691
+ return None
692
+ custom_op = torch._custom_op.impl.global_registry[qualname]
693
+ if custom_op is None:
694
+ return None
695
+ if not custom_op._has_impl("abstract"):
696
+ return None
697
+ return custom_op._get_impl("abstract").func
698
+
699
+
700
+ def _custom_op_with_schema(qualname, schema, needs_fixed_stride_order=True):
701
+ ns, name = qualname.split("::")
702
+ schema_str = f"{name}{schema}"
703
+ function_schema = FunctionSchema.parse(schema_str)
704
+ validate_schema(function_schema)
705
+ tags = [torch._C.Tag.needs_fixed_stride_order] if needs_fixed_stride_order else []
706
+ lib = library.Library(ns, "FRAGMENT")
707
+ lib.define(schema_str, tags=tags)
708
+ ophandle = find_ophandle_or_throw(ns, function_schema.name)
709
+ result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
710
+ result._register_autograd_kernel_indirection()
711
+
712
+ torch._C._dispatch_set_report_error_callback(
713
+ ophandle, functools.partial(report_error_callback, weakref.proxy(result))
714
+ )
715
+ return get_op(qualname)
phivenv/Lib/site-packages/torch/_decomp/__init__.py ADDED
@@ -0,0 +1,544 @@
1
+ # mypy: allow-untyped-defs
2
+ import inspect
3
+ from collections import defaultdict
4
+ from collections.abc import Sequence
5
+ from functools import lru_cache, partial, wraps
6
+ from itertools import chain
7
+ from typing import Callable, Optional, TYPE_CHECKING, TypeVar, Union
8
+ from typing_extensions import ParamSpec
9
+
10
+
11
+ if TYPE_CHECKING:
12
+ from torch.export.decomp_utils import CustomDecompTable
13
+
14
+ import torch
15
+ import torch.library
16
+ from torch._ops import HigherOrderOperator, OperatorBase, OpOverload, OpOverloadPacket
17
+ from torch._prims_common import CustomOutParamAnnotation
18
+ from torch._subclasses.functional_tensor import FunctionalTensor
19
+ from torch.utils import _pytree as pytree
20
+
21
+
22
+ __all__ = [
23
+ "decomposition_table",
24
+ "pre_autograd_decomposition_table",
25
+ "meta_table",
26
+ "register_decomposition",
27
+ "get_decompositions",
28
+ "core_aten_decompositions",
29
+ "_should_decompose_because_unsafe_op",
30
+ ]
31
+
32
+ _T = TypeVar("_T")
33
+ _P = ParamSpec("_P")
34
+
35
+ # TODO: relax key type here; torch registrations should be possible too, but
36
+ # right now this type is accurate
37
+ global_decomposition_table: dict[str, dict[torch._ops.OperatorBase, Callable]] = (
38
+ defaultdict(dict)
39
+ )
40
+
41
+ decomposition_table = global_decomposition_table["post_autograd"]
42
+ pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"]
43
+ meta_table = global_decomposition_table["meta"]
44
+
45
+
46
+ def _should_decompose_because_unsafe_op(op: torch._ops.OperatorBase) -> bool:
47
+ """
48
+ Returns True if the op must always decompose in export/compile tracing system
49
+
50
+ In export, we always decompose certain CIA ops that are tagged with
51
+ maybe_aliasing_or_mutating because we statically need to know if the op is
52
+ mutating or not. But these CIA ops could have different behaviour in runtime.
53
+
54
+ native_batch_norm is a prim op which has a wrong schema and it needs to be replaced
55
+ with correct schema. But until then, we will force decompose it via this tag.
56
+ """
57
+ if not isinstance(op, torch._ops.OpOverload):
58
+ return False
59
+ if torch.Tag.maybe_aliasing_or_mutating in op.tags:
60
+ return True
61
+ return op == torch.ops.aten.native_batch_norm.default
62
+
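+ # For illustration: _should_decompose_because_unsafe_op(
+ # torch.ops.aten.native_batch_norm.default) returns True, while any op that
+ # is not an OpOverload (e.g. a HigherOrderOperator) returns False.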
63
+
64
+ def _add_op_to_registry(registry, op, fn):
65
+ """
66
+ This is an internal API for adding an op to the decomposition table.
67
+
68
+ If op is OpOverload, it will be added to the registry directly.
69
+ If op is OpOverloadPacket, all the valid op_overloads in the packet will be added to the registry.
70
+ """
71
+ overloads: list[torch._ops.OperatorBase] = []
72
+ if isinstance(op, HigherOrderOperator):
73
+ # There's no concept of overloads for HigherOrderOperator
74
+ registry[op] = fn
75
+ return
76
+ elif isinstance(op, OpOverload):
77
+ overloads.append(op)
78
+ else:
79
+ assert isinstance(op, OpOverloadPacket)
80
+ for ol in op.overloads():
81
+ overloads.append(getattr(op, ol))
82
+
83
+ for op_overload in overloads:
84
+ if op_overload in registry:
85
+ raise RuntimeError(f"duplicate registrations for {op_overload}")
86
+ # TorchScript dumps a bunch of extra nonsense overloads
87
+ # which don't have corresponding dispatcher entries, so we need
88
+ # to filter those out, e.g. aten.add.float_int
89
+ if torch._C._dispatch_has_kernel(op_overload.name()):
90
+ registry[op_overload] = fn
91
+
92
+
93
+ def _convert_out_params(f):
94
+ out_annotation = f.__annotations__.get("out")
95
+
96
+ # If there are no out params, do not wrap the function.
97
+ if not out_annotation:
98
+ return f
99
+
100
+ # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this
101
+ if getattr(out_annotation, "__origin__", None) is tuple:
102
+ sig = inspect.signature(f)
103
+ out_names = sig.return_annotation._fields
104
+ # If out is a tuple, we need to register a function that unpacks all the out
105
+ # elements as this is what native_functions.yaml expects
106
+
107
+ @wraps(f)
108
+ def _fn(*args, **kwargs):
109
+ out_kwargs = tuple(kwargs.pop(o, None) for o in out_names)
110
+ # Either all of the out kwargs are set or none of them
111
+ is_none = out_kwargs[0] is None
112
+ assert all((o is None) == is_none for o in out_kwargs)
113
+ return f(*args, **kwargs, out=None if is_none else out_kwargs)
114
+
115
+ out_params = [
116
+ inspect.Parameter(
117
+ o,
118
+ kind=inspect.Parameter.KEYWORD_ONLY,
119
+ default=None,
120
+ annotation=t,
121
+ )
122
+ for o, t in zip(out_names, out_annotation.__args__)
123
+ ]
124
+ # Drop the out parameter and concatenate the new kwargs in the signature
125
+ params = chain((v for k, v in sig.parameters.items() if k != "out"), out_params)
126
+ _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined]
127
+ parameters=params, # type: ignore[arg-type]
128
+ return_annotation=sig.return_annotation,
129
+ )
130
+ # Drop the out parameter and concatenate the new kwargs in the annotations
131
+ _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
132
+ for o in out_params:
133
+ _fn.__annotations__[o.name] = o.annotation
134
+
135
+ # Propagate that this function is wrapped by `out_wrapper`
136
+ _fn._torch_decompositions_out_wrapper = f._torch_decompositions_out_wrapper # type: ignore[attr-defined]
137
+
138
+ return _fn
139
+
140
+ # Alternatively, there may be a single tensor out parameter with a name
141
+ # other than "out". This will need special treatment and is indicated by an
142
+ # annotation, which we will remove here so it is not exposed after wrapping.
143
+ custom_out_param_name = f.__annotations__.pop(CustomOutParamAnnotation, None)
144
+ if custom_out_param_name:
145
+
146
+ @wraps(f)
147
+ def _fn(*args, **kwargs):
148
+ out_kwarg = kwargs.pop(custom_out_param_name, None)
149
+ return f(*args, **kwargs, out=out_kwarg)
150
+
151
+ out_param = inspect.Parameter(
152
+ custom_out_param_name,
153
+ kind=inspect.Parameter.KEYWORD_ONLY,
154
+ default=None,
155
+ annotation=out_annotation,
156
+ )
157
+
158
+ # Drop the out parameter and concatenate the new kwarg in the signature
159
+ sig = inspect.signature(f)
160
+ params = chain(
161
+ (v for k, v in sig.parameters.items() if k != "out"), (out_param,)
162
+ )
163
+ _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined]
164
+ parameters=params, # type: ignore[arg-type]
165
+ return_annotation=sig.return_annotation,
166
+ )
167
+
168
+ # Drop the out parameter and concatenate the new kwargs in the annotations
169
+ _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
170
+ _fn.__annotations__[out_param.name] = out_param.annotation
171
+
172
+ return _fn
173
+
174
+ return f
175
+
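+ # Editor's illustration (hypothetical decomposition, not in this file): given
+ #
+ #     VarMean = NamedTuple("VarMean", [("var", Tensor), ("mean", Tensor)])
+ #
+ #     def var_mean_decomp(x, *, out: VarMean = None): ...
+ #
+ # _convert_out_params exposes `var=` and `mean=` as separate keyword-only
+ # parameters and repacks them into the `out` tuple before calling the
+ # original function, matching what native_functions.yaml expects for
+ # multi-output out= overloads.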
176
+
177
+ def register_decomposition(
178
+ aten_op, registry=None, *, type="post_autograd", unsafe=False
179
+ ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
180
+ """
181
+ A decorator to register a function as a decomposition to the Python
182
+ decomposition table. Use it like this::
183
+
184
+ @register_decomposition(torch.ops.aten.clamp_min)
185
+ def clamp_min(x, min):
186
+ return torch.clamp(x, min=min)
187
+
188
+ If you are writing a new decomposition, consider contributing it
189
+ directly to PyTorch in torch._decomp.decompositions.
190
+
191
+ This API is experimental; we are almost certainly going to extend
192
+ the API when we make decompositions eligible for use in transforms (e.g.,
193
+ autograd) and not just backend tracing, where we then need to know if a
194
+ decomposition can be used to simulate a transform.
195
+
196
+ By default, we will also register it to the dispatcher's Meta key,
197
+ replacing the C++ Meta implementation if one already exists.
198
+
199
+ The unsafe kwarg exists so this function can be reused to register
200
+ non-function things.
201
+ """
202
+
203
+ assert type in {"post_autograd", "pre_autograd", "meta"}
204
+
205
+ def decomposition_decorator(fn: Callable[_P, _T]) -> Callable[_P, _T]:
206
+ orig_fn = fn
207
+ if not unsafe:
208
+ fn = _convert_out_params(fn)
209
+
210
+ nonlocal registry
211
+ if registry is None:
212
+ registry = global_decomposition_table[type]
213
+
214
+ def register(op):
215
+ _add_op_to_registry(registry, op, fn)
216
+
217
+ # To handle allowing multiple aten_ops at once
218
+ pytree.tree_map_(register, aten_op)
219
+ return orig_fn
220
+
221
+ return decomposition_decorator
222
+
223
+
224
+ def get_decompositions(
225
+ aten_ops: Sequence[Union[torch._ops.OperatorBase, OpOverloadPacket]],
226
+ type: str = "post_autograd",
227
+ ) -> dict[torch._ops.OperatorBase, Callable]:
228
+ """
229
+ Retrieve a dictionary of decompositions corresponding to the list of
230
+ operator overloads and overload packets passed as input. Overload
231
+ packets will include all decomposed overloads in the packet. If there is
232
+ no decomposition for a requested operator, it is silently ignored.
233
+
234
+ This API is experimental; we are almost certainly going to give an alternate,
235
+ more recommended formulation, where a user provides the set of operators
236
+ they know how to implement, and we provide decompositions for everything
237
+ not in this set.
238
+ """
239
+ assert type in {"post_autograd", "pre_autograd", "meta"}
240
+
241
+ registry = global_decomposition_table[type]
242
+ packets_to_overloads = defaultdict(list)
243
+ for opo in registry:
244
+ if isinstance(opo, (OpOverload, OpOverloadPacket)):
245
+ packets_to_overloads[opo.overloadpacket].append(opo)
246
+ decompositions: dict[torch._ops.OperatorBase, Callable] = {}
247
+ for op in aten_ops:
248
+ if isinstance(op, OpOverloadPacket) and op in packets_to_overloads:
249
+ for op_overload in packets_to_overloads[op]:
250
+ decompositions[op_overload] = registry[op_overload]
251
+ elif isinstance(op, (torch._ops.OperatorBase)) and op in registry:
252
+ decompositions[op] = registry[op]
253
+ return decompositions
254
+
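+ # Editor's sketch (the chosen ops are arbitrary examples): requesting
+ # decompositions by overload packet and by a specific overload:
+ #
+ #     aten = torch.ops.aten
+ #     decomps = get_decompositions([aten.addcmul, aten.logit.default])
+ #     # decomps maps each matched OpOverload to its Python decomposition.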
255
+
256
+ def remove_decompositions(
257
+ decompositions: dict[torch._ops.OperatorBase, Callable],
258
+ aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]],
259
+ ) -> None:
260
+ """
261
+ Given a dictionary of decompositions obtained from get_decompositions(), removes
262
+ operators associated with a list of operator overloads and overload packets passed
263
+ as input. If the decomposition dictionary does not contain a decomposition that is
264
+ specified to be removed, it is silently ignored.
265
+ """
266
+ for op in aten_ops:
267
+ if isinstance(op, OpOverloadPacket):
268
+ for overload_name in op.overloads():
269
+ opo = getattr(op, overload_name)
270
+ decompositions.pop(opo, None)
271
+ elif isinstance(op, OpOverload):
272
+ decompositions.pop(op, None)
273
+
274
+
275
+ # populate the table
276
+ import torch._decomp.decompositions
277
+ import torch._refs
278
+
279
+
280
+ def core_aten_decompositions() -> "CustomDecompTable":
281
+ from torch.export.exported_program import default_decompositions
282
+
283
+ return default_decompositions()
284
+
285
+
286
+ # See NOTE [Core ATen Ops]
287
+ #
288
+ # list was copied from torch/_inductor/decomposition.py
289
+ # excluding decompositions that result in prim ops.
290
+ # The resulting opset of the decomposition is core ATen ops.
291
+ def _core_aten_decompositions_post_autograd() -> dict[
292
+ torch._ops.OperatorBase, Callable
293
+ ]:
294
+ aten = torch.ops.aten
295
+ return get_decompositions(
296
+ [
297
+ aten.addcdiv,
298
+ aten.addcdiv_,
299
+ aten.addcmul,
300
+ aten.addcmul_,
301
+ aten.addr,
302
+ aten.affine_grid_generator,
303
+ aten.alias_copy,
304
+ aten.all,
305
+ aten.aminmax,
306
+ aten.arange.default,
307
+ aten.arange.start,
308
+ aten.avg_pool2d_backward,
309
+ aten.baddbmm,
310
+ aten.binary_cross_entropy,
311
+ aten.binary_cross_entropy_backward,
312
+ aten.binary_cross_entropy_with_logits,
313
+ aten.block_diag,
314
+ aten.bernoulli.p,
315
+ aten.bernoulli.default,
316
+ aten.celu,
317
+ aten.celu_,
318
+ aten.channel_shuffle,
319
+ aten.clamp_max,
320
+ aten.clamp_min,
321
+ aten.col2im,
322
+ aten.count_nonzero,
323
+ aten.linalg_cross,
324
+ aten.cudnn_batch_norm,
325
+ aten.cudnn_batch_norm_backward,
326
+ aten.miopen_batch_norm_backward,
327
+ aten.deg2rad,
328
+ aten.deg2rad_,
329
+ aten.detach,
330
+ aten.diag_embed,
331
+ aten.diagonal_backward,
332
+ aten.diagonal_copy,
333
+ aten.dot,
334
+ aten.vdot,
335
+ aten.elu_,
336
+ aten.elu_backward,
337
+ aten._embedding_bag,
338
+ aten.embedding_dense_backward,
339
+ aten.empty_like,
340
+ aten._euclidean_dist.default,
341
+ aten.expand_as,
342
+ aten.expand_copy,
343
+ aten.eye,
344
+ aten.fill,
345
+ aten.fill_,
346
+ aten.floor_divide,
347
+ aten.frac,
348
+ aten.frac_,
349
+ aten._fused_moving_avg_obs_fq_helper,
350
+ aten.gelu_,
351
+ aten.gelu_backward,
352
+ aten.glu,
353
+ aten.glu_backward,
354
+ aten.hardshrink,
355
+ aten.hardsigmoid,
356
+ aten.hardsigmoid_,
357
+ aten.hardsigmoid_backward,
358
+ aten.hardswish,
359
+ aten.hardswish_,
360
+ aten.hardswish_backward,
361
+ aten.hardtanh_,
362
+ aten.hardtanh_backward,
363
+ aten.heaviside,
364
+ aten.heaviside_,
365
+ aten.huber_loss,
366
+ aten.huber_loss_backward,
367
+ aten.im2col,
368
+ aten.index_add.out,
369
+ aten.index_add.default,
370
+ aten.index_add_,
371
+ aten.index_copy.out,
372
+ aten.index_copy.default,
373
+ aten.index_copy_,
374
+ aten.index_fill.int_Scalar,
375
+ aten.index_fill.int_Tensor,
376
+ aten.index_fill.int_Scalar_out,
377
+ aten.index_fill.int_Tensor_out,
378
+ aten.index_fill_,
379
+ aten.isin,
380
+ aten.isneginf,
381
+ aten.isposinf,
382
+ aten.l1_loss,
383
+ aten._lazy_clone,
384
+ aten._test_parallel_materialize,
385
+ aten.leaky_relu_,
386
+ aten.leaky_relu_backward,
387
+ aten.lerp,
388
+ aten.lerp_,
389
+ aten.linspace,
390
+ aten.logaddexp,
391
+ aten.logaddexp2,
392
+ aten.logit,
393
+ aten.logit_,
394
+ aten.logit_backward,
395
+ aten.log_sigmoid_backward,
396
+ aten.log_sigmoid_forward,
397
+ aten._log_softmax_backward_data,
398
+ aten.logspace,
399
+ aten.logsumexp.default,
400
+ aten.masked_fill,
401
+ aten.masked_fill_,
402
+ aten.max_unpool2d,
403
+ aten.max_unpool3d,
404
+ aten.mish,
405
+ aten.mish_,
406
+ aten.mse_loss,
407
+ aten.mse_loss_backward,
408
+ aten.multi_margin_loss,
409
+ aten.multilabel_margin_loss_forward,
410
+ aten.mv,
411
+ aten.mvlgamma,
412
+ aten.mvlgamma_,
413
+ aten.nansum,
414
+ aten.nan_to_num,
415
+ aten.nan_to_num_,
416
+ aten.narrow,
417
+ aten.native_batch_norm_backward,
418
+ aten.native_dropout_backward,
419
+ aten.native_group_norm_backward,
420
+ aten.native_layer_norm_backward,
421
+ aten.new_empty,
422
+ aten.new_full,
423
+ aten.new_ones,
424
+ aten.new_zeros,
425
+ aten.nll_loss2d_forward,
426
+ aten.nll_loss2d_backward,
427
+ aten.nll_loss_backward,
428
+ aten.nll_loss_forward,
429
+ aten.norm.ScalarOpt_dtype,
430
+ aten.norm.Scalar,
431
+ aten.norm.ScalarOpt_dim_dtype,
432
+ aten.norm.ScalarOpt_dim,
433
+ aten.norm.dtype_out,
434
+ aten.norm.out,
435
+ aten.norm.names_dtype_out,
436
+ aten.norm.names_out,
437
+ aten.norm.ScalarOpt_dtype_out,
438
+ aten.norm.Scalar_out,
439
+ aten.ones,
440
+ aten.ones_like,
441
+ aten.pixel_shuffle,
442
+ aten.pixel_unshuffle,
443
+ aten._prelu_kernel,
444
+ aten._prelu_kernel_backward,
445
+ aten._reshape_alias,
446
+ aten.rad2deg,
447
+ aten.rad2deg_,
448
+ aten.reflection_pad1d,
449
+ aten.reflection_pad1d_backward,
450
+ aten.reflection_pad2d,
451
+ aten.reflection_pad2d_backward,
452
+ aten.reflection_pad3d,
453
+ aten.reflection_pad3d_backward,
454
+ aten.replication_pad1d,
455
+ aten.replication_pad2d,
456
+ aten.replication_pad3d,
457
+ aten.renorm,
458
+ aten.renorm_,
460
+ aten.resize_as,
461
+ aten.roll,
462
+ aten.rot90,
463
+ aten.rrelu_with_noise,
464
+ aten.rrelu_with_noise_,
465
+ aten.rsub,
466
+ aten._safe_softmax,
467
+ aten._scaled_dot_product_flash_attention_for_cpu.default,
468
+ aten.select_backward,
469
+ aten.select_scatter,
470
+ aten.sgn,
471
+ aten.sgn_,
472
+ aten.sigmoid_backward,
473
+ aten.silu,
474
+ aten.silu_,
475
+ aten.silu_backward.grad_input,
476
+ aten.sinc,
477
+ aten.sinc_,
478
+ aten.slice_backward,
479
+ aten.smooth_l1_loss,
480
+ aten.smooth_l1_loss_backward,
481
+ aten.soft_margin_loss,
482
+ aten.soft_margin_loss_backward,
483
+ aten._softmax_backward_data,
484
+ aten.softplus,
485
+ aten.softplus_backward,
486
+ aten.softshrink,
487
+ aten.special_entr,
488
+ aten.special_log_ndtr,
489
+ aten.special_xlog1py,
490
+ aten.split.Tensor,
491
+ aten.split_with_sizes_copy,
492
+ aten.squeeze_copy,
493
+ aten.squeeze.default,
494
+ aten.squeeze.dim,
495
+ aten.std.correction,
496
+ aten.std.out,
497
+ aten.std.correction_out,
498
+ aten.std.names_out,
499
+ aten.std.correction_names_out,
500
+ aten.std_mean.correction,
501
+ aten.std_mean.correction_out,
502
+ aten.stack,
503
+ aten.sum.default,
504
+ aten.sum.out,
505
+ aten.t,
506
+ aten.t_copy,
507
+ aten.take,
508
+ aten.tanh_backward,
509
+ aten.threshold,
510
+ aten.threshold_,
511
+ aten.threshold_backward,
512
+ aten.trace,
513
+ aten.transpose.int,
514
+ aten.transpose_copy,
515
+ aten.tril,
516
+ aten.tril_,
517
+ aten.triu,
518
+ aten.triu_,
519
+ aten.unbind,
520
+ aten.unfold_backward,
521
+ aten.unfold_copy,
522
+ aten._unsafe_index,
523
+ aten._unsafe_index_put,
524
+ aten._unsafe_masked_index,
525
+ aten._unsafe_masked_index_put_accumulate,
526
+ aten.unsafe_split.Tensor,
527
+ aten.unsafe_split_with_sizes,
528
+ aten.unsqueeze_copy,
529
+ aten._unsafe_view,
530
+ aten.upsample_linear1d,
531
+ aten.upsample_bilinear2d.out,
532
+ aten.upsample_trilinear3d.out,
533
+ aten.upsample_nearest2d_backward,
534
+ aten.view_as_complex,
535
+ aten.xlogy,
536
+ aten.xlogy_,
537
+ aten.zero,
538
+ aten.zero_,
539
+ aten.zeros,
540
+ aten.zeros_like,
541
+ aten._chunk_cat,
542
+ aten._weight_norm_interface,
543
+ ]
544
+ )
phivenv/Lib/site-packages/torch/_decomp/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (14 kB).
 
phivenv/Lib/site-packages/torch/_decomp/__pycache__/decompositions_for_jvp.cpython-39.pyc ADDED
Binary file (6.7 kB).
 
phivenv/Lib/site-packages/torch/_decomp/__pycache__/decompositions_for_rng.cpython-39.pyc ADDED
Binary file (8.1 kB).
 
phivenv/Lib/site-packages/torch/_decomp/decompositions.py ADDED
The diff for this file is too large to render.
 
phivenv/Lib/site-packages/torch/_decomp/decompositions_for_jvp.py ADDED
@@ -0,0 +1,335 @@
1
+ # mypy: allow-untyped-decorators
2
+ # mypy: allow-untyped-defs
3
+ import inspect
4
+ from typing import Callable, Optional
5
+
6
+ import torch
7
+ import torch._decomp
8
+ from torch import Tensor
9
+ from torch._prims_common.wrappers import _maybe_remove_out_wrapper
10
+
11
+
12
+ decomposition_table = torch._decomp.decomposition_table
13
+ decomposition_table_for_jvp: dict[torch._ops.OperatorBase, Callable] = {}
14
+ register_decomposition = torch._decomp.register_decomposition
15
+ aten = torch.ops.aten
16
+
17
+ # NOTE: [forward-mode AD decompositions mechanism]
18
+ #
19
+ # The mechanism is in VariableType,
20
+ # IF any inputs have forward grad
21
+ # AND there is no forward AD formula implemented
22
+ # AND the functions are actually differentiable
23
+ # run the decomposition
24
+ # See run_jit_decomposition_with_args_for_jvp
25
+ # We currently use python decompositions that we torchscript.
26
+ #
27
+ # Note that we would be building the backward graph at the decomposed level
28
+ # too, but that is OK, because we would've errored out otherwise anyway.
29
+ #
30
+ # TODO: The mechanism we are using to register decompositions doesn't
31
+ # seem to be used exclusively for jvp. So an open question here is whether
32
+ # torch/csrc/jit/runtime/decomposition_registry.cpp is being used for other things.
33
+ # If that is the case, we may go down the decomposition path unexpectedly
34
+ # (and possibly produce an unintelligible error) vs erroring out earlier and
35
+ # printing that the forward AD formula is not implemented.
36
+ #
37
+ # The solution to this may be to have an explicit allowlist controlling when
38
+ # to enable the decomposition.
39
+
40
+
41
+ def maybe_register_decomposition(op):
42
+ def decorator(f):
43
+ try:
44
+ return register_decomposition(op)(f)
45
+ except Exception:
46
+ return f
47
+
48
+ return decorator
49
+
50
+
51
+ # Functions where we need a special decomposition for jvp but there's another version that
52
+ # should be used more generally (e.g., for jvp we need to recompute the mean and variance for
53
+ # the backward of a normalization function; without jvp, it should use the saved values)
54
+ decomposition_table_for_jvp = {}
55
+
56
+
57
+ def register_decomposition_for_jvp(fn):
58
+ return register_decomposition(fn, registry=decomposition_table_for_jvp)
59
+
60
+
61
+ def _register_jit_decomposition_for_jvp(decomp, use_python=False):
62
+ if decomp in decomposition_table_for_jvp:
63
+ decomposition_table_used = decomposition_table_for_jvp
64
+ elif decomp in decomposition_table:
65
+ decomposition_table_used = decomposition_table
66
+ else:
67
+ raise RuntimeError(f"could not find decomposition for {decomp}")
68
+ decomp_fn = decomposition_table_used[decomp]
69
+
70
+ # `out_wrapper` extends a decomposition's signature with
72
+ # an `out` parameter. However, jit will use the unwrapped function's
73
+ # signature instead, so we need to unwrap here to prevent an error
73
+ decomp_fn = _maybe_remove_out_wrapper(decomp_fn)
74
+
75
+ if use_python:
76
+ decomp_fn = torch.jit.ignore(decomp_fn)
77
+ sig = inspect.signature(decomp_fn)
78
+
79
+ # Create a string wrapping the function from the signature
80
+ # example output:
81
+ # def wrapped_decomp(x: torch.Tensor, y: int, z: int):
82
+ # return decomp_fn(x, y, z)
83
+ # Thanks copilot!
84
+ def get_function_def(sig):
85
+ param_def = [f"{param_str}" for param_str in sig.parameters.values()]
86
+ param_use = [f"{param_str}" for param_str in sig.parameters.keys()]
87
+
88
+ return f"def wrapped_decomp({', '.join(param_def)}):\n return decomp_fn({', '.join(param_use)})\n"
89
+
90
+ f_str = get_function_def(sig)
91
+ graph = torch.jit.CompilationUnit(f_str).wrapped_decomp.graph
92
+ else:
93
+ graph = torch.jit.script(decomp_fn).graph
94
+ torch.jit._register_decomposition(decomp, graph)
95
+
96
+
97
+ # The only decompositions here are temporary or hacks for the purposes of jvp
98
+
99
+
100
+ # TODO: do these also belong here?
101
+ @maybe_register_decomposition(aten.trace.default)
102
+ def trace(self: Tensor) -> Tensor:
103
+ return torch.sum(torch.diag(self))
104
+
105
+
106
+ @maybe_register_decomposition(aten.log_sigmoid_forward.default)
107
+ def log_sigmoid_forward(self: Tensor) -> tuple[Tensor, Tensor]:
108
+ min = torch.minimum(self.new_zeros(()), self)
109
+ z = torch.exp(-torch.abs(self))
110
+ if self.is_cuda or self.is_xpu:
111
+ buffer = self.new_zeros((0,))
112
+ else:
113
+ buffer = z
114
+ return min - torch.log1p(z), buffer
115
+
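+ # Editor's note: the expression above is the numerically stable identity
+ # log(sigmoid(x)) = min(0, x) - log(1 + exp(-|x|)), which avoids overflowing
+ # exp() for large |x|.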
116
+
117
+ def recompute_mean_var(
118
+ input: Tensor, rstd: Tensor, inner_dim_indices: list[int], keepdim: bool
119
+ ):
120
+ # for most norm decompositions, it will be the same as the core version except for here.
121
+ # We recompute the mean and variance so that they track gradients through input
122
+
123
+ mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
124
+ var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
125
+ eps = torch.pow(1 / rstd, 2) - var # this makes me so sad inside
126
+ eps = eps.detach()
127
+ rstd = 1 / torch.sqrt(var + eps)
128
+ return mean, rstd
129
+
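+ # Editor's note on the eps recovery above: the saved rstd satisfies
+ # rstd = 1 / sqrt(var + eps), so eps = (1 / rstd) ** 2 - var. Recomputing
+ # rstd from the freshly computed var (which tracks gradients through input)
+ # with that detached eps reproduces the forward value while restoring the
+ # gradient path.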
130
+
131
+ @register_decomposition_for_jvp(aten.native_layer_norm_backward)
132
+ def native_layer_norm_backward(
133
+ grad_out: Tensor,
134
+ input: Tensor,
135
+ normalized_shape: list[int],
136
+ mean: Tensor,
137
+ rstd: Tensor,
138
+ weight: Optional[Tensor],
139
+ bias: Optional[Tensor],
140
+ output_mask: list[bool],
141
+ ) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
142
+ input_shape = input.shape
143
+ input_ndim = input.dim()
144
+
145
+ axis = input_ndim - len(normalized_shape)
146
+ inner_dims = input_shape[axis:]
147
+ outer_dims = input_shape[:axis]
148
+ inner_dim_indices = list(range(axis, input_ndim))
149
+ outer_dim_indices = list(range(0, axis))
150
+
151
+ N = 1
152
+ for i in inner_dims:
153
+ N *= i
154
+ M = 1
155
+ for i in outer_dims:
156
+ M *= i
157
+ if M <= 0 or N <= 0:
158
+ return (
159
+ input.new_zeros(input_shape),
160
+ input.new_zeros(input_shape[axis:]),
161
+ input.new_zeros(input_shape[axis:]),
162
+ )
163
+
164
+ mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)
165
+
166
+ x_hat = (input - mean_) * rstd_
167
+ if weight is not None:
168
+ grad_x_hat = grad_out * weight
169
+ else:
170
+ grad_x_hat = grad_out
171
+ a = grad_x_hat * N
172
+ b = torch.sum(grad_x_hat, inner_dim_indices, True)
173
+ c1 = torch.mul(grad_x_hat, x_hat)
174
+ c2 = torch.sum(c1, inner_dim_indices, True)
175
+ c3 = torch.mul(x_hat, c2)
176
+ inner = a - b - c3
177
+
178
+ if output_mask[0]:
179
+ d_input: Optional[Tensor] = (rstd_ / N) * inner
180
+ else:
181
+ d_input = torch.zeros_like(input) # should be None but doesn't work with vjp
182
+
183
+ if output_mask[1] and weight is not None:
184
+ if len(outer_dim_indices) > 0:
185
+ d_weight: Optional[Tensor] = torch.sum(
186
+ grad_out * x_hat, outer_dim_indices, False
187
+ )
188
+ else:
189
+ d_weight = grad_out * x_hat
190
+ elif weight is not None:
191
+ d_weight = torch.zeros_like(weight) # should be None but doesn't work with vjp
192
+ else:
193
+ d_weight = torch.zeros(()) # should be None but doesn't work with vjp
194
+
195
+ if output_mask[2] and bias is not None:
196
+ if len(outer_dim_indices) > 0:
197
+ d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False)
198
+ else:
199
+ d_bias = grad_out.clone()
200
+ elif bias is not None:
201
+ d_bias = torch.zeros_like(bias) # should be None but doesn't work with vjp
202
+ else:
203
+ d_bias = torch.zeros(()) # should be None but doesn't work with vjp
204
+
205
+ return (d_input, d_weight, d_bias)
206
+
207
+
208
+ def prod(x: list[int]):
209
+ r = 1
210
+ for i in x:
211
+ r *= i
212
+ return r
213
+
214
+
215
+ @register_decomposition_for_jvp(aten.native_batch_norm_backward)
216
+ def native_batch_norm_backward(
217
+ grad_out: Tensor,
218
+ input: Tensor,
219
+ weight: Optional[Tensor],
220
+ running_mean: Optional[Tensor],
221
+ running_var: Optional[Tensor],
222
+ save_mean: Optional[Tensor],
223
+ save_invstd: Optional[Tensor],
224
+ train: bool,
225
+ eps: float,
226
+ output_mask: list[bool],
227
+ ) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
228
+ input_shape = input.shape
229
+ input_rank = input.dim()
230
+ assert input_rank >= 2, "rank of the input must be at least 2"
231
+
232
+ axis = 1
233
+ num_features = prod(input_shape) / input_shape[axis] # type: ignore[arg-type]
234
+ mean = save_mean
235
+ invstd = save_invstd
236
+ if train:
237
+ assert save_mean is not None and save_invstd is not None, (
238
+ "when train=True, save_mean and save_invstd are required"
239
+ )
240
+
241
+ reduction_dims = [0] + list(range(2, input.dim()))
242
+ assert invstd is not None # for typing
243
+ mean, invstd = recompute_mean_var(input, invstd, reduction_dims, keepdim=False)
244
+ else:
245
+ assert running_mean is not None and running_var is not None
246
+ mean = running_mean
247
+ invstd = torch.rsqrt(running_var + eps)
248
+
249
+ assert invstd is not None and mean is not None
250
+
251
+ broadcast_mask = [1] * input_rank
252
+ broadcast_mask[axis] = input_shape[axis]
253
+
254
+ reduction_axes: list[int] = []
255
+ for i in range(input_rank):
256
+ if i != axis:
257
+ reduction_axes.append(i)
258
+
259
+ mean = torch.reshape(mean, broadcast_mask)
260
+ norm = 1.0 / num_features
261
+ grad_output_sum = torch.sum(grad_out, reduction_axes)
262
+ dot_p = torch.sum(grad_out * (input - mean), reduction_axes)
263
+
264
+ grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
265
+ proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)
266
+
267
+ if weight is None:
268
+ grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
269
+ else:
270
+ grad_scale = torch.reshape(invstd * weight, broadcast_mask)
271
+
272
+ if train:
273
+ proj = (input - mean) * proj_scale
274
+ grad_input = ((grad_out - proj) - grad_mean) * grad_scale
275
+ else:
276
+ grad_input = grad_out * grad_scale
277
+
278
+ if output_mask[1]:
279
+ grad_weight = dot_p * invstd
280
+ elif weight is not None:
281
+ grad_weight = torch.zeros_like(
282
+ weight
283
+ ) # should be None but doesn't work with vjp
284
+ else:
285
+ grad_weight = torch.zeros(()) # should be None but doesn't work with vjp
286
+
287
+ if output_mask[2]:
288
+ grad_bias = grad_output_sum
289
+ else:
290
+ grad_bias = torch.zeros_like(
291
+ grad_output_sum
292
+ ) # should be None but doesn't work with vjp
293
+
294
+ return (grad_input, grad_weight, grad_bias)
295
+
296
+
297
+ @register_decomposition_for_jvp(aten.batch_norm_backward)
298
+ def batch_norm_backward(
299
+ grad_out: Tensor,
300
+ input: Tensor,
301
+ weight: Tensor,
302
+ running_mean: Optional[Tensor],
303
+ running_var: Optional[Tensor],
304
+ save_mean: Optional[Tensor],
305
+ save_var: Optional[Tensor],
306
+ update: bool,
307
+ eps: float,
308
+ output_mask: list[bool],
309
+ reserve: Tensor,
310
+ ) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
311
+ return native_batch_norm_backward(
312
+ grad_out,
313
+ input,
314
+ weight,
315
+ running_mean,
316
+ running_var,
317
+ save_mean,
318
+ save_var,
319
+ update,
320
+ eps,
321
+ output_mask,
322
+ )
323
+
324
+
325
+ _register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True)
326
+ _register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default)
327
+ _register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default)
328
+ _register_jit_decomposition_for_jvp(torch.ops.aten._log_softmax_backward_data.default)
329
+ _register_jit_decomposition_for_jvp(torch.ops.aten._softmax_backward_data.default)
330
+ _register_jit_decomposition_for_jvp(torch.ops.aten.log_sigmoid_forward.default)
331
+ _register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default)
332
+ _register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default)
333
+ _register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default)
334
+ _register_jit_decomposition_for_jvp(torch.ops.aten.batch_norm_backward.default)
335
+ _register_jit_decomposition_for_jvp(torch.ops.aten.miopen_batch_norm_backward.default)
phivenv/Lib/site-packages/torch/_decomp/decompositions_for_rng.py ADDED
@@ -0,0 +1,266 @@
1
+ # mypy: allow-untyped-decorators
2
+ # mypy: allow-untyped-defs
3
+ import functools
4
+ from collections import defaultdict
5
+ from typing import Callable
6
+
7
+ import torch
8
+ import torch._decomp as decomp
9
+ from torch._decomp import get_decompositions
10
+ from torch._ops import OpOverload
11
+
12
+
13
+ aten = torch.ops.aten
14
+
15
+ rng_decompositions: dict[str, dict[OpOverload, Callable]] = defaultdict(dict)
16
+
17
+
18
+ def register_rng_decomposition(aten_op):
19
+ return decomp.register_decomposition(aten_op, rng_decompositions)
20
+
21
+
22
+ def throw_on_non_cuda(device):
23
+ raise RuntimeError(
24
+ f"You are trying to functionalize a {device.type} RNG operator but {device.type} does not "
25
+ f"use Philox/counter-based RNG. Therefore, functionalizing a {device.type} RNG operator is "
26
+ "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU."
27
+ )
28
+
29
+
30
+ # TODO - We have to register many more distributions here, and also higher level
31
+ # ops like dropout, which have fused implementations and can hide the rand inside.
32
+ @register_rng_decomposition(aten.rand)
33
+ def rand(shape, dtype=None, layout=torch.strided, device=None, pin_memory=False):
34
+ if device and device.type != "cuda":
35
+ throw_on_non_cuda(device)
36
+ seed, offset = PhiloxStateTracker.get_state_as_tuple()
37
+ dtype = dtype or torch.float32
38
+ out, offset_jump = torch.ops.rngprims.philox_rand(
39
+ shape, seed, offset, None, device, dtype
40
+ )
41
+ PhiloxStateTracker.advance_offset(offset_jump)
42
+ return out
43
+
44
+
45
+ @register_rng_decomposition(aten.rand_like)
46
+ def rand_like(
47
+ x: torch.Tensor,
48
+ dtype=None,
49
+ layout=None,
50
+ device=None,
51
+ pin_memory=False,
52
+ memory_format=torch.preserve_format,
53
+ ):
54
+ device = device or x.device
55
+ if device.type != "cuda":
56
+ throw_on_non_cuda(device)
57
+ dtype = dtype or x.dtype
58
+ seed, offset = PhiloxStateTracker.get_state_as_tuple()
59
+ out, offset_jump = torch.ops.rngprims.philox_rand(
60
+ x.shape, seed, offset, None, device, dtype
61
+ )
62
+ PhiloxStateTracker.advance_offset(offset_jump)
63
+ return out
64
+
65
+
66
+ class PhiloxState:
67
+ """
68
+ Represents a PhiloxRngState - (seed, offset) where offset = base_offset +
69
+ relative_offset. seed and base_offset basically point to the rng state just
70
+ before tracing starts. relative_offset tracks the total offset consumed at
71
+ trace time.
72
+ """
73
+
74
+ def __init__(self) -> None:
75
+ self.reset()
76
+
77
+ def reset(self):
78
+ self.seed = torch.tensor(())
79
+ self.base_offset = torch.tensor(())
80
+ self.relative_offset = 0
81
+ self.offset_advanced_alteast_once = False
82
+
83
+ def validate_state(self):
84
+ assert self.seed.numel() != 0 and self.base_offset.numel() != 0
85
+
86
+ def advance_offset(self, consumed_offset):
87
+ self.offset_advanced_alteast_once = True
88
+ self.relative_offset = self.relative_offset + consumed_offset
89
+
90
+ def set_state(self, seed, base_offset, relative_offset=0):
91
+ self.seed = seed
92
+ self.base_offset = base_offset
93
+ self.relative_offset = relative_offset
94
+
95
+ def get_state_as_tuple(self):
96
+ self.validate_state()
97
+ return (self.seed, self.base_offset + self.relative_offset)
98
+
99
+ def get_state_as_tensor(self):
100
+ # Only needed because we override get_rng_state.
101
+ self.validate_state()
102
+ return torch.stack([self.seed, self.base_offset + self.relative_offset])
103
+
104
+ def set_state_from_tensor(self, state):
105
+ # Only needed because we override set_rng_state.
106
+ self.seed, self.base_offset = torch.unbind(state)
107
+ self.relative_offset = 0
108
+
109
+
110
+ class PhiloxStateTracker:
111
+ """
112
+ Singleton class to track the philox rng state during AOT Autograd tracing.
113
+ For each aot tracing instance, AOT Autograd resets this tracker and keeps
114
+ track of both forward and backward offsets. At runtime, we only care about
115
+ the total consumed forward and backward offsets. For dynamic shapes, these
116
+ offsets are a function of input shapes. Therefore, the AOT generated graphs
117
+ have additional outputs that compute total consumed forward and backward
118
+ offsets.
119
+ """
120
+
121
+ running_state: PhiloxState
122
+ fwd_state: PhiloxState
123
+ bwd_state: PhiloxState
124
+
125
+ def __enter__(self):
126
+ PhiloxStateTracker.reset()
127
+ return self
128
+
129
+ def __exit__(self, exc_type, exc_val, exc_tb):
130
+ PhiloxStateTracker.reset()
131
+
132
+ @classmethod
133
+ def reset(cls):
134
+ cls.running_state = PhiloxState()
135
+ cls.fwd_state = PhiloxState()
136
+ cls.bwd_state = PhiloxState()
137
+
138
+ @classmethod
139
+ def mark_beginning_of_forward(cls):
140
+ # Tells the tracker to use fwd_state as the running state
141
+ cls.running_state = cls.fwd_state
142
+
143
+ @classmethod
144
+ def mark_beginning_of_backward(cls):
145
+ # Tells the tracker to use bwd_state as the running state
146
+ cls.running_state = cls.bwd_state
147
+
148
+ @classmethod
149
+ def record_state(cls, seed, offset, mode):
150
+ # Records the seed and offset tensors. These tensors are used to invoke
151
+ # the philox_rand functional primitives.
152
+ if mode == "forward":
153
+ cls.fwd_state.set_state(seed, offset)
154
+ cls.mark_beginning_of_forward()
155
+ else:
156
+ assert mode == "backward"
157
+ cls.bwd_state.set_state(seed, offset)
158
+
159
+ @classmethod
160
+ def get_state_as_tensor(cls):
161
+ # The only reason this exists is because we override get_rng_state and
162
+ # set_rng_state during tracing. get_rng_state expects a tensor output,
163
+ # so return (seed, offset) tuple upset other parts of the program like
164
+ # ctx.saved_tensors.
165
+
166
+ # A bad consequence is that if user saves and restores rng state, we
167
+ # have little bit of ugliness in the generated code, where we first
168
+ # concat the (seed, offset) to create a tensor for get_rng_state, and
169
+ # then split it back to get (seed, offset) tuple in set_rng_state.
170
+
171
+ # TODO: Investigate if there is be a better way to wrap the tuple in a
172
+ # false Tensor object, and then desugar it later on.
173
+ return cls.running_state.get_state_as_tensor()
174
+
175
+ @classmethod
176
+ def get_state_as_tuple(cls):
177
+ return cls.running_state.get_state_as_tuple()
178
+
179
+ @classmethod
180
+ def set_state_from_tensor(cls, x):
181
+ # This is only needed because we override set_rng_state. Look at the
182
+ # comment in get_state_from_tensor method.
183
+ cls.running_state.set_state_from_tensor(x)
184
+
185
+ @classmethod
186
+ def advance_offset(cls, consumed_offset):
187
+ cls.running_state.advance_offset(consumed_offset)
188
+
189
+ @classmethod
190
+ def get_current_relative_offset(cls):
191
+ return cls.running_state.relative_offset
192
+
193
+ @staticmethod
194
+ def multiple_of_4(offset):
195
+ # torch cuda rng state offset must be a multiple of 4. For inductor, as
196
+ # we sum up all the numel, the result might not be a multiple of 4. This
197
+ # method achieves that.
198
+ return (offset + 3) // 4 * 4
199
+
200
+ @classmethod
201
+ def get_updated_fwd_offset(cls):
202
+ # Short circuit if no rand ops were observed
203
+ if not cls.fwd_state.offset_advanced_alteast_once:
204
+ return cls.fwd_state.base_offset
205
+ return cls.multiple_of_4(
206
+ cls.fwd_state.base_offset + cls.fwd_state.relative_offset
207
+ )
208
+
209
+ @classmethod
210
+ def get_updated_bwd_offset(cls):
211
+ # Short circuit if no rand ops were observed
212
+ if not cls.bwd_state.offset_advanced_alteast_once:
213
+ return cls.bwd_state.base_offset
214
+ return cls.multiple_of_4(
215
+ cls.bwd_state.base_offset + cls.bwd_state.relative_offset
216
+ )
217
+
218
+
219
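To make the flow concrete, here is a minimal sketch of how a tracing pass might drive the tracker; the seed/offset tensors and the consumed count of 16 are hypothetical placeholders, and it assumes PhiloxState.set_state records the base offset as defined earlier in this file:

import torch  # plus the PhiloxStateTracker / PhiloxState definitions above

seed, offset = torch.tensor(123), torch.tensor(0)  # stand-ins for graph inputs
with PhiloxStateTracker():
    PhiloxStateTracker.record_state(seed, offset, "forward")  # select fwd_state
    PhiloxStateTracker.advance_offset(16)  # a rand op consumed 16 Philox offsets
    fwd_total = PhiloxStateTracker.get_updated_fwd_offset()  # base + 16, rounded up to a multiple of 4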
+ # Adding more decompositions which eventually use rand_like inside decomps.
+ # Adding these in rng_decompositions ensures the functionalization of the
+ # rand_like ops used in these decomps. The list is copied from the inductor
+ # codebase, which uses it for a similar purpose.
+ #
+ # Caution - These decomps do not have the same accuracy as eager. However,
+ # we can't just disable them with a config flag like fallback_random, because
+ # for functionalization of rng ops, we have to decompose these ops.
+ extra_random_decomps = get_decompositions(
+     [
+         aten.cauchy,
+         aten.cauchy_,
+         aten.exponential,
+         aten.exponential_,
+         aten.geometric,
+         aten.geometric_,
+         aten.native_dropout,
+         aten.normal,
+         aten.normal_,
+         aten.normal_functional,
+         aten.log_normal,
+         aten.log_normal_,
+         aten.rrelu_with_noise,
+         aten.rrelu_with_noise_,
+         aten.uniform_,
+     ]
+ )
+ register_extra_random_decomp = functools.partial(
+     decomp.register_decomposition, registry=extra_random_decomps
+ )
+
+
+ @register_extra_random_decomp([aten.bernoulli_])
+ def bernoulli_(self, p=0.5):
+     if self.device == torch.device("cpu"):
+         return NotImplemented
+     return self.copy_(torch.rand_like(self, dtype=torch.float32) < p)
+
+
+ @register_extra_random_decomp([aten.bernoulli.p])
+ def bernoulli_p(self, p=0.5, *, generator=None):
+     if self.device == torch.device("cpu"):
+         return NotImplemented
+     assert generator is None
+     return torch.rand_like(self, dtype=torch.float32) < p
+
+
+ rng_decompositions.update(extra_random_decomps)  # type: ignore[arg-type]
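The two bernoulli decomps rely on the identity that a uniform(0, 1) sample compared against p is a Bernoulli(p) draw; a standalone eager-mode sketch of the same pattern:

import torch

x = torch.empty(10_000)
p = 0.3
samples = (torch.rand_like(x, dtype=torch.float32) < p).float()
print(samples.mean())  # ~0.3 in expectation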
phivenv/Lib/site-packages/torch/_dispatch/__init__.py ADDED
File without changes
phivenv/Lib/site-packages/torch/_dispatch/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (159 Bytes)
phivenv/Lib/site-packages/torch/_dispatch/__pycache__/python.cpython-39.pyc ADDED
Binary file (6.92 kB)
phivenv/Lib/site-packages/torch/_dispatch/python.py ADDED
@@ -0,0 +1,192 @@
+ # mypy: allow-untyped-defs
+ import itertools
+ import unittest.mock
+ from collections.abc import Iterator
+ from contextlib import contextmanager
+ from typing import Callable, TypeVar, Union
+ from typing_extensions import ParamSpec
+
+ import torch
+ import torch._C
+ import torch._ops
+ import torch.utils._python_dispatch
+ import torch.utils._pytree as pytree
+ from torch._C import DispatchKey
+
+
+ __all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]
+
+ no_python_dispatcher = torch._C._DisablePythonDispatcher
+ enable_python_dispatcher = torch._C._EnablePythonDispatcher
+ enable_pre_dispatch = torch._C._EnablePreDispatch
+
+ CROSSREF_FUNCTIONALIZE = False
+
+ _P = ParamSpec("_P")
+ _T = TypeVar("_T")
+
+
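The three exported names are aliases for C++ guard objects; assuming, as elsewhere in PyTorch, that they follow the context-manager protocol, typical use looks like this sketch:

import torch
from torch._dispatch.python import enable_python_dispatcher

# While the guard is active, dispatch consults the Python dispatcher, so
# Python-level registrations (e.g. decompositions) can intercept ops.
with enable_python_dispatcher():
    out = torch.add(torch.ones(2), torch.ones(2))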
+ def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
+     """
+     Warning: the set of overloads this will report is very subtle. It is precisely
+     the set of torch.ops functions that have actually been accessed from Python
+     (e.g., we actually called torch.ops.aten.blah at some point). This is DIFFERENT
+     from the set of registered operators, which will in general be a larger set,
+     as it would include all operators for which we ran C++ static initializers or
+     Python operator registration. This does not eagerly populate the list on
+     torch.ops.aten; this list is lazy!
+
+     In other words, this is good for traversing over everything that has an
+     OpOverload object allocated in Python. We use it for cache invalidation, but
+     we don't rely on this list being complete.
+
+     Note that even if we did report all C++ registered overloads, this isn't
+     guaranteed to be complete either, as a subsequent lazy load of a library
+     which triggers more registrations could add more things to the set.
+     """
+     for ns in torch.ops:
+         packets = getattr(torch.ops, ns)
+         for op_name in packets:
+             packet = getattr(packets, op_name)
+             for overload in packet:
+                 yield getattr(packet, overload)
+
+
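Because the generator only walks OpOverloads that Python has already materialized, touching an op first makes the effect visible; a small sketch:

import torch
from torch._dispatch.python import all_py_loaded_overloads

_ = torch.ops.aten.add.Tensor  # materialize one OpOverload in Python
print(sum(1 for _ in all_py_loaded_overloads()))  # counts only touched overloads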
+ @contextmanager
+ def suspend_functionalization():
+     # Temporarily disable functionalization (if it is on), restoring both the
+     # dispatch key state and the reapply_views setting on exit.
+     f_tls = torch._C._dispatch_tls_is_dispatch_key_included(
+         torch._C.DispatchKey.Functionalize
+     )
+     f_rv = torch._C._functionalization_reapply_views_tls()
+     if f_tls:
+         torch._disable_functionalization()
+     try:
+         yield
+     finally:
+         if f_tls:
+             torch._enable_functionalization(reapply_views=f_rv)
+
+
+ def check_tensor_metadata_matches(nv, rv, desc):
+     assert callable(desc)
+     assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}"
+     assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}"
+     same_strides, idx = torch._prims_common.check_significant_strides(
+         nv, rv, only_cuda=False
+     )
+     assert same_strides, (
+         f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"
+     )
+
+
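The helper passes silently when sizes, dtypes, and significant strides agree, and raises with the desc() label otherwise; a runnable sketch (the "demo" label is arbitrary):

import torch
from torch._dispatch.python import check_tensor_metadata_matches

a = torch.empty_strided((2, 3), (3, 1))
b = torch.empty_strided((2, 3), (3, 1))
check_tensor_metadata_matches(a, b, lambda: "demo")  # identical metadata: passes

c = torch.empty_strided((2, 3), (1, 2))  # column-major-style strides
# check_tensor_metadata_matches(a, c, lambda: "demo")  # would raise AssertionError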
+ def check_metadata_matches(n, r, desc):
+     assert callable(desc)
+     n_vals, _n_spec = pytree.tree_flatten(n)
+     r_vals, _r_spec = pytree.tree_flatten(r)
+     # TODO: test the specs match; empirically sometimes we have a tuple
+     # on one side and a list on the other
+     assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
+     for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals):
+         if not isinstance(rv, torch.Tensor):
+             continue
+         check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}")
+
+
+ class Lit:
+     # Wraps a string so that repr() returns it verbatim; used by _fmt below to
+     # render tensors as constructor expressions rather than tensor contents.
+     def __init__(self, s):
+         self.s = s
+
+     def __repr__(self):
+         return self.s
+
+
+ def _fmt(a: object) -> object:
+     if isinstance(a, torch.Tensor):
+         return Lit(
+             f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})"
+         )
+     else:
+         return a
+
+
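Lit and _fmt cooperate so error messages render tensors as reproducible constructor calls instead of dumping their contents; for example:

import torch
from torch._dispatch.python import _fmt

t = torch.empty_strided((2, 3), (3, 1))
print(repr(_fmt(t)))  # torch.empty_strided((2, 3), (3, 1), dtype=torch.float32)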
+ def make_crossref_functionalize(
+     op: torch._ops.OpOverload[_P, _T], final_key: DispatchKey
+ ) -> Union[Callable[_P, _T], DispatchKey]:
+     from torch._subclasses.fake_tensor import FakeTensorMode
+
+     # This case is pretty weird, suppress it for now
+     if op == torch.ops.aten.lift_fresh.default:
+         return final_key
+
+     def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T:
+         fake_mode = FakeTensorMode()
+
+         def fakeify_defun(t):
+             if isinstance(t, torch.Tensor):
+                 if torch._is_functional_tensor(t):
+                     r = torch._from_functional_tensor(t)
+                     # NB: This assumes that the inner tensor sizes/strides match
+                     # the outer tensor sizes/strides. This doesn't necessarily
+                     # have to be the case, see discussion at
+                     # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456
+                     assert t.size() == r.size()
+                     assert t.stride() == r.stride()
+                 else:
+                     r = t
+                 # TODO: suppress guards
+                 return fake_mode.from_tensor(r)
+             return t
+
+         def maybe_detach(t):
+             if isinstance(t, torch.Tensor):
+                 return t.detach()
+             else:
+                 return t
+
+         # TODO: This probably does the wrong thing if you're running other
+         # substantive modes with the normal op outside here
+         with (
+             torch.utils._python_dispatch._disable_current_modes(),
+             suspend_functionalization(),
+         ):
+             f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
+             orig_f_args, orig_f_kwargs = pytree.tree_map(
+                 maybe_detach, (f_args, f_kwargs)
+             )
+             with fake_mode:
+                 f_r = op(*f_args, **f_kwargs)
+         r = op._op_dk(final_key, *args, **kwargs)
+
+         def desc():
+             fmt_args = ", ".join(
+                 itertools.chain(
+                     (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
+                     (
+                         f"{k}={pytree.tree_map(_fmt, v)}"
+                         for k, v in orig_f_kwargs.items()
+                     ),
+                 )
+             )
+             return f"{op}({fmt_args})"
+
+         check_metadata_matches(f_r, r, desc)
+         return r
+
+     return handler
+
+
+ # NB: enabling this is slow, don't do it in a hot loop. This is purely
+ # for debugging purposes.
+ @contextmanager
+ def enable_crossref_functionalize():
+     for op in all_py_loaded_overloads():
+         op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
+     try:
+         with (
+             enable_python_dispatcher(),
+             unittest.mock.patch("torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True),
+         ):
+             yield
+     finally:
+         for op in all_py_loaded_overloads():
+             op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
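Putting it together, a hedged usage sketch: the cross-checks fire for ops dispatched through the Functionalize key, e.g. under torch.func.functionalize, and the whole thing is debug-only:

import torch
from torch._dispatch.python import enable_crossref_functionalize

def f(x):
    return x.add(1)

with enable_crossref_functionalize():
    out = torch.func.functionalize(f)(torch.randn(4))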