from __future__ import annotations

import copy
from typing import TypeVar

import torch


__all__ = [
    "fuse_conv_bn_eval",
    "fuse_conv_bn_weights",
    "fuse_linear_bn_eval",
    "fuse_linear_bn_weights",
]

ConvT = TypeVar("ConvT", bound="torch.nn.modules.conv._ConvNd")
LinearT = TypeVar("LinearT", bound="torch.nn.Linear")


def fuse_conv_bn_eval(
    conv: ConvT,
    bn: torch.nn.modules.batchnorm._BatchNorm,
    transpose: bool = False,
) -> ConvT:
    r"""Fuse a convolutional module and a BatchNorm module into a single, new convolutional module.

    Args:
        conv (torch.nn.modules.conv._ConvNd): A convolutional module.
        bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module.
        transpose (bool, optional): If True, transpose the convolutional weight. Defaults to False.

    Returns:
        torch.nn.modules.conv._ConvNd: The fused convolutional module.

    .. note::
        Both ``conv`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed.
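
    Example (a minimal, illustrative sketch; it assumes float32 modules that are
    already in eval mode, so the fused module matches the unfused pair up to
    floating-point rounding)::

        >>> conv = torch.nn.Conv2d(3, 8, 3).eval()
        >>> bn = torch.nn.BatchNorm2d(8).eval()
        >>> fused_conv = fuse_conv_bn_eval(conv, bn)
        >>> x = torch.randn(1, 3, 16, 16)
        >>> torch.allclose(fused_conv(x), bn(conv(x)), atol=1e-5)
        True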
| """ | |
| assert not (conv.training or bn.training), "Fusion only for eval!" | |
| fused_conv = copy.deepcopy(conv) | |
| assert bn.running_mean is not None and bn.running_var is not None | |
| fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights( | |
| fused_conv.weight, | |
| fused_conv.bias, | |
| bn.running_mean, | |
| bn.running_var, | |
| bn.eps, | |
| bn.weight, | |
| bn.bias, | |
| transpose, | |
| ) | |
| return fused_conv | |


def fuse_conv_bn_weights(
    conv_w: torch.Tensor,
    conv_b: torch.Tensor | None,
    bn_rm: torch.Tensor,
    bn_rv: torch.Tensor,
    bn_eps: float,
    bn_w: torch.Tensor | None,
    bn_b: torch.Tensor | None,
    transpose: bool = False,
) -> tuple[torch.nn.Parameter, torch.nn.Parameter]:
    r"""Fuse convolutional module parameters and BatchNorm module parameters into new convolutional module parameters.

    Args:
        conv_w (torch.Tensor): Convolutional weight.
        conv_b (Optional[torch.Tensor]): Convolutional bias.
        bn_rm (torch.Tensor): BatchNorm running mean.
        bn_rv (torch.Tensor): BatchNorm running variance.
        bn_eps (float): BatchNorm epsilon.
        bn_w (Optional[torch.Tensor]): BatchNorm weight.
        bn_b (Optional[torch.Tensor]): BatchNorm bias.
        transpose (bool, optional): If True, transpose the conv weight. Defaults to False.

    Returns:
        Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused convolutional weight and bias.
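
    .. note::
        As a rough guide to what the code below computes, the BatchNorm affine
        transform is folded into the convolution as
        ``fused_conv_w = conv_w * (bn_w / sqrt(bn_rv + bn_eps))`` (broadcast over
        the output-channel dimension, or the input-channel dimension when
        ``transpose`` is True) and
        ``fused_conv_b = (conv_b - bn_rm) * bn_w / sqrt(bn_rv + bn_eps) + bn_b``.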
| """ | |
| conv_weight_dtype = conv_w.dtype | |
| conv_bias_dtype = conv_b.dtype if conv_b is not None else conv_weight_dtype | |
| if conv_b is None: | |
| conv_b = torch.zeros_like(bn_rm) | |
| if bn_w is None: | |
| bn_w = torch.ones_like(bn_rm) | |
| if bn_b is None: | |
| bn_b = torch.zeros_like(bn_rm) | |
| bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps) | |
| if transpose: | |
| shape = [1, -1] + [1] * (len(conv_w.shape) - 2) | |
| else: | |
| shape = [-1, 1] + [1] * (len(conv_w.shape) - 2) | |
| fused_conv_w = (conv_w * (bn_w * bn_var_rsqrt).reshape(shape)).to( | |
| dtype=conv_weight_dtype | |
| ) | |
| fused_conv_b = ((conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b).to( | |
| dtype=conv_bias_dtype | |
| ) | |
| return ( | |
| torch.nn.Parameter(fused_conv_w, conv_w.requires_grad), | |
| torch.nn.Parameter(fused_conv_b, conv_b.requires_grad), | |
| ) | |


def fuse_linear_bn_eval(
    linear: LinearT,
    bn: torch.nn.modules.batchnorm._BatchNorm,
) -> LinearT:
    r"""Fuse a linear module and a BatchNorm module into a single, new linear module.

    Args:
        linear (torch.nn.Linear): A Linear module.
        bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module.

    Returns:
        torch.nn.Linear: The fused linear module.

    .. note::
        Both ``linear`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed.
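
    Example (a minimal, illustrative sketch; it assumes float32 modules that are
    already in eval mode and a BatchNorm whose feature count matches the linear
    output features)::

        >>> linear = torch.nn.Linear(4, 8).eval()
        >>> bn = torch.nn.BatchNorm1d(8).eval()
        >>> fused_linear = fuse_linear_bn_eval(linear, bn)
        >>> x = torch.randn(2, 4)
        >>> torch.allclose(fused_linear(x), bn(linear(x)), atol=1e-5)
        True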
| """ | |
| assert not (linear.training or bn.training), "Fusion only for eval!" | |
| fused_linear = copy.deepcopy(linear) | |
| """ | |
| Linear-BN needs to be fused while preserving the shapes of linear weight/bias. | |
| To preserve the shapes of linear weight/bias, the channel dim of bn needs to be broadcastable with the last dim of linear, | |
| because bn operates over the channel dim, (N, C_in, H, W) while linear operates over the last dim, (*, H_in). | |
| To be broadcastable, the number of features in bn and | |
| the number of output features from linear must satisfy the following condition: | |
| 1. they are equal, or | |
| 2. the number of features in bn is 1 | |
| Otherwise, skip the folding path | |
| """ | |
    assert linear.out_features == bn.num_features or bn.num_features == 1, (
        "To fuse, linear.out_features == bn.num_features or bn.num_features == 1"
    )

    assert bn.running_mean is not None and bn.running_var is not None
    fused_linear.weight, fused_linear.bias = fuse_linear_bn_weights(
        fused_linear.weight,
        fused_linear.bias,
        bn.running_mean,
        bn.running_var,
        bn.eps,
        bn.weight,
        bn.bias,
    )
    return fused_linear


def fuse_linear_bn_weights(
    linear_w: torch.Tensor,
    linear_b: torch.Tensor | None,
    bn_rm: torch.Tensor,
    bn_rv: torch.Tensor,
    bn_eps: float,
    bn_w: torch.Tensor,
    bn_b: torch.Tensor,
) -> tuple[torch.nn.Parameter, torch.nn.Parameter]:
    r"""Fuse linear module parameters and BatchNorm module parameters into new linear module parameters.

    Args:
        linear_w (torch.Tensor): Linear weight.
        linear_b (Optional[torch.Tensor]): Linear bias.
        bn_rm (torch.Tensor): BatchNorm running mean.
        bn_rv (torch.Tensor): BatchNorm running variance.
        bn_eps (float): BatchNorm epsilon.
        bn_w (torch.Tensor): BatchNorm weight.
        bn_b (torch.Tensor): BatchNorm bias.

    Returns:
        Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused linear weight and bias.
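
    .. note::
        In effect, with ``scale = bn_w / sqrt(bn_rv + bn_eps)`` the fusion
        computes ``fused_w = linear_w * scale.unsqueeze(-1)`` and
        ``fused_b = (linear_b - bn_rm) * scale + bn_b``, matching the code below.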
| """ | |
| linear_weight_dtype = linear_w.dtype | |
| linear_bias_dtype = linear_b.dtype if linear_b is not None else linear_weight_dtype | |
| if linear_b is None: | |
| linear_b = torch.zeros_like(bn_rm) | |
| bn_scale = bn_w * torch.rsqrt(bn_rv + bn_eps) | |
| fused_w = linear_w * bn_scale.unsqueeze(-1).to(dtype=linear_weight_dtype) | |
| fused_b = ((linear_b - bn_rm) * bn_scale + bn_b).to(dtype=linear_bias_dtype) | |
| return torch.nn.Parameter(fused_w, linear_w.requires_grad), torch.nn.Parameter( | |
| fused_b, linear_b.requires_grad | |
| ) | |