import typing
from typing import Any, List
import fvcore
from fvcore.nn import activation_count, flop_count, parameter_count, parameter_count_table
from torch import nn

from detectron2.export import TracingAdapter

__all__ = [
    "activation_count_operators",
    "flop_count_operators",
    "parameter_count_table",
    "parameter_count",
    "FlopCountAnalysis",
]

FLOPS_MODE = "flops"
ACTIVATIONS_MODE = "activations"

# Operators to skip when counting: mostly elementwise and reduction ops whose
# cost is negligible next to convolutions and matrix multiplications.
_IGNORED_OPS = {
    "aten::add",
    "aten::add_",
    "aten::argmax",
    "aten::argsort",
    "aten::batch_norm",
    "aten::constant_pad_nd",
    "aten::div",
    "aten::div_",
    "aten::exp",
    "aten::log2",
    "aten::max_pool2d",
    "aten::meshgrid",
    "aten::mul",
    "aten::mul_",
    "aten::neg",
    "aten::nonzero_numpy",
    "aten::reciprocal",
    "aten::repeat_interleave",
    "aten::rsub",
    "aten::sigmoid",
    "aten::sigmoid_",
    "aten::softmax",
    "aten::sort",
    "aten::sqrt",
    "aten::sub",
    "torchvision::nms",
}

class FlopCountAnalysis(fvcore.nn.FlopCountAnalysis):
    """
    Same as :class:`fvcore.nn.FlopCountAnalysis`, but supports detectron2 models.
    """

    def __init__(self, model, inputs):
        """
        Args:
            model (nn.Module):
            inputs (Any): inputs of the given model. Does not have to be tuple of tensors.
        """
        # TracingAdapter flattens detectron2's dict-based inputs/outputs into
        # tensors so that the model becomes traceable by JIT.
        wrapper = TracingAdapter(model, inputs, allow_non_tensor=True)
        super().__init__(wrapper, wrapper.flattened_inputs)
        # Registering ``None`` as the handle tells fvcore to explicitly ignore
        # these ops instead of warning about them being uncounted.
        self.set_op_handle(**{k: None for k in _IGNORED_OPS})

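# Example usage of FlopCountAnalysis (a minimal sketch; ``build_model``, ``cfg``
# and the input shape are illustrative, any detectron2 model and any valid input
# in its expected format would do):
#
#   import torch
#   from detectron2.modeling import build_model
#
#   model = build_model(cfg).eval()
#   inputs = [{"image": torch.rand(3, 800, 800)}]
#   flops = FlopCountAnalysis(model, inputs)
#   print(flops.total() / 1e9, "GFlops")
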
def flop_count_operators(model: nn.Module, inputs: list) -> typing.Dict[str, float]:
    """
    Implement operator-level flop counting using jit.
    This is a wrapper of :func:`fvcore.nn.flop_count` and adds support for standard
    detection models in detectron2.
    Please use :class:`FlopCountAnalysis` for more advanced functionalities.

    Note:
        The function runs the input through the model to compute flops.
        The flops of a detection model are often input-dependent, for example,
        the flops of the box & mask head depend on the number of proposals &
        the number of detected objects.
        Therefore, flop counting using a single input may not accurately
        reflect the computation cost of a model. It's recommended to average
        across a number of inputs.

    Args:
        model: a detectron2 model that takes `list[dict]` as input.
        inputs (list[dict]): inputs to model, in detectron2's standard format.
            Only the "image" key will be used.

    Returns:
        dict[str, float]: Gflop count per operator
    """
    old_train = model.training
    model.eval()  # count flops in inference mode, with training-only branches disabled
    ret = FlopCountAnalysis(model, inputs).by_operator()
    model.train(old_train)  # restore the original training mode
    return {k: v / 1e9 for k, v in ret.items()}

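# Example usage (a minimal sketch; as noted above, the counts are input-dependent,
# so averaging the results over many inputs is advisable):
#
#   gflops = flop_count_operators(model, inputs)
#   print(sum(gflops.values()), "GFlops in total")
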
def activation_count_operators(
    model: nn.Module, inputs: list, **kwargs
) -> typing.DefaultDict[str, float]:
    """
    Implement operator-level activation counting using jit.
    This is a wrapper of :func:`fvcore.nn.activation_count` and adds support for standard
    detection models in detectron2.

    Note:
        The function runs the input through the model to compute activations.
        The activations of a detection model are often input-dependent, for example,
        the activations of the box & mask head depend on the number of proposals &
        the number of detected objects.

    Args:
        model: a detectron2 model that takes `list[dict]` as input.
        inputs (list[dict]): inputs to model, in detectron2's standard format.
            Only the "image" key will be used.
        supported_ops (dict[str, Handle]): see documentation of
            :func:`fvcore.nn.activation_count`

    Returns:
        Counter: activation count per operator
    """
    return _wrapper_count_operators(model=model, inputs=inputs, mode=ACTIVATIONS_MODE, **kwargs)

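# Example usage (a minimal sketch; a custom counting function may be passed via
# ``supported_ops``, following fvcore's handle convention of a callable that takes
# the traced op's inputs and outputs; "aten::my_op" and ``my_handle`` are
# hypothetical):
#
#   acts = activation_count_operators(model, inputs)
#   acts = activation_count_operators(
#       model, inputs, supported_ops={"aten::my_op": my_handle}
#   )
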
def _wrapper_count_operators(
    model: nn.Module, inputs: list, mode: str, **kwargs
) -> typing.DefaultDict[str, float]:
    # Explicitly ignore the ops in _IGNORED_OPS by giving them a handle that
    # counts nothing; user-provided handles take precedence.
    supported_ops = {k: lambda *args, **kwargs: {} for k in _IGNORED_OPS}
    supported_ops.update(kwargs.pop("supported_ops", {}))
    kwargs["supported_ops"] = supported_ops

    assert len(inputs) == 1, "Please use batch size=1"
    tensor_input = inputs[0]["image"]
    inputs = [{"image": tensor_input}]  # remove other keys, in case there are any

    old_train = model.training
    if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)):
        model = model.module
    wrapper = TracingAdapter(model, inputs)
    wrapper.eval()
    if mode == FLOPS_MODE:
        ret = flop_count(wrapper, (tensor_input,), **kwargs)
    elif mode == ACTIVATIONS_MODE:
        ret = activation_count(wrapper, (tensor_input,), **kwargs)
    else:
        raise NotImplementedError("Count for mode {} is not supported yet.".format(mode))
    # flop_count/activation_count return a (counts, skipped ops) tuple;
    # keep only the counts.
    if isinstance(ret, tuple):
        ret = ret[0]
    model.train(old_train)  # restore the original training mode
    return ret

def find_unused_parameters(model: nn.Module, inputs: Any) -> List[str]:
    """
    Given a model, find parameters that do not contribute
    to the loss.

    Args:
        model: a model in training mode that returns losses
        inputs: a single argument or a tuple of arguments to the model

    Returns:
        list[str]: the names of unused parameters
    """
    assert model.training
    # Clear all gradients first, so that any parameter whose grad is still None
    # after the backward pass is known to be unused.
    for _, prm in model.named_parameters():
        prm.grad = None

    if isinstance(inputs, tuple):
        losses = model(*inputs)
    else:
        losses = model(inputs)

    if isinstance(losses, dict):
        losses = sum(losses.values())
    losses.backward()

    unused: List[str] = []
    for name, prm in model.named_parameters():
        if prm.grad is None:
            unused.append(name)
        prm.grad = None  # leave no side effect on the model's gradients
    return unused
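
# Example usage (a minimal sketch; assumes a model that returns a dict of losses
# in training mode, e.g. a standard detectron2 model, with ``inputs`` in whatever
# format that model expects):
#
#   model.train()
#   unused = find_unused_parameters(model, inputs)
#   if unused:
#       print("parameters that received no gradient:", unused)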