| |
| |
| |
| |
|
|
| import logging |
| import typing |
| import warnings |
| from collections import Counter |
| from copy import copy |
| from dataclasses import dataclass |
| from numbers import Number |
| from typing import (Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, |
| TypeVar, Union) |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch import Tensor |
| from torch.jit import TracerWarning, _get_trace_graph |
|
|
| from mmengine.logging import print_log |
| from .jit_handles import Handle |
|
|
# Bound TypeVar so the fluent setter methods (which ``return self``) are
# typed to return the concrete subclass they were called on.
T = TypeVar('T', bound='JitModelAnalysis')

# Operators that are deliberately excluded from statistics: shape/view
# manipulation, indexing, tensor construction and elementwise bookkeeping
# ops that contribute no meaningful cost. Encountering one of these during
# the trace is NOT reported as an "unsupported" operator.
_IGNORED_OPS: Set[str] = {
    'aten::Int',
    'aten::ScalarImplicit',
    'aten::__and__',
    'aten::arange',
    'aten::bitwise_not',
    'aten::cat',
    'aten::chunk',
    'aten::clamp',
    'aten::clamp_',
    'aten::constant_pad_nd',
    'aten::contiguous',
    'aten::copy_',
    'aten::detach',
    'aten::dropout',
    'aten::empty',
    'aten::eq',
    'aten::expand',
    'aten::flatten',
    'aten::floor',
    'aten::floor_divide',
    'aten::full',
    'aten::full_like',
    'aten::gather',
    'aten::ge',
    'aten::gt',
    'aten::index',
    'aten::index_put_',
    'aten::masked_fill',
    'aten::max',
    'aten::narrow',
    'aten::new_empty',
    'aten::new_full',
    'aten::new_zeros',
    'aten::nonzero',
    'aten::ones',
    'aten::permute',
    'aten::relu',
    'aten::relu_',
    'aten::remainder',
    'aten::reshape',
    'aten::roll',
    'aten::select',
    'aten::size',
    'aten::slice',
    'aten::split',
    'aten::split_with_sizes',
    'aten::squeeze',
    'aten::stack',
    'aten::t',
    'aten::to',
    'aten::transpose',
    'aten::type_as',
    'aten::unbind',
    'aten::unsqueeze',
    'aten::unsqueeze_',
    'aten::view',
    'aten::zeros',
    'aten::zeros_like',
}
|
|
|
|
@dataclass
class Statistics:
    """For keeping track of the various model statistics recorded during
    analysis."""

    # Maps a canonical module name to a Counter of per-operator statistics
    # (as produced by the registered op handles) for that module's scope.
    counts: Dict[str, typing.Counter[str]]
    # Maps a canonical module name to a Counter of operators that were
    # encountered in its scope but had no handle and were not ignored.
    unsupported_ops: Dict[str, typing.Counter[str]]
    # Canonical names of submodules that never appeared in the traced graph.
    uncalled_mods: Set[str]
|
|
|
|
| def _named_modules_with_dup(model: nn.Module, |
| prefix: str = '' |
| ) -> Iterable[Tuple[str, nn.Module]]: |
| """The same as `model.named_modules()`, except that it includes duplicated |
| modules that have more than one name.""" |
| yield prefix, model |
| for name, module in model._modules.items(): |
| if module is None: |
| continue |
| submodule_prefix = prefix + ('.' if prefix else '') + name |
| yield from _named_modules_with_dup(module, submodule_prefix) |
|
|
|
|
def _named_modules_without_dup(
        model: nn.Module) -> Iterator[Tuple[str, nn.Module]]:
    """Like .named_modules(), but the results are slightly different for some
    wrapped models."""
    emitted = set()
    for name, mod in _named_modules_with_dup(model):
        # Yield each module object once, under the first name it appears as.
        if mod in emitted:
            continue
        emitted.add(mod)
        yield name, mod
