# ColorMNet / model /resnet.py
# root
# add test code
# d01f62c
"""
resnet.py - A modified ResNet structure
We append extra channels to the first conv by some network surgery
"""
from collections import OrderedDict
import math
import torch
import torch.nn as nn
from torch.utils import model_zoo
from torch.hub import load
import torchvision.models as models
import warnings
warnings.filterwarnings("ignore")
import torch.nn.functional as F
from einops import rearrange
def load_weights_add_extra_dim(target, source_state, extra_dim=1):
    """Load ``source_state`` into ``target``, widening tensors whose shape differs.

    Only keys of ``target.state_dict()`` are considered. ``num_batches_tracked``
    entries and keys missing from ``source_state`` are left out of the state
    that is finally passed to ``target.load_state_dict``.

    Args:
        target: module that receives the weights (modified in place).
        source_state: mapping from parameter name to tensor.
        extra_dim: number of channels appended along dim 1 when a source
            tensor's shape does not match the target's.
    """
    loaded = OrderedDict()
    for name, own in target.state_dict().items():
        if 'num_batches_tracked' in name or name not in source_state:
            continue

        weight = source_state[name]
        if own.shape != weight.shape:
            # The appended channels are allocated as zeros and then overwritten
            # in place with an orthogonal initialisation.
            out_ch, _, k0, k1 = own.shape
            extra = torch.zeros((out_ch, extra_dim, k0, k1), device=weight.device)
            nn.init.orthogonal_(extra)
            weight = torch.cat([weight, extra], 1)

        loaded[name] = weight

    target.load_state_dict(loaded)
# Checkpoint URLs on download.pytorch.org, keyed by architecture name.
model_urls = dict(
    resnet18='https://download.pytorch.org/models/resnet18-5c106cde.pth',
    resnet50='https://download.pytorch.org/models/resnet50-19c8e357.pth',
)
def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d``.

    The padding is set equal to ``dilation``, so with ``stride=1`` the spatial
    size of the input is preserved.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False,
    )
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    ``downsample``, when given, is applied to the block input to produce the
    shortcut branch; otherwise the input itself is used.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super().__init__()
        # Only the first convolution receives the stride.
        self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions with batch norm.

    The last convolution outputs ``planes * 4`` channels. Stride and dilation
    are applied by the middle 3x3 convolution. ``downsample``, when given, is
    applied to the block input to produce the shortcut branch.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super().__init__()
        wide = planes * 4
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride,
            dilation=dilation, padding=dilation, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, wide, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(wide)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
class ResNet(nn.Module):
    """ResNet stem and four residual stages, with optional extra input channels.

    The class only builds and initialises the sub-modules (``conv1``, ``bn1``,
    ``relu``, ``maxpool``, ``layer1`` .. ``layer4``); it defines no ``forward``
    of its own.

    Args:
        block: residual block class; must expose an ``expansion`` attribute and
            accept ``(inplanes, planes, stride, downsample)`` plus a
            ``dilation`` keyword.
        layers: number of blocks in each of the four stages.
        extra_dim: channels added to the 3 input channels of the first conv.
    """

    def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0):
        super().__init__()
        # Running channel count consumed by the next stage; updated by _make_layer.
        self.inplanes = 64

        self.conv1 = nn.Conv2d(3 + extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Conv2d):
                # Normal init with std = sqrt(2 / (kh * kw * out_channels)).
                k_h, k_w = module.kernel_size
                fan = k_h * k_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        The first block receives ``stride`` and, when the stride is not 1 or the
        channel count changes, a 1x1 conv + batch norm shortcut. Note that
        ``dilation`` is passed only to the remaining blocks, not to the first.
        """
        out_planes = planes * block.expansion

        shortcut = None
        if not (stride == 1 and self.inplanes == out_planes):
            shortcut = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, shortcut)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*stage)
def resnet18(pretrained=True, extra_dim=0):
    """Build a ``ResNet`` of ``BasicBlock`` units with stage sizes ``[2, 2, 2, 2]``.

    Args:
        pretrained: when true, download the ``'resnet18'`` checkpoint listed in
            ``model_urls`` and load it via ``load_weights_add_extra_dim``.
        extra_dim: number of input channels added to the first convolution.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim)
    if not pretrained:
        return model

    checkpoint = model_zoo.load_url(model_urls['resnet18'])
    load_weights_add_extra_dim(model, checkpoint, extra_dim)
    return model
