|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
|
|
|
class Bottleneck(nn.Module):
    """Three-convolution residual block: 1x1 -> 3x3 -> 1x1, plus a shortcut.

    The first 1x1 convolution maps ``inplanes`` channels to ``planes``, the
    3x3 convolution keeps ``planes`` channels (and carries ``stride``), and
    the last 1x1 convolution widens to ``planes * expansion`` channels. Each
    convolution is followed by batch normalisation; ReLU follows the first
    two stages and the final residual sum.

    Args:
        inplanes: Number of channels of the input tensor.
        planes: Channel width of the inner convolutions.
        stride: Stride of the 3x3 convolution.
        downsample: Optional callable applied to the block input to produce
            the shortcut branch. When ``None`` the input itself is used, so
            it must already have the shape of the main-branch output for the
            addition to be valid.
        bn_momentum: ``momentum`` passed to every ``nn.BatchNorm2d`` created
            by this block.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, bn_momentum=0.1):
        super(Bottleneck, self).__init__()
        out_planes = planes * self.expansion

        def make_norm(channels):
            # All normalisation layers of the block share one momentum value.
            return nn.BatchNorm2d(channels, momentum=bn_momentum)

        # Stage 1: pointwise convolution, inplanes -> planes.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = make_norm(planes)

        # Stage 2: 3x3 convolution; the only layer that applies ``stride``.
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = make_norm(planes)

        # Stage 3: pointwise convolution, planes -> planes * expansion.
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = make_norm(out_planes)

        # One in-place ReLU module is reused for every activation.
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        # Kept as an attribute only; forward() does not read it.
        self.stride = stride

    def forward(self, x):
        """Run the three conv/bn stages, add the shortcut, and apply ReLU."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        # The shortcut is evaluated after the main branch; it is the raw
        # input unless a downsample callable was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        # In-place accumulation into the main-branch tensor.
        y += shortcut
        return self.relu(y)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BasicBlock(nn.Module):
    """Two-convolution residual block: 3x3 -> 3x3, plus a shortcut.

    The first 3x3 convolution maps ``inplanes`` channels to ``planes`` and
    carries ``stride``; the second 3x3 convolution keeps ``planes`` channels
    at stride 1. Each convolution is followed by batch normalisation; ReLU
    follows the first stage and the final residual sum.

    Args:
        inplanes: Number of channels of the input tensor.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first 3x3 convolution.
        downsample: Optional callable applied to the block input to produce
            the shortcut branch. When ``None`` the input itself is used, so
            it must already have the shape of the main-branch output for the
            addition to be valid.
        bn_momentum: ``momentum`` passed to both ``nn.BatchNorm2d`` layers.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, bn_momentum=0.1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=bn_momentum)
        self.relu = nn.ReLU(inplace=True)
        # conv2 consumes the output of conv1/bn1, which has ``planes``
        # channels, so its input width must be ``planes`` (not ``inplanes``);
        # otherwise any block with inplanes != planes fails in forward().
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=bn_momentum)
        self.downsample = downsample
        # Kept as an attribute only; forward() does not read it.
        self.stride = stride

    def forward(self, x):
        """Run both conv/bn stages, add the shortcut, and apply ReLU."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Use the transformed input as shortcut when a downsample callable
        # was supplied; otherwise the raw input is added.
        if self.downsample is not None:
            residual = self.downsample(x)

        # In-place accumulation into the main-branch tensor.
        out += residual
        out = self.relu(out)

        return out