|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
import typing |
|
|
import warnings |
|
|
from collections import Counter |
|
|
from copy import copy |
|
|
from dataclasses import dataclass |
|
|
from numbers import Number |
|
|
from typing import (Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, |
|
|
TypeVar, Union) |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch import Tensor |
|
|
from torch.jit import TracerWarning, _get_trace_graph |
|
|
|
|
|
from mmengine.logging import print_log |
|
|
from .jit_handles import Handle |
|
|
|
|
|
# Type variable bound to ``JitModelAnalysis``. It annotates ``self`` and the
# return value of the fluent configuration methods (``tracer_warnings``,
# ``ancestor_mode``, ...), which return the instance they were called on.
T = TypeVar('T', bound='JitModelAnalysis')

# Operator kinds that are ignored by default: when one of these appears in the
# traced graph and no handle is registered for it, it is skipped instead of
# being recorded as an unsupported operator. Every ``JitModelAnalysis``
# starts from its own copy of this set, so per-instance additions never
# modify this module-level default.
_IGNORED_OPS: Set[str] = {
    'aten::Int',
    'aten::ScalarImplicit',
    'aten::__and__',
    'aten::arange',
    'aten::bitwise_not',
    'aten::cat',
    'aten::chunk',
    'aten::clamp',
    'aten::clamp_',
    'aten::constant_pad_nd',
    'aten::contiguous',
    'aten::copy_',
    'aten::detach',
    'aten::dropout',
    'aten::empty',
    'aten::eq',
    'aten::expand',
    'aten::flatten',
    'aten::floor',
    'aten::floor_divide',
    'aten::full',
    'aten::full_like',
    'aten::gather',
    'aten::ge',
    'aten::gt',
    'aten::index',
    'aten::index_put_',
    'aten::masked_fill',
    'aten::max',
    'aten::narrow',
    'aten::new_empty',
    'aten::new_full',
    'aten::new_zeros',
    'aten::nonzero',
    'aten::ones',
    'aten::permute',
    'aten::relu',
    'aten::relu_',
    'aten::remainder',
    'aten::reshape',
    'aten::roll',
    'aten::select',
    'aten::size',
    'aten::slice',
    'aten::split',
    'aten::split_with_sizes',
    'aten::squeeze',
    'aten::stack',
    'aten::t',
    'aten::to',
    'aten::transpose',
    'aten::type_as',
    'aten::unbind',
    'aten::unsqueeze',
    'aten::unsqueeze_',
    'aten::view',
    'aten::zeros',
    'aten::zeros_like',
}
|
|
|
|
|
|
|
|
@dataclass
class Statistics:
    """For keeping track of the various model statistics recorded during
    analysis.

    Attributes:
        counts: Maps each canonical module name to a counter of the
            statistic accumulated for that module, keyed by operator name.
        unsupported_ops: Maps each canonical module name to a counter of
            how many times each operator kind without a handle (and not
            ignored) was encountered within that module.
        uncalled_mods: Canonical names of the modules that never appeared
            as a scope of any node in the traced graph.
    """

    counts: Dict[str, typing.Counter[str]]
    unsupported_ops: Dict[str, typing.Counter[str]]
    uncalled_mods: Set[str]
|
|
|
|
|
|
|
|
def _named_modules_with_dup(model: nn.Module, |
|
|
prefix: str = '' |
|
|
) -> Iterable[Tuple[str, nn.Module]]: |
|
|
"""The same as `model.named_modules()`, except that it includes duplicated |
|
|
modules that have more than one name.""" |
|
|
yield prefix, model |
|
|
for name, module in model._modules.items(): |
|
|
if module is None: |
|
|
continue |
|
|
submodule_prefix = prefix + ('.' if prefix else '') + name |
|
|
yield from _named_modules_with_dup(module, submodule_prefix) |
|
|
|
|
|
|
|
|
def _named_modules_without_dup(
        model: nn.Module) -> Iterator[Tuple[str, nn.Module]]:
    """Yield ``(name, module)`` pairs with every module reported only once.

    The traversal is the one performed by ``_named_modules_with_dup``; a
    module that is reachable under several names is reported under the first
    name encountered and skipped afterwards.

    Args:
        model (nn.Module): Root of the module tree to walk.

    Yields:
        tuple[str, nn.Module]: The dotted name and the module.
    """
    visited = set()
    for mod_name, submodule in _named_modules_with_dup(model):
        if submodule in visited:
            continue
        visited.add(submodule)
        yield mod_name, submodule
