| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import contextlib |
| from copy import deepcopy |
| from typing import Sequence |
|
|
| import torch |
| import torch.nn as nn |
| from thop import profile |
|
|
|
|
# Public names exported by a star-import of this module.
__all__ = [
    "fuse_conv_and_bn", "fuse_model", "get_model_info",
    "replace_module", "freeze_module", "adjust_status",
]
|
|
|
|
def get_model_info(resolution, model: nn.Module, tsize: Sequence[int]) -> str:
    """Return a one-line summary of a model's parameter count and GFLOPs.

    A deep copy of ``model`` is profiled once with ``thop.profile``, using a
    single all-zeros ``(1, 3, resolution, resolution)`` tensor. The tensor is
    placed on the device of the model's first parameter, or on CPU if the
    model has no parameters.

    The operation count is then rescaled by area from ``resolution x resolution``
    to ``tsize[0] x tsize[1]`` and doubled, because thop's count is treated as
    multiply-accumulates. The area rescaling assumes cost grows linearly with
    the number of input pixels, so the result is an approximation.

    Args:
        resolution (int): side length of the square profiling input.
        model (nn.Module): model to profile. Only a deep copy is run.
        tsize (Sequence[int]): target input size. Only ``tsize[0] * tsize[1]``
            is used.

    Returns:
        str: ``"Params: <millions>M, Gflops: <giga-FLOPs>"`` with two decimals.
    """
    stride = resolution
    # With a bare next(), a parameter-free model would leak StopIteration.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    img = torch.zeros((1, 3, stride, stride), device=device)
    flops, params = profile(deepcopy(model), inputs=(img,), verbose=False)
    params /= 1e6
    flops /= 1e9
    # Rescale from the profiled area to the target area; x2 converts MACs to FLOPs.
    flops *= tsize[0] * tsize[1] / stride / stride * 2
    return "Params: {:.2f}M, Gflops: {:.2f}".format(params, flops)
|
|
|
|
def fuse_conv_and_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """
    Fuse convolution and batchnorm layers.
    check more info on https://tehnokv.com/posts/fusing-batchnorm-and-conv/

    The returned convolution computes ``bn(conv(x))`` as evaluated with the
    BN running statistics, i.e. the behaviour of ``bn`` in eval mode. All
    geometric settings of ``conv`` are preserved: kernel size, stride,
    padding, dilation, groups and padding mode. The fused layer always has a
    bias.

    Args:
        conv (nn.Conv2d): convolution to fuse.
        bn (nn.BatchNorm2d): batchnorm to fuse. It must track running
            statistics (``running_mean``/``running_var`` are read). An
            affine-less BN (``weight``/``bias`` None) is treated as gamma=1,
            beta=0.

    Returns:
        nn.Conv2d: a new convolution on the same device and dtype as
        ``conv.weight``, whose parameters do not require grad.
    """
    fusedconv = (
        nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
            bias=True,
            padding_mode=conv.padding_mode,
        )
        .requires_grad_(False)
        .to(device=conv.weight.device, dtype=conv.weight.dtype)
    )

    # The source parameters normally require grad. Copying grad-tracking
    # results into the fused parameters would attach autograd history to them,
    # making them non-leaf tensors that e.g. cannot be deep-copied.
    with torch.no_grad():
        gamma = bn.weight if bn.weight is not None else torch.ones_like(bn.running_var)
        beta = bn.bias if bn.bias is not None else torch.zeros_like(bn.running_mean)
        # Per-output-channel BN scale: gamma / sqrt(running_var + eps).
        scale = gamma / torch.sqrt(bn.running_var + bn.eps)

        # Weight: diag(scale) @ W_conv, done as a broadcast row scaling.
        w_conv = conv.weight.reshape(conv.out_channels, -1)
        fusedconv.weight.copy_((scale[:, None] * w_conv).reshape(fusedconv.weight.shape))

        # Bias: scale * b_conv + (beta - scale * running_mean).
        b_conv = conv.bias if conv.bias is not None else torch.zeros_like(scale)
        b_bn = beta - scale * bn.running_mean
        fusedconv.bias.copy_(scale * b_conv + b_bn)

    return fusedconv
