Spaces:
Sleeping
Sleeping
| from torch import nn | |
| from timm import create_model | |
| from torchvision.transforms import Normalize | |
class SwinModel(nn.Module):
    """Frozen, pretrained Swin Transformer V2 feature extractor.

    Wraps a timm model created with ``features_only=True`` and
    ``out_indices=[-4, -3, -2, -1]``. Every feature map returned by the
    backbone is upsampled by a factor of 2 before being returned.

    Supported ``pretrained_model`` keys and the timm models they load:

    * ``"swinv2-cr-t-224"`` -> ``swinv2_cr_tiny_ns_224.sw_in1k``
    * ``"swinv2-t-256"``    -> ``swinv2_tiny_window16_256.ms_in1k``
    * ``"swinv2-cr-s-224"`` -> ``swinv2_cr_small_ns_224.sw_in1k``

    The backbone parameters have ``requires_grad`` disabled and the backbone
    is kept in eval mode, including after ``.train()`` is called on this
    module or on a parent module.
    """

    # Short key accepted by ``__init__`` -> timm model identifier.
    _TIMM_MODEL_NAMES = {
        "swinv2-cr-t-224": "swinv2_cr_tiny_ns_224.sw_in1k",
        "swinv2-t-256": "swinv2_tiny_window16_256.ms_in1k",
        "swinv2-cr-s-224": "swinv2_cr_small_ns_224.sw_in1k",
    }

    # Keys whose backbone outputs are treated as channel-last and are permuted
    # to channel-first in ``forward``.
    _CHANNEL_LAST_MODELS = frozenset({"swinv2-t-256"})

    def __init__(self, pretrained_model="swinv2-cr-t-224", device="cuda") -> None:
        """Load the selected pretrained backbone and freeze it.

        Args:
            pretrained_model: One of the keys listed in the class docstring.
            device: Device the backbone is moved to (passed to ``.to``).

        Raises:
            NotImplementedError: If ``pretrained_model`` is not a supported key.
        """
        super().__init__()
        self.device = device
        self.pretrained_model = pretrained_model

        try:
            timm_name = self._TIMM_MODEL_NAMES[pretrained_model]
        except (KeyError, TypeError):
            # TypeError covers unhashable arguments so that every unsupported
            # value still surfaces as NotImplementedError, as before.
            supported = ", ".join(sorted(self._TIMM_MODEL_NAMES))
            raise NotImplementedError(
                f"Unsupported pretrained_model {pretrained_model!r}; "
                f"expected one of: {supported}"
            ) from None

        self.pretrained = create_model(
            timm_name,
            pretrained=True,
            features_only=True,
            out_indices=[-4, -3, -2, -1],
        ).to(device)
        self.pretrained.eval()

        # Standard ImageNet channel statistics. NOTE: ``forward`` does not
        # apply this transform; it is only exposed as an attribute, so inputs
        # are presumably normalized by the caller -- TODO confirm.
        self.normalizer = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.upsample = nn.Upsample(scale_factor=2)

        # The backbone is used as a fixed feature extractor.
        for param in self.pretrained.parameters():
            param.requires_grad = False

    def train(self, mode: bool = True):
        """Set the training mode of this module while keeping the backbone in eval.

        ``nn.Module.train`` recurses into child modules, which would undo the
        ``eval()`` call made in ``__init__`` and re-enable any train-only
        layers (e.g. dropout) inside the frozen backbone.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            This module.
        """
        super().train(mode)
        self.pretrained.eval()
        return self

    def forward(self, x):
        """Extract backbone features and upsample each of them by 2x.

        Args:
            x: Input batch, passed unchanged to the backbone.

        Returns:
            A list with one upsampled feature map per backbone output. For
            ``"swinv2-t-256"`` each map is first permuted with
            ``(0, 3, 1, 2)`` (channel-last to channel-first).
        """
        outputs = self.pretrained(x)
        if self.pretrained_model in self._CHANNEL_LAST_MODELS:
            # Change channel-last to channel-first.
            outputs = [feat.permute(0, 3, 1, 2) for feat in outputs]
        return [self.upsample(feat) for feat in outputs]