| |
|
|
| import typing |
| from collections import defaultdict |
| from typing import Any, Counter, DefaultDict, Dict, Optional, Tuple, Union |
|
|
| import torch.nn as nn |
| from rich import box |
| from rich.console import Console |
| from rich.table import Table |
| from torch import Tensor |
|
|
| from .jit_analysis import JitModelAnalysis |
| from .jit_handles import (Handle, addmm_flop_jit, batchnorm_flop_jit, |
| bmm_flop_jit, conv_flop_jit, einsum_flop_jit, |
| elementwise_flop_counter, generic_activation_jit, |
| linear_flop_jit, matmul_flop_jit, norm_flop_counter) |
|
|
| |
# Default TorchScript-operator-name -> flop-counting handle mapping that
# ``FlopAnalyzer`` registers on construction. Keys are operator names as
# they appear in a jit trace; values are callables from ``.jit_handles``
# taking the traced op's (inputs, outputs).
_DEFAULT_SUPPORTED_FLOP_OPS: Dict[str, Handle] = {
    # Dense / matrix-multiply style ops.
    'aten::addmm': addmm_flop_jit,
    'aten::bmm': bmm_flop_jit,
    'aten::_convolution': conv_flop_jit,
    'aten::einsum': einsum_flop_jit,
    'aten::matmul': matmul_flop_jit,
    'aten::mm': matmul_flop_jit,
    'aten::linear': linear_flop_jit,

    # Normalization and resampling ops, counted per element.
    # NOTE(review): the integer arguments to ``norm_flop_counter`` /
    # ``elementwise_flop_counter`` presumably encode per-element flop
    # scales for input/output — confirm against ``.jit_handles``.
    'aten::batch_norm': batchnorm_flop_jit,
    'aten::group_norm': norm_flop_counter(2),
    'aten::layer_norm': norm_flop_counter(2),
    'aten::instance_norm': norm_flop_counter(1),
    'aten::upsample_nearest2d': elementwise_flop_counter(0, 1),
    'aten::upsample_bilinear2d': elementwise_flop_counter(0, 4),
    'aten::adaptive_avg_pool2d': elementwise_flop_counter(1, 0),
    'aten::grid_sampler': elementwise_flop_counter(0, 4),
}
|
|
| |
| |
# Default TorchScript-operator-name -> activation-counting handle mapping
# that ``ActivationAnalyzer`` registers on construction. Only convolution
# and dot-product style operators are counted by default.
_DEFAULT_SUPPORTED_ACT_OPS: Dict[str, Handle] = {
    # 'conv' is passed as an argument here while the other handles take
    # none — presumably it overrides the reported operator name; confirm
    # against ``generic_activation_jit`` in ``.jit_handles``.
    'aten::_convolution': generic_activation_jit('conv'),
    'aten::addmm': generic_activation_jit(),
    'aten::bmm': generic_activation_jit(),
    'aten::einsum': generic_activation_jit(),
    'aten::matmul': generic_activation_jit(),
    'aten::linear': generic_activation_jit(),
}
|
|
|
|
class FlopAnalyzer(JitModelAnalysis):
    """Per-submodule flop counting driven by pytorch's jit tracing.

    The analyzer ships with flop handles for a set of common operators
    (see ``_DEFAULT_SUPPORTED_FLOP_OPS``); additional handles can be
    registered — or the defaults replaced — via
    ``.set_op_handle(name, func)``. See that method's documentation for
    details.

    Note:
        - Flop is not a well-defined concept. We just produce our best
          estimate.
        - One fused multiply-add is counted as one flop.

    Flop counts can be obtained as:

    - ``.total(module_name="")``: total flop count for the module
    - ``.by_operator(module_name="")``: flop counts for the module, as a
      Counter over different operator types
    - ``.by_module()``: Counter of flop counts for all submodules
    - ``.by_module_and_operator()``: dictionary mapping each submodule
      name to a Counter over different operator types

    An operator counts toward a module only when it runs inside that
    module's ``__call__``; operators executed in other methods, or via an
    explicit ``module.forward(...)`` call, are not attributed to it.

    Modified from
    https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/flop_count.py

    Args:
        model (nn.Module): The model to analyze.
        inputs (Union[Tensor, Tuple[Tensor, ...]]): The input to the model.

    Examples:
        >>> import torch.nn as nn
        >>> import torch
        >>> class TestModel(nn.Module):
        ...    def __init__(self):
        ...        super().__init__()
        ...        self.fc = nn.Linear(in_features=1000, out_features=10)
        ...        self.conv = nn.Conv2d(
        ...            in_channels=3, out_channels=10, kernel_size=1
        ...        )
        ...        self.act = nn.ReLU()
        ...    def forward(self, x):
        ...        return self.fc(self.act(self.conv(x)).flatten(1))
        >>> model = TestModel()
        >>> inputs = (torch.randn((1,3,10,10)),)
        >>> flops = FlopAnalyzer(model, inputs)
        >>> flops.total()
        13000
        >>> flops.total("fc")
        10000
        >>> flops.by_operator()
        Counter({"addmm" : 10000, "conv" : 3000})
        >>> flops.by_module()
        Counter({"" : 13000, "fc" : 10000, "conv" : 3000, "act" : 0})
        >>> flops.by_module_and_operator()
        {"" : Counter({"addmm" : 10000, "conv" : 3000}),
        "fc" : Counter({"addmm" : 10000}),
        "conv" : Counter({"conv" : 3000}),
        "act" : Counter()
        }
    """

    def __init__(self, model: nn.Module,
                 inputs: Union[Tensor, Tuple[Tensor, ...]]) -> None:
        super().__init__(model=model, inputs=inputs)
        # Install the default flop handles; callers may override later.
        self.set_op_handle(**_DEFAULT_SUPPORTED_FLOP_OPS)

    __init__.__doc__ = JitModelAnalysis.__init__.__doc__