|
|
|
|
def fuse_model(model: nn.Module) -> nn.Module:
    """Fold BatchNorm into the convolution of every ``BaseConv`` in ``model``.

    Only modules whose exact type is ``BaseConv`` (not subclasses) and that
    still own a ``bn`` attribute are processed. For each one, ``conv`` is
    replaced by ``fuse_conv_and_bn(conv, bn)``, the ``bn`` attribute is
    deleted, and the instance's ``forward`` is rebound to its ``fuseforward``.
    The model is modified in place.

    Args:
        model (nn.Module): model to fuse.

    Returns:
        nn.Module: the same ``model`` object, fused.
    """
    from yolod.models.blocks.network_blocks import BaseConv

    # Gather targets up front so the module tree is not mutated mid-walk.
    targets = [
        block
        for block in model.modules()
        if type(block) is BaseConv and hasattr(block, "bn")
    ]
    for block in targets:
        block.conv = fuse_conv_and_bn(block.conv, block.bn)
        delattr(block, "bn")
        block.forward = block.fuseforward
    return model
|
|
|
|
def replace_module(module, replaced_module_type, new_module_type, replace_func=None) -> nn.Module:
    """
    Replace every submodule of a given type with a new module. Mostly used in deploy.

    The tree is walked recursively via ``named_children``. A matched module is
    replaced as a whole and not descended into. Replacements are installed
    with ``add_module``, so ``module`` is modified in place. The only
    exception is when ``module`` itself matches: the replacement is then
    returned instead.

    NOTE: ``named_children`` yields a shared submodule only once, so a module
    registered under several names is replaced only under its first name.

    Args:
        module (nn.Module): model to apply replace operation.
        replaced_module_type (Type): module type (or tuple of types, as
            accepted by ``isinstance``) to be replaced.
        new_module_type (Type): type handed to ``replace_func``.
        replace_func (callable, optional): ``replace_func(replaced_module_type,
            new_module_type)`` returning the replacement module. It is applied
            at every depth of the tree. Defaults to calling
            ``new_module_type()``.

    Returns:
        nn.Module: the (possibly replaced) module.
    """

    def default_replace_func(replaced_module_type, new_module_type):
        return new_module_type()

    if replace_func is None:
        replace_func = default_replace_func

    if isinstance(module, replaced_module_type):
        return replace_func(replaced_module_type, new_module_type)

    for name, child in module.named_children():
        # Forward replace_func so custom replacement logic applies to nested children too.
        new_child = replace_module(child, replaced_module_type, new_module_type, replace_func)
        if new_child is not child:
            module.add_module(name, new_child)

    return module
|
|
|
|
def freeze_module(module: nn.Module, name=None) -> nn.Module:
    """Freeze (part of) a module in place.

    Each selected parameter gets ``requires_grad = False``, and each selected
    submodule is switched to eval mode. Without ``name``, everything is
    selected, including ``module`` itself. With ``name``, a parameter or
    submodule is selected when ``name`` occurs as a plain substring of its
    qualified name. There are no wildcards or regex, but e.g. "backbone" also
    matches "my_backbone.conv".

    Args:
        module (nn.Module): module to freeze.
        name (str, optional): substring selecting what to freeze. Defaults to
            None (freeze the whole module).

    Returns:
        nn.Module: the same ``module`` object.

    Examples:
        freeze the backbone of model
        >>> freeze_module(model.backbone)

        or freeze the backbone of model by name
        >>> freeze_module(model, name="backbone")
    """

    def _selected(qualified_name):
        return name is None or name in qualified_name

    for param_name, param in module.named_parameters():
        if _selected(param_name):
            param.requires_grad = False

    # eval() matters for layers such as BatchNorm/Dropout whose forward depends on mode.
    for sub_name, sub in module.named_modules():
        if _selected(sub_name):
            sub.eval()

    return module
|
|
|
|
@contextlib.contextmanager
def adjust_status(module: nn.Module, training: bool = False) -> nn.Module:
    """Temporarily set the training flag of a module and all its submodules.

    Each submodule's ``training`` attribute is recorded and then set directly
    to ``training``; any overridden ``train()`` method is not called. On exit,
    the exact recorded per-module flags are restored, so mixed train/eval
    states survive. Restoration also happens when the ``with`` body raises.

    Args:
        module (nn.Module): module to adjust status.
        training (bool): training mode to set. True for train mode, False for eval mode.

    Yields:
        nn.Module: the same ``module`` object.

    Examples:
        >>> with adjust_status(model, training=False):
        ...     model(data)
    """
    status = {}
    for m in module.modules():
        status[m] = m.training
        m.training = training

    try:
        yield module
    finally:
        # Restore from the snapshot rather than re-walking the tree, so
        # modules added or removed inside the context cannot cause KeyError.
        for m, was_training in status.items():
            m.training = was_training
|
|