Mhara's picture
Upload folder using huggingface_hub
dae5c90 verified
Raw
History Blame Contribute Delete
11.9 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from itertools import chain
import torch.utils.checkpoint as cp
from collections import defaultdict
from math import log2
from timm.models.registry import register_model
from ..utils.utils import _UpsampleBlend
# Names exported by a star-import of this module.
__all__ = ['ResNet', 'resnet18', 'resnet34', 'BasicBlock']
# Checkpoint download URLs keyed by architecture name. In the visible part of
# this file only the 'resnet18' and 'resnet34' entries are read (by the factory
# functions of the same name).
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
def convkxk(in_planes, out_planes, stride=1, k=3):
    """Build a bias-free ``k`` x ``k`` 2-D convolution padded by ``k // 2``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        k: kernel size; the padding is derived from it as ``k // 2``.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    half_kernel = k // 2
    conv = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=k,
        stride=stride,
        padding=half_kernel,
        bias=False,
    )
    return conv
def _bn_function_factory(conv, norm, relu=None):
def bn_function(x):
x = norm(conv(x))
if relu is not None:
x = relu(x)
return x
return bn_function
def do_efficient_fwd(block, x, efficient):
    """Evaluate ``block(x)``, optionally through gradient checkpointing.

    ``cp.checkpoint`` is used only when ``efficient`` is truthy and ``x``
    requires grad; in every other case ``block`` is called directly.
    """
    use_checkpoint = efficient and x.requires_grad
    if not use_checkpoint:
        return block(x)
    return cp.checkpoint(block, x)
class Identity(nn.Module):
    """Pass-through module that returns its input unchanged.

    The constructor accepts and discards arbitrary arguments, so the class can
    be called with the same arguments as a normalisation layer (it is used as
    the ``bn_class`` replacement when batch norm is disabled).
    """

    def __init__(self, *args, **kwargs):
        # Constructor arguments are intentionally ignored.
        super().__init__()

    def forward(self, input):
        # No computation: hand back the very same object.
        return input
class BasicBlock(nn.Module):
    """Two-convolution residual block with one normalisation layer per level.

    The convolutions are shared, while ``bn1`` and ``bn2`` are ``ModuleList``s
    holding ``levels`` independent normalisation layers; ``forward`` selects
    the pair to use through its ``level`` argument.

    Args:
        inplanes: input channels of the first convolution.
        planes: output channels of both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to produce the
            residual; when ``None`` the input itself is the residual.
        efficient: forwarded to ``do_efficient_fwd`` for both conv/norm steps.
        bn_class: callable building a normalisation layer from a channel count.
        levels: number of normalisation layers kept per convolution.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, efficient=True, bn_class=nn.BatchNorm2d, levels=3):
        super().__init__()
        self.conv1 = convkxk(inplanes, planes, stride)
        self.bn1 = nn.ModuleList([bn_class(planes) for _ in range(levels)])
        self.relu_inp = nn.ReLU(inplace=True)
        # Not in-place: ``forward`` also returns the pre-activation tensor.
        self.relu = nn.ReLU(inplace=False)
        self.conv2 = convkxk(planes, planes)
        self.bn2 = nn.ModuleList([bn_class(planes) for _ in range(levels)])
        self.downsample = downsample
        self.stride = stride
        self.efficient = efficient
        self.num_levels = levels

    def forward(self, x, level):
        """Run the block using the normalisation layers of ``level``.

        Returns:
            A tuple ``(activated, pre_activation)``: the ReLU of the residual
            sum, and the residual sum itself.
        """
        residual = x

        bn_1 = _bn_function_factory(self.conv1, self.bn1[level], self.relu_inp)
        bn_2 = _bn_function_factory(self.conv2, self.bn2[level])

        out = do_efficient_fwd(bn_1, x, self.efficient)
        out = do_efficient_fwd(bn_2, out, self.efficient)

        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        relu = self.relu(out)

        return relu, out

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Additionally copy un-indexed ``bn1.*`` / ``bn2.*`` entries into every level.

        A checkpoint that stores a single normalisation layer per convolution
        (keys ``<prefix>bn1.weight`` ...) is broadcast to all per-level layers.
        Checkpoints that already contain per-level keys
        (``<prefix>bn1.0.weight`` ...) are handled by the regular child loading
        that PyTorch performs after this method.
        """
        super()._load_from_state_dict(state_dict, prefix, local_metadata, False, missing_keys,
                                      unexpected_keys, error_msgs)
        for name, norms in (('bn1', self.bn1), ('bn2', self.bn2)):
            shared_prefix = prefix + name + '.'
            # Work on a filtered copy: a BatchNorm loader may add entries (such
            # as ``num_batches_tracked``) to the mapping it receives, and a key
            # like ``<prefix>bn1.num_batches_tracked`` left in the caller's
            # ``state_dict`` would later be reported as unexpected by the
            # ModuleList, breaking ``strict=True`` loading.
            shared_state = {k: v for k, v in state_dict.items() if k.startswith(shared_prefix)}
            for bn in norms:
                # Missing/unexpected keys of this best-effort broadcast are
                # deliberately discarded; shape errors are still reported.
                bn._load_from_state_dict(shared_state, shared_prefix, local_metadata, strict, [], [],
                                         error_msgs)
class ResNet(nn.Module):
    """ResNet encoder applied to an image pyramid, followed by an upsampling path.

    ``forward`` resamples the input into ``pyramid_levels`` resolutions, runs
    every resolution through the same convolutions (with normalisation layers
    selected per pyramid level), projects the four stage outputs with
    bottleneck convolutions, sums features that fall into the same skip bucket
    and blends them coarse-to-fine with ``_UpsampleBlend`` modules.

    Args:
        block: block class used for the residual stages; it must expose
            ``expansion`` and accept the constructor arguments passed in
            ``_make_layer``.
        layers: four integers, the number of blocks in each stage.
        num_features: output channels of the bottleneck convolutions, also
            passed to every ``_UpsampleBlend``.
        pyramid_levels: number of pyramid resolutions and of normalisation
            layers kept per convolution.
        use_bn: use ``nn.BatchNorm2d`` when true, ``Identity`` otherwise; also
            forwarded to ``_UpsampleBlend``. Input normalisation in ``forward``
            is only applied when this is true.
        k_bneck: kernel size of the bottleneck convolutions.
        k_upsample: forwarded to ``_UpsampleBlend`` as ``k``.
        efficient: forwarded to the blocks (gradient checkpointing switch).
        upsample_skip: forwarded to ``_UpsampleBlend`` as ``use_skip``.
        mean, std: per-channel statistics stored in the ``img_mean`` /
            ``img_std`` buffers.
        scale: when different from 1, stored in the ``img_scale`` buffer and
            used to divide the input before mean/std normalisation.
        detach_upsample_skips: indices of the blend modules that are built
            with ``detach_skip=True``.
        detach_upsample_in: detach the coarsest feature map before blending.
        align_corners, pyramid_subsample: passed to ``F.interpolate`` when the
            pyramid is built (``pyramid_subsample`` is the ``mode``).
        target_size: optional ``(h, w)``; when given, pyramid levels and blend
            modules use sizes derived from it instead of scale factors.
        output_stride: ``max(0, int(log2(output_stride) - 2))`` of the finest
            bottlenecks (and as many skip levels) are dropped.
        **kwargs: accepted and ignored.
    """

    def _make_layer(self, block, planes, blocks, stride=1, bn_class=nn.BatchNorm2d):
        """Build one residual stage of ``blocks`` blocks and update ``self.inplanes``."""
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            # Projection shortcut so the residual matches the block output.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                bn_class(out_planes),
            )
        layers = [block(self.inplanes, planes, stride, downsample, self.efficient, bn_class=bn_class,
                        levels=self.pyramid_levels)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, bn_class=bn_class, levels=self.pyramid_levels,
                                efficient=self.efficient))

        return nn.Sequential(*layers)

    def __init__(self, block, layers, *, num_features=128, pyramid_levels=3, use_bn=True, k_bneck=1, k_upsample=3,
                 efficient=False, upsample_skip=True, mean=(73.1584, 82.9090, 72.3924),
                 std=(44.9149, 46.1529, 45.3192), scale=1, detach_upsample_skips=(), detach_upsample_in=False,
                 align_corners=None, pyramid_subsample='bicubic', target_size=None,
                 output_stride=4, **kwargs):
        super().__init__()
        self.inplanes = 64
        self.efficient = efficient
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        bn_class = nn.BatchNorm2d if use_bn else Identity
        self.register_buffer('img_mean', torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer('img_std', torch.tensor(std).view(1, -1, 1, 1))
        if scale != 1:
            self.register_buffer('img_scale', torch.tensor(scale).view(1, -1, 1, 1).float())

        self.pyramid_levels = pyramid_levels
        self.num_features = num_features
        self.replicated = False

        self.align_corners = align_corners
        self.pyramid_subsample = pyramid_subsample

        # One stem normalisation layer per pyramid level.
        self.bn1 = nn.ModuleList([bn_class(64) for _ in range(pyramid_levels)])
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Each stage gets a bottleneck conv projecting its output (whose width
        # is ``self.inplanes`` right after the stage is built) to num_features.
        bottlenecks = []
        self.layer1 = self._make_layer(block, 64, layers[0], bn_class=bn_class)
        bottlenecks += [convkxk(self.inplanes, num_features, k=k_bneck)]
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, bn_class=bn_class)
        bottlenecks += [convkxk(self.inplanes, num_features, k=k_bneck)]
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, bn_class=bn_class)
        bottlenecks += [convkxk(self.inplanes, num_features, k=k_bneck)]
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, bn_class=bn_class)
        bottlenecks += [convkxk(self.inplanes, num_features, k=k_bneck)]

        # Drop the finest bottlenecks when a coarser output stride is requested.
        num_bn_remove = max(0, int(log2(output_stride) - 2))
        self.num_skip_levels = self.pyramid_levels + 3 - num_bn_remove
        bottlenecks = bottlenecks[num_bn_remove:]

        self.fine_tune = [self.conv1, self.maxpool, self.layer1, self.layer2, self.layer3, self.layer4, self.bn1]

        # Stored coarsest-first, matching ``reversed(features)`` in forward_down.
        self.upsample_bottlenecks = nn.ModuleList(bottlenecks[::-1])
        num_pyr_modules = 2 + pyramid_levels - num_bn_remove
        self.target_size = target_size
        if self.target_size is not None:
            h, w = target_size
            target_sizes = [(h // 2 ** i, w // 2 ** i) for i in range(2, 2 + num_pyr_modules)][::-1]
        else:
            target_sizes = [None] * num_pyr_modules
        self.upsample_blends = nn.ModuleList(
            [_UpsampleBlend(num_features,
                            use_bn=use_bn,
                            use_skip=upsample_skip,
                            detach_skip=i in detach_upsample_skips,
                            fixed_size=ts,
                            k=k_upsample)
             for i, ts in enumerate(target_sizes)])
        self.detach_upsample_in = detach_upsample_in

        self.random_init = [self.upsample_bottlenecks, self.upsample_blends]

        self.features = num_features

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def random_init_params(self):
        """Iterate over the parameters of the modules listed in ``self.random_init``."""
        return chain.from_iterable(f.parameters() for f in self.random_init)

    def fine_tune_params(self):
        """Iterate over the parameters of the modules listed in ``self.fine_tune``."""
        return chain.from_iterable(f.parameters() for f in self.fine_tune)

    def forward_resblock(self, x, layers, idx):
        """Run ``x`` through ``layers``, passing ``idx`` to ``BasicBlock`` members.

        Returns:
            ``(x, skip)`` where ``skip`` is the second element of the last
            tuple returned by a layer, or ``None`` if no layer returned a tuple.
        """
        skip = None
        for layer in layers:
            x = layer(x, idx) if isinstance(layer, BasicBlock) else layer(x)
            if isinstance(x, tuple):
                x, skip = x
        return x, skip

    def forward_down(self, image, skips, idx=-1):
        """Encode one pyramid level and append its projected features to ``skips``.

        ``idx`` selects the normalisation layers and offsets the bucket index:
        the i-th retained stage feature is appended to ``skips[idx + i]``.
        ``skips`` is modified in place and also returned.
        """
        x = self.conv1(image)
        x = self.bn1[idx](x)
        x = self.relu(x)
        x = self.maxpool(x)

        features = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x, skip = self.forward_resblock(x, stage, idx)
            features.append(skip)

        # Bottlenecks are stored coarsest-first, so pair them with the reversed
        # features; ``zip`` drops the finest features when bottlenecks were removed.
        skip_feats = [b(f) for b, f in zip(self.upsample_bottlenecks, reversed(features))]

        for i, s in enumerate(reversed(skip_feats)):
            skips[idx + i].append(s)

        return skips

    def forward(self, image):
        """Compute the blended feature map for ``image``.

        The input tensor is not modified; normalisation produces a new tensor.

        Returns:
            ``(x, additional)`` where ``x`` is the output of the last blend
            module and ``additional`` is ``{'pyramid': [...]}`` holding the
            (normalised) image at every pyramid resolution.
        """
        if isinstance(self.bn1[0], nn.BatchNorm2d):
            # Out-of-place on purpose: in-place ops would alter the caller's
            # tensor (double normalisation if it is reused) and fail for
            # integer inputs or leaf tensors that require grad.
            if hasattr(self, 'img_scale'):
                image = image / self.img_scale
            image = (image - self.img_mean) / self.img_std
        pyramid = [image]
        for level in range(1, self.pyramid_levels):
            if self.target_size is not None:
                ts = [si // 2 ** level for si in self.target_size]
                pyramid += [
                    F.interpolate(image, size=ts, mode=self.pyramid_subsample, align_corners=self.align_corners)]
            else:
                pyramid += [F.interpolate(image, scale_factor=1 / 2 ** level, mode=self.pyramid_subsample,
                                          align_corners=self.align_corners)]
        skips = [[] for _ in range(self.num_skip_levels)]
        additional = {'pyramid': pyramid}
        for idx, p in enumerate(pyramid):
            skips = self.forward_down(p, skips, idx=idx)
        # Process buckets from the last one filled to the first.
        skips = skips[::-1]
        x = skips[0][0]
        if self.detach_upsample_in:
            x = x.detach()
        for sk, blend in zip(skips[1:], self.upsample_blends):
            x = blend(x, sum(sk))
        return x, additional

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Additionally copy un-indexed ``bn1.*`` entries into every stem norm layer.

        A checkpoint with a single stem normalisation layer (keys
        ``<prefix>bn1.weight`` ...) is broadcast to all per-level layers.
        Per-level keys (``<prefix>bn1.0.weight`` ...) are handled by the
        regular child loading that PyTorch performs after this method.
        """
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys,
                                      unexpected_keys, error_msgs)
        shared_prefix = prefix + 'bn1.'
        # Filtered copy: a BatchNorm loader may add entries (such as
        # ``num_batches_tracked``) to the mapping it receives; leaving
        # ``<prefix>bn1.num_batches_tracked`` in the caller's ``state_dict``
        # would make the ModuleList report it as unexpected under strict loading.
        shared_state = {k: v for k, v in state_dict.items() if k.startswith(shared_prefix)}
        for bn in self.bn1:
            # The broadcast is best-effort, so its missing/unexpected keys are
            # discarded (as BasicBlock does); shape errors are still reported.
            bn._load_from_state_dict(shared_state, shared_prefix, local_metadata, strict, [], [],
                                     error_msgs)
@register_model
def resnet18(pretrained=True, **kwargs):
    """Build the ResNet-18 variant (``BasicBlock``, two blocks per stage).

    Args:
        pretrained (bool): when true, download the ``'resnet18'`` checkpoint
            listed in ``model_urls`` and load it into the model.
        **kwargs: forwarded to the ``ResNet`` constructor.
    """
    net = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet18'])
    # Non-strict: key mismatches between checkpoint and model are tolerated.
    net.load_state_dict(weights, strict=False)
    return net
@register_model
def resnet34(pretrained=True, **kwargs):
    """Build the ResNet-34 variant (``BasicBlock``, stages of 3/4/6/3 blocks).

    Args:
        pretrained (bool): when true, download the ``'resnet34'`` checkpoint
            listed in ``model_urls`` and load it into the model.
        **kwargs: forwarded to the ``ResNet`` constructor.
    """
    net = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet34'])
    # Non-strict: key mismatches between checkpoint and model are tolerated.
    net.load_state_dict(weights, strict=False)
    return net