File size: 5,611 Bytes
59f1501 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
# mypy: ignore-errors
"""
This module implements TorchDynamo's backend registry system for managing compiler backends.
The registry provides a centralized way to register, discover and manage different compiler
backends that can be used with torch.compile(). It handles:
- Backend registration and discovery through decorators and entry points
- Lazy loading of backend implementations
- Lookup and validation of backend names
- Categorization of backends using tags (debug, experimental, etc.)
Key components:
- CompilerFn: Type for backend compiler functions that transform FX graphs
- _BACKENDS: Registry mapping backend names to entry points
- _COMPILER_FNS: Registry mapping backend names to loaded compiler functions
Example usage:
    @register_backend
    def my_compiler(fx_graph, example_inputs):
        # Transform FX graph into optimized implementation
        return compiled_fn

    # Use registered backend
    torch.compile(model, backend="my_compiler")
The registry also supports discovering backends through setuptools entry points
in the "torch_dynamo_backends" group. Example:
```
setup.py
---
from setuptools import setup

setup(
    name='my_torch_backend',
    version='0.1',
    packages=['my_torch_backend'],
    entry_points={
        'torch_dynamo_backends': [
            # name = path to entry point of backend implementation
            'my_compiler = my_torch_backend.compiler:my_compiler_function',
        ],
    },
)
```
```
my_torch_backend/compiler.py
---
def my_compiler_function(fx_graph, example_inputs):
    # Transform FX graph into optimized implementation
    return compiled_fn
```
Using `my_compiler` backend:
```
import torch
model = ... # Your PyTorch model
optimized_model = torch.compile(model, backend="my_compiler")
```
"""
import functools
import logging
import sys
from collections.abc import Sequence
from importlib.metadata import EntryPoint
from typing import Callable, Optional, Protocol
import torch
from torch import fx
# Logger for this module, named after the module itself.
log: logging.Logger = logging.getLogger(name=__name__)
class CompiledFn(Protocol):
    """Structural type of the callable that a backend compiler returns.

    Any object whose ``__call__`` takes tensors positionally and returns a
    tuple of tensors matches this protocol; no subclassing is required.
    """

    def __call__(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Execute the compiled graph on the given tensors."""
# A backend compiler: given the captured FX graph and a list of example input
# tensors, it returns a callable that runs the compiled graph.
CompilerFn = Callable[[fx.GraphModule, list[torch.Tensor]], CompiledFn]

# Every known backend name. The value is the EntryPoint to load the backend
# from, or None (e.g. for backends passed directly to register_backend).
_BACKENDS: dict[str, Optional[EntryPoint]] = dict()

# Backend name -> compiler function, for backends that are actually loaded.
# A backend discovered through an entry point is listed in _BACKENDS first and
# only lands here once lookup_backend loads it.
_COMPILER_FNS: dict[str, CompilerFn] = dict()
def register_backend(
    compiler_fn: Optional[CompilerFn] = None,
    name: Optional[str] = None,
    tags: Sequence[str] = (),
):
    """
    Decorator to add a given compiler to the registry to allow calling
    `torch.compile` with string shorthand. Note: for projects not
    imported by default, it might be easier to pass a function directly
    as a backend and not use a string.

    Usable bare (``@register_backend``) or with keyword arguments
    (``@register_backend(name="foo", tags=("debug",))``).

    Args:
        compiler_fn: Callable taking a FX graph and fake tensor inputs
        name: Optional name, defaults to `compiler_fn.__name__`
        tags: Optional set of string tags to categorize backend with

    Returns:
        ``compiler_fn`` itself, with a ``_tags`` attribute attached. When
        ``compiler_fn`` is omitted, a decorator that registers its argument
        with the given ``name`` and ``tags``.
    """
    if compiler_fn is None:
        # @register_backend(name="") syntax
        return functools.partial(register_backend, name=name, tags=tags)
    assert callable(compiler_fn)
    name = name or compiler_fn.__name__
    assert name not in _COMPILER_FNS, f"duplicate name: {name}"
    # Attach the tags before touching either registry. If compiler_fn rejects
    # attribute assignment, the backend is then not left half-registered
    # (list_backends reads `_tags` from every loaded backend).
    compiler_fn._tags = tuple(tags)
    # Register the *name*, but keep any EntryPoint that entry-point discovery
    # already stored for it. This check never hashes compiler_fn, so unhashable
    # callable objects can be registered too.
    _BACKENDS.setdefault(name, None)
    _COMPILER_FNS[name] = compiler_fn
    return compiler_fn
# Convenience decorators that behave like ``register_backend`` with the
# ``tags`` keyword pre-bound. Backends tagged "debug" or "experimental" are
# hidden by ``list_backends()`` under its default ``exclude_tags``.
register_debug_backend = functools.partial(
    register_backend,
    tags=("debug",),
)
register_experimental_backend = functools.partial(
    register_backend,
    tags=("experimental",),
)
def lookup_backend(compiler_fn):
    """Expand backend strings to functions.

    A string is treated as a backend name and resolved to its registered
    compiler function. Any other value is assumed to already be a compiler
    and is returned unchanged.

    For an unknown name, the built-in and entry-point backends are imported
    first, via the cached ``_lazy_import``. A name known only through an
    entry point is loaded and registered on its first lookup.

    Raises:
        InvalidBackend: ``compiler_fn`` is a string that names no backend.
    """
    if not isinstance(compiler_fn, str):
        return compiler_fn

    backend_name = compiler_fn
    if backend_name not in _BACKENDS:
        # Not seen yet: populate the registry, then check again.
        _lazy_import()
    if backend_name not in _BACKENDS:
        from ..exc import InvalidBackend

        raise InvalidBackend(name=backend_name)

    if backend_name not in _COMPILER_FNS:
        # Only an entry point is known so far: import the backend now.
        loaded = _BACKENDS[backend_name].load()
        register_backend(compiler_fn=loaded, name=backend_name)
    return _COMPILER_FNS[backend_name]
def list_backends(exclude_tags=("debug", "experimental")) -> list[str]:
    """
    Return, sorted, the backend names accepted by
    ``torch.compile(..., backend="name")``.

    Args:
        exclude_tags: Backends carrying any of these tags are omitted. A
            falsy value (``None`` or empty) omits nothing. Backends discovered
            through entry points but not yet loaded have no tags to inspect,
            so they are always included.
    """
    _lazy_import()
    hidden = set(exclude_tags or ())
    visible = []
    for backend_name in _BACKENDS:
        if backend_name in _COMPILER_FNS:
            if hidden.intersection(_COMPILER_FNS[backend_name]._tags):
                continue
        visible.append(backend_name)
    visible.sort()
    return visible
@functools.cache
def _lazy_import():
    """Populate the backend registry on first use.

    Runs ``import_submodule`` over the ``backends`` package imported from the
    parent package. Presumably this imports each backend module so that its
    ``@register_backend`` decorators run; verify against ``import_submodule``.
    It then imports the dynamo minifier backend and records backends
    advertised through the ``torch_dynamo_backends`` entry-point group.

    Because of ``functools.cache``, calls after the first successful one
    return the cached result without repeating the work.
    """
    from .. import backends
    from ..utils import import_submodule

    import_submodule(backends)

    from ..repro.after_dynamo import dynamo_minifier_backend

    # NOTE(review): this import appears to be done for its side effect
    # (making the minifier backend available); the assert mainly keeps the
    # name "used". Confirm.
    assert dynamo_minifier_backend is not None

    _discover_entrypoint_backends()
@functools.cache
def _discover_entrypoint_backends():
    """Record every backend advertised in the ``torch_dynamo_backends``
    entry-point group.

    Only the EntryPoint objects are stored in ``_BACKENDS``; the backends are
    imported later, on demand, by ``lookup_backend``. Cached, so the scan runs
    once.
    """
    # Imported at call time rather than module level so the mocked version in
    # test_backends.py is picked up.
    from importlib.metadata import entry_points

    group_name = "torch_dynamo_backends"
    if sys.version_info >= (3, 10):
        group = entry_points(group=group_name)
        discovered = {ep_name: group[ep_name] for ep_name in group.names}
    else:
        # Python < 3.10: entry_points() takes no selection arguments and
        # returns a mapping from group name to that group's entry points.
        all_groups = entry_points()
        group = all_groups[group_name] if group_name in all_groups else []
        discovered = {ep.name: ep for ep in group}
    _BACKENDS.update(discovered)
|