def resnet50(pretrained=True, extra_dim=0):
    """Build a ``ResNet`` of ``Bottleneck`` units with stage sizes ``[3, 4, 6, 3]``.

    Args:
        pretrained: when true, download the ``'resnet50'`` checkpoint listed in
            ``model_urls`` and load it via ``load_weights_add_extra_dim``.
        extra_dim: number of input channels added to the first convolution.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim)
    if not pretrained:
        return model

    checkpoint = model_zoo.load_url(model_urls['resnet50'])
    load_weights_add_extra_dim(model, checkpoint, extra_dim)
    return model
# DINOv2 variants selectable by short key: torch.hub entry name, embedding width
# and patch size of each.
dino_backbones = {
    'dinov2_s': {
        'name': 'dinov2_vits14',
        'embedding_size': 384,
        'patch_size': 14,
    },
    'dinov2_b': {
        'name': 'dinov2_vitb14',
        'embedding_size': 768,
        'patch_size': 14,
    },
    'dinov2_l': {
        'name': 'dinov2_vitl14',
        'embedding_size': 1024,
        'patch_size': 14,
    },
    'dinov2_g': {
        'name': 'dinov2_vitg14',
        'embedding_size': 1536,
        'patch_size': 14,
    },
}
class conv_head(nn.Module):
    """Small convolutional head: two (2x upsample -> 3x3 conv) stages and a sigmoid.

    Args:
        embedding_size: channel count of the incoming feature map.
        num_classes: channel count of the output map.
    """

    def __init__(self, embedding_size=384, num_classes=5):
        super().__init__()
        stages = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(embedding_size, 64, (3, 3), padding=(1, 1)),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, num_classes, (3, 3), padding=(1, 1)),
        ]
        self.segmentation_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Upsample ``x`` by 4x overall and squash the result into (0, 1)."""
        return torch.sigmoid(self.segmentation_conv(x))
