|
|
import torch |
|
|
from torch.fx._symbolic_trace import Tracer |
|
|
from torch.fx.node import Target, Node, Argument |
|
|
from torch.nn.intrinsic import _FusedModule |
|
|
from typing import List, Callable, Tuple, Any, Dict, Optional |
|
|
|
|
|
# Only the tracer is exported as this module's public API.
__all__ = ["QuantizationTracer"]
|
|
|
|
|
class Scope:
    """Record of the module that owns a Node in a traced graph.

    A ``Scope`` holds two mutable attributes:

    * ``module_path``: the qualified name of the module (``""`` for the
      top-level module), and
    * ``module_type``: the class of that module (``None`` for the top-level
      module in the example below).

    It is used to track which module contains each Node of a GraphModule's
    Graph. For example::

        class Sub(torch.nn.Module):
            def forward(self, x):
                # call_method Node in the GraphModule;
                # scope: (module_path="sub", module_type=Sub)
                return x.transpose(1, 2)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sub = Sub()

            def forward(self, x):
                # also a call_method Node;
                # scope: (module_path="", module_type=None)
                x = x.transpose(1, 2)
                x = self.sub(x)
                return x
    """

    def __init__(self, module_path: str, module_type: Any):
        super().__init__()
        self.module_path, self.module_type = module_path, module_type
|
|
|
|
|
|
|
|
class ScopeContextManager(object): |
|
|
""" A context manager to track the Scope of Node during symbolic tracing. |
|
|
When entering a forward function of a Module, we'll update the scope information of |
|
|
the current module, and when we exit, we'll restore the previous scope information. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, scope: Scope, current_module: torch.nn.Module, current_module_path: str |
|
|
): |
|
|
super().__init__() |
|
|
self.prev_module_type = scope.module_type |
|
|
self.prev_module_path = scope.module_path |
|
|
self.scope = scope |
|
|
self.scope.module_path = current_module_path |
|
|
self.scope.module_type = type(current_module) |
|
|
|
|
|
def __enter__(self): |
|
|
return |
|
|
|
|
|
def __exit__(self, *args): |
|
|
self.scope.module_path = self.prev_module_path |
|
|
self.scope.module_type = self.prev_module_type |
|
|
return |
|
|
|
|
|
class QuantizationTracer(Tracer):
    """``torch.fx`` tracer used for FX graph mode quantization.

    On top of the base ``Tracer`` behavior it:

    * decides which modules are kept as opaque leaves (see
      ``is_leaf_module``), and
    * tags every created Node with the ``(module_path, module_type)`` of the
      module whose ``forward`` was being traced when the Node was created.
      The tags are stored in ``node_name_to_scope``, keyed by ``node.name``.

    Args:
        skipped_module_names: qualified module names to keep as leaves.
        skipped_module_classes: module classes to keep as leaves (matched by
            exact type, not by ``isinstance``).
    """

    def __init__(
        self, skipped_module_names: List[str], skipped_module_classes: List[Callable]
    ):
        super().__init__()
        self.skipped_module_names = skipped_module_names
        self.skipped_module_classes = skipped_module_classes
        # Scope of the top-level module: empty path and no type. call_module
        # temporarily overwrites it while tracing into a submodule.
        self.scope = Scope("", None)
        # node.name -> (module_path, module_type). module_type is None for
        # Nodes created directly in the top-level module's forward.
        self.node_name_to_scope: Dict[str, Tuple[str, Optional[type]]] = {}
        # Ask the base Tracer to attach stack traces to the Nodes it creates.
        self.record_stack_traces = True

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        """Return True if ``m`` should be recorded as a single call_module Node.

        A module is a leaf when any of these holds:

        * its class is defined under ``torch.nn`` or ``torch.ao.nn`` and it is
          not a ``torch.nn.Sequential`` (a Sequential is traced through);
        * its qualified name is listed in ``skipped_module_names``;
        * its exact type is listed in ``skipped_module_classes``;
        * it is a fused module (``_FusedModule``).
        """
        # ``m.__module__`` resolves to the module that defines m's class.
        is_torch_builtin = m.__module__.startswith(
            ("torch.nn", "torch.ao.nn")
        ) and not isinstance(m, torch.nn.Sequential)
        return (
            is_torch_builtin
            or module_qualified_name in self.skipped_module_names
            or type(m) in self.skipped_module_classes
            or isinstance(m, _FusedModule)
        )

    def call_module(
        self,
        m: torch.nn.Module,
        forward: Callable[..., Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> Any:
        """Trace a call to submodule ``m`` with ``self.scope`` pointing at ``m``.

        The scope is switched to ``m``'s qualified path and type for the
        duration of the base ``Tracer.call_module`` call and restored
        afterwards. Nodes created while handling this call are therefore
        tagged with ``m``; this includes the call_module Node itself when
        ``m`` is a leaf.
        """
        module_qualified_name = self.path_of_module(m)
        with ScopeContextManager(self.scope, m, module_qualified_name):
            return super().call_module(m, forward, args, kwargs)

    def create_node(
        self,
        kind: str,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        name: Optional[str] = None,
        type_expr: Optional[Any] = None,
    ) -> Node:
        """Create a Node via the base tracer and record the current scope for it."""
        node = super().create_node(kind, target, args, kwargs, name, type_expr)
        self.node_name_to_scope[node.name] = (
            self.scope.module_path,
            self.scope.module_type,
        )
        return node
|
|
|