| """PyTorch ResNet |
| |
| This started as a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with |
| additional dropout and dynamic global avg/max pool. |
| |
| ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman |
| |
| Copyright 2019, Ross Wightman |
| """ |
| import math |
| from functools import partial |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| |
| |
| |
|
|
| |
| |
|
|
def get_padding(kernel_size, stride, dilation=1):
    """Return the padding amount for a convolution.

    The result is ``((stride - 1) + dilation * (kernel_size - 1)) // 2``.

    Args:
        kernel_size: size of the convolution kernel.
        stride: stride of the convolution.
        dilation: dilation of the convolution (default 1).
    """
    dilated_extent = dilation * (kernel_size - 1)
    return (stride - 1 + dilated_extent) // 2
|
|
|
|
class softball(nn.Module):
    """Smoothly shrink the input along dim 1.

    The output is ``x / sqrt(1 + sum(x**2, dim=1) / radius2)``.

    Args:
        radius2: divisor applied to the squared sum. When ``None`` it is taken
            from ``x.size()[1]`` on the first forward call and then kept.
        inplace: accepted but not used by this module.
    """

    def __init__(self, radius2=None, inplace=True):
        super().__init__()
        self.radius2 = radius2

    def forward(self, x):
        if self.radius2 is None:
            # Lazily fix the value from the first input seen; it is stored on
            # the module and reused for every later call.
            self.radius2 = x.size()[1]
        squared_sum = (x * x).sum(1, keepdim=True)
        denom = torch.sqrt(1 + squared_sum / self.radius2)
        return x / denom
|
|
class hardball(nn.Module):
    """Clip the L2 norm taken over dim 1 to a fixed radius.

    Vectors along dim 1 whose norm exceeds ``radius`` are rescaled to
    ``radius * x / norm``; all other vectors are returned unchanged.

    Args:
        radius2: squared radius. When ``None`` the radius is set to
            ``sqrt(x.size()[1])`` on the first forward call and then kept.
    """

    def __init__(self, radius2=None):
        super(hardball, self).__init__()
        self.radius = np.sqrt(radius2) if radius2 is not None else None

    def forward(self, x):
        if self.radius is None:
            # Lazily fix the radius from the first input seen.
            self.radius = np.sqrt(x.size()[1])

        sq_norm = (x * x).sum(1, keepdim=True)
        # The mask is not differentiable, so compute it on a detached copy.
        outside = torch.sqrt(sq_norm.detach()) > self.radius

        # torch.where back-propagates a zero gradient into the branch it did
        # not select. If that branch contains sqrt(0) or a division by zero,
        # the zero gradient turns into NaN (0 / 0). Replace the norm by 1 at
        # positions that are not rescaled so both ops stay finite there; the
        # forward values at rescaled positions are untouched.
        safe_norm = torch.sqrt(torch.where(outside, sq_norm, torch.ones_like(sq_norm)))
        return torch.where(outside, x * self.radius / safe_norm, x)
|
|
|
|
class ConvBN(nn.Module):
    """A convolution followed by a normalization layer, with optional fusion.

    In training mode the two sub-modules are applied one after the other.
    In eval mode, once ``fuse_bn()`` has been called, a single ``F.conv2d``
    with pre-folded weight and bias is used instead.

    Args:
        conv: the convolution module (its ``weight``, ``bias``, ``stride``,
            ``padding``, ``dilation`` and ``groups`` are read).
        bn: the normalization module applied after ``conv``. ``fuse_bn``
            reads ``running_mean``, ``running_var``, ``eps``, ``weight`` and
            ``bias`` from it.
    """

    def __init__(self, conv, bn):
        super(ConvBN, self).__init__()
        self.conv = conv
        self.bn = bn
        # Folded parameters. They are plain attributes (not buffers), so they
        # are not part of the state dict and are not moved by ``Module.to``;
        # call ``fuse_bn()`` again after moving the module.
        self.fused_weight = None
        self.fused_bias = None

    def train(self, mode=True):
        """Set the training mode; entering training mode drops fused tensors.

        Fused tensors are a snapshot of the parameters and running statistics
        at fusion time, so they would be stale after any further training.
        """
        if mode:
            self.fused_weight = None
            self.fused_bias = None
        return super(ConvBN, self).train(mode)

    def forward(self, x):
        use_fused = (
            not self.training
            and self.fused_weight is not None
            and self.fused_bias is not None
        )
        if use_fused:
            return F.conv2d(
                x, self.fused_weight, self.fused_bias,
                self.conv.stride, self.conv.padding,
                self.conv.dilation, self.conv.groups)
        return self.bn(self.conv(x))

    def fuse_bn(self):
        """Fold the normalization layer into the convolution for inference.

        Raises:
            RuntimeError: if the module is in training mode.
        """
        if self.training:
            raise RuntimeError("Call fuse_bn only in eval mode")

        conv, bn = self.conv, self.bn
        # The fused tensors are inference constants; do not track gradients.
        with torch.no_grad():
            std = torch.sqrt(bn.running_var + bn.eps)
            # ``weight`` / ``bias`` are None when the norm layer has no affine
            # parameters; that is equivalent to gamma = 1 and beta = 0.
            scale = 1.0 / std if bn.weight is None else bn.weight / std
            shift = -bn.running_mean * scale
            if conv.bias is not None:
                # bn(conv(x) + b) contributes an extra ``b * scale`` term.
                shift = shift + conv.bias * scale
            if bn.bias is not None:
                shift = shift + bn.bias

            self.fused_weight = conv.weight * scale.reshape(-1, 1, 1, 1)
            self.fused_bias = shift