|
|
|
|
def _get_scoped_trace_graph(
    module: nn.Module,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    aliases: Dict[Union[str, nn.Module], str],
) -> torch._C.Graph:
    """Traces the provided module using torch.jit._get_trace_graph, but adds
    submodule scope information to each graph node.

    The resulting graph is in-lined and has all model parameters treated as
    inputs. The input model has the scope name '', while its descendants
    have names of the form 'child.grandchild.grandgrandchild...'.

    Args:
        model (nn.Module): The module to trace
        inputs (tuple): Inputs used during the trace of the model
        aliases (dict[str or nn.Module, str]): maps modules and module
            names to the canonical name to be used as the scope for
            that module.

    Returns:
        graph (torch._C.Graph): The pytorch JIT trace of the model
    """

    # Forward pre-hook: pushes this module's canonical name onto the JIT
    # tracing-state scope stack, so every node emitted while this module's
    # forward() runs carries the module's scope name.
    class ScopePushHook:

        def __init__(self, name: str) -> None:
            self.name = name

        def __call__(self, module: nn.Module, inputs: Any) -> Any:
            # Tracing state only exists while torch.jit is tracing; the
            # hooks are no-ops on ordinary (non-traced) forward calls.
            tracing_state = torch._C._get_tracing_state()
            if tracing_state:
                tracing_state.push_scope(self.name)
            return inputs

    # Forward (post-)hook: pops the scope pushed by ScopePushHook once the
    # module's forward() has finished.
    class ScopePopHook:

        def __call__(self, module: nn.Module, inputs: Any,
                     outputs: Any) -> Any:
            tracing_state = torch._C._get_tracing_state()
            if tracing_state:
                tracing_state.pop_scope()
            return outputs

    # Keep every hook handle so all hooks can be removed after tracing.
    hook_handles: List[Any] = []

    def register_hooks(mod: nn.Module, name: str) -> None:
        prehook = mod.register_forward_pre_hook(ScopePushHook(name))
        posthook = mod.register_forward_hook(ScopePopHook())
        hook_handles.append(prehook)
        hook_handles.append(posthook)

    # Torch script won't take a DataParallel, so the wrapper is unwrapped
    # and the inner module is traced under the wrapper's canonical name.
    module_list = (nn.parallel.distributed.DistributedDataParallel,
                   nn.DataParallel)

    if isinstance(module, module_list):
        root_name = aliases[module]
        module = module.module
        # NOTE(review): the loop below registers hooks on this module again
        # under its own alias, so the unwrapped module pushes two scopes
        # (wrapper's and its own) — presumably intentional to keep both
        # names visible in the graph; confirm against upstream behavior.
        register_hooks(module, root_name)

    for name, mod in _named_modules_without_dup(module):
        # Always use the canonical alias as the scope name, so shared
        # modules are attributed consistently regardless of access path.
        name = aliases[mod]
        register_hooks(mod, name)

    graph, _ = _get_trace_graph(module, inputs)

    # Remove the scope hooks so the model is left unmodified.
    for handle in hook_handles:
        handle.remove()

    return graph
