| |
| import torch |
| import torch.nn as nn |
|
|
|
|
def _fuse_conv_bn(conv, bn):
    """Fold a batch-norm layer into the convolution that precedes it.

    ``conv`` is modified in place so that, afterwards, ``conv(x)`` equals
    ``bn(conv(x))`` evaluated with the BN running statistics, i.e.::

        factor = gamma / sqrt(running_var + eps)
        W'     = W * factor            (per output channel)
        b'     = (b - running_mean) * factor + beta

    Args:
        conv (nn.Module): Convolution whose weight has output channels along
            dimension 0 (e.g. ``nn.Conv1d`` / ``nn.Conv2d`` / ``nn.Conv3d``).
            A missing bias is treated as zeros.
        bn (nn.Module): Batch-norm layer applied to ``conv``'s output. It must
            track running statistics. ``affine=False`` layers are supported
            (scale 1, shift 0).

    Returns:
        nn.Module: ``conv`` itself, with its weight and bias replaced by the
        fused parameters.

    Raises:
        ValueError: If ``bn`` has no running statistics to fold.
    """
    if bn.running_mean is None or bn.running_var is None:
        # Without running stats BN normalizes with per-batch statistics even
        # in eval mode, which is not a fixed affine map and cannot be folded.
        raise ValueError(
            'Cannot fuse a batch-norm layer that does not track running '
            'statistics (track_running_stats=False).')

    with torch.no_grad():
        conv_w = conv.weight
        mean = bn.running_mean
        # affine=False BN layers have weight/bias set to None; they behave
        # like a unit scale and zero shift.
        gamma = bn.weight if bn.weight is not None else torch.ones_like(mean)
        beta = bn.bias if bn.bias is not None else torch.zeros_like(mean)
        conv_b = conv.bias if conv.bias is not None else torch.zeros_like(mean)

        factor = gamma / torch.sqrt(bn.running_var + bn.eps)
        # Reshape to (out_channels, 1, ..., 1) matching the weight's rank so
        # 1-D/2-D/3-D convolutions all work. Using the explicit out-channel
        # count (rather than -1) keeps the original check that BN and conv
        # channel counts agree.
        factor_shape = [conv_w.shape[0]] + [1] * (conv_w.dim() - 1)
        fused_w = conv_w * factor.reshape(factor_shape)
        fused_b = (conv_b - mean) * factor + beta

    conv.weight = nn.Parameter(fused_w)
    conv.bias = nn.Parameter(fused_b)
    return conv
|
|
|
|
def fuse_conv_bn(module):
    """Recursively fuse conv and bn layers in a module.

    At inference time a batch-norm layer no longer computes batch statistics.
    It only applies a per-channel affine transform built from its running
    mean and variance. That transform can be folded into the preceding
    convolution to save computation and simplify the network structure.

    Fusion is decided from the registration order of each module's direct
    children. A batch-norm child is folded into an ``nn.Conv2d`` child only
    when it is registered immediately after that convolution. This is a
    heuristic that assumes ``forward`` applies the children in registration
    order. Any other child in between (activation, pooling, nested container,
    ...) cancels the pending fusion rather than risk folding a batch norm
    across a non-linear operation. Batch-norm layers without running
    statistics are left untouched. Each fused batch norm is replaced by
    ``nn.Identity`` so the module's attribute names remain valid.

    Args:
        module (nn.Module): Module to be fused. It is modified in place.

    Returns:
        nn.Module: The same ``module``, with eligible conv/bn pairs fused.
    """
    last_conv = None
    last_conv_name = None

    # Iterate over a snapshot: entries of ``module._modules`` are reassigned
    # inside the loop.
    for name, child in list(module.named_children()):
        if isinstance(child,
                      (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
            # Only fold a BN that directly follows a conv and has running
            # statistics (without them BN uses batch stats even in eval mode).
            can_fuse = (last_conv is not None
                        and child.running_mean is not None
                        and child.running_var is not None)
            if can_fuse:
                fused_conv = _fuse_conv_bn(last_conv, child)
                module._modules[last_conv_name] = fused_conv
                # Keep the attribute name alive instead of deleting the BN.
                module._modules[name] = nn.Identity()
            last_conv = None
        elif isinstance(child, nn.Conv2d):
            last_conv = child
            last_conv_name = name
        else:
            # Some other layer now sits between the last conv and any later
            # BN, so that BN no longer acts directly on the conv's output.
            last_conv = None
            fuse_conv_bn(child)
    return module
|
|