File size: 5,611 Bytes
59f1501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
# mypy: ignore-errors

"""

This module implements TorchDynamo's backend registry system for managing compiler backends.

The registry provides a centralized way to register, discover and manage different compiler
backends that can be used with torch.compile(). It handles:

- Backend registration and discovery through decorators and entry points
- Lazy loading of backend implementations
- Lookup and validation of backend names
- Categorization of backends using tags (debug, experimental, etc.)

Key components:
- CompilerFn: Type for backend compiler functions that transform FX graphs
- _BACKENDS: Registry of every known backend name, mapped to the entry point it
  can be loaded from (or None when no entry point is needed, e.g. for backends
  registered directly with the decorator)
- _COMPILER_FNS: Registry mapping backend names to loaded compiler functions

Example usage:

    @register_backend
    def my_compiler(fx_graph, example_inputs):
        # Transform FX graph into optimized implementation
        return compiled_fn

    # Use registered backend
    torch.compile(model, backend="my_compiler")

The registry also supports discovering backends through setuptools entry points
in the "torch_dynamo_backends" group. Example:
```
setup.py
---
from setuptools import setup

setup(
    name='my_torch_backend',
    version='0.1',
    packages=['my_torch_backend'],
    entry_points={
        'torch_dynamo_backends': [
            # name = path to entry point of backend implementation
            'my_compiler = my_torch_backend.compiler:my_compiler_function',
        ],
    },
)
```

```
my_torch_backend/compiler.py
---
def my_compiler_function(fx_graph, example_inputs):
    # Transform FX graph into optimized implementation
    return compiled_fn
```

Using `my_compiler` backend:
```
import torch

model = ...  # Your PyTorch model
optimized_model = torch.compile(model, backend="my_compiler")
```

"""

import functools
import logging
import sys
from collections.abc import Sequence
from importlib.metadata import EntryPoint
from typing import Callable, Optional, Protocol

import torch
from torch import fx


# Module-level logger, named after this module (standard logging convention).
log = logging.getLogger(__name__)