|
|
|
|
class QLBlock(nn.Module):
    """Residual block whose non-linearity is a product of channel pairs.

    Data flow of ``forward``:
      1. ``conv1`` (1x1 conv + norm) expands ``inplanes`` to ``2 * width``
         channels, where ``width = inplanes * k``.
      2. Even-indexed channels are multiplied element-wise with odd-indexed
         channels, giving ``width`` channels.
      3. ``conv2`` (3x3, ``groups=width``) followed by ``bn2``.
      4. ``conv3`` (1x1 conv + norm) projects to the output width.
      5. The skip branch is added and ``act3`` (``hardball``) is applied.

    Args:
        inplanes: number of input channels.
        planes: base output width; the block outputs ``planes * expansion``
            channels when ``downsample`` is given, otherwise ``inplanes``.
        stride: stride of the 3x3 conv and of the projection skip.
        downsample: when not ``None`` a 1x1 conv + norm projection is used on
            the skip branch; otherwise the skip is the identity.
        cardinality, base_width, reduce_first, act_layer: accepted for
            signature compatibility; not used by this block.
        dilation: fallback value for ``first_dilation``.
        first_dilation: dilation of the convolutions (defaults to ``dilation``).
        norm_layer: constructor of the normalization layers.
    """
    expansion = 1

    def __init__(
            self,
            inplanes,
            planes,
            stride=1,
            downsample=None,
            cardinality=1,
            base_width=64,
            reduce_first=1,
            dilation=1,
            first_dilation=None,
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
    ):
        super(QLBlock, self).__init__()

        self.k = 8 if inplanes <= 128 else 4 if inplanes <= 256 else 2
        width = inplanes * self.k
        # Honour ``planes`` when projecting so the output width matches what
        # the stage builder assumes (``planes * expansion``). Without a
        # projection the identity skip forces the output width to ``inplanes``.
        outplanes = inplanes if downsample is None else planes * self.expansion
        first_dilation = first_dilation or dilation

        self.conv1 = ConvBN(
            nn.Conv2d(inplanes, width * 2, kernel_size=1, stride=1,
                      dilation=first_dilation, groups=1, bias=False),
            norm_layer(width * 2))

        # For a 3x3 kernel the size-preserving padding equals the dilation.
        # A fixed padding of 1 would shrink the map for dilation > 1 and
        # break the residual addition.
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
                               padding=first_dilation, dilation=first_dilation,
                               groups=width, bias=False)
        self.bn2 = norm_layer(width)

        self.conv3 = ConvBN(
            nn.Conv2d(width, outplanes, kernel_size=1, groups=1, bias=False),
            norm_layer(outplanes))

        if downsample is not None:
            self.skip = ConvBN(
                nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride,
                          dilation=first_dilation, groups=1, bias=False),
                norm_layer(outplanes))
        else:
            self.skip = nn.Identity()

        self.act3 = hardball(radius2=outplanes)

    def zero_init_last(self):
        """Zero the affine weight of the last norm layer in the residual path."""
        if getattr(self.conv3.bn, 'weight', None) is not None:
            nn.init.zeros_(self.conv3.bn.weight)

    def conv_forward(self, x):
        """Apply ``conv2``'s kernel, repeated along dim 0, as a grouped conv.

        The kernel (and bias, if any) is repeated ``x.size(1) // k`` times and
        used with ``groups=x.size(1)``. This helper is not called by
        ``forward``.
        """
        conv = self.conv2
        C = x.size(1) // self.k
        kernel = conv.weight.repeat(C, 1, 1, 1)
        bias = conv.bias.repeat(C) if conv.bias is not None else None
        return F.conv2d(x, kernel, bias, conv.stride,
                        conv.padding, conv.dilation, x.size(1))

    def forward(self, x):
        x0 = self.skip(x)
        x = self.conv1(x)
        # Pairwise product: channel 2i times channel 2i + 1.
        x = x[:, ::2, :, :] * x[:, 1::2, :, :]

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.conv3(x)
        x += x0
        if self.act3 is not None:
            x = self.act3(x)
        return x
