from typing import Dict

import torch
from torch.nn import Module
from torch._ops import OpOverload

from torch.fx import GraphModule
from torch.fx.node import Node, _get_qualified_name
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch._prims.executor import execute
from torch.fx.experimental.proxy_tensor import DecompositionInterpreter
from torch._decomp import decomposition_table

import typing as t

import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

def aten_to_dtype(self, dtype: torch.dtype, **kwargs):
    # only the simple to.dtype(self, dtype) overload form is supported here
    if len(kwargs) > 0 or not dtype:
        raise RuntimeError("No support for to.dtype() formats other than to.dtype(self, dtype)")
    return torch._prims.convert_element_type(self, dtype)
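# For example, a node traced from `x.to(torch.float16)`, i.e.
# aten.to.dtype(x, torch.float16), is rewritten by this decomposition into
# prims.convert_element_type(x, torch.float16).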

# decomposition_table contains both aten2aten and aten2prim decompositions. Split
# them apart here; only the aten2prim entries are needed when lowering the fused
# subgraphs to prims for nvfuser.
aten2aten_decomp = {}
aten2prim_decomp = {}

# ATen-to-ATen decompositions to skip; these ops stay as-is in the ATen graph.
aten2aten_decomp_skips = {
    "aten.native_layer_norm_backward.default",
    "aten.embedding_dense_backward.default",
    "aten.addmm.default",
}

for op, decomp_fn in decomposition_table.items():
    if "torch._refs" in decomp_fn.__module__:
        aten2prim_decomp[op] = decomp_fn
    elif str(op) not in aten2aten_decomp_skips:
        aten2aten_decomp[op] = decomp_fn

# register the hand-written prim lowering for aten.to.dtype
aten2prim_decomp[torch.ops.aten.to.dtype] = aten_to_dtype
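
# A sketch of how these tables are consumed: aten2prim_decomp drives the
# DecompositionInterpreter in NvFuserBackend.lower_to_prims_and_execute below,
# while aten2aten_decomp can be applied while tracing an ATen graph (`fn` and
# `example_inputs` are placeholders):
#
#     from torch.fx.experimental.proxy_tensor import make_fx
#     aten_graph = make_fx(fn, decomposition_table=aten2aten_decomp)(*example_inputs)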
class NvFuserOperatorSupport(OperatorSupport):
    """
    Operator support for the nvFuser backend.

    Currently, partitioning is based on the FX ATen graph; the fused subgraphs are later
    decomposed into prims. To determine whether an ATen op is supported by nvFuser, we
    check the prim ops used in its ref decomposition: the ATen op is considered supported
    only if every prim op in the ref has an nvfuser_impl.

    Note: when adding a rule, please add it to the corresponding section and keep the
    entries in alphabetical order.
    """

    def __init__(self):

        support_dict = {
            # pointwise binary ops
            "torch.ops.aten.add": None,
            "torch.ops.aten.sub": None,
            "torch.ops.aten.div": None,
            "torch.ops.aten.atan2": None,
            "torch.ops.aten.mul": None,
            "torch.ops.aten.max": None,
            "torch.ops.aten.min": None,
            "torch.ops.aten.pow": None,
            "torch.ops.aten.remainder": None,
            "torch.ops.aten.fmod": None,

            # bitwise and shift ops
            "torch.ops.aten.bitwise_and": None,
            "torch.ops.aten.__and__": None,
            "torch.ops.aten.bitwise_or": None,
            "torch.ops.aten.__or__": None,
            "torch.ops.aten.bitwise_xor": None,
            "torch.ops.aten.__xor__": None,
            "torch.ops.aten.bitwise_left_shift": None,
            "torch.ops.aten.__lshift__": None,
            "torch.ops.aten.bitwise_right_shift": None,
            "torch.ops.aten.__rshift__": None,

            # comparison ops
            "torch.ops.aten.eq": None,
            "torch.ops.aten.ne": None,
            "torch.ops.aten.ge": None,
            "torch.ops.aten.gt": None,
            "torch.ops.aten.le": None,
            "torch.ops.aten.lt": None,

            # pointwise unary ops
            "torch.ops.aten.abs": None,
            "torch.ops.aten.bitwise_not": None,
            "torch.ops.aten.ceil": None,
            "torch.ops.aten.floor": None,
            "torch.ops.aten.frac": None,
            "torch.ops.aten.neg": None,
            "torch.ops.aten.relu": None,
            "torch.ops.aten.round": None,
            "torch.ops.aten.silu": None,
            "torch.ops.aten.trunc": None,
            "torch.ops.aten.log": None,
            "torch.ops.aten.log10": None,
            "torch.ops.aten.log1p": None,
            "torch.ops.aten.log2": None,
            "torch.ops.aten.lgamma": None,
            "torch.ops.aten.exp": None,
            "torch.ops.aten.expm1": None,
            "torch.ops.aten.erf": None,
            "torch.ops.aten.erfc": None,
            "torch.ops.aten.cos": None,
            "torch.ops.aten.acos": None,
            "torch.ops.aten.cosh": None,
            "torch.ops.aten.sin": None,
            "torch.ops.aten.asin": None,
            "torch.ops.aten.sinh": None,
            "torch.ops.aten.tan": None,
            "torch.ops.aten.atan": None,
            "torch.ops.aten.tanh": None,
            "torch.ops.aten.atanh": None,
            "torch.ops.aten.sqrt": None,
            "torch.ops.aten.rsqrt": None,
            "torch.ops.aten.reciprocal": None,
            "torch.ops.aten.sigmoid": None,
            "torch.ops.aten.isfinite": None,
            "torch.ops.aten.isinf": None,
            "torch.ops.aten.isnan": None,
            "torch.ops.aten.isneginf": None,
            "torch.ops.aten.isposinf": None,
            "torch.ops.aten.isreal": None,

            # activation and clamping ops
            "torch.ops.aten.softplus": None,
            "torch.ops.aten.threshold": None,
            "torch.ops.aten.clamp": None,

            # ternary ops
            "torch.ops.aten.where.self": None,
            "torch.ops.aten.lerp": None,
            "torch.ops.aten.addcmul": None,

            # dropout
            "torch.ops.aten.dropout": None,

            # normalization ops
            "torch.ops.aten.instance_norm": None,
            "torch.ops.aten._batch_norm_impl_index": None,
            "torch.ops.aten.batch_norm": None,
            "torch.ops.aten.cudnn_batch_norm": None,
            "torch.ops.aten._batch_norm_impl_index_backward": None,
            "torch.ops.aten.native_layer_norm": None,
            "torch.ops.aten.layer_norm": None,

            # softmax ops
            "torch.ops.aten.softmax.int": None,
            "torch.ops.aten.log_softmax.int": None,
            "torch.ops.aten._log_softmax_backward_data": None,

            # reduction and casting ops
            "torch.ops.aten.std.dim": None,
            "torch.ops.aten.sum": None,
            "torch.ops.aten._grad_sum_to_size": None,
            "torch.ops.aten.sum_to_size": None,
            "torch.ops.aten._autocast_to_reduced_precision": None,
            "torch.ops.aten._autocast_to_full_precision": None,

            # linear and activation ops
            "torch.ops.aten.linear": None,
            "torch.ops.aten.gelu": None,
            "torch.ops.aten.leaky_relu": None,
            "torch.ops.aten.square": None,
            "torch.ops.aten.tanh_backward": None,

            # shape ops
            "torch.ops.aten.flatten.using_ints": None,

            # non-ATen call targets that may appear inside a partition
            "getattr": None,
            "_operator.getitem": None,
        }

        super().__init__(support_dict)

    def is_node_supported(
        self, submodules: t.Mapping[str, Module], node: Node
    ) -> bool:

        # only call_{function, method, module} nodes are candidates for fusion
        if node.op not in CALLABLE_NODE_OPS:
            return False

        # match OpOverload targets by the qualified name of their overload packet,
        # so a single support_dict entry covers every overload of an op
        if isinstance(node.target, OpOverload):
            target = _get_qualified_name(node.target.overloadpacket)
            if target in self._support_dict:
                return True

        return super().is_node_supported(submodules, node)
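# Example of the overload-insensitive matching above (a sketch; `aten_graph` is a
# placeholder ATen-level GraphModule, e.g. produced by make_fx):
#
#     support = NvFuserOperatorSupport()
#     submodules = dict(aten_graph.named_modules())
#     for node in aten_graph.graph.nodes:
#         # a node targeting torch.ops.aten.add.Tensor matches the
#         # "torch.ops.aten.add" entry via its overload packet's qualified name
#         print(node.format_node(), support.is_node_supported(submodules, node))
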
class NvFuserBackend:
    def __init__(self):
        self.supported_ops = NvFuserOperatorSupport()

        # cached partitioned graph modules, keyed by the original graph module
        self.partitioner_cache: Dict[GraphModule, GraphModule] = {}

        # cached prim-decomposed graph modules, keyed by the fused ATen submodule
        self.prim_decomp_cache: Dict[GraphModule, GraphModule] = {}

    def lower_to_prims_and_execute(self, graph_module: GraphModule, *args, **kwargs):
        # `graph_module` is a fused submodule containing only nvfuser-supported ATen
        # ops; lower it to a prims graph and run that graph with the nvfuser executor

        if graph_module in self.prim_decomp_cache:
            logger.debug("prim_decomp_cache hit!")
            prim_module = self.prim_decomp_cache[graph_module]
        else:
            prim_graph = torch.fx.Graph()
            DecompositionInterpreter(graph_module, prim_graph, decomposition_table=aten2prim_decomp).run(*args, **kwargs)
            prim_module = torch.fx.GraphModule(graph_module, prim_graph)
            self.prim_decomp_cache[graph_module] = prim_module

        logger.debug("Lower to prims graph: %s", prim_module.code)

        # kwargs are only needed for the decomposition run above; execution takes args
        return execute(prim_module, *args, executor="nvfuser")

    def compile(self, graph_module: GraphModule) -> GraphModule:
        logger.debug("Compiling graph_module: %s", graph_module.code)

        # partition the ATen graph into fused submodules based on nvfuser-supported
        # ops (see the sketch after this method)
        if graph_module in self.partitioner_cache:
            logger.debug("partitioner_cache hit!")
            fused_graph_module = self.partitioner_cache[graph_module]
        else:
            partitioner = CapabilityBasedPartitioner(
                graph_module, self.supported_ops, allows_single_node_partition=False)
            fused_graph_module = partitioner.partition_and_fuse()

            self.partitioner_cache[graph_module] = fused_graph_module

        # route calls to each fused submodule through lower_to_prims_and_execute()
        for node in fused_graph_module.graph.nodes:
            if node.op == "call_module" and "fused_" in node.name:
                fused_module = getattr(fused_graph_module, node.name)
                fused_module._wrapped_call = self.lower_to_prims_and_execute

        return fused_graph_module

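    # A sketch of the partitioned graph compile() returns: supported nodes are
    # grouped into call_module nodes whose names start with "fused_", e.g.:
    #
    #     def forward(self, a, b):
    #         fused_0 = self.fused_0(a, b)
    #         return fused_0
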
    def __call__(self, graph_module: GraphModule, _) -> GraphModule:
        # the second argument is unused; only the graph module is needed to compile
        return self.compile(graph_module)
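
# Example usage (a minimal sketch; `fn` is a placeholder, and the ATen-level graph
# is obtained here via make_fx):
#
#     from torch.fx.experimental.proxy_tensor import make_fx
#
#     def fn(a, b):
#         return (a + b).relu()
#
#     a = torch.randn(4, device="cuda")
#     b = torch.randn(4, device="cuda")
#     aten_graph = make_fx(fn)(a, b)
#
#     nvfuser = NvFuserBackend()
#     fused_module = nvfuser.compile(aten_graph)
#     out = fused_module(a, b)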