class CompiledFn(Protocol):
    """Structural type of the callable that a backend compiler returns.

    Any object that can be called with tensors as positional arguments and
    returns a tuple of tensors satisfies this protocol.
    """

    def __call__(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Invoke the compiled artifact with tensor positional arguments."""


# Signature every backend compiler must satisfy: it receives an FX
# GraphModule plus a list of example input tensors and returns a CompiledFn.
CompilerFn = Callable[[fx.GraphModule, list[torch.Tensor]], CompiledFn]

# Every known backend name.  The value is the EntryPoint the backend can be
# lazily loaded from, or None when no entry point is needed (e.g. backends
# registered directly through register_backend).
_BACKENDS: dict[str, Optional[EntryPoint]] = {}
# Backend names whose compiler function has been registered (directly or by
# loading its entry point), mapped to that function.  register_backend also
# ensures each of these names is present as a key of _BACKENDS.
_COMPILER_FNS: dict[str, CompilerFn] = {}


def register_backend(
    compiler_fn: Optional[CompilerFn] = None,
    name: Optional[str] = None,
    tags: Sequence[str] = (),
):
    """
    Decorator to add a given compiler to the registry to allow calling
    `torch.compile` with string shorthand.  Note: for projects not
    imported by default, it might be easier to pass a function directly
    as a backend and not use a string.

    Usable bare (``@register_backend``) or with keyword arguments
    (``@register_backend(name="foo", tags=("debug",))``).

    Args:
        compiler_fn: Callable taking a FX graph and fake tensor inputs.
            When None, a decorator pre-bound with ``name`` and ``tags`` is
            returned instead.
        name: Optional name, defaults to `compiler_fn.__name__`.
        tags: Optional sequence of string tags to categorize the backend
            with; stored on the function as ``compiler_fn._tags``.  A bare
            string is treated as a single tag.

    Returns:
        ``compiler_fn`` itself (so this works as a decorator), or a
        ``functools.partial`` of this function when ``compiler_fn`` is None.

    Raises:
        AssertionError: if ``compiler_fn`` is not callable or ``name`` is
            already registered.
    """
    if compiler_fn is None:
        # @register_backend(name="") syntax
        return functools.partial(register_backend, name=name, tags=tags)
    assert callable(compiler_fn)
    name = name or compiler_fn.__name__
    assert name not in _COMPILER_FNS, f"duplicate name: {name}"
    # Only insert a placeholder for names the registry has never seen.  A name
    # discovered through an entry point (including the one lookup_backend is
    # currently loading) keeps its EntryPoint rather than being clobbered.
    if name not in _BACKENDS:
        _BACKENDS[name] = None
    _COMPILER_FNS[name] = compiler_fn
    if isinstance(tags, str):
        # tuple("debug") would yield one-character tags; treat it as one tag.
        tags = (tags,)
    compiler_fn._tags = tuple(tags)
    return compiler_fn


# Convenience variants of register_backend that pre-tag the backend as
# "debug" or "experimental".  Both tags are excluded by list_backends() under
# its default arguments.
register_debug_backend = functools.partial(register_backend, tags=("debug",))
register_experimental_backend = functools.partial(
    register_backend, tags=("experimental",)
)


def lookup_backend(compiler_fn):
    """
    Resolve ``compiler_fn`` to a compiler function.

    Non-string arguments are returned unchanged.  A string is treated as a
    backend name: if the name is not known yet, the registry is populated via
    ``_lazy_import``; a backend known only through its entry point is loaded
    and registered on first use.

    Raises:
        InvalidBackend: if the string does not name any known backend.
    """
    if not isinstance(compiler_fn, str):
        return compiler_fn

    backend_name = compiler_fn
    if backend_name not in _BACKENDS:
        # Not seen yet: import the bundled backends and scan entry points.
        _lazy_import()
    if backend_name not in _BACKENDS:
        from ..exc import InvalidBackend

        raise InvalidBackend(name=backend_name)

    if backend_name not in _COMPILER_FNS:
        # Known only via its entry point so far; load and register it now.
        loaded_fn = _BACKENDS[backend_name].load()
        register_backend(compiler_fn=loaded_fn, name=backend_name)
    return _COMPILER_FNS[backend_name]


def list_backends(exclude_tags=("debug", "experimental")) -> list[str]:
    """
    Return valid strings that can be passed to:

        torch.compile(..., backend="name")

    Args:
        exclude_tags: backends carrying any of these tags are omitted.  Pass
            ``None`` or an empty sequence to list every backend.  Backends
            known only through a not-yet-loaded entry point are always
            listed, because their tags are not available yet.

    Returns:
        The backend names, sorted.
    """
    _lazy_import()
    hidden = set(exclude_tags) if exclude_tags else set()

    names = []
    for backend_name in _BACKENDS:
        if backend_name in _COMPILER_FNS and not hidden.isdisjoint(
            _COMPILER_FNS[backend_name]._tags
        ):
            continue  # tagged with something the caller asked to hide
        names.append(backend_name)
    names.sort()
    return names


@functools.cache
def _lazy_import():
    """Populate the backend registry on first use.

    Wrapped in ``functools.cache``, so the body runs only on the first call;
    later calls return the cached ``None`` immediately.
    """
    from .. import backends
    from ..utils import import_submodule

    # NOTE(review): presumably imports every submodule of the ``backends``
    # package so their @register_backend decorators run — import_submodule is
    # defined in ..utils; confirm there.
    import_submodule(backends)

    # Imported for its side effect; presumably the minifier backend registers
    # itself when ..repro.after_dynamo is imported (confirm in that module).
    from ..repro.after_dynamo import dynamo_minifier_backend

    # Uses the imported name so the import is not removed as unused.
    assert dynamo_minifier_backend is not None

    # Finally, add third-party backends advertised via entry points.
    _discover_entrypoint_backends()


@functools.cache
def _discover_entrypoint_backends():
    """
    Record every backend advertised under the ``torch_dynamo_backends``
    entry-point group.

    Each discovered name is mapped to its (not yet loaded) ``EntryPoint`` in
    ``_BACKENDS``; ``lookup_backend`` loads it on demand.  Wrapped in
    ``functools.cache`` so repeat calls do not rescan entry points.
    """
    # importing here so it will pick up the mocked version in test_backends.py
    from importlib.metadata import entry_points

    group_name = "torch_dynamo_backends"
    if sys.version_info >= (3, 10):
        # 3.10+: select the group directly, then index entry points by name.
        selected = entry_points(group=group_name)
        discovered = {ep_name: selected[ep_name] for ep_name in selected.names}
    else:
        # Older Pythons: entry_points() returns a mapping of group -> list.
        by_group = entry_points()
        group_eps = by_group[group_name] if group_name in by_group else []
        discovered = {ep.name: ep for ep in group_eps}
    _BACKENDS.update(discovered)