Spaces:
Sleeping
Sleeping
| from torch import nn as nn | |
| from torchvision.models import resnet18 | |
class BasicBlock(nn.Module):
    """Two-convolution residual block normalised with GroupNorm and activated by GELU.

    Main path: conv3x3(stride) -> GroupNorm -> GELU -> conv3x3 -> GroupNorm.
    The shortcut (``x`` itself, or ``downsample(x)`` when a downsample module is
    supplied) is added in place to the main-path output, and GELU is applied to
    the sum.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by both convolutions.
        stride: Stride of the first convolution only; the second always uses 1.
        downsample: Optional module that maps ``x`` onto the shortcut. Because
            the shortcut is added in place to the main-path output, this is
            typically required whenever ``stride != 1`` or
            ``in_channels != out_channels``.
        num_groups: Number of groups for both GroupNorm layers; ``out_channels``
            must be divisible by it (``nn.GroupNorm`` rejects it otherwise).
    """

    # Ratio of output channels to ``out_channels`` (the block emits exactly
    # ``out_channels``). ResNet-style builders conventionally read this attribute.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, num_groups=8):
        super().__init__()
        # Attribute names and registration order are kept stable: they define
        # the state_dict keys and the order of parameters().
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.gn1 = nn.GroupNorm(num_groups, out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.gn2 = nn.GroupNorm(num_groups, out_channels)
        self.downsample = downsample
        self.gelu = nn.GELU()

    def forward(self, x):
        """Run the main path and the shortcut on ``x``, then return GELU(main + shortcut)."""
        branch = self.gelu(self.gn1(self.conv1(x)))
        branch = self.gn2(self.conv2(branch))

        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.gelu(branch)
class ResNetFeatureWarpper(nn.Module):
    """Multi-scale feature extractor built from the early stages of torchvision's ResNet-18.

    ``forward(x)`` returns a list of intermediate feature maps, in order:

    1. the raw ``conv1`` output, taken *before* ``bn1`` / ReLU / max-pool;
    2. the ``layer1`` output;
    3. the ``layer2`` output (omitted when ``shallow_resnet_feature`` is True).

    The class name keeps its original spelling ("Warpper") because callers and
    saved checkpoints may refer to it.

    Args:
        shallow_resnet_feature: If True, ``layer2`` is neither stored as a
            submodule nor run, and only two feature maps are returned.
        pretrained: Forwarded to ``resnet18``. Defaults to True, which loads the
            ImageNet weights (the original behaviour). Pass False to build a
            randomly initialised backbone, e.g. when weights will be restored
            from a checkpoint afterwards or when running without network access.
    """

    def __init__(self, shallow_resnet_feature=False, pretrained=True):
        super().__init__()
        self.shallow_resnet_feature = shallow_resnet_feature
        # NOTE(review): torchvision >= 0.13 deprecates ``pretrained`` in favour of
        # ``weights=`` (pretrained=True maps to IMAGENET1K_V1). The keyword is kept
        # here for compatibility with older torchvision; migrate once the minimum
        # supported version is known.
        resnet = resnet18(pretrained=pretrained)
        # Only the stem and the first stage(s) are kept as submodules; the rest of
        # the torchvision model is not referenced after __init__ returns.
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        if not shallow_resnet_feature:
            self.layer2 = resnet.layer2

    def forward(self, x):
        """Return the list of intermediate feature maps described in the class docstring."""
        features = []

        x = self.conv1(x)
        # Deliberately captured before normalisation and activation.
        features.append(x)

        x = self.maxpool(self.relu(self.bn1(x)))
        x = self.layer1(x)
        features.append(x)

        if not self.shallow_resnet_feature:
            x = self.layer2(x)
            features.append(x)
        return features