File size: 5,267 Bytes
59f1501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
# mypy: allow-untyped-defs
from typing import Any, Callable

import torch

class GlobalStateGuard:
    # Declarations only (bodies are `...`). check() is typed as bool and
    # reason() as str — presumably reason() explains a failed check (TODO confirm).
    def check(self) -> bool: ...
    def reason(self) -> str: ...

class LeafGuard:
    # Opaque type: this stub declares no members for it.
    ...
class GuardDebugInfo:
    # Opaque type (used as the return annotation of GuardManager.check_verbose);
    # no members are declared here.
    ...

class GuardManager:
    # Stub declarations only: every body is `...`.
    # check() is declared to return a bool; check_verbose() returns a
    # GuardDebugInfo instead.
    def check(self, value) -> bool: ...
    def check_verbose(self, value) -> GuardDebugInfo: ...

    # Accessors: each returns a GuardManager. Apart from type_manager, each
    # takes an accessor-specific first argument followed by the common
    # (source, example_value, guard_manager_enum) parameters.
    def globals_dict_manager(
        self,
        f_globals: dict[str, Any],
        source,
        example_value,
        guard_manager_enum,
    ) -> GuardManager: ...
    def framelocals_manager(
        self,
        key: tuple[str, int],
        source,
        example_value,
        guard_manager_enum,
    ) -> GuardManager: ...
    def dict_getitem_manager(
        self,
        key,
        source,
        example_value,
        guard_manager_enum,
    ) -> GuardManager: ...
    def global_weakref_manager(
        self,
        global_name: str,
        source,
        example_value,
        guard_manager_enum,
    ) -> GuardManager: ...
    def type_manager(
        self,
        source,
        example_value,
        guard_manager_enum,
    ) -> GuardManager: ...
    def getattr_manager(
        self,
        attr: str,
        source,
        example_value,
        guard_manager_enum,
    ) -> GuardManager: ...
    def tensor_property_size_manager(
        self,
        idx: int,
        source,
        example_value,
        guard_manager_enum,
    ) -> GuardManager: ...
    def tensor_property_shape_manager(
        self,
        idx: int,
        source,
        example_value,
        guard_manager_enum,
    ) -> GuardManager: ...
    def tensor_property_storage_offset_manager(
        self,
        # Annotated as None (unlike the int idx of the size/shape accessors) —
        # presumably because the storage offset takes no index (TODO confirm).
        idx: None,
        source,
        example_value,
        guard_manager_enum,
    ) -> GuardManager: ...
    def indexed_manager(
        self,
        idx: int,
        source,
        example_value,
        guard_manager_enum,
    ) -> GuardManager: ...
    def lambda_manager(
        self,
        python_lambda,
        source,
        example_value,
        guard_manager_enum,
    ) -> GuardManager: ...

    # Leaf guards: each is declared to return None and takes the list of
    # verbose code parts as its last argument.
    def add_lambda_guard(self, user_lambda, verbose_code_parts: list[str]) -> None: ...
    def add_id_match_guard(self, id_val, verbose_code_parts: list[str]) -> None: ...
    def add_equals_match_guard(
        self,
        equals_val,
        verbose_code_parts: list[str],
    ) -> None: ...
    def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ...
    def add_torch_function_mode_stack_guard(
        self, initial_stack, verbose_code_parts: list[str]
    ) -> None: ...
    def add_mapping_keys_guard(self, value, verbose_code_parts: list[str]) -> None: ...

class RootGuardManager(GuardManager):
    # Adds epilogue lambda guards and cloning on top of GuardManager.
    def get_epilogue_lambda_guards(self) -> list[LeafGuard]: ...
    def add_epilogue_lambda_guard(
        self, guard: LeafGuard, verbose_code_parts: list[str]
    ) -> None: ...
    # clone_filter_fn maps a GuardManager to a bool — presumably selecting
    # which managers the clone keeps (TODO confirm).
    def clone_manager(
        self, clone_filter_fn: Callable[[GuardManager], bool]
    ) -> RootGuardManager: ...

class DictGuardManager(GuardManager):
    # Key and value accessors share one signature; `index` presumably is the
    # entry's position in the dict (TODO confirm). Both return a GuardManager.
    def get_key_manager(
        self, index, source, example_value, guard_manager_enum
    ) -> GuardManager: ...
    def get_value_manager(
        self, index, source, example_value, guard_manager_enum
    ) -> GuardManager: ...

# Declaration only; no return type is annotated.
# guard_managers and tensor_names are presumably parallel lists (TODO confirm).
def install_object_aliasing_guard(
    guard_managers: list[GuardManager],
    tensor_names: list[str],
    verbose_code_parts: list[str],
): ...
# Declaration only; same parameter shape as install_object_aliasing_guard,
# with no return type annotated.
def install_no_tensor_aliasing_guard(
    guard_managers: list[GuardManager],
    tensor_names: list[str],
    verbose_code_parts: list[str],
): ...
# Declaration only: takes two separate lists of guard managers (overlapping
# and non-overlapping) plus the verbose code parts; no return type annotated.
def install_storage_overlapping_guard(
    overlapping_guard_managers: list[GuardManager],
    non_overlapping_guard_managers: list[GuardManager],
    verbose_code_parts: list[str],
): ...
# Declaration only; no return type annotated.
# py_addr is an int — presumably a function address — and py_addr_keep_alive
# presumably keeps its owner alive (TODO confirm both).
def install_symbolic_shape_guard(
    guard_managers: list[GuardManager],
    nargs_int: int,
    nargs_float: int,
    py_addr: int,
    py_addr_keep_alive: Any,
    verbose_code_parts: list[str],
): ...
# Declared to return a float; presumably a timing result over n_iters runs
# against f_locals (TODO confirm units).
def profile_guard_manager(
    guard_manager: GuardManager,
    f_locals: dict[str, Any],
    n_iters: int,
) -> float: ...

class TensorGuards:
    # Both constructor options are keyword-only and default to None.
    def __init__(
        self,
        *,
        dynamic_dims_sizes: list[torch.SymInt | None] | None = None,
        dynamic_dims_strides: list[torch.SymInt | None] | None = None,
    ) -> None: ...
    def check(self, *args) -> bool: ...
    # Declared to return either a bool or a str.
    def check_verbose(self, *args, tensor_check_names=None) -> bool | str: ...

# Declaration only: size and stride share the torch.types._size annotation;
# op_name is optional. No return type annotated.
def assert_size_stride(
    item: torch.Tensor,
    size: torch.types._size,
    stride: torch.types._size,
    op_name: str | None = None,
): ...
# Declaration only: alignment is an int (units not stated here); op_name is
# optional. No return type annotated.
def assert_alignment(
    item: torch.Tensor,
    alignment: int,
    op_name: str | None = None,
): ...
# Declared to return a bool; `expected` is presumably an id() value (TODO confirm).
def check_obj_id(
    obj: object,
    expected: int,
) -> bool: ...
# Declared to return a bool; `expected` is presumably the id() of a type
# (TODO confirm).
def check_type_id(
    obj: object,
    expected: int,
) -> bool: ...
# Declared to return an int for any dict argument.
def dict_version(
    d: dict[Any, Any],
) -> int: ...
# Declared to return a set of ints — presumably indices into `tensors`
# (TODO confirm). `symbolic` defaults to True.
def compute_overlapping_tensors(
    tensors: list[torch.Tensor],
    symbolic: bool = True,
) -> set[int]: ...