| | |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from mmcv.cnn.bricks.wrappers import NewEmptyTensorOp, obsolete_torch_version |
| |
|
| | from mmdet.registry import MODELS |
| |
|
# Parse the installed PyTorch version once, at import time.
# The 'parrots' backend reports a non-numeric version string, which is kept
# as-is. Otherwise only the first two dot-separated components are converted
# to ints, e.g. '1.13.1+cu117' -> (1, 13), so a suffix on a later component
# never reaches int().
if torch.__version__ != 'parrots':
    TORCH_VERSION = tuple(map(int, torch.__version__.split('.')[:2]))
else:
    TORCH_VERSION = torch.__version__
| |
|
| |
|
def adaptive_avg_pool2d(input, output_size):
    """Handle empty batch dimension to adaptive_avg_pool2d.

    When ``input`` has zero elements and
    ``obsolete_torch_version(TORCH_VERSION, (1, 9))`` reports an obsolete
    PyTorch, the functional op is bypassed. Instead, an empty tensor of the
    expected output shape is produced with ``NewEmptyTensorOp``. Every other
    input is forwarded to :func:`torch.nn.functional.adaptive_avg_pool2d`.

    Args:
        input (Tensor): Input tensor whose last two dimensions are pooled
            (typically 4D, ``(N, C, H, W)``).
        output_size (int | tuple[int | None, int | None]): The target output
            size. An int applies to both spatial dims. A ``None`` entry keeps
            the corresponding input dimension, as in
            :func:`torch.nn.functional.adaptive_avg_pool2d`.

    Returns:
        Tensor: The pooled tensor. For the empty-input shortcut this is an
        empty tensor of shape ``(*input.shape[:-2], out_h, out_w)``.
    """
    if input.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
        if isinstance(output_size, int):
            spatial_size = [output_size, output_size]
        else:
            # ``None`` means "keep the input size" for that dimension. It has
            # to be resolved here; otherwise ``None`` would end up in the
            # shape requested from NewEmptyTensorOp.
            spatial_size = [
                size if size is not None else dim
                for size, dim in zip(output_size, input.shape[-2:])
            ]
        # Keep every non-spatial leading dim; for 4D input this is (N, C).
        new_shape = [*input.shape[:-2], *spatial_size]
        return NewEmptyTensorOp.apply(input, new_shape)
    return F.adaptive_avg_pool2d(input, output_size)
| |
|
| |
|
class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d):
    """Adaptive average pooling that tolerates empty inputs on old PyTorch.

    If the input has zero elements and
    ``obsolete_torch_version(TORCH_VERSION, (1, 9))`` reports an obsolete
    PyTorch, regular pooling is skipped. An empty tensor of the expected
    output shape is built with ``NewEmptyTensorOp`` instead. All other inputs
    are pooled by :class:`torch.nn.AdaptiveAvgPool2d`.
    """

    def forward(self, x):
        # Common path: anything that is not an empty tensor on an obsolete
        # PyTorch goes straight to the parent implementation.
        if not (x.numel() == 0
                and obsolete_torch_version(TORCH_VERSION, (1, 9))):
            return super().forward(x)

        target = self.output_size
        if isinstance(target, int):
            spatial = [target, target]
        else:
            # A ``None`` entry keeps the matching input spatial dimension.
            spatial = []
            for requested, current in zip(target, x.size()[-2:]):
                spatial.append(current if requested is None else requested)
        return NewEmptyTensorOp.apply(x, [*x.shape[:2], *spatial])
| |
|
| |
|
| | |
| | |
@MODELS.register_module('FrozenBN')
class FrozenBatchNorm2d(nn.Module):
    """BatchNorm2d whose statistics and affine parameters never change.

    ``weight``, ``bias``, ``running_mean`` and ``running_var`` are registered
    as buffers rather than parameters. A freshly constructed module is an
    identity transform: weight 1, bias 0, mean 0 and variance ``1 - eps``, so
    ``running_var + eps`` is 1 up to rounding.

    Args:
        num_features (int): :math:`C` from an expected input of size
            :math:`(N, C, H, W)`.
        eps (float): a value added to the denominator for numerical stability.
            Default: 1e-5
        **kwargs: Accepted and ignored (presumably so generic norm configs
            can be passed through unchanged -- TODO confirm).
    """

    def __init__(self, num_features, eps=1e-5, **kwargs):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        # Registration order is kept stable (it defines state_dict order).
        initial_buffers = (
            ('weight', torch.ones(num_features)),
            ('bias', torch.zeros(num_features)),
            ('running_mean', torch.zeros(num_features)),
            ('running_var', torch.ones(num_features) - eps),
        )
        for buffer_name, tensor in initial_buffers:
            self.register_buffer(buffer_name, tensor)

    def forward(self, x):
        if not x.requires_grad:
            # No gradient flows through x: use the inference-mode batch norm.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

        # x requires grad: fold the frozen statistics into a per-channel
        # scale/shift and apply them with plain tensor ops. (Presumably this
        # avoids F.batch_norm's backward also computing weight/bias
        # gradients -- TODO confirm the original motivation.)
        broadcast_shape = (1, -1, 1, 1)
        inv_std = (self.running_var + self.eps).rsqrt()
        scale = self.weight * inv_std
        shift = self.bias - self.running_mean * scale
        dtype = x.dtype
        return (x * scale.reshape(broadcast_shape).to(dtype)
                + shift.reshape(broadcast_shape).to(dtype))

    def __repr__(self):
        template = 'FrozenBatchNorm2d(num_features={}, eps={})'
        return template.format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py

        Args:
            module (torch.nn.Module): The module (tree) to convert.

        Returns:
            torch.nn.Module: If ``module`` is a BatchNorm2d/SyncBatchNorm, a
            new frozen module is returned. Its weight and bias are cloned when
            the source is affine; its running statistics share storage with
            the source's buffers. Otherwise ``module`` itself is returned,
            with matching descendants replaced in place.
        """
        batchnorm_types = (nn.modules.batchnorm.BatchNorm2d,
                           nn.modules.batchnorm.SyncBatchNorm)

        if not isinstance(module, batchnorm_types):
            # Walk the children, swapping in any converted replacement.
            for child_name, child in module.named_children():
                converted = cls.convert_frozen_batchnorm(child)
                if converted is not child:
                    module.add_module(child_name, converted)
            return module

        frozen = cls(module.num_features)
        if module.affine:
            frozen.weight.data = module.weight.data.clone().detach()
            frozen.bias.data = module.bias.data.clone().detach()
        frozen.running_mean.data = module.running_mean.data
        frozen.running_var.data = module.running_var.data
        frozen.eps = module.eps
        return frozen
| |
|