"""
Creates an Xception Model as defined in:
Francois Chollet
Xception: Deep Learning with Depthwise Separable Convolutions
https://arxiv.org/pdf/1610.02357.pdf
This weights ported from the Keras implementation. Achieves the following performance on the validation set:
Loss:0.9173 Prec@1:78.892 Prec@5:94.292
REMEMBER to set your image size to 3x299x299 for both test and validation
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5])
The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
"""
import math

import torch.nn as nn
import torch.utils.model_zoo as model_zoo

from ..builder import MODELS
from .common import conv_block, BN_MOMENTUM

model_urls = {
    'xception': 'https://www.dropbox.com/s/1hplpzet9d7dv29/xception-c0a72b38.pth.tar?dl=1'
}
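
# NOTE: init_weights(pretrained=True) fetches this checkpoint through
# torch.utils.model_zoo; if the Dropbox link is unreachable, the .pth.tar can
# be downloaded manually and loaded with load_state_dict(..., strict=False).
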
class SeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=0, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()
        # Depthwise convolution: one filter per input channel.
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride,
                               padding, dilation, groups=in_channels, bias=bias)
        # Pointwise 1x1 convolution mixes the channels.
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1,
                                   bias=bias)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pointwise(x)
        return x
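
# Shape sanity-check sketch (hypothetical tensors; shown as a comment so the
# module stays import-safe):
#   sep = SeparableConv2d(64, 128, kernel_size=3, padding=1)
#   y = sep(torch.randn(1, 64, 56, 56))  # -> torch.Size([1, 128, 56, 56])
# The depthwise conv (groups=in_channels) filters each channel independently;
# the 1x1 pointwise conv then mixes channels, as in the Xception paper.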
class Block(nn.Module):
    def __init__(self, in_filters, out_filters, reps, strides=1,
                 start_with_relu=True, grow_first=True):
        super(Block, self).__init__()

        # 1x1 projection on the residual path whenever the shape changes.
        if out_filters != in_filters or strides != 1:
            self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides,
                                  bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip = None

        self.relu = nn.ReLU(inplace=True)
        rep = []

        filters = in_filters
        if grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1,
                                       padding=1, bias=False))
            rep.append(nn.BatchNorm2d(out_filters))
            filters = out_filters

        for i in range(reps - 1):
            rep.append(self.relu)
            rep.append(SeparableConv2d(filters, filters, 3, stride=1,
                                       padding=1, bias=False))
            rep.append(nn.BatchNorm2d(filters))

        if not grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1,
                                       padding=1, bias=False))
            rep.append(nn.BatchNorm2d(out_filters))

        if not start_with_relu:
            rep = rep[1:]
        else:
            # Avoid modifying the block input in place: the skip branch
            # still needs the original tensor.
            rep[0] = nn.ReLU(inplace=False)

        if strides != 1:
            rep.append(nn.MaxPool2d(3, strides, 1))
        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        x = self.rep(inp)

        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp

        x += skip
        return x
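
# Usage sketch (hypothetical shapes): a strided Block halves the spatial
# resolution via its MaxPool2d and matches channels through the 1x1 skip:
#   blk = Block(64, 128, reps=2, strides=2, start_with_relu=False)
#   y = blk(torch.randn(1, 64, 112, 112))  # -> torch.Size([1, 128, 56, 56])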
@MODELS.register_module()
class Xception(nn.Module):
"""
Xception optimized for the ImageNet dataset, as specified in
https://arxiv.org/pdf/1610.02357.pdf
"""
    def __init__(self,
                 heads,
                 head_conv=64,
                 cls_based_hm=True,
                 dropout_prob=0.5,
                 **kwargs):
        """ Constructor
        Args:
            heads (dict): maps each head name to its number of output channels
            head_conv (int): channel width of the intermediate head convolutions
            cls_based_hm (bool): if True, the 'cls' head is computed from the
                heatmap ('hm') output instead of the shared feature map
            dropout_prob (float): dropout probability applied after the backbone
        """
        self.heads = heads
        self.head_conv = head_conv
        self.cls_based_hm = cls_based_hm
        self.dropout_prob = dropout_prob

        super(Xception, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        # do relu here
        self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True)

        self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)

        self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False)
        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536)
        # do relu here
        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048)
        self.dropout = nn.Dropout2d(p=self.dropout_prob)
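
        # Decoder: three conv_block + stride-2 ConvTranspose2d stages upsample
        # the 2048-channel backbone features by 8x, down to 64 channels, for
        # the prediction heads.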
self.conv_block_1 = conv_block(2048, 256, (3,3), padding=1)
self.deconv_1 = nn.Sequential(
nn.ConvTranspose2d(
in_channels=256,
out_channels=256,
kernel_size=(4,4),
stride=2,
padding=1,
output_padding=0,
bias=False),
nn.BatchNorm2d(256, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)
)
self.conv_block_2 = conv_block(256, 256, (3,3), padding=1)
self.deconv_2 = nn.Sequential(
nn.ConvTranspose2d(
in_channels=256,
out_channels=128,
kernel_size=(4,4),
stride=2,
padding=1,
output_padding=0,
bias=False),
nn.BatchNorm2d(128, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)
)
self.conv_block_3 = conv_block(128, 128, (3,3), padding=1)
self.deconv_3 = nn.Sequential(
nn.ConvTranspose2d(
in_channels=128,
out_channels=64,
kernel_size=(4,4),
stride=2,
padding=1,
output_padding=0,
bias=False),
nn.BatchNorm2d(64, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)
)
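
        # Build one output module per requested head. With cls_based_hm=True,
        # the 'cls' head pools the heatmap ('hm') output into global class
        # scores; otherwise every head reads the shared 64-channel decoder
        # features.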
for head in sorted(self.heads):
num_output = self.heads[head]
if self.head_conv > 0:
if head != 'cls':
fc = nn.Sequential(
nn.Conv2d(64, self.head_conv,
kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(self.head_conv),
nn.ReLU(inplace=True),
nn.Conv2d(self.head_conv, num_output,
kernel_size=1, stride=1, padding=0)
)
                else:
                    if self.cls_based_hm:
                        # Classification head computed from the heatmap output.
                        # Note: the Linear input size assumes the pooled tensor
                        # has a single channel (i.e. heads['hm'] == 1).
                        fc = nn.Sequential(
                            nn.AdaptiveAvgPool2d(head_conv // 4),
                            nn.Flatten(),
                            nn.Linear((head_conv // 4) ** 2, head_conv, bias=False),
                            nn.BatchNorm1d(head_conv, momentum=BN_MOMENTUM),
                            nn.ReLU(inplace=True),
                            nn.Linear(head_conv, num_output, bias=True),
                            nn.Sigmoid()
                        )
                    else:
                        # Classification head computed from the shared features.
                        # Note: as above, the Linear input size assumes the
                        # pooled tensor has a single channel (num_output == 1).
                        fc = nn.Sequential(
                            nn.Conv2d(64, head_conv, kernel_size=3,
                                      padding=1, bias=False),
                            nn.BatchNorm2d(head_conv, momentum=BN_MOMENTUM),
                            nn.ReLU(inplace=True),
                            nn.Conv2d(head_conv, num_output, kernel_size=1,
                                      stride=1, padding=0, bias=False),
                            nn.BatchNorm2d(num_output),
                            nn.AdaptiveAvgPool2d(head_conv // 4),
                            nn.Flatten(),
                            nn.Linear((head_conv // 4) ** 2, head_conv, bias=False),
                            nn.BatchNorm1d(head_conv, momentum=BN_MOMENTUM),
                            nn.ReLU(inplace=True),
                            nn.Linear(head_conv, num_output, bias=True),
                            nn.Sigmoid()
                        )
else:
fc = nn.Conv2d(
in_channels=64,
out_channels=num_output,
kernel_size=1,
stride=1,
padding=0
)
self.__setattr__(head, fc)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.block1(x)
x = self.block2(x)
x = self.block3(x)
x = self.block4(x)
x = self.block5(x)
x = self.block6(x)
x = self.block7(x)
x = self.block8(x)
x = self.block9(x)
x = self.block10(x)
x = self.block11(x)
x = self.block12(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.relu(x)
x = self.conv4(x)
x = self.bn4(x)
x = self.relu(x)
x = self.dropout(x)
x = self.conv_block_1(x)
x = self.deconv_1(x)
x = self.conv_block_2(x)
x = self.deconv_2(x)
x = self.conv_block_3(x)
x = self.deconv_3(x)
        ret = {}
        x1_hm = None
        # With cls_based_hm=True, the 'hm' head must be evaluated before 'cls';
        # the assert below guards the iteration order of self.heads.
        for head in self.heads:
            if not self.cls_based_hm or head != 'cls':
                ret[head] = self.__getattr__(head)(x)
                if head == 'hm':
                    x1_hm = ret[head]
            else:
                assert 'hm' in ret.keys(), \
                    "The 'cls' head needs the heatmap output; define an 'hm' head so it is evaluated first!"
                ret[head] = self.__getattr__(head)(x1_hm)
        return [ret]
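
    # Usage sketch (hypothetical head configuration; real configs are supplied
    # by whatever builds this registered model):
    #   model = Xception(heads={'hm': 1, 'wh': 2, 'reg': 2}, head_conv=64)
    #   model.init_weights(pretrained=False)
    #   out = model(torch.randn(2, 3, 299, 299))[0]  # dict, one tensor per head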
    def init_weights(self, pretrained=False):
        if not pretrained:
            # He-style initialization for conv layers; identity for batch norm.
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
                elif isinstance(m, nn.ConvTranspose2d):
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
        else:
            self.load_state_dict(model_zoo.load_url(model_urls['xception']),
                                 strict=False)
        # Init head parameters: bias the final linear layer of each sigmoid
        # head towards a low prior activation (focal-loss-style prior init).
        prior = 1 / 71
        for head in self.heads:
            final_layer = self.__getattr__(head)
            for m in final_layer.modules():
                if isinstance(m, nn.Linear):
                    if m.weight.shape[0] == self.heads[head]:
                        nn.init.constant_(m.bias, -math.log((1 - prior) / prior))
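
# Registry usage sketch (hypothetical config; the exact build call depends on
# the framework that owns the MODELS registry imported above):
#   cfg = dict(type='Xception', heads={'hm': 1, 'wh': 2, 'reg': 2})
#   model = MODELS.build(cfg)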