|
|
|
|
class JitModelAnalysis:
    """Provides access to per-submodule model statistics obtained by tracing a
    model with pytorch's jit tracing functionality.

    Calculates a statistic on a per-operator basis using the provided set of
    functions that acts on the inputs and outputs to the operator, then
    aggregates this over modules in the model. Can return the aggregate
    statistic for any submodule in the model. Is lazily evaluated, and will
    perform the trace when a statistic is first requested. Changing the
    operator handles will cause the trace to be rerun on the next request.

    Submodules may be referred to using the module's name. The input model has
    name "", while its descendants have names of the form
    "child.grandchild.grandgrandchild...".

    An operator is treated as within the scope of a module if calling that
    module directly resulted in that operator being run. In particular, this
    means that calls to other functions owned by a module or explicit
    calls to module.forward(...) will not register resulting operators as
    contributing statistics to that module.

    We will trace the execution of `model.forward(inputs)`. This means
    inputs have to be tensors or tuple of tensors (see
    https://pytorch.org/docs/stable/generated/torch.jit.trace.html#torch.jit.trace).
    In order to trace other methods or unsupported input types,
    you may need to implement a wrapper module.

    Args:
        model: The model to analyze
        inputs: The inputs to the model for analysis.
    """

    def __init__(
        self,
        model: nn.Module,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
    ) -> None:
        self._model = model
        self._inputs = inputs
        self._op_handles: Dict[str, Handle] = {}
        # Mapping from names to modules, used to check requested submodules
        self._named_modules: Dict[str, nn.Module] = dict(
            _named_modules_with_dup(model))
        # Mapping from modules and their aliases to the canonical name of
        # each module
        self._aliases: Dict[Union[nn.Module, str],
                            str] = self._get_aliases(model)
        # Cached analysis results; invalidated whenever handles change.
        self._stats: Optional[Statistics] = None

        self._ignored_ops: Set[str] = copy(_IGNORED_OPS)
        self.unsupported_ops_warnings(True)
        self.uncalled_modules_warnings(True)
        self.tracer_warnings('no_tracer_warning')
        self.ancestor_mode('owner')

    def total(self, module_name: str = '') -> int:
        """Returns the total aggregated statistic across all operators for the
        requested module.

        Args:
            module_name (str): The submodule to get data for. Defaults to
                the entire model.

        Returns:
            int: The aggregated statistic.
        """
        stats = self._analyze()
        module_name = self.canonical_module_name(module_name)
        total_count = sum(stats.counts[module_name].values())
        return total_count

    def by_operator(self, module_name: str = '') -> typing.Counter[str]:
        """Returns the statistics for a requested module, grouped by operator
        type.

        The operator handle determines the name associated with each
        operator type.

        Args:
            module_name (str): The submodule to get data for. Defaults
                to the entire model.

        Returns:
            Counter(str): The statistics for each operator.
        """
        stats = self._analyze()
        module_name = self.canonical_module_name(module_name)
        return stats.counts[module_name]

    def by_module_and_operator(self) -> Dict[str, typing.Counter[str]]:
        """Returns the statistics for all submodules, separated out by operator
        type for each submodule.

        The operator handle determines the name associated with
        each operator type.

        Returns:
            dict[str, Counter(str)]: The statistics for each submodule
            and each operator. Grouped by submodule names, then
            by operator name.
        """
        stats = self._analyze()
        return stats.counts

    def by_module(self) -> typing.Counter[str]:
        """Returns the statistics for all submodules, aggregated over all
        operators.

        Returns:
            Counter(str): statistics counter grouped by submodule names
        """
        stats = self._analyze()
        summed_counts: typing.Counter[str] = Counter()
        for mod, results in stats.counts.items():
            summed_counts[mod] = sum(results.values())
        return summed_counts

    def unsupported_ops(self, module_name: str = '') -> typing.Counter[str]:
        """Lists the number of operators that were encountered but unsupported
        because no operator handle is available for them.

        Does not include operators that are explicitly ignored.

        Args:
            module_name (str): The submodule to list unsupported ops.
                Defaults to the entire model.

        Returns:
            Counter(str): The number of occurrences of each unsupported
            operator.

        Raises:
            RuntimeError: If the analysis has not been run yet.
        """
        if self._stats is None:
            raise RuntimeError('Analysis results should be computed '
                               'before calling unsupported_ops()')
        module_name = self.canonical_module_name(module_name)
        return self._stats.unsupported_ops[module_name]

    def uncalled_modules(self) -> Set[str]:
        """Returns a set of submodules that were never called during the trace
        of the graph.

        This may be because they were unused, or because they were
        accessed via direct calls .forward() or with other python methods.
        In the latter case, statistics will not be attributed to the submodule,
        though the statistics will be included
        in the parent module.

        Returns:
            set[str]: The set of submodule names that were never called
            during the trace of the model.
        """
        stats = self._analyze()
        return stats.uncalled_mods

    def set_op_handle(self, *args,
                      **kwargs: Optional[Handle]) -> 'JitModelAnalysis':
        """Sets additional operator handles, or replaces existing ones.

        If a handle is ``None``, the op will be explicitly ignored. Otherwise,
        handle should be a function that calculates the desirable statistic
        from an operator. The function must take two arguments, which are the
        inputs and outputs of the operator, in the form of
        ``list(torch._C.Value)``. The function should return a counter object
        with per-operator statistics.

        Args:
            args: (str, Handle) pairs of operator names and handles.
            kwargs: mapping from operator names to handles.

        Examples:
            >>> handlers = {"aten::linear": my_handler}
            >>> counter.set_op_handle("aten::matmul", None,
            ...     "aten::bmm", my_handler2).set_op_handle(**handlers)
        """
        # Handles changed: any cached analysis is stale.
        self._stats = None
        if len(args) % 2 != 0:
            raise TypeError(
                'set_op_handle should be called with pairs of names and '
                'handles!')
        for name, handle in zip(args[::2], args[1::2]):
            kwargs[name] = handle
        for name, handle in kwargs.items():
            if handle is None:
                self._ignored_ops.add(name)
            else:
                self._op_handles[name] = handle
        return self

    def clear_op_handles(self) -> 'JitModelAnalysis':
        """Clears all operator handles currently set."""
        self._op_handles = {}
        self._ignored_ops = copy(_IGNORED_OPS)
        self._stats = None
        return self

    def canonical_module_name(self, name: str) -> str:
        """Returns the canonical module name of the given ``name``, which might
        be different from the given ``name`` if the module is shared.

        This is the name that will be used as a key when statistics are
        output using .by_module() and .by_module_and_operator().

        Args:
            name (str): The name of the module to find the canonical name for.

        Returns:
            str: The canonical name of the module.

        Raises:
            KeyError: If the name is not a descendant of the analyzed model.
        """
        # Blobs are not supported as module names.
        assert isinstance(name, str), 'Module name must be a string.'
        if name in self._aliases:
            return self._aliases[name]
        else:
            raise KeyError('Requested module name is not among '
                           'the descendants of the analyzed model.')

    def copy(
        self,
        new_model: Optional[nn.Module] = None,
        new_inputs: Union[None, Tensor, Tuple[Tensor, ...]] = None,
    ) -> 'JitModelAnalysis':
        """Returns a copy of the :class:`JitModelAnalysis` object, keeping all
        settings, but on a new model or new inputs.

        Args:
            new_model (nn.Module or None): a new model for the new
                JitModelAnalysis. If None, uses the original model.
                Defaults to None.
            new_inputs (typing.Tuple[object, ...], optional): new inputs
                for the new JitModelAnalysis. If None, uses the original
                inputs. Defaults to None.

        Returns:
            JitModelAnalysis: the new model analysis object
        """
        model = self._model if new_model is None else new_model
        inputs = self._inputs if new_inputs is None else new_inputs
        analyzer = (JitModelAnalysis(model=model, inputs=inputs).set_op_handle(
            **self._op_handles).unsupported_ops_warnings(
                self._enable_warn_unsupported_ops).uncalled_modules_warnings(
                    self._enable_warn_uncalled_mods).tracer_warnings(
                        self._warn_trace).ancestor_mode(self._ancestor_mode))
        # ``set_op_handle(**self._op_handles)`` above only restores the
        # handles; ops that were explicitly ignored via
        # ``set_op_handle(name, None)`` live in ``_ignored_ops`` and must be
        # carried over separately, or the copy would silently lose them.
        analyzer._ignored_ops = copy(self._ignored_ops)
        return analyzer

    def tracer_warnings(self: T, mode: str) -> T:
        """Sets which warnings to print when tracing the graph to calculate
        statistics. There are three modes. Defaults to 'no_tracer_warning'.
        Allowed values are:

        * 'all' : keeps all warnings raised while tracing
        * 'no_tracer_warning' : suppress torch.jit.TracerWarning only
        * 'none' : suppress all warnings raised while tracing

        Args:
            mode (str) : warning mode in one of the above values.
        """
        if mode not in ['all', 'no_tracer_warning', 'none']:
            raise ValueError(f'Unrecognized tracer warning mode {mode}.')
        self._warn_trace = mode
        return self

    def ancestor_mode(self: T, mode: str) -> T:
        """Sets how to determine the ancestor modules of an operator. Must be
        one of "owner" or "caller".

        * "caller": an operator belongs to all modules that are currently
          executing `forward()` at the time the operator is called.
        * "owner": an operator belongs to the last module that's executing
          `forward()` at the time the operator is called, plus this
          module's recursive parents. If an module has multiple parents
          (e.g. a shared module), only one will be picked.

        For most cases, a module only calls submodules it owns, so both
        options would work identically. In certain edge cases, this option
        will affect the hierarchy of results, but won't affect the total
        count.
        """
        if mode not in ['owner', 'caller']:
            raise ValueError(f'Unrecognized ancestor mode: {mode}')
        self._ancestor_mode = mode
        return self

    def unsupported_ops_warnings(self: T, enabled: bool) -> T:
        """Sets if warnings for unsupported operators are shown.

        Defaults to True. Counts of unsupported operators may be
        obtained from :meth:`unsupported_ops` regardless of this setting.

        Args:
            enabled (bool): Set to 'True' to show unsupported operator
                warnings.
        """
        self._enable_warn_unsupported_ops = enabled
        return self

    def uncalled_modules_warnings(self: T, enabled: bool) -> T:
        """Sets if warnings from uncalled submodules are shown.

        Defaults to true. A submodule is considered "uncalled" if it is never
        called during tracing. This may be because it is actually unused, or
        because it is accessed via calls to ``.forward()`` or other methods of
        the module. The set of uncalled modules may be obtained from
        :meth:`uncalled_modules` regardless of this setting.

        Args:
            enabled (bool): Set to 'True' to show warnings.
        """
        self._enable_warn_uncalled_mods = enabled
        return self

    def _warn_unsupported_ops(self, ops: typing.Counter[str]) -> None:
        """Logs a warning for each unsupported operator, if enabled."""
        if not self._enable_warn_unsupported_ops:
            return

        for op, freq in ops.items():
            print_log(
                'Unsupported operator {} encountered {} time(s)'.format(
                    op, freq),
                'current',
                logging.WARNING,
            )

    def _warn_uncalled_mods(self, uncalled_mods: Set[str]) -> None:
        """Logs a warning listing traced-but-never-called submodules, if
        enabled.

        Modules without a meaningful ``forward`` (containers etc.) are
        filtered out to reduce noise.
        """
        if not self._enable_warn_uncalled_mods:
            return
        uncalled_mods = {x for x in uncalled_mods if self._has_forward(x)}
        if len(uncalled_mods) == 0:
            return

        print_log(
            'The following submodules of the model were never '
            'called during the trace of the graph. They may be '
            'unused, or they were accessed by direct calls to '
            '.forward() or via other python methods. In the latter '
            'case they will have zeros for statistics, though their '
            'statistics will still contribute to their parent calling '
            'module.\n' + ', '.join(sorted(uncalled_mods)), 'current',
            logging.WARNING)

    def _get_aliases(self,
                     model: nn.Module) -> Dict[Union[str, nn.Module], str]:
        """Builds the mapping from modules AND their names to each module's
        canonical name (the first name the module appears under)."""
        aliases = {}
        for name, module in _named_modules_with_dup(model):
            if module not in aliases:
                aliases[module] = name
            aliases[name] = aliases[module]
        return aliases

    def _get_all_ancestors(self, module_name: str) -> Set[str]:
        """Get all ancestors of the given module, defined by ownership.

        If the given module has multiple owners, use its canonical name.
        """
        parts = self.canonical_module_name(module_name).split('.')
        res = {''}
        for k in range(len(parts) + 1):
            res.add('.'.join(parts[:k]))
        return res

    def _analyze(self) -> 'Statistics':
        """Traces the model (if not cached) and aggregates per-module
        statistics from the resulting graph."""
        # Don't calculate if results are already stored.
        stats = self._stats
        if stats is not None:
            return stats

        with warnings.catch_warnings():
            if self._warn_trace == 'none':
                warnings.simplefilter('ignore')
            elif self._warn_trace == 'no_tracer_warning':
                warnings.filterwarnings('ignore', category=TracerWarning)
            graph = _get_scoped_trace_graph(self._model, self._inputs,
                                            self._aliases)

        # Assures even modules not in the trace graph are initialized to
        # zero count.
        counts = {}
        unsupported_ops = {}
        # We don't need the duplication here, but self._model.named_modules()
        # gives slightly different results for some wrapped models.
        for _, mod in _named_modules_with_dup(self._model):
            name = self._aliases[mod]
            counts[name] = Counter()
            unsupported_ops[name] = Counter()

        all_seen = set()
        for node in graph.nodes():
            kind = node.kind()
            if kind == 'prim::PythonOp':
                # for PythonOp, pyname contains the actual name in Python.
                kind = kind + '.' + node.pyname()
            # The scope name is a '/'-joined stack of module names pushed
            # by the tracing hooks; every entry was executing when the
            # operator ran.
            scope_names = node.scopeName().split('/')
            all_seen.update(scope_names)
            if self._ancestor_mode == 'caller':
                ancestors = set(scope_names)
            else:
                # 'owner': attribute to the innermost module plus its
                # ownership chain (canonicalized for shared modules).
                ancestors = self._get_all_ancestors(scope_names[-1])
                all_seen.update(ancestors)
            if kind not in self._op_handles:
                if self._should_ignore_node(node):
                    continue
                for name in ancestors:
                    unsupported_ops[name][kind] += 1
            else:
                inputs, outputs = list(node.inputs()), list(node.outputs())
                op_counts = self._op_handles[kind](inputs, outputs)
                if isinstance(op_counts, Number):
                    # Handles may return a bare number; normalize it to a
                    # Counter keyed by the op name without its namespace.
                    op_counts = Counter(
                        {self._simplify_op_name(kind): op_counts})
                for v in op_counts.values():
                    if not isinstance(v, (int, float, np.float64, np.int64)):
                        raise ValueError(
                            f'Invalid type {type(v)} for the flop count! '
                            'Please use a wider type to avoid overflow.')

                # Assures an op contributes at most once to a module
                for name in ancestors:
                    counts[name] += op_counts

        uncalled_mods = set(self._aliases.values()) - all_seen
        stats = Statistics(
            counts=counts,
            unsupported_ops=unsupported_ops,
            uncalled_mods=uncalled_mods)
        self._stats = stats
        self._warn_unsupported_ops(unsupported_ops[''])
        self._warn_uncalled_mods(uncalled_mods)
        return stats

    def _simplify_op_name(self, full_op_name: str) -> str:
        """Get simplified name of the op without the preceding namespace, e.g.
        aten::batch_norm -> batch_norm."""
        p = full_op_name.find('::')
        if p != -1:
            return full_op_name[p + 2:]
        else:
            return full_op_name

    def _has_forward(self, mod_name: str) -> bool:
        """Returns whether the named module defines its own ``forward``
        (i.e. is not a bare container/identity), for warning filtering."""
        module = self._named_modules.get(mod_name)
        if module is None:
            return False
        module_type = type(module)
        # Containers are not meant to be called anyway (they don't have
        # forward defined on themselves).
        no_forward_mods = {
            nn.ModuleList, nn.ModuleDict, nn.Module, nn.Identity
        }
        for mod in no_forward_mods:
            if module_type.forward is mod.forward:
                return False
        return True

    def _should_ignore_node(self, node) -> bool:
        """Returns whether a graph node should be silently skipped rather
        than reported as unsupported."""
        kind = node.kind()
        if kind in self._ignored_ops:
            return True
        # Ignore all prim:: operators, with two exceptions:
        # * prim::PythonOp and prim::CallFunction can be a user-implemented op
        #   and should be reported if no handle exists for it.
        if kind.startswith('prim::PythonOp') or kind.startswith(
                'prim::CallFunction'):
            return False
        if kind.startswith('prim::'):
            return True
        return False
|
|