import torch
from torch import nn


class Bottleneck(nn.Module):
    """ResNet-style bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions.

    The first 1x1 convolution maps ``inplanes`` to ``planes`` channels, the
    3x3 convolution (which carries ``stride``) keeps ``planes`` channels, and
    the final 1x1 convolution widens to ``planes * expansion`` channels.  Every
    convolution is followed by batch normalization; ReLU follows the first two
    stages and is applied once more after the shortcut has been added.

    Args:
        inplanes: Number of input channels.
        planes: Width of the inner convolutions.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input to build the shortcut.
            When ``None`` the input itself is added, so it must already have
            the same shape as the main-path output.
        bn_momentum: ``momentum`` passed to every ``BatchNorm2d``.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, bn_momentum=0.1):
        super().__init__()
        out_planes = planes * self.expansion
        # Submodules are created in this exact order so that parameter
        # initialisation (RNG consumption) and state_dict key order match.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=bn_momentum)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes, momentum=bn_momentum)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes, momentum=bn_momentum)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride  # kept as an attribute; forward() does not read it

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        y = x
        # Stages 1 and 2: conv -> BN -> ReLU.  The in-place ReLU only ever
        # overwrites freshly computed BN outputs, never ``x`` itself.
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            y = self.relu(norm(conv(y)))
        # Stage 3 has no activation before the residual sum.
        y = self.bn3(self.conv3(y))

        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.relu(y)


# NOTE(review): the disabled draft below is unfinished and would fail if
# uncommented as-is: it calls ``super(Bottleneck, self).__init__()`` from a
# class that does not derive from Bottleneck (TypeError), the
# ``nn.ConvTranspose2d(c, ...)`` line references an undefined ``c`` and is
# never assigned to an attribute, and the class name misspells "Transpose".
# class Bottleneck_Tranpose(nn.Module):
#     expansion = 4
#
#     def __init__(self, inplanes, planes, stride=1, downsample=None, bn_momentum=0.1):
#         super(Bottleneck, self).__init__()
#         nn.ConvTranspose2d(c, 64, (3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1)),
#         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
#         self.bn1 = nn.BatchNorm2d(planes, momentum=bn_momentum)
#         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
#         self.bn2 = nn.BatchNorm2d(planes, momentum=bn_momentum)
#         self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
#         self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=bn_momentum)
#         self.relu = nn.ReLU(inplace=True)
#         self.downsample = downsample
#         self.stride = stride
#
#     def forward(self, x):
#         residual = x
#         out = self.conv1(x)
#         out = self.bn1(out)
#         out = self.relu(out)
#         out = self.conv2(out)
#         out = self.bn2(out)
#         out = self.relu(out)
#         out = self.conv3(out)
#         out = self.bn3(out)
#         if self.downsample is not None:
#             residual = self.downsample(x)
#         out += residual
#         out = self.relu(out)
#         return out


class BasicBlock(nn.Module):
    """ResNet-style basic residual block: two 3x3 convolutions plus a shortcut.

    Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.  The shortcut
    (the input, or ``downsample(input)``) is added and a final ReLU applied.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels of both convolutions.
        stride: Stride of the first 3x3 convolution.
        downsample: Optional module applied to the input to build the shortcut.
            When ``None`` the input itself is added, so it must already have
            the same shape as the main-path output.
        bn_momentum: ``momentum`` passed to both ``BatchNorm2d`` layers.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, bn_momentum=0.1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=bn_momentum)
        self.relu = nn.ReLU(inplace=True)
        # conv2 consumes conv1's output, which has ``planes`` channels (not
        # ``inplanes``); using ``inplanes`` breaks any block that changes width.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=bn_momentum)
        self.downsample = downsample
        self.stride = stride  # kept as an attribute; forward() does not read it

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)  # in-place on the BN output, never on ``x``

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out