|
|
|
|
class ActivationAnalyzer(JitModelAnalysis):
    """Per-submodule activation counting driven by pytorch's jit tracing.

    The analyzer ships with activation handles for convolutional and
    dot-product operators (see ``_DEFAULT_SUPPORTED_ACT_OPS``); additional
    handles can be registered — or the defaults replaced — via
    ``.set_op_handle(name, func)``. See that method's documentation for
    details.

    Activation counts can be obtained as:

    - ``.total(module_name="")``: total activation count for a module
    - ``.by_operator(module_name="")``: activation counts for the module,
      as a Counter over different operator types
    - ``.by_module()``: Counter of activation counts for all submodules
    - ``.by_module_and_operator()``: dictionary mapping each submodule
      name to a Counter over different operator types

    An operator counts toward a module only when it runs inside that
    module's ``__call__``; operators executed in other methods, or via an
    explicit ``module.forward(...)`` call, are not attributed to it.

    Modified from
    https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/activation_count.py

    Args:
        model (nn.Module): The model to analyze.
        inputs (Union[Tensor, Tuple[Tensor, ...]]): The input to the model.

    Examples:
        >>> import torch.nn as nn
        >>> import torch
        >>> class TestModel(nn.Module):
        ...    def __init__(self):
        ...        super().__init__()
        ...        self.fc = nn.Linear(in_features=1000, out_features=10)
        ...        self.conv = nn.Conv2d(
        ...            in_channels=3, out_channels=10, kernel_size=1
        ...        )
        ...        self.act = nn.ReLU()
        ...    def forward(self, x):
        ...        return self.fc(self.act(self.conv(x)).flatten(1))
        >>> model = TestModel()
        >>> inputs = (torch.randn((1,3,10,10)),)
        >>> acts = ActivationAnalyzer(model, inputs)
        >>> acts.total()
        1010
        >>> acts.total("fc")
        10
        >>> acts.by_operator()
        Counter({"conv" : 1000, "addmm" : 10})
        >>> acts.by_module()
        Counter({"" : 1010, "fc" : 10, "conv" : 1000, "act" : 0})
        >>> acts.by_module_and_operator()
        {"" : Counter({"conv" : 1000, "addmm" : 10}),
        "fc" : Counter({"addmm" : 10}),
        "conv" : Counter({"conv" : 1000}),
        "act" : Counter()
        }
    """

    def __init__(self, model: nn.Module,
                 inputs: Union[Tensor, Tuple[Tensor, ...]]) -> None:
        super().__init__(model=model, inputs=inputs)
        # Install the default activation handles; callers may override.
        self.set_op_handle(**_DEFAULT_SUPPORTED_ACT_OPS)

    __init__.__doc__ = JitModelAnalysis.__init__.__doc__
|
|
|
|
def flop_count(
    model: nn.Module,
    inputs: Tuple[Any, ...],
    supported_ops: Optional[Dict[str, Handle]] = None,
) -> Tuple[DefaultDict[str, float], Counter[str]]:
    """Given a model and an input to the model, compute the per-operator Gflops
    of the given model.

    Adopted from
    https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/flop_count.py

    Args:
        model (nn.Module): The model to compute flop counts.
        inputs (tuple): Inputs that are passed to `model` to count flops.
            Inputs need to be in a tuple.
        supported_ops (dict(str,Callable) or None) : provide additional
            handlers for extra ops, or overwrite the existing handlers for
            convolution and matmul and einsum. The key is operator name and
            the value is a function that takes (inputs, outputs) of the op.
            We count one Multiply-Add as one FLOP.

    Returns:
        tuple[defaultdict, Counter]: A dictionary that records the number of
        gflops for each operation and a Counter that records the number of
        unsupported operations.
    """
    analyzer = FlopAnalyzer(model, inputs)
    # Extra/override handles supplied by the caller take precedence.
    analyzer.set_op_handle(**(supported_ops or {}))
    gflops: DefaultDict[str, float] = defaultdict(float)
    for op_name, num_flops in analyzer.by_operator().items():
        gflops[op_name] = num_flops / 1e9  # raw flops -> Gflops
    return gflops, analyzer.unsupported_ops()
