|
|
import torch |
|
|
from torch.fx.graph import Node, Graph |
|
|
from ..utils import _parent_name |
|
|
from torch.ao.quantization.quantization_types import NodePattern, Pattern |
|
|
from ..fuser_method_mappings import get_fuser_method_new |
|
|
from abc import ABC, abstractmethod |
|
|
from typing import Any, Callable, Dict, Optional, Union, List |
|
|
from .custom_config import FuseCustomConfig |
|
|
from .match_utils import MatchAllNode |
|
|
from torch.nn.utils.parametrize import type_before_parametrizations |
|
|
|
|
|
__all__ = [ |
|
|
"DefaultFuseHandler", |
|
|
"FuseHandler", |
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FuseHandler(ABC): |
|
|
""" Base handler class for the fusion patterns |
|
|
""" |
|
|
def __init__(self, node: Node): |
|
|
pass |
|
|
|
|
|
@abstractmethod |
|
|
def fuse(self, |
|
|
load_arg: Callable, |
|
|
named_modules: Dict[str, torch.nn.Module], |
|
|
fused_graph: Graph, |
|
|
root_node: Node, |
|
|
extra_inputs: List[Any], |
|
|
matched_node_pattern: NodePattern, |
|
|
fuse_custom_config: FuseCustomConfig, |
|
|
fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]], |
|
|
is_qat: bool) -> Node: |
|
|
pass |
|
|
|
|
|
|
|
|
class DefaultFuseHandler(FuseHandler): |
|
|
def __init__( |
|
|
self, |
|
|
node: Node): |
|
|
super().__init__(node) |
|
|
|
|
|
def fuse(self, |
|
|
load_arg: Callable, |
|
|
named_modules: Dict[str, torch.nn.Module], |
|
|
fused_graph: Graph, |
|
|
root_node: Node, |
|
|
extra_inputs: List[Any], |
|
|
matched_node_pattern: NodePattern, |
|
|
fuse_custom_config: FuseCustomConfig, |
|
|
fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]], |
|
|
is_qat: bool) -> Node: |
|
|
assert root_node.op == "call_module", "Expecting module node to be a call_module Node" |
|
|
root_module = named_modules[str(root_node.target)] |
|
|
|
|
|
def get_modules(pattern): |
|
|
""" Given a node pattern, extract the corresponding modules |
|
|
e.g. input: (relu_node, (bn_node, conv_node)) |
|
|
output: (relu_module, (bn_module, conv_module)) |
|
|
""" |
|
|
if isinstance(pattern, (tuple, list)): |
|
|
n, *args = pattern |
|
|
modules: List[torch.nn.Module] = [] |
|
|
modules.append(get_modules(n)) |
|
|
for a in args: |
|
|
modules.append(get_modules(a)) |
|
|
return tuple(modules) |
|
|
else: |
|
|
n = pattern |
|
|
if n.op == "call_module": |
|
|
return named_modules[n.target] |
|
|
elif n.op == "call_function" and n.target == torch.nn.functional.relu: |
|
|
relu = torch.nn.ReLU() |
|
|
relu.training = root_module.training |
|
|
return relu |
|
|
elif n.op == "call_function" or n.op == "call_method": |
|
|
return n.target |
|
|
else: |
|
|
return MatchAllNode |
|
|
|
|
|
|
|
|
matched_modules = get_modules(matched_node_pattern) |
|
|
|
|
|
def get_matched_types(m): |
|
|
if isinstance(m, tuple): |
|
|
return tuple(map(get_matched_types, m)) |
|
|
if isinstance(m, torch.nn.Module): |
|
|
return type_before_parametrizations(m) |
|
|
return m |
|
|
|
|
|
matched_module_types = get_matched_types(matched_modules) |
|
|
module_parent_name, module_name = _parent_name(root_node.target) |
|
|
fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping) |
|
|
|
|
|
|
|
|
fused_module = fuser_method(is_qat, *matched_modules) |
|
|
setattr(named_modules[module_parent_name], module_name, fused_module) |
|
|
extra_args = [] |
|
|
for input in extra_inputs: |
|
|
extra_args.append(load_arg(input)) |
|
|
node = fused_graph.node_copy(root_node, load_arg) |
|
|
args = list(node.args) |
|
|
args.extend(extra_args) |
|
|
node.args = tuple(args) |
|
|
return node |
|
|
|