|
|
def make_blocks(
        block_fn,
        channels,
        block_repeats,
        inplanes,
        reduce_first=1,
        output_stride=32,
        down_kernel_size=1,
        avg_down=False,
        **kwargs,
):
    """Build the residual stages of the network.

    One stage is created per ``(planes, num_blocks)`` pair. The first stage
    uses stride 1 and every later stage stride 2, until the accumulated
    network stride (starting at 4) reaches ``output_stride``; from then on the
    stride is replaced by dilation.

    Args:
        block_fn: block constructor; must expose an ``expansion`` attribute
            and is called as ``block_fn(inplanes, planes, stride, downsample,
            first_dilation=..., reduce_first=..., dilation=..., **kwargs)``.
        channels: base width of each stage.
        block_repeats: number of blocks in each stage.
        inplanes: number of channels entering the first stage.
        reduce_first: forwarded to every block.
        output_stride: maximum overall stride of the network.
        down_kernel_size: accepted but not used here.
        avg_down: accepted but not used here.
        **kwargs: forwarded to every block.

    Returns:
        ``(stages, feature_info)`` where ``stages`` is a list of
        ``(name, nn.Sequential)`` pairs named ``layer1``, ``layer2``, ... and
        ``feature_info`` is a list of dicts with keys ``num_chs``,
        ``reduction`` and ``module``.
    """
    stages = []
    feature_info = []
    net_stride = 4
    dilation = prev_dilation = 1
    for stage_idx, (planes, num_blocks) in enumerate(zip(channels, block_repeats)):
        stage_name = f'layer{stage_idx + 1}'
        stride = 1 if stage_idx == 0 else 2
        if net_stride >= output_stride:
            # Target stride reached: trade further striding for dilation.
            dilation *= stride
            stride = 1
        else:
            net_stride *= stride

        # Only the first block of a stage may change resolution or width, so
        # only it can need a projection on its skip branch.
        needs_projection = stride != 1 or inplanes != planes * block_fn.expansion

        block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, **kwargs)
        blocks = []
        for block_idx in range(num_blocks):
            is_first = block_idx == 0
            blocks.append(block_fn(
                inplanes,
                planes,
                stride if is_first else 1,
                True if (is_first and needs_projection) else None,
                first_dilation=prev_dilation,
                **block_kwargs))
            prev_dilation = dilation
            inplanes = planes * block_fn.expansion

        stages.append((stage_name, nn.Sequential(*blocks)))
        feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name))

    return stages, feature_info
