# mypy: allow-untyped-defs
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
from typing_extensions import ParamSpec

import torch

from . import config


if TYPE_CHECKING:
    from ._cache import CacheInfo


__all__ = [
    "compile",
    "config",
    "assume_constant_result",
    "reset",
    "allow_in_graph",
    "substitute_in_graph",
    "list_backends",
    "disable",
    "set_stance",
    "set_enable_guard_collectives",
    "cudagraph_mark_step_begin",
    "wrap_numpy",
    "is_compiling",
    "is_dynamo_compiling",
    "is_exporting",
    "save_cache_artifacts",
    "load_cache_artifacts",
    "skip_guard_on_inbuilt_nn_modules_unsafe",
    "skip_guard_on_all_nn_modules_unsafe",
    "keep_tensor_guards_unsafe",
    "skip_guard_on_globals_unsafe",
    "nested_compile_region",
]


_P = ParamSpec("_P")
_R = TypeVar("_R")


def compile(*args, **kwargs):
    """

    See :func:`torch.compile` for details on the arguments for this function.
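
    Example (a minimal sketch; this is simply an alias for :func:`torch.compile`)::

        compiled_fn = torch.compiler.compile(lambda x: x.sin() + x.cos())
        out = compiled_fn(torch.randn(8))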

    """
    return torch.compile(*args, **kwargs)


def reset() -> None:
    """

    This function clears all compilation caches and restores the system to its initial state.

    It is recommended to call this function, especially after using operations like `torch.compile(...)`

    to ensure a clean state before another unrelated compilation
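
    Example (a minimal sketch)::

        compiled_fn = torch.compile(lambda x: x + 1)
        compiled_fn(torch.randn(4))
        torch.compiler.reset()  # clear all compilation caches before an unrelated compilation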

    """
    import torch._dynamo

    torch._dynamo.reset()


def allow_in_graph(fn):
    """

    Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function

    and instead directly write it to the graph when encountered.



    If you are using :func:`torch.compile` (with backend="inductor" (the default)), or

    :func:`torch.export.export`, and trying to black-box a Python function throughout

    all tracing, do not use this API.

    Instead, please create a custom operator (see `PyTorch Custom Operators Landing Page

    <https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html>`_)



    .. warning::



        If you're a typical torch.compile user (e.g. you're applying torch.compile to

        a model to make it run faster), you probably don't want to use this function.

        :func:`allow_in_graph` is a footgun because it skips the compiler frontend

        (Dynamo) that is responsible for doing safety checks (graph breaks, handling

        closures, etc). Incorrect usage will lead to difficult-to-debug silent

        incorrectness issues.



    Given a Python function with no allow_in_graph decorator, regular execution

    of torch.compile traces through the function. :func:`allow_in_graph` changes

    it so that the frontend does not trace inside the function, but the compiler

    backend still traces through it. Compare this to custom operators, which

    treats a function as a black box throughout the torch.compile stack. The following

    table compares these mechanisms.



    +------------------------+-----------------------+--------------------------------+

    | Mechanism              | Frontend (Dynamo)     | Backend (AOTAutograd+Inductor) |

    +========================+=======================+================================+

    | no decorator           | trace inside          | trace inside                   |

    +------------------------+-----------------------+--------------------------------+

    | allow_in_graph         | opaque callable       | trace inside                   |

    +------------------------+-----------------------+--------------------------------+

    | custom op              | opaque callable       | opaque callable                |

    +------------------------+-----------------------+--------------------------------+



    One common use case for :func:`allow_in_graph()` is as an escape hatch for the compiler

    frontend: if you know the function works w.r.t. to the downstream components of the

    compilation stack (AOTAutograd and Inductor) but there is a Dynamo bug that prevents it from

    symbolically introspecting the function properly (or if your code is in C/C++ and

    therefore cannot be introspected with Dynamo), then one can decorate said function

    with :func:`allow_in_graph` to bypass Dynamo.



    We require that ``fn`` adhere to the following restrictions. Failure to adhere

    results in undefined behavior:



    - The inputs to ``fn`` must be Proxy-able types in the FX graph. Valid types include:

      Tensor/int/bool/float/None/List[Tensor?]/List[int?]/List[float?]

      Tuple[Tensor?, ...]/Tuple[int?, ...]/Tuple[float?, ...]/torch.dtype/torch.device

    - The outputs to ``fn`` must be Proxy-able types in the FX graph (see previous bullet)

    - all Tensors used inside of ``fn`` must be passed directly as inputs to ``fn``

      (as opposed to being captured variables).



    Args:

        fn: A callable representing the function to be included in the graph.

            If ``fn`` is a list or tuple of callables it recursively applies

            :func:`allow_in_graph()` to each function and returns a new list or

            tuple containing the modified functions.



    Example::



        torch.compiler.allow_in_graph(my_custom_function)



        @torch.compile(...)

        def fn(x):

            x = torch.add(x, 1)

            x = my_custom_function(x)

            x = torch.add(x, 1)

            return x



        fn(...)



    Will capture a single graph containing ``my_custom_function()``.



    """
    import torch._dynamo

    return torch._dynamo.allow_in_graph(fn)


def substitute_in_graph(
    original_fn: Callable[_P, _R],
    *,
    can_constant_fold_through: bool = False,
    skip_signature_check: bool = False,
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
    """
    Register a polyfill handler for a function, usually a C function from the C extension, to be
    used in place of the original function when inlining the original function in the graph.

    .. note::

        The polyfill handler is only used when inlining the original function. It is not used when
        the original function is called directly. In eager mode, the decorated function calls
        the performant C function rather than the polyfill handler.

    The polyfill handler is a function that will be called in place of the original function when
    inlining the original function. The polyfill handler should have the same signature and the same
    behavior as the original function.

    Args:
        original_fn (callable): The original function, usually a C function, to register a polyfill
            handler for.
        can_constant_fold_through (bool, optional): Whether the polyfill handler can be constant
            folded through. That is, if the polyfill handler is a pure function and its arguments
            are constant, the result of the polyfill handler can be constant folded during
            compilation. Defaults to ``False``.
        skip_signature_check (bool, optional): Whether to skip the signature check between the
            original function and the polyfill handler. Defaults to ``False``.

    Returns:
        A decorator that registers the polyfill handler for the original function.

    Example::

        >>> import operator
        >>> operator.indexOf([1, 2, 3, 4, 5], 3)
        2
        >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
        ... # xdoctest: +SKIP("Long tracebacks")
        Traceback (most recent call last):
        ...
        torch._dynamo.exc.Unsupported: ...

        >>> @torch.compiler.substitute_in_graph(operator.indexOf)
        ... def indexOf(a, b, /):
        ...     for i, item in enumerate(a):
        ...         if item is b or item == b:
        ...             return i
        ...     raise ValueError("sequence.index(x): x not in sequence")
        >>>
        >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
        2
    """
    import torch._dynamo

    return torch._dynamo.substitute_in_graph(
        original_fn,
        can_constant_fold_through=can_constant_fold_through,
        skip_signature_check=skip_signature_check,
    )


def list_backends(exclude_tags=("debug", "experimental")) -> list[str]:
    """

    Return valid strings that can be passed to `torch.compile(..., backend="name")`.



    Args:

        exclude_tags(optional): A tuple of strings representing tags to exclude.
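
    Example (illustrative; the exact list depends on your PyTorch build)::

        >>> "inductor" in torch.compiler.list_backends()
        True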

    """
    import torch._dynamo

    return torch._dynamo.list_backends(exclude_tags)


def assume_constant_result(fn):
    """

    This function is used to mark a function `fn` as having a constant result.

    This allows the compiler to optimize away your function.

    Returns The same function `fn`



    Args:

        fn: The function to be marked as having a constant result.



    .. warning::

        `assume_constant_result` can if invalid cause safety and soundness issues, :func:`torch.compile`

        will not attempt to validate whether the constant assumption is true or not
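
    Example (a minimal sketch; ``CONFIG`` and ``get_config_flag`` are hypothetical names)::

        CONFIG = {"use_fast_path": True}

        @torch.compiler.assume_constant_result
        def get_config_flag():
            # assumed constant: the compiler may bake this value into the graph
            return CONFIG["use_fast_path"]

        @torch.compile
        def fn(x):
            if get_config_flag():
                return x * 2
            return x / 2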



    """
    import torch._dynamo

    return torch._dynamo.assume_constant_result(fn)


def disable(fn=None, recursive=True, *, reason=None):
    """

    This function provides a decorator to disable compilation on a function.

    It also provides the option of recursively disabling called functions.



    Args:

        fn (optional): The function to disable

        recursive (optional): A boolean value indicating whether the disabling should be recursive.

        reason (optional): A string value indicating the reason for disabling the function.
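
    Example (a minimal sketch)::

        @torch.compiler.disable(recursive=False)
        def debug_helper(x):
            # this function runs eagerly, but functions it calls may still be compiled
            return x + 1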

    """
    import torch._dynamo

    return torch._dynamo.disable(fn, recursive, reason=reason)


def set_stance(
    stance: str = "default", *, skip_guard_eval_unsafe=False, force_backend=None
):
    """
    Set the current stance of the compiler.
    Can be used as a function, context manager, or decorator.
    Do not use this function inside a `torch.compile` region - an error will be raised otherwise.

    .. code-block:: python

        @torch.compile
        def foo(x):
            ...


        @torch.compiler.set_stance("force_eager")
        def bar():
            # will not be compiled
            foo(...)


        bar()

        with torch.compiler.set_stance("force_eager"):
            # will also not be compiled
            foo(...)

        torch.compiler.set_stance("force_eager")
        # will also not be compiled
        foo(...)
        torch.compiler.set_stance("default")

        # will be compiled
        foo(...)

    Args:
        stance: The stance to set the compiler to. Valid values are:

            - "default": The default stance, used for normal compilation.
            - "force_eager": Ignore all `torch.compile` directives.
            - "eager_on_recompile": Run code eagerly when a recompile is necessary.
              If there is cached compiled code valid for the input, it will still be used.
            - "fail_on_recompile": Raise an error when recompiling a function.
            - "eager_then_compile": Run the first invocation in eager mode, then compile on
              subsequent calls. This is beneficial for dynamic shapes as it allows inferring
              dynamism from the first two invocations instead of wasting a static compile on
              the first invocation.
            - "aot_eager_then_compile": Run the first invocation with AOT eager to get memory
              benefits from activation checkpointing, then compile on subsequent calls. Like
              eager_then_compile, this improves handling of dynamic shapes by avoiding an
              initial static compile.

        skip_guard_eval_unsafe: A flag to run only differentiating guards.
            CAUTION - This flag is unsafe and should only be used if your setup
            meets the following conditions.

            torch.compile uses a guard system to support recompilations and
            choose which compiled artifact to run at runtime. These guards,
            though efficient, add some overhead, which may impact performance in
            scenarios where you need to optimize for minimal guard processing
            time. This API enables you to disable guard evaluation, assuming
            that you have warmed up the compiled model with a sufficient variety
            of inputs. This assumption means that, after the warmup phase, no
            further recompilations will be necessary. If this assumption fails,
            there is a risk of silently producing incorrect results (hence the
            term "unsafe" in the API name).

        force_backend: If `stance` is "default", this argument can be used to force `torch.compile`
            to use a specific backend. Otherwise, an error is raised.
    """
    import torch._dynamo

    return torch._dynamo.set_stance(
        stance,
        skip_guard_eval_unsafe=skip_guard_eval_unsafe,
        force_backend=force_backend,
    )


# forbid in graph
set_stance._dynamo_forbidden = True  # type: ignore[attr-defined]


def set_enable_guard_collectives(enabled: bool):
    """

    Enables use of collectives *during* guard evaluation to synchronize behavior

    across ranks.  This is expensive: we have to issue a collective every time

    we enter a compiled code region, even if no rank actually would need to

    compile.  This can help prevent NCCL hangs by ensuring that we never have a

    situation where one rank starts recompiling while other ranks don't compile;

    it is especially useful in conjunction with enable_compiler_collectives

    where such a situation would immediately cause a hang (as it is necessary

    for all ranks to compile at the same time to run compiler collectives).  Like

    compiler collectives, you can only run this on SPMD programs; you will hang

    otherwise.  Note that a guard collective is only issued if there is any

    compiled code to guard on; if this the first time we encounter a frame or

    the frame is skipped, we don't issue collectives.



    Returns the previous setting of enabled.
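
    Example (illustrative; intended for SPMD distributed runs)::

        prev = torch.compiler.set_enable_guard_collectives(True)
        # ... run the compiled, distributed workload ...
        torch.compiler.set_enable_guard_collectives(prev)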

    """
    from torch._C._dynamo.eval_frame import set_guard_complete_hook  # noqa: F401
    from torch._dynamo.eval_frame import guard_collectives_hook

    if enabled:
        return set_guard_complete_hook(guard_collectives_hook) is not None
    else:
        return set_guard_complete_hook(None) is not None


set_enable_guard_collectives._dynamo_forbidden = True  # type: ignore[attr-defined]


def cudagraph_mark_step_begin():
    """

    Indicates that a new iteration of inference or training is about to begin.



    CUDA Graphs will free tensors of a prior iteration. A new iteration is started on each invocation of

    torch.compile, so long as there is not a pending backward that has not been called.



    If that heuristic is wrong, such as in the following example, manually mark it with this api.



    .. code-block:: python



        @torch.compile(mode="reduce-overhead")

        def rand_foo():

            return torch.rand([4], device="cuda")



        for _ in range(5):

            torch.compiler.cudagraph_mark_step_begin()

            rand_foo() + rand_foo()



    For more details, see `torch.compiler_cudagraph_trees <https://pytorch.org/docs/main/torch.compiler_cudagraph_trees.html>`__

    """
    from torch._inductor import cudagraph_trees

    cudagraph_trees.mark_step_begin()


def wrap_numpy(fn):
    r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function

    from ``torch.Tensor``s to ``torch.Tensor``s.



    It is designed to be used with :func:`torch.compile` with ``fullgraph=True``. It allows to

    compile a NumPy function as if it were a PyTorch function. This allows you to run NumPy code

    on CUDA or compute its gradients.



    .. note::



        This decorator does not work without :func:`torch.compile`.



    Example::



        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)

        >>> # Compile a NumPy function as a Tensor -> Tensor function

        >>> @torch.compile(fullgraph=True)

        >>> @torch.compiler.wrap_numpy

        >>> def fn(a: np.ndarray):

        >>>     return np.sum(a * a)

        >>> # Execute the NumPy function using Tensors on CUDA and compute the gradients

        >>> x = torch.arange(6, dtype=torch.float32, device="cuda", requires_grad=True)

        >>> out = fn(x)

        >>> out.backward()

        >>> print(x.grad)

        tensor([ 0.,  2.,  4.,  6.,  8., 10.], device='cuda:0')

    """
    from torch._dynamo.external_utils import wrap_numpy as wrap

    return wrap(fn)


_is_compiling_flag: bool = False
_is_exporting_flag: bool = False


def is_compiling() -> bool:
    """

    Indicates whether a graph is executed/traced as part of torch.compile() or torch.export().



    Note that there are 2 other related flags that should deprecated eventually:

      * torch._dynamo.external_utils.is_compiling()

      * torch._utils.is_compiling()



    Example::



        >>> def forward(self, x):

        >>>     if not torch.compiler.is_compiling():

        >>>        pass # ...logic that is not needed in a compiled/traced graph...

        >>>

        >>>     # ...rest of the function...

    """
    if torch.jit.is_scripting():
        return False
    else:
        return _is_compiling_flag


def is_dynamo_compiling() -> bool:
    """

    Indicates whether a graph is traced via TorchDynamo.



    It's stricter than is_compiling() flag, as it would only be set to True when

    TorchDynamo is used.



    Example::



        >>> def forward(self, x):

        >>>     if not torch.compiler.is_dynamo_compiling():

        >>>        pass # ...logic that is not needed in a TorchDynamo-traced graph...

        >>>

        >>>     # ...rest of the function...

    """
    return False


def is_exporting() -> bool:
    """

    Indicated whether we're under exporting.



    It's stricter than is_compiling() flag, as it would only be set to True when

    torch.export is used.



    Example::



        >>> def forward(self, x):

        >>>     if not torch.compiler.is_exporting():

        >>>        pass # ...logic that is not needed in export...

        >>>

        >>>     # ...rest of the function...

    """
    return _is_exporting_flag


def save_cache_artifacts() -> Optional[tuple[bytes, "CacheInfo"]]:
    """

    Serializes all the cache artifacts that were created during the compilation



    Example:



    - Execute torch.compile

    - Call torch.compiler.save_cache_artifacts()
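
    A minimal sketch of that flow::

        compiled_fn = torch.compile(lambda x: x * 2)
        compiled_fn(torch.randn(4))  # populate the compilation caches
        result = torch.compiler.save_cache_artifacts()
        if result is not None:
            artifact_bytes, cache_info = result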

    """
    from ._cache import CacheArtifactManager, CacheInfo

    return CacheArtifactManager.serialize()


def load_cache_artifacts(serialized_artifacts: bytes) -> Optional["CacheInfo"]:
    """

    Hot loads cache artifacts that were previously serialized via

    save_cache_artifacts



    Example:



    # From a previous invocation

    artifacts = torch.compiler.save_cache_artifacts()



    torch.compiler.load_cache_artifacts(artifacts[0])

    """
    from ._cache import CacheArtifactManager, CacheInfo

    artifacts = CacheArtifactManager.deserialize(serialized_artifacts)
    if artifacts is not None:
        return CacheArtifactManager.populate_caches(artifacts)
    return None


def skip_guard_on_inbuilt_nn_modules_unsafe(guard_entries):
    """

    A common function to skip guards on the inbuilt nn modules like

    torch.nn.Linear. This is unsafe to use by default. But for majority of

    torch.compile users, the model code does not modify the inbuilt nn module

    attributes. They can benefit from reduction in guard latency overhead using

    this API.



    To use this API, use guard_filter_fn argument while calling torch.compile



    >> opt_mod = torch.compile(

    >>     mod,

    >>     options={"guard_filter_fn": torch.compiler.skip_guard_on_all_nn_modules_unsafe},

    >> )

    """
    return [
        not entry.orig_guard.source.is_unspecialized_builtin_nn_module()
        for entry in guard_entries
    ]


def skip_guard_on_all_nn_modules_unsafe(guard_entries):
    """

    A common function to skip guards on all nn modules, both user defined as

    well inbuilt nn modules (like torch.nn.Linear). This is unsafe to use by

    default. But for majority of torch.compile users, the model code does not

    modify the nn module attributes. They can benefit from reduction in guard

    latency overhead using this API.



    To use this API, use guard_filter_fn argument while calling torch.compile



    >> opt_mod = torch.compile(

    >>     mod,

    >>     options={"guard_filter_fn": torch.compiler.skip_guard_on_all_nn_modules_unsafe},

    >> )

    """

    return [
        not entry.orig_guard.source.is_unspecialized_nn_module()
        for entry in guard_entries
    ]


def keep_tensor_guards_unsafe(guard_entries, keep_parameters=False):
    """

    A common function to keep tensor guards on all tensors. This is unsafe to

    use by default. But if you don't expect any changes in the model code, you

    can just keep the tensor guards.





    >> opt_mod = torch.compile(

    >>     mod,

    >>     options={"guard_filter_fn": torch.compiler.keep_tensor_guards},

    >> )

    """

    keep_flags = []
    for entry in guard_entries:
        if entry.guard_type == "TENSOR_MATCH":
            if not isinstance(entry.value, torch.nn.Parameter):
                keep_flags.append(True)
            elif keep_parameters:
                keep_flags.append(True)
            else:
                keep_flags.append(False)
        else:
            keep_flags.append(False)
    return keep_flags


def skip_guard_on_globals_unsafe(guard_entries):
    """

    A common function to skip guards on all globals. This is unsafe to use by

    default. But if you don't expect any changes in the globals, you can just

    keep the tensor guards.



    >> opt_mod = torch.compile(

    >>     mod,

    >>     options={"guard_filter_fn": torch.compiler.skip_guard_on_globals},

    >> )

    """

    return [not entry.is_global for entry in guard_entries]


def nested_compile_region(fn=None):
    """

    Tells **``torch.compile``** that the marked set of operations forms a nested

    compile region (which is often repeated in the full model) whose code can be

    compiled once and safely reused.  ``nested_compile_region`` can also be used

    as a decorator.



    During **``torch.compile``** tracing, the compiler applies *hierarchical

    compilation* with ``nested_compile_region``: it emits optimized code for the

    marked region the first time it is encountered and re-emits (or “stamps

    out”) the previously compiled code on every subsequent invocation.  This can

    substantially reduce overall compile time for deeply-stacked,

    structurally-identical components such as the transformer layers of a

    large-language-model (LLM).



    Outside a ``torch.compile`` context—i.e., in standard eager execution—the

    call is a no-op, so existing workflows remain unaffected.



    Note that ``nested_compile_region`` **does not** promise that a region will

    be compiled exactly once.  If the compiler detects that new input conditions

    (shape, dtype, device, stride, globals etc.) make the cached version invalid

    to reuse, it will transparently re-compile the region.  Using it is

    therefore *safe*: correctness is always preserved, and you pay the extra

    compilation cost only when required.
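
    Example (a minimal sketch; ``block`` is a hypothetical repeated sub-computation)::

        @torch.compiler.nested_compile_region
        def block(x):
            return torch.nn.functional.relu(x) * 2

        @torch.compile
        def model(x):
            for _ in range(4):  # the region is compiled once and stamped out for each call
                x = block(x)
            return x

        model(torch.randn(8))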

    """

    from torch._higher_order_ops.invoke_subgraph import (
        mark_compile_region as _mark_compile_region,
    )

    return _mark_compile_region(fn)