class Segmentor(nn.Module):
    """DINOv2 feature extractor followed by a trainable 1x1 conv / BN / ReLU.

    The backbone is fetched with ``torch.hub`` (network access is required the
    first time) and is only ever run under ``torch.no_grad()``. The feature maps
    of four intermediate transformer layers are concatenated channel-wise, so
    the projection works on ``4 * embedding_size`` channels (1536 for the
    default ``'dinov2_s'``).

    Args:
        num_classes: currently unused; kept for interface compatibility.
        backbone: key into ``backbones`` selecting the DINOv2 variant.
        head: currently unused; kept for interface compatibility.
        backbones: mapping from key to a dict with at least ``'name'`` (hub
            entry point) and ``'embedding_size'``.
    """

    # Indices of the backbone layers whose outputs are concatenated.
    _TAPPED_LAYERS = (8, 9, 10, 11)

    def __init__(self, num_classes=5, backbone='dinov2_s', head='conv', backbones=dino_backbones):
        super(Segmentor, self).__init__()
        self.heads = {
            'conv': conv_head
        }

        # Honour the table passed by the caller (it used to be silently
        # replaced by the module-level default).
        self.backbones = backbones
        spec = self.backbones[backbone]

        # Online loading. For an offline setup, torch.hub can instead be pointed
        # at a local checkout, e.g.
        #   load('/root/.cache/torch/hub/facebookresearch_dinov2_main', spec['name'],
        #        source='local', pretrained=False)
        # followed by load_state_dict() of a locally stored checkpoint such as
        #   '/root/.cache/torch/hub/checkpoints/dinov2_vits14_pretrain.pth'.
        self.backbone = load('facebookresearch/dinov2', spec['name'])
        self.backbone.eval()

        # Channel count after concatenating the tapped layers. Deriving it from
        # the backbone spec keeps 1536 for 'dinov2_s' and makes the wider
        # variants usable too.
        feat_dim = len(self._TAPPED_LAYERS) * spec['embedding_size']
        self.conv3 = nn.Conv2d(feat_dim, feat_dim, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(feat_dim)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the projected, resized backbone features for image batch ``x``."""
        # The backbone is used as a fixed feature extractor: no gradients flow
        # into it.
        with torch.no_grad():
            tokens = self.backbone.get_intermediate_layers(
                x, n=list(self._TAPPED_LAYERS), reshape=True)
            f16 = torch.cat(tokens, dim=1)

        f16 = self.conv3(f16)
        f16 = self.bn3(f16)
        f16 = self.relu(f16)

        # Shrink the feature grid to 14/16 of its size.
        # NOTE(review): presumably converts the patch-14 grid to a stride-16
        # grid expected downstream -- confirm against the caller.
        old_size = (f16.shape[2], f16.shape[3])
        new_size = (int(old_size[0] * 14 / 16), int(old_size[1] * 14 / 16))
        f16 = F.interpolate(f16, size=new_size, mode='bilinear', align_corners=False)

        return f16
class LayerNormFunction(torch.autograd.Function):
    """Layer normalisation over the channel axis of a 4-D tensor, with a manual backward.

    For every (sample, row, column) position the values along dim 1 are
    normalised to zero mean / unit variance and then scaled and shifted by the
    per-channel ``weight`` and ``bias``.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        """Normalise ``x`` along dim 1.

        Args:
            x: 4-D input tensor; dim 1 is the normalised axis.
            weight: per-channel scale with ``C`` elements.
            bias: per-channel shift with ``C`` elements.
            eps: constant added to the variance before the square root.
        """
        ctx.eps = eps
        # The unpacking also enforces that the input is 4-D.
        _, C, _, _ = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        # Save the normalised activations (before scale/shift) for backward.
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(x, weight, bias, eps)``; ``eps`` gets ``None``."""
        eps = ctx.eps
        _, C, _, _ = grad_output.size()
        # ``saved_tensors`` replaces the deprecated ``saved_variables`` accessor.
        y, var, weight = ctx.saved_tensors

        # Gradient with respect to the normalised activations.
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)
        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)

        # Scale/shift gradients are reduced over every axis except the channel.
        grad_weight = (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0)
        grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)
        return gx, grad_weight, grad_bias, None
class LayerNorm2d(nn.Module):
    """Module wrapper around ``LayerNormFunction`` with learnable scale and shift.

    Args:
        channels: length of the ``weight`` (initialised to ones) and ``bias``
            (initialised to zeros) parameters.
        eps: constant forwarded to ``LayerNormFunction``.
    """

    def __init__(self, channels, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        """Apply ``LayerNormFunction`` to ``x`` using this module's parameters."""
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
class CrossChannelAttention(nn.Module):
    """Multi-head cross attention computed between channels rather than positions.

    Queries come from ``encoder``; keys and values come from ``decoder``. Each
    projection is a 1x1 convolution to ``2 * dim`` channels followed by a
    depth-wise 3x3 convolution. Per head, the attention matrix has shape
    (channels x channels), scaled by a learnable per-head ``temperature``.

    Args:
        dim: channel count of both inputs and of the output.
        heads: number of attention heads; ``2 * dim`` is split across them.
    """

    def __init__(self, dim, heads=8):
        super().__init__()
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
        self.heads = heads

        hidden = dim * 2
        self.to_q = nn.Conv2d(dim, hidden, kernel_size=1, bias=True)
        self.to_q_dw = nn.Conv2d(hidden, hidden, kernel_size=3, stride=1, padding=1, groups=hidden, bias=True)

        self.to_k = nn.Conv2d(dim, hidden, kernel_size=1, bias=True)
        self.to_k_dw = nn.Conv2d(hidden, hidden, kernel_size=3, stride=1, padding=1, groups=hidden, bias=True)

        self.to_v = nn.Conv2d(dim, hidden, kernel_size=1, bias=True)
        self.to_v_dw = nn.Conv2d(hidden, hidden, kernel_size=3, stride=1, padding=1, groups=hidden, bias=True)

        self.to_out = nn.Sequential(
            nn.Conv2d(hidden, dim, 1, 1, 0),
        )

    def forward(self, encoder, decoder):
        """Attend from ``encoder`` (queries) to ``decoder`` (keys / values)."""
        _, _, h, w = encoder.shape
        split_heads = 'b (head c) h w -> b head c (h w)'

        q = rearrange(self.to_q_dw(self.to_q(encoder)), split_heads, head=self.heads)
        k = rearrange(self.to_k_dw(self.to_k(decoder)), split_heads, head=self.heads)
        v = rearrange(self.to_v_dw(self.to_v(decoder)), split_heads, head=self.heads)

        # L2-normalise queries and keys along the flattened spatial axis.
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        scores = (q @ k.transpose(-2, -1)) * self.temperature
        weights = scores.softmax(dim=-1)

        mixed = weights @ v
        mixed = rearrange(mixed, 'b head c (h w) -> b (head c) h w', head=self.heads, h=h, w=w)
        return self.to_out(mixed)
def normalize(in_channels):
    """Return an affine ``GroupNorm`` with 32 groups over ``in_channels`` channels."""
    return nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
@torch.jit.script
def swish(x):
    """Swish activation ``x * sigmoid(x)``, compiled with TorchScript."""
    gate = torch.sigmoid(x)
    return x * gate
class ResBlock(nn.Module):
    """Pre-activation residual block: (GroupNorm -> swish -> 3x3 conv) twice.

    When the input and output channel counts differ, the shortcut is projected
    with a 1x1 convolution before being added.

    Args:
        in_channels: channel count of the input. ``normalize`` builds a
            32-group ``GroupNorm``, so this must be divisible by 32.
        out_channels: channel count of the output (same divisibility
            constraint); defaults to ``in_channels`` when ``None``.
    """

    def __init__(self, in_channels, out_channels=None):
        super(ResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels

        # Use the resolved channel count everywhere: with the default
        # ``out_channels=None`` the raw argument would be passed to the layers
        # below and construction would fail.
        self.norm1 = normalize(self.in_channels)
        self.conv1 = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = normalize(self.out_channels)
        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.conv_out = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x_in):
        """Return ``residual_branch(x_in) + shortcut(x_in)``."""
        x = x_in
        x = self.norm1(x)
        x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = swish(x)
        x = self.conv2(x)

        # ``conv_out`` only exists when the channel count changes.
        if self.in_channels != self.out_channels:
            x_in = self.conv_out(x_in)

        return x + x_in
class Fuse(nn.Module):
    """Fuse two feature maps with cross-channel attention and a residual path.

    ``enc`` is first projected from ``dine_feat`` to ``out_feat`` channels by a
    3x3 convolution. The projected ``enc`` (queries) and ``dnc`` (keys/values)
    are layer-normalised and passed to ``CrossChannelAttention``; the projected
    ``enc`` is added back, followed by a final layer norm and ReLU.

    Args:
        dine_feat: channel count of ``enc`` as passed to ``forward``.
        out_feat: channel count of the output; the norm applied to ``dnc`` is
            also built with this width.
    """

    def __init__(self, dine_feat, out_feat):
        # enc / dnc need to keep the same channel count and H x W
        super().__init__()
        self.encode_enc = nn.Conv2d(dine_feat, out_feat, kernel_size=3, stride=1, padding=1)

        self.dim = out_feat
        self.norm1 = LayerNorm2d(self.dim)
        self.norm2 = LayerNorm2d(self.dim)

        self.dine_feat = dine_feat
        self.out_feat = out_feat

        self.crossattn = CrossChannelAttention(dim=out_feat)
        self.norm3 = LayerNorm2d(self.dim)
        self.relu3 = nn.ReLU(inplace=True)

    def forward(self, enc, dnc):
        """Return the fused feature map for ``enc`` and ``dnc``."""
        projected = self.encode_enc(enc)

        attended = self.crossattn(self.norm1(projected), self.norm2(dnc))
        fused = attended + projected

        return self.relu3(self.norm3(fused))