|
|
|
|
class QLNet(nn.Module):
    """ResNet-style image classifier assembled from residual block stages.

    The network consists of a stem (``conv1`` -> ``bn1`` -> ``maxpool``), four
    stages ``layer1`` .. ``layer4`` with base widths 64/128/256/512 built by
    ``make_blocks``, and a pooling + classifier head (``global_pool``, ``fc``).
    """

    # NOTE(review): get_act_layer, get_norm_layer, create_classifier,
    # model_entrypoint and checkpoint_seq are used below but are not imported
    # in the visible part of this file - presumably they come from timm;
    # confirm the imports exist.

    def __init__(
            self,
            block=QLBlock,
            layers=(3, 4, 12, 3),
            num_classes=1000,
            in_chans=3,
            output_stride=32,
            global_pool='avg',
            cardinality=1,
            base_width=64,
            stem_width=32,
            stem_type='',
            replace_stem_pool=False,
            block_reduce_first=1,
            down_kernel_size=1,
            avg_down=False,
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
            zero_init_last=True,
            block_args=None,
    ):
        """
        Args:
            block (nn.Module): class for the residual block (default QLBlock).
            layers (Sequence[int]): number of blocks in each stage
            num_classes (int): number of classification classes (default 1000)
            in_chans (int): number of input (color) channels. (default 3)
            output_stride (int): output stride of the network, 32, 16, or 8. (default 32)
            global_pool (str): Global pooling type, passed to create_classifier as pool_type (default 'avg')
            cardinality (int): forwarded to the block constructor (default 1)
            base_width (int): forwarded to the block constructor (default 64)
            stem_width (int): number of channels in stem convolutions (default 32)
            stem_type (str): The type of stem (default ''):
                * '', default - a single 7x7 conv with 64 output channels
                * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
                * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
            replace_stem_pool (bool): replace stem max-pooling layer with a 3x3 stride-2 convolution
            block_reduce_first (int): forwarded to the block constructor as ``reduce_first`` (default 1)
            down_kernel_size (int): forwarded to make_blocks (default: 1)
            avg_down (bool): forwarded to make_blocks (default False)
            act_layer (str, nn.Module): activation layer
            norm_layer (str, nn.Module): normalization layer
            zero_init_last (bool): zero-init the last weight in residual path (usually last BN affine weight)
            block_args (dict): Extra kwargs to pass through to block module
        """
        super(QLNet, self).__init__()
        block_args = block_args or dict()
        assert output_stride in (8, 16, 32)
        self.num_classes = num_classes
        self.grad_checkpointing = False

        act_layer = get_act_layer(act_layer)
        norm_layer = get_norm_layer(norm_layer)

        # Stem
        deep_stem = 'deep' in stem_type
        inplanes = stem_width * 2 if deep_stem else 64
        if deep_stem:
            stem_chs = (stem_width, stem_width)
            if 'tiered' in stem_type:
                stem_chs = (3 * (stem_width // 4), stem_width)
            self.conv1 = nn.Sequential(*[
                nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False),
                norm_layer(stem_chs[0]),
                act_layer(inplace=True),
                nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False),
                norm_layer(stem_chs[1]),
                act_layer(inplace=True),
                nn.Conv2d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False)])
        else:
            self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(inplanes)

        # This model has no ``act1`` module; ``bn1`` is the last stem module
        # at reduction 2, so it is the one feature extraction must refer to.
        self.feature_info = [dict(num_chs=inplanes, reduction=2, module='bn1')]

        # Stem pooling
        if replace_stem_pool:
            self.maxpool = nn.Sequential(*filter(None, [
                nn.Conv2d(inplanes, inplanes, 3, stride=2, padding=1, bias=False),
                norm_layer(inplanes),
                act_layer(inplace=True)
            ]))
        else:
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Feature blocks
        channels = [64, 128, 256, 512]
        stage_modules, stage_feature_info = make_blocks(
            block,
            channels,
            layers,
            inplanes,
            cardinality=cardinality,
            base_width=base_width,
            output_stride=output_stride,
            reduce_first=block_reduce_first,
            avg_down=avg_down,
            down_kernel_size=down_kernel_size,
            act_layer=act_layer,
            norm_layer=norm_layer,
            **block_args,
        )
        for stage in stage_modules:
            self.add_module(*stage)  # registers layer1, layer2, ...
        self.feature_info.extend(stage_feature_info)

        # Head (pooling and classifier)
        self.num_features = 512 * block.expansion
        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

        self.init_weights(zero_init_last=zero_init_last)

    @staticmethod
    def from_pretrained(model_name: str, load_weights=True, **kwargs) -> 'ResNet':
        """Build a model through its registered entrypoint.

        Args:
            model_name: name passed to ``model_entrypoint``.
            load_weights: passed to the entrypoint as ``pretrained``.
            **kwargs: forwarded to the entrypoint.
        """
        entry_fn = model_entrypoint(model_name, 'resnet')
        # ``load_weights=True`` must request pretrained weights, not suppress them.
        return entry_fn(pretrained=load_weights, **kwargs)

    @torch.jit.ignore
    def init_weights(self, zero_init_last=True):
        """Kaiming-initialise every Conv2d weight, then optionally zero the
        last norm weight of each block that provides ``zero_init_last``."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='linear')

        if zero_init_last:
            for m in self.modules():
                if hasattr(m, 'zero_init_last'):
                    m.zero_init_last()

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return regex patterns that group parameters into stem and blocks."""
        matcher = dict(stem=r'^conv1|bn1|maxpool', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)')
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable or disable gradient checkpointing in ``forward_features``."""
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self, name_only=False):
        """Return the classifier module, or its attribute name if ``name_only``."""
        return 'fc' if name_only else self.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the pooling layer and classifier for ``num_classes`` outputs."""
        self.num_classes = num_classes
        # Build the head for the requested number of classes.
        self.global_pool, self.fc = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool)

    def forward_features(self, x):
        """Run the stem and the four stages; returns the unpooled feature map."""
        x = self.conv1(x)
        x = self.bn1(x)

        x = self.maxpool(x)

        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq([self.layer1, self.layer2, self.layer3, self.layer4], x, flatten=True)
        else:
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool the features and, unless ``pre_logits``, apply the classifier."""
        x = self.global_pool(x)
        return x if pre_logits else self.fc(x)

    def forward(self, x):
        x = self.forward_features(x)

        x = self.forward_head(x)
        return x
|
|
|
|
| |
| |
|
|
|
|
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
| |
| |
| |
|
|