|
|
|
|
def activation_count(
    model: nn.Module,
    inputs: Tuple[Any, ...],
    supported_ops: Optional[Dict[str, Handle]] = None,
) -> Tuple[DefaultDict[str, float], Counter[str]]:
    """Given a model and an input to the model, compute the total number of
    activations of the model.

    Adopted from
    https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/activation_count.py

    Args:
        model (nn.Module): The model to compute activation counts.
        inputs (tuple): Inputs that are passed to `model` to count activations.
            Inputs need to be in a tuple.
        supported_ops (dict(str,Callable) or None) : provide additional
            handlers for extra ops, or overwrite the existing handlers for
            convolution and matmul. The key is operator name and the value
            is a function that takes (inputs, outputs) of the op.

    Returns:
        tuple[defaultdict, Counter]: A dictionary that records the number of
        activation (mega) for each operation and a Counter that records the
        number of unsupported operations.
    """
    analyzer = ActivationAnalyzer(model, inputs)
    # Extra/override handles supplied by the caller take precedence.
    analyzer.set_op_handle(**(supported_ops or {}))
    mega_acts: DefaultDict[str, float] = defaultdict(float)
    for op_name, num_acts in analyzer.by_operator().items():
        mega_acts[op_name] = num_acts / 1e6  # raw count -> mega-activations
    return mega_acts, analyzer.unsupported_ops()
|
|
|
|
def parameter_count(model: nn.Module) -> typing.DefaultDict[str, int]:
    """Count parameters of a model and its submodules.

    Adopted from
    https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/parameter_count.py

    Args:
        model (nn.Module): the model to count parameters.

    Returns:
        dict[str, int]: the key is either a parameter name or a module name.
        The value is the number of elements in the parameter, or in all
        parameters of the module. The key "" corresponds to the total
        number of parameters of the model.
    """
    counts: typing.DefaultDict[str, int] = defaultdict(int)
    for full_name, param in model.named_parameters():
        numel = param.numel()
        parts = full_name.split('.')
        # Credit this parameter to itself and to every ancestor prefix,
        # including the root prefix '' (k == 0) holding the grand total.
        for k in range(len(parts) + 1):
            counts['.'.join(parts[:k])] += numel
    return counts
|
|
|
|
def parameter_count_table(model: nn.Module, max_depth: int = 3) -> str:
    """Format the parameter count of the model (and its submodules or
    parameters)

    Adopted from
    https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/parameter_count.py

    Args:
        model (nn.Module): the model to count parameters.
        max_depth (int): maximum depth to recursively print submodules or
            parameters

    Returns:
        str: the table to be printed
    """
    count: typing.DefaultDict[str, int] = parameter_count(model)
    # Leaf parameters are shown by shape instead of element count.
    param_shape: typing.Dict[str, typing.Tuple] = {
        name: tuple(tensor.shape)
        for name, tensor in model.named_parameters()
    }

    rows: typing.List[typing.Tuple] = []

    def format_size(num: int) -> str:
        # Human-readable count. Thresholds kick in one decade early, so
        # e.g. 2e8 renders as '0.2G' rather than '200.0M'.
        for bound, divisor, suffix in ((1e8, 1e9, 'G'), (1e5, 1e6, 'M'),
                                       (1e2, 1e3, 'K')):
            if num > bound:
                return f'{num / divisor:.1f}{suffix}'
        return str(num)

    def fill(lvl: int, prefix: str) -> None:
        # Depth-first emission: each module's row is immediately followed
        # by the rows of its children, indented one extra space per level.
        if lvl >= max_depth:
            return
        for name, num in count.items():
            if name.count('.') != lvl or not name.startswith(prefix):
                continue
            indent = ' ' * (lvl + 1)
            cell = (str(param_shape[name])
                    if name in param_shape else format_size(num))
            rows.append((indent + name, indent + cell))
            fill(lvl + 1, name + '.')

    # Root row first; pop('') so the total is not re-emitted by fill().
    rows.append(('model', format_size(count.pop(''))))
    fill(0, '')

    table = Table(
        title=f'parameter count of {model.__class__.__name__}', box=box.ASCII2)
    table.add_column('name')
    table.add_column('#elements or shape')
    for row in rows:
        table.add_row(*row)

    console = Console()
    # Render into a string instead of stdout.
    with console.capture() as captured:
        console.print(table, end='')
    return captured.get()
|
|