Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torchvision.models.resnet | |
| from torchvision.models.resnet import BasicBlock, Bottleneck | |
class ResNet(torchvision.models.resnet.ResNet):
    """torchvision ``ResNet`` with a modified stem pool and stage downsampling.

    After the parent constructor has built the network, two things change:

    * ``self.maxpool`` is replaced by a 3x3, stride-2 max-pool with
      ``padding=0`` and ``ceil_mode=True`` (output size rounded up).
    * In the first block of ``layer2``, ``layer3`` and ``layer4``, the
      stride-2 step is put on ``conv1`` and ``conv2`` is set to stride 1.

    Only module hyper-parameters (strides, pooling settings) are touched, never
    parameter tensors, so the ``state_dict`` keys and shapes are the same as the
    parent's for the same ``block``/``layers``.
    NOTE(review): this layout presumably mimics the original Caffe ResNet
    release — confirm before relying on that.

    Args:
        block: Residual block class forwarded to the parent constructor.
        layers: Number of blocks per stage, forwarded to the parent.
        num_classes: Size of the final classifier output (default 1000).
    """

    def __init__(self, block, layers, num_classes=1000):
        super().__init__(block, layers, num_classes)

        # Stem pool: no padding, ceil rounding.
        self.maxpool = nn.MaxPool2d(
            kernel_size=3,
            stride=2,
            padding=0,
            ceil_mode=True,
        )

        # Downsample in the first conv of each later stage's leading block.
        for stage in (self.layer2, self.layer3, self.layer4):
            leading_block = stage[0]
            leading_block.conv1.stride = (2, 2)
            leading_block.conv2.stride = (1, 1)
def resnet18(pretrained=False):
    """Construct a ResNet-18 (``BasicBlock``, ``[2, 2, 2, 2]``) using this module's ``ResNet``.

    Args:
        pretrained (bool): If True, download torchvision's ImageNet ResNet-18
            checkpoint and load it into the model before returning.

    Returns:
        ResNet: the constructed model, with weights loaded if requested.

    Note:
        This module's ``ResNet`` changes only strides and max-pool settings,
        so the checkpoint's tensor shapes match and loading succeeds.
        NOTE(review): outputs may not match the reference torchvision model
        exactly; confirm accuracy if that matters.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2])
    if pretrained:
        # ``model_zoo``/``model_urls`` were never imported here (and
        # ``model_urls`` was removed from recent torchvision), so fetch the
        # checkpoint directly. ``check_hash`` verifies the SHA-256 prefix
        # embedded in the file name.
        from torch.hub import load_state_dict_from_url

        state_dict = load_state_dict_from_url(
            "https://download.pytorch.org/models/resnet18-f37072fd.pth",
            check_hash=True,
        )
        model.load_state_dict(state_dict)
    return model
def resnet34(pretrained=False):
    """Construct a ResNet-34 (``BasicBlock``, ``[3, 4, 6, 3]``) using this module's ``ResNet``.

    Args:
        pretrained (bool): If True, download torchvision's ImageNet ResNet-34
            checkpoint and load it into the model before returning.

    Returns:
        ResNet: the constructed model, with weights loaded if requested.

    Note:
        This module's ``ResNet`` changes only strides and max-pool settings,
        so the checkpoint's tensor shapes match and loading succeeds.
        NOTE(review): outputs may not match the reference torchvision model
        exactly; confirm accuracy if that matters.
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3])
    if pretrained:
        # ``model_zoo``/``model_urls`` were never imported here (and
        # ``model_urls`` was removed from recent torchvision), so fetch the
        # checkpoint directly. ``check_hash`` verifies the SHA-256 prefix
        # embedded in the file name.
        from torch.hub import load_state_dict_from_url

        state_dict = load_state_dict_from_url(
            "https://download.pytorch.org/models/resnet34-b627a593.pth",
            check_hash=True,
        )
        model.load_state_dict(state_dict)
    return model
def resnet50(pretrained=False):
    """Construct a ResNet-50 (``Bottleneck``, ``[3, 4, 6, 3]``) using this module's ``ResNet``.

    Args:
        pretrained (bool): If True, download torchvision's ImageNet ResNet-50
            (V1) checkpoint and load it into the model before returning.

    Returns:
        ResNet: the constructed model, with weights loaded if requested.

    Note:
        This module's ``ResNet`` changes only strides and max-pool settings,
        so the checkpoint's tensor shapes match and loading succeeds.
        NOTE(review): because the downsampling stride is moved between convs,
        outputs may differ from the reference torchvision model; confirm
        accuracy if that matters.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3])
    if pretrained:
        # ``model_zoo``/``model_urls`` were never imported here (and
        # ``model_urls`` was removed from recent torchvision), so fetch the
        # checkpoint directly. ``check_hash`` verifies the SHA-256 prefix
        # embedded in the file name.
        from torch.hub import load_state_dict_from_url

        state_dict = load_state_dict_from_url(
            "https://download.pytorch.org/models/resnet50-0676ba61.pth",
            check_hash=True,
        )
        model.load_state_dict(state_dict)
    return model
def resnet101(pretrained=False):
    """Construct a ResNet-101 (``Bottleneck``, ``[3, 4, 23, 3]``) using this module's ``ResNet``.

    Args:
        pretrained (bool): If True, download torchvision's ImageNet ResNet-101
            (V1) checkpoint and load it into the model before returning.

    Returns:
        ResNet: the constructed model, with weights loaded if requested.

    Note:
        This module's ``ResNet`` changes only strides and max-pool settings,
        so the checkpoint's tensor shapes match and loading succeeds.
        NOTE(review): because the downsampling stride is moved between convs,
        outputs may differ from the reference torchvision model; confirm
        accuracy if that matters.
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3])
    if pretrained:
        # ``model_zoo``/``model_urls`` were never imported here (and
        # ``model_urls`` was removed from recent torchvision), so fetch the
        # checkpoint directly. ``check_hash`` verifies the SHA-256 prefix
        # embedded in the file name.
        from torch.hub import load_state_dict_from_url

        state_dict = load_state_dict_from_url(
            "https://download.pytorch.org/models/resnet101-63fe2227.pth",
            check_hash=True,
        )
        model.load_state_dict(state_dict)
    return model
def resnet152(pretrained=False):
    """Construct a ResNet-152 (``Bottleneck``, ``[3, 8, 36, 3]``) using this module's ``ResNet``.

    Args:
        pretrained (bool): If True, download torchvision's ImageNet ResNet-152
            (V1) checkpoint and load it into the model before returning.

    Returns:
        ResNet: the constructed model, with weights loaded if requested.

    Note:
        This module's ``ResNet`` changes only strides and max-pool settings,
        so the checkpoint's tensor shapes match and loading succeeds.
        NOTE(review): because the downsampling stride is moved between convs,
        outputs may differ from the reference torchvision model; confirm
        accuracy if that matters.
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3])
    if pretrained:
        # ``model_zoo``/``model_urls`` were never imported here (and
        # ``model_urls`` was removed from recent torchvision), so fetch the
        # checkpoint directly. ``check_hash`` verifies the SHA-256 prefix
        # embedded in the file name.
        from torch.hub import load_state_dict_from_url

        state_dict = load_state_dict_from_url(
            "https://download.pytorch.org/models/resnet152-394f9c45.pth",
            check_hash=True,
        )
        model.load_state_dict(state_dict)
    return model