Spaces:
Sleeping
Sleeping
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import models
from torchvision.ops.misc import FrozenBatchNorm2d
class Backbone(nn.Module):
    """ResNet-based feature extractor returning multi-scale features.

    Wraps a torchvision ResNet (frozen batch norm), optionally loads SwAV
    self-supervised weights for resnet50, and on the forward pass concatenates
    the outputs of layers 2, 3 and 4, each bilinearly resized to the input
    spatial size divided by ``reduction``.
    """

    def __init__(
        self,
        name: str,
        pretrained: bool,
        dilation: bool,
        reduction: int,
        swav: bool,
        requires_grad: bool
    ):
        super().__init__()
        # Build the torchvision ResNet by name; last stage optionally dilated,
        # batch-norm statistics frozen.
        net = getattr(models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=pretrained, norm_layer=FrozenBatchNorm2d
        )
        self.backbone = net
        self.reduction = reduction

        # SwAV checkpoint is only available for resnet50. Its keys carry a
        # "module." prefix from DataParallel training; strip it before loading.
        # strict=False because the checkpoint does not include the fc head.
        if swav and name == 'resnet50':
            ckpt = torch.hub.load_state_dict_from_url(
                'https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_pretrain.pth.tar',
                map_location="cpu"
            )
            cleaned = {key.replace("module.", ""): val for key, val in ckpt.items()}
            self.backbone.load_state_dict(cleaned, strict=False)

        # Channel count of the concatenation of layers 2, 3 and 4
        # (basic-block resnets: 128+256+512; bottleneck resnets: 512+1024+2048).
        self.num_channels = 3584 if name not in ('resnet18', 'resnet34') else 896

        # Freeze everything before layer2; layers 2-4 follow ``requires_grad``.
        for pname, p in self.backbone.named_parameters():
            trainable = any(tag in pname for tag in ('layer2', 'layer3', 'layer4'))
            p.requires_grad_(requires_grad if trainable else False)

    def forward(self, x):
        # Target spatial size: input H/W divided by the reduction factor.
        target = (x.size(-2) // self.reduction, x.size(-1) // self.reduction)

        # Stem.
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        x = self.backbone.layer1(x)

        # Run stages 2-4, keeping each stage's output.
        feats = []
        for stage in (self.backbone.layer2, self.backbone.layer3, self.backbone.layer4):
            x = stage(x)
            feats.append(x)

        # Resize every stage output to the common target size and concatenate
        # along the channel dimension.
        resized = [
            F.interpolate(f, size=target, mode='bilinear', align_corners=True)
            for f in feats
        ]
        return torch.cat(resized, dim=1)