|
|
|
|
|
|
|
|
def _get_scoped_trace_graph(
    module: nn.Module,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    aliases: Dict[Union[str, nn.Module], str],
) -> torch._C.Graph:
    """Traces the provided module using torch.jit._get_trace_graph, but adds
    submodule scope information to each graph node.

    The resulting graph is in-lined and has all model parameters treated as
    inputs. The input model has the scope name '', while its descendants
    have names of the form 'child.grandchild.grandgrandchild...'.

    Scope information is recorded through temporary forward hooks that are
    registered on every submodule. The hooks are always removed before this
    function returns, including when tracing raises, so the model is never
    left with stale hooks attached.

    Args:
        module (nn.Module): The module to trace
        inputs (tuple): Inputs used during the trace of the model
        aliases (dict[str or nn.Module, str]): maps modules and module
            names to the canonical name to be used as the scope for
            that module.

    Returns:
        graph (torch._C.Graph): The pytorch JIT trace of the model
    """

    class ScopePushHook:
        """Forward pre-hook that pushes a scope name while tracing."""

        def __init__(self, name: str) -> None:
            self.name = name

        def __call__(self, module: nn.Module, inputs: Any) -> Any:
            tracing_state = torch._C._get_tracing_state()
            # Outside of a trace there is no tracing state to annotate.
            if tracing_state:
                tracing_state.push_scope(self.name)
            return inputs

    class ScopePopHook:
        """Forward hook that pops the scope pushed by ``ScopePushHook``."""

        def __call__(self, module: nn.Module, inputs: Any,
                     outputs: Any) -> Any:
            tracing_state = torch._C._get_tracing_state()
            if tracing_state:
                tracing_state.pop_scope()
            return outputs

    hook_handles: List[Any] = []

    def register_hooks(mod: nn.Module, name: str) -> None:
        # Record each handle as soon as it exists so that cleanup also covers
        # a failure that happens between the two registrations.
        hook_handles.append(mod.register_forward_pre_hook(ScopePushHook(name)))
        hook_handles.append(mod.register_forward_hook(ScopePopHook()))

    module_list = (nn.parallel.distributed.DistributedDataParallel,
                   nn.DataParallel)

    try:
        # For a parallel wrapper, trace the wrapped module directly and
        # give it the scope name that belongs to the wrapper.
        if isinstance(module, module_list):
            root_name = aliases[module]
            module = module.module
            register_hooks(module, root_name)

        for _, mod in _named_modules_without_dup(module):
            # The scope is the canonical alias, not the traversal name.
            register_hooks(mod, aliases[mod])

        graph, _ = _get_trace_graph(module, inputs)
    finally:
        # Always detach the temporary hooks, even if tracing failed.
        for handle in hook_handles:
            handle.remove()

    return graph
