r'''
FX is a toolkit that developers can use to transform ``nn.Module``
instances. FX consists of three main components: a **symbolic tracer**,
an **intermediate representation**, and **Python code generation**. A
demonstration of these components in action:

::

    import torch

    # Simple module for demonstration
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x):
            return self.linear(x + self.param).clamp(min=0.0, max=1.0)

    module = MyModule()

    from torch.fx import symbolic_trace

    # Symbolic tracing frontend - captures the semantics of the module
    symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

    # High-level intermediate representation (IR) - Graph representation
    print(symbolic_traced.graph)
    """
    graph():
        %x : [#users=1] = placeholder[target=x]
        %param : [#users=1] = get_attr[target=param]
        %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
        %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
        %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
        return clamp
    """

    # Code generation - valid Python code
    print(symbolic_traced.code)
    """
    def forward(self, x):
        param = self.param
        add = x + param; x = param = None
        linear = self.linear(add); add = None
        clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
        return clamp
    """

The **symbolic tracer** performs "symbolic execution" of the Python
code. It feeds fake values, called Proxies, through the code and records
the operations performed on these Proxies. More information about
symbolic tracing can be found in the :func:`symbolic_trace` and
:class:`Tracer` documentation.
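
To make the Proxy mechanism a little more concrete, here is a minimal sketch that drives a :class:`Tracer` directly instead of going through :func:`symbolic_trace`. The ``AddRelu`` module is a hypothetical example. ``Tracer.trace`` returns the recorded :class:`Graph`, which is then wrapped in a :class:`GraphModule` so that it can be called like any other module.

```python
import torch
from torch.fx import GraphModule, Tracer


class AddRelu(torch.nn.Module):
    # Hypothetical example module used only for illustration.
    def forward(self, x, y):
        # While tracing, x and y are Proxies: the add and the relu are
        # recorded as Nodes instead of being computed.
        return torch.relu(x + y)


root = AddRelu()

# Tracer.trace() runs forward() with Proxies and returns the recorded Graph.
graph = Tracer().trace(root)

# Pair the Graph with its root module to get a callable GraphModule.
traced = GraphModule(root, graph)

ops = [node.op for node in graph.nodes]
print(ops)  # ['placeholder', 'placeholder', 'call_function', 'call_function', 'output']
```

Calling ``traced`` with real tensors runs the generated ``forward`` and computes the same result as the original module.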

The **intermediate representation** is the container for the operations
that were recorded during symbolic tracing. It consists of a list of
Nodes that represent function inputs, callsites (to functions, methods,
or :class:`torch.nn.Module` instances), and return values. More information
about the IR can be found in the documentation for :class:`Graph`. The
IR is the format on which transformations are applied.
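
As a rough illustration of working with the IR, the sketch below traces a small free function (``f`` is a hypothetical example) and prints each Node's opcode, target, and arguments. It then applies a simple transformation that retargets ``torch.add`` calls to ``torch.mul``. Calling ``Graph.lint`` and ``GraphModule.recompile`` after editing the Graph is one reasonable pattern, not the only one.

```python
import torch
import torch.fx


def f(x, y):
    # Hypothetical function used only for illustration.
    return torch.add(x, y) * 2


gm = torch.fx.symbolic_trace(f)

# Each Node records its opcode, its target, and the Nodes/values it consumes.
for node in gm.graph.nodes:
    print(node.op, node.target, node.args)

# A simple transformation: rewrite every torch.add call into torch.mul.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == torch.add:
        node.target = torch.mul

gm.graph.lint()  # sanity-check the edited Graph
gm.recompile()   # regenerate forward() from the edited Graph
```

After the rewrite, ``gm(x, y)`` computes ``x * y * 2`` instead of ``(x + y) * 2``.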

**Python code generation** is what makes FX a Python-to-Python (or
Module-to-Module) transformation toolkit. For each Graph IR, we can
create valid Python code matching the Graph's semantics. This
functionality is wrapped up in :class:`GraphModule`, which is a
:class:`torch.nn.Module` instance that holds a :class:`Graph` as well as a
``forward`` method generated from the Graph.
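
The sketch below suggests how code generation works from the IR alone. It builds a Graph by hand rather than by tracing. The node names and the empty ``torch.nn.Module()`` root are assumptions made for illustration. Printing ``gm.code`` shows the ``forward`` method that :class:`GraphModule` generated from the Graph.

```python
import torch
import torch.fx

graph = torch.fx.Graph()

# Build the IR directly: one input, two function calls, one output.
x = graph.placeholder("x")
doubled = graph.call_function(torch.mul, args=(x, 2))
activated = graph.call_function(torch.relu, args=(doubled,))
graph.output(activated)

# The GraphModule generates Python source for forward() from the Graph.
gm = torch.fx.GraphModule(torch.nn.Module(), graph)
print(gm.code)
```

Because ``gm`` is an ordinary ``torch.nn.Module``, it can be called directly, for example ``gm(torch.tensor([-1.0, 2.0]))``.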

Taken together, these components (symbolic tracing ->
intermediate representation -> transforms -> Python code generation)
constitute FX's Python-to-Python transformation pipeline. The components
can also be used separately. For example, symbolic tracing can be used
in isolation to capture a form of the code for analysis (rather than
transformation) purposes, and code generation can be used to generate
models programmatically, for example from a config file. There are many
uses for FX!
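
For the analysis-only use case, a minimal sketch might trace a model and inspect the Graph without changing it. The ``torch.nn.Sequential`` model here is a stand-in. The code tallies Node opcodes with ``collections.Counter`` and lists the type of each submodule invoked by a ``call_module`` Node.

```python
from collections import Counter

import torch
from torch.fx import symbolic_trace

# Stand-in model; any traceable nn.Module would do.
model = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 4),
)
traced = symbolic_trace(model)

# Tally the kinds of operations in the traced Graph; nothing is transformed.
op_counts = Counter(node.op for node in traced.graph.nodes)

# Resolve each call_module Node's target to the submodule it calls.
module_types = [
    type(traced.get_submodule(node.target)).__name__
    for node in traced.graph.nodes
    if node.op == "call_module"
]
print(op_counts)
print(module_types)  # ['Linear', 'ReLU', 'Linear']
```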

Several example transformations can be found in the
`examples <https://github.com/pytorch/examples/tree/master/fx>`__
repository.
'''

from .graph_module import GraphModule
from ._symbolic_trace import symbolic_trace, Tracer, wrap, PH, ProxyableClassMeta
from .graph import Graph, CodeGen
from .node import Node, map_arg
from .proxy import Proxy
from .interpreter import Interpreter as Interpreter, Transformer as Transformer
from .subgraph_rewriter import replace_pattern