Spaces:
Build error
Build error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
| import torch.nn as nn | |
def _fuse_conv_bn(conv: nn.Module, bn: nn.Module) -> nn.Module:
    """Fuse conv and bn into one module.

    The batch norm's running statistics and (optional) affine parameters are
    folded into the convolution's weight and bias, so that the returned conv
    alone computes what ``bn(conv(x))`` computes when ``bn`` uses its running
    statistics (i.e. in eval mode).

    ``conv`` is modified in place and returned; ``bn`` is left untouched.

    Args:
        conv (nn.Module): Conv to be fused. Its weight is reshaped against
            ``[out_channels, 1, 1, 1]``, so a 4-D (Conv2d-style) weight is
            expected.
        bn (nn.Module): BN to be fused. Must carry ``running_mean`` and
            ``running_var``. ``weight``/``bias`` may be ``None`` (a BN layer
            created with ``affine=False``), in which case scale 1 and shift 0
            are used.

    Returns:
        nn.Module: Fused module (the same object as ``conv``).
    """
    conv_w = conv.weight
    # A conv without bias behaves like one with an all-zero bias.
    conv_b = (conv.bias if conv.bias is not None else
              torch.zeros_like(bn.running_mean))

    # BN without affine parameters is a pure normalisation: gamma=1, beta=0.
    bn_w = (bn.weight if bn.weight is not None else
            torch.ones_like(bn.running_var))
    bn_b = (bn.bias if bn.bias is not None else
            torch.zeros_like(bn.running_mean))

    # Per-output-channel scale applied by BN: gamma / sqrt(var + eps).
    factor = bn_w / torch.sqrt(bn.running_var + bn.eps)

    # Broadcast the per-channel scale over the remaining weight dimensions.
    conv.weight = nn.Parameter(
        conv_w * factor.reshape([conv.out_channels, 1, 1, 1]))
    conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn_b)
    return conv
def fuse_conv_bn(module: nn.Module) -> nn.Module:
    """Recursively fuse conv and bn in a module.

    During inference, batch norm layers no longer update their statistics;
    they only apply the stored per-channel mean and variance. This makes it
    possible to fold each of them into the preceding conv layer, which saves
    computation and simplifies the network structure.

    Children are visited in registration order. A BN child is fused only when
    the child registered immediately before it is an ``nn.Conv2d``; any other
    child in between (an activation, a container, ...) breaks the pair, since
    folding the BN across it would change the computed function. Children
    that are neither Conv2d nor BN are processed recursively.

    ``module`` is modified in place: the fused conv replaces the original
    conv entry and the BN entry is replaced by ``nn.Identity``.

    Args:
        module (nn.Module): Module to be fused.

    Returns:
        nn.Module: Fused module (the same object as ``module``).
    """
    last_conv = None
    last_conv_name = None

    for name, child in module.named_children():
        if isinstance(child,
                      (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
            if last_conv is None:  # only fuse BN that is after Conv
                continue
            fused_conv = _fuse_conv_bn(last_conv, child)
            module._modules[last_conv_name] = fused_conv
            # To reduce changes, set BN as Identity instead of deleting it.
            module._modules[name] = nn.Identity()
            last_conv = None
        elif isinstance(child, nn.Conv2d):
            last_conv = child
            last_conv_name = name
        else:
            # Something sits between the last conv and whatever follows, so
            # that conv is no longer a valid fusion candidate.
            last_conv = None
            last_conv_name = None
            fuse_conv_bn(child)
    return module