|
|
|
|
|
|
|
|
class JitModelAnalysis:
    """Provides access to per-submodule model statistics obtained by tracing a
    model with pytorch's jit tracing functionality.

    Calculates a statistic on a per-operator basis using the provided set of
    functions that acts on the inputs and outputs to the operator, then
    aggregates this over modules in the model. Can return the aggregate
    statistic for any submodule in the model. Is lazily evaluated, and will
    perform the trace when a statistic is first requested. Changing the
    operator handles will cause the trace to be rerun on the next request.

    Submodules may be referred to using the module's name. The input model has
    name "", while its descendants have names of the form
    "child.grandchild.grandgrandchild...".

    An operator is treated as within the scope of a module if calling that
    module directly resulted in that operator being run. In particular, this
    means that calls to other functions owned by a module or explicit
    calls to module.forward(...) will not register resulting operators as
    contributing statistics to that module.

    We will trace the execution of `model.forward(inputs)`. This means
    inputs have to be tensors or tuple of tensors (see
    https://pytorch.org/docs/stable/generated/torch.jit.trace.html#torch.jit.trace).
    In order to trace other methods or unsupported input types,
    you may need to implement a wrapper module.

    Args:
        model: The model to analyze
        inputs: The inputs to the model for analysis.
    """

    def __init__(
        self,
        model: nn.Module,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
    ) -> None:
        self._model = model
        self._inputs = inputs
        self._op_handles: Dict[str, Handle] = {}
        # Every name of every submodule, including duplicated names of
        # shared modules.
        self._named_modules: Dict[str, nn.Module] = dict(
            _named_modules_with_dup(model))
        # Maps both module objects and module names to the canonical name.
        self._aliases: Dict[Union[nn.Module, str],
                            str] = self._get_aliases(model)
        # Cached analysis result; ``None`` means "not computed / outdated".
        self._stats: Optional[Statistics] = None

        # Private copy so that per-instance changes never touch the default.
        self._ignored_ops: Set[str] = copy(_IGNORED_OPS)
        self.unsupported_ops_warnings(True)
        self.uncalled_modules_warnings(True)
        self.tracer_warnings('no_tracer_warning')
        self.ancestor_mode('owner')

    def total(self, module_name: str = '') -> int:
        """Returns the total aggregated statistic across all operators for the
        requested module.

        Args:
            module_name (str): The submodule to get data for. Defaults to
                the entire model.

        Returns:
            int: The aggregated statistic.
        """
        stats = self._analyze()
        module_name = self.canonical_module_name(module_name)
        total_count = sum(stats.counts[module_name].values())
        return total_count

    def by_operator(self, module_name: str = '') -> typing.Counter[str]:
        """Returns the statistics for a requested module, grouped by operator
        type.

        The operator handle determines the name associated with each
        operator type.

        Args:
            module_name (str): The submodule to get data for. Defaults
                to the entire model.

        Returns:
            Counter(str): The statistics for each operator.
        """
        stats = self._analyze()
        module_name = self.canonical_module_name(module_name)
        return stats.counts[module_name]

    def by_module_and_operator(self) -> Dict[str, typing.Counter[str]]:
        """Returns the statistics for all submodules, separated out by operator
        type for each submodule.

        The operator handle determines the name associated with
        each operator type.

        Returns:
            dict[str, Counter(str)]: The statistics for each submodule
            and each operator. Grouped by submodule names, then
            by operator name.
        """
        stats = self._analyze()
        return stats.counts

    def by_module(self) -> typing.Counter[str]:
        """Returns the statistics for all submodules, aggregated over all
        operators.

        Returns:
            Counter(str): statistics counter grouped by submodule names
        """
        stats = self._analyze()
        summed_counts = Counter()
        for mod, results in stats.counts.items():
            summed_counts[mod] = sum(results.values())
        return summed_counts

    def unsupported_ops(self, module_name: str = '') -> typing.Counter[str]:
        """Lists the number of operators that were encountered but unsupported
        because no operator handle is available for them.

        Does not include operators that are explicitly ignored. Unlike the
        other accessors, this method does not trigger the analysis itself.

        Args:
            module_name (str): The submodule to list unsupported ops.
                Defaults to the entire model.

        Returns:
            Counter(str): The number of occurrences each unsupported operator.

        Raises:
            RuntimeError: If no analysis result is currently available, i.e.
                no statistic has been requested since the object was created
                or since the operator handles were last changed.
        """
        if self._stats is None:
            raise RuntimeError('Analysis results should be computed '
                               'before calling unsupported_ops()')
        module_name = self.canonical_module_name(module_name)
        return self._stats.unsupported_ops[module_name]

    def uncalled_modules(self) -> Set[str]:
        """Returns a set of submodules that were never called during the trace
        of the graph.

        This may be because they were unused, or because they were
        accessed via direct calls .forward() or with other python methods.
        In the latter case, statistics will not be attributed to the submodule,
        though the statistics will be included
        in the parent module.

        Returns:
            set[str]: The set of submodule names that were never called
            during the trace of the model.
        """
        stats = self._analyze()
        return stats.uncalled_mods

    def set_op_handle(self, *args,
                      **kwargs: Optional[Handle]) -> 'JitModelAnalysis':
        """Sets additional operator handles, or replaces existing ones.

        If a handle is ``None``, the op will be explicitly ignored; a handle
        previously registered for that op is dropped. Otherwise,
        handle should be a function that calculates the desirable statistic
        from an operator. The function must take two arguments, which are the
        inputs and outputs of the operator, in the form of
        ``list(torch._C.Value)``. The function should return a counter object
        with per-operator statistics.

        Calling this method invalidates any cached analysis result.

        Args:
            args: (str, Handle) pairs of operator names and handles.
            kwargs: mapping from operator names to handles.

        Returns:
            JitModelAnalysis: This object, to allow chained calls.

        Raises:
            TypeError: If the positional arguments do not form pairs.

        Examples:
            >>> handlers = {"aten::linear": my_handler}
            >>> counter.set_op_handle("aten::matmul", None,
            ...     "aten::bmm", my_handler2).set_op_handle(**handlers)
        """
        self._stats = None
        if len(args) % 2 != 0:
            raise TypeError(
                'set_op_handle should be called with pairs of names and '
                'handles!')
        for name, handle in zip(args[::2], args[1::2]):
            kwargs[name] = handle
        for name, handle in kwargs.items():
            if handle is None:
                self._ignored_ops.add(name)
                # Registered handles take precedence over the ignore set
                # during analysis, so a stale handle must be removed for the
                # op to actually be ignored.
                self._op_handles.pop(name, None)
            else:
                self._op_handles[name] = handle
        return self

    def clear_op_handles(self) -> 'JitModelAnalysis':
        """Clears all operator handles currently set.

        The set of ignored operators is reset to the default as well, and any
        cached analysis result is invalidated.

        Returns:
            JitModelAnalysis: This object, to allow chained calls.
        """
        self._op_handles = {}
        self._ignored_ops = copy(_IGNORED_OPS)
        self._stats = None
        return self

    def canonical_module_name(self, name: str) -> str:
        """Returns the canonical module name of the given ``name``, which might
        be different from the given ``name`` if the module is shared.

        This is the name that will be used as a key when statistics are
        output using .by_module() and .by_module_and_operator().

        Args:
            name (str): The name of the module to find the canonical name for.

        Returns:
            str: The canonical name of the module.

        Raises:
            KeyError: If ``name`` is not the name of the analyzed model or
                of one of its descendants.
        """
        # The alias table also contains module objects as keys; only string
        # names are valid here.
        assert isinstance(name, str), 'Module name must be a string.'
        if name in self._aliases:
            return self._aliases[name]
        else:
            raise KeyError('Requested module name is not among '
                           'the descendants of the analyzed model.')

    def copy(
        self,
        new_model: Optional[nn.Module] = None,
        new_inputs: Union[None, Tensor, Tuple[Tensor, ...]] = None,
    ) -> 'JitModelAnalysis':
        """Returns a copy of the :class:`JitModelAnalysis` object, keeping all
        settings, but on a new model or new inputs.

        The settings carried over are the operator handles, the explicitly
        ignored operators, the warning switches, the tracer warning mode and
        the ancestor mode. Cached analysis results are not copied.

        Args:
            new_model (nn.Module or None): a new model for the new
                JitModelAnalysis. If None, uses the original model.
                Defaults to None.
            new_inputs (typing.Tuple[object, ...], optional): new inputs
                for the new JitModelAnalysis. If None, uses the original
                inputs. Defaults to None.

        Returns:
            JitModelAnalysis: the new model analysis object
        """
        model = self._model if new_model is None else new_model
        inputs = self._inputs if new_inputs is None else new_inputs
        new_analysis = (
            JitModelAnalysis(model=model, inputs=inputs).set_op_handle(
                **self._op_handles).unsupported_ops_warnings(
                    self._enable_warn_unsupported_ops).
            uncalled_modules_warnings(
                self._enable_warn_uncalled_mods).tracer_warnings(
                    self._warn_trace).ancestor_mode(self._ancestor_mode))
        # Ops ignored through ``set_op_handle(name, None)`` are a setting
        # too; copy the set so the two objects stay independent.
        new_analysis._ignored_ops = copy(self._ignored_ops)
        return new_analysis

    def tracer_warnings(self: T, mode: str) -> T:
        """Sets which warnings to print when tracing the graph to calculate
        statistics. There are three modes. Defaults to 'no_tracer_warning'.
        Allowed values are:

        * 'all' : keeps all warnings raised while tracing
        * 'no_tracer_warning' : suppress torch.jit.TracerWarning only
        * 'none' : suppress all warnings raised while tracing

        Args:
            mode (str) : warning mode in one of the above values.

        Raises:
            ValueError: If ``mode`` is not one of the allowed values.
        """
        if mode not in ['all', 'no_tracer_warning', 'none']:
            raise ValueError(f'Unrecognized tracer warning mode {mode}.')
        self._warn_trace = mode
        return self

    def ancestor_mode(self: T, mode: str) -> T:
        """Sets how to determine the ancestor modules of an operator. Must be
        one of "owner" or "caller".

        * "caller": an operator belongs to all modules that are currently
          executing `forward()` at the time the operator is called.
        * "owner": an operator belongs to the last module that's executing
          `forward()` at the time the operator is called, plus this
          module's recursive parents. If an module has multiple parents
          (e.g. a shared module), only one will be picked.

        For most cases, a module only calls submodules it owns, so both
        options would work identically. In certain edge cases, this option
        will affect the hierarchy of results, but won't affect the total
        count.

        Raises:
            ValueError: If ``mode`` is neither "owner" nor "caller".
        """
        if mode not in ['owner', 'caller']:
            raise ValueError(f'Unrecognized ancestor mode: {mode}')
        self._ancestor_mode = mode
        return self

    def unsupported_ops_warnings(self: T, enabled: bool) -> T:
        """Sets if warnings for unsupported operators are shown.

        Defaults to True. Counts of unsupported operators may be
        obtained from :meth:`unsupported_ops` regardless of this setting.

        Args:
            enabled (bool): Set to 'True' to show unsupported operator
                warnings.
        """
        self._enable_warn_unsupported_ops = enabled
        return self

    def uncalled_modules_warnings(self: T, enabled: bool) -> T:
        """Sets if warnings from uncalled submodules are shown.

        Defaults to true. A submodule is considered "uncalled" if it is never
        called during tracing. This may be because it is actually unused, or
        because it is accessed via calls to ``.forward()`` or other methods of
        the module. The set of uncalled modules may be obtained from
        :meth:`uncalled_modules` regardless of this setting.

        Args:
            enabled (bool): Set to 'True' to show warnings.
        """
        self._enable_warn_uncalled_mods = enabled
        return self

    def _warn_unsupported_ops(self, ops: typing.Counter[str]) -> None:
        """Log one warning per unsupported operator kind, if enabled."""
        if not self._enable_warn_unsupported_ops:
            return

        for op, freq in ops.items():
            print_log(
                f'Unsupported operator {op} encountered {freq} time(s)',
                'current',
                logging.WARNING,
            )

    def _warn_uncalled_mods(self, uncalled_mods: Set[str]) -> None:
        """Log a single warning listing the uncalled modules, if enabled.

        Modules for which :meth:`_has_forward` is false are left out of the
        message; nothing is logged when no module remains.
        """
        if not self._enable_warn_uncalled_mods:
            return
        uncalled_mods = {x for x in uncalled_mods if self._has_forward(x)}
        if len(uncalled_mods) == 0:
            return

        print_log(
            'The following submodules of the model were never '
            'called during the trace of the graph. They may be '
            'unused, or they were accessed by direct calls to '
            '.forward() or via other python methods. In the latter '
            'case they will have zeros for statistics, though their '
            'statistics will still contribute to their parent calling '
            'module.\n' + ', '.join(sorted(uncalled_mods)), 'current',
            logging.WARNING)

    def _get_aliases(self,
                     model: nn.Module) -> Dict[Union[str, nn.Module], str]:
        """Build the alias table for ``model``.

        The returned dict maps every module object, and every name under
        which a module is reachable, to the first name encountered for that
        module (its canonical name).
        """
        aliases = {}
        for name, module in _named_modules_with_dup(model):
            if module not in aliases:
                aliases[module] = name
            aliases[name] = aliases[module]
        return aliases

    def _get_all_ancestors(self, module_name: str) -> Set[str]:
        """Get all ancestors of the given module, defined by ownership.

        If the given module has multiple owners, use its canonical name.

        The result contains the root name '', every dotted prefix of the
        canonical name, and the canonical name itself.
        """
        parts = self.canonical_module_name(module_name).split('.')
        # ``k == 0`` yields '' (the root), ``k == len(parts)`` the full name.
        return {'.'.join(parts[:k]) for k in range(len(parts) + 1)}

    def _analyze(self) -> 'Statistics':
        """Trace the model and aggregate the statistics per module.

        The result is cached in ``self._stats`` and returned directly on
        later calls until the cache is invalidated.

        Returns:
            Statistics: Counts, unsupported operators and uncalled modules.

        Raises:
            ValueError: If an operator handle returns a count whose type is
                not one of int, float, np.float64 or np.int64.
        """
        stats = self._stats
        if stats is not None:
            return stats

        with warnings.catch_warnings():
            if self._warn_trace == 'none':
                warnings.simplefilter('ignore')
            elif self._warn_trace == 'no_tracer_warning':
                warnings.filterwarnings('ignore', category=TracerWarning)
            graph = _get_scoped_trace_graph(self._model, self._inputs,
                                            self._aliases)

        # Pre-populate with empty counters so that modules that never show
        # up in the trace still have an entry.
        counts = {}
        unsupported_ops = {}
        for _, mod in _named_modules_with_dup(self._model):
            name = self._aliases[mod]
            counts[name] = Counter()
            unsupported_ops[name] = Counter()

        all_seen = set()
        for node in graph.nodes():
            kind = node.kind()
            if kind == 'prim::PythonOp':
                # Distinguish individual python ops by appending their name.
                kind = kind + '.' + node.pyname()
            # The scope name is a '/'-separated chain of the scopes that
            # were active when the node was recorded.
            scope_names = node.scopeName().split('/')
            all_seen.update(scope_names)

            if self._ancestor_mode == 'caller':
                ancestors = set(scope_names)
            else:
                # 'owner' mode: innermost scope plus its parents by name.
                ancestors = self._get_all_ancestors(scope_names[-1])
                all_seen.update(ancestors)
            if kind not in self._op_handles:
                if self._should_ignore_node(node):
                    continue
                for name in ancestors:
                    unsupported_ops[name][kind] += 1
            else:
                inputs, outputs = list(node.inputs()), list(node.outputs())
                op_counts = self._op_handles[kind](inputs, outputs)
                if isinstance(op_counts, Number):
                    # A bare number is attributed to the simplified op name.
                    op_counts = Counter(
                        {self._simplify_op_name(kind): op_counts})
                for v in op_counts.values():
                    if not isinstance(v, (int, float, np.float64, np.int64)):
                        raise ValueError(
                            f'Invalid type {type(v)} for the flop count! '
                            'Please use a wider type to avoid overflow.')

                for name in ancestors:
                    counts[name] += op_counts

        uncalled_mods = set(self._aliases.values()) - all_seen
        stats = Statistics(
            counts=counts,
            unsupported_ops=unsupported_ops,
            uncalled_mods=uncalled_mods)
        self._stats = stats
        self._warn_unsupported_ops(unsupported_ops[''])
        self._warn_uncalled_mods(uncalled_mods)
        return stats

    def _simplify_op_name(self, full_op_name: str) -> str:
        """Get simplified name of the op without the preceding namespace, e.g.
        aten::batch_norm -> batch_norm."""
        p = full_op_name.find('::')
        if p != -1:
            return full_op_name[p + 2:]
        else:
            return full_op_name

    def _has_forward(self, mod_name: str) -> bool:
        """Tell whether the named module defines its own ``forward``.

        Returns False when ``mod_name`` is unknown, or when the module's
        class uses the very same ``forward`` function as ``nn.ModuleList``,
        ``nn.ModuleDict``, ``nn.Module`` or ``nn.Identity``. Such modules
        are left out of the "uncalled modules" warning.
        """
        module = self._named_modules.get(mod_name)
        if module is None:
            return False
        module_type = type(module)
        # NOTE(review): nn.Identity does define a forward; it is presumably
        # listed here because it would otherwise be reported as uncalled --
        # confirm the intent.
        no_forward_mods = {
            nn.ModuleList, nn.ModuleDict, nn.Module, nn.Identity
        }
        for mod in no_forward_mods:
            if module_type.forward is mod.forward:
                return False
        return True

    def _should_ignore_node(self, node) -> bool:
        """Tell whether a node without a handle should be silently skipped.

        A node is skipped when its kind is in the ignored set, or when it is
        a ``prim::`` operator other than ``prim::PythonOp`` and
        ``prim::CallFunction``; those two are reported as unsupported
        unless explicitly ignored.
        """
        kind = node.kind()
        if kind in self._ignored_ops:
            return True

        if kind.startswith(('prim::PythonOp', 'prim::CallFunction')):
            return False
        if kind.startswith('prim::'):
            return True
        return False
|
|
|