Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision.models import resnet18 | |
# Prefer CUDA when PyTorch can see a GPU; fall back to the CPU otherwise.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class SimCLRResNet18(nn.Module):
    """Randomly initialised ResNet-18 trunk plus a linear head with L2-normalised output.

    Args:
        out_dim: width of the embedding produced by the linear head.
    """

    def __init__(self, out_dim=256):
        super().__init__()
        resnet = resnet18(weights=None)
        feat_dim = resnet.fc.in_features  # 512 for ResNet-18
        # Reuse every child module except the last one (the classifier `fc`).
        trunk = list(resnet.children())
        self.backbone = nn.Sequential(*trunk[:-1])
        self.fc = nn.Linear(feat_dim, out_dim)

    def forward(self, x):
        """Embed *x* and scale each output row to unit L2 norm.

        NOTE(review): x is presumably an image batch (N, 3, H, W) — confirm with callers.
        """
        features = self.backbone(x).flatten(start_dim=1)
        return F.normalize(self.fc(features), dim=1)
def load_encoder(path="encoder_resnet18_simclr.pth"):
    """Build a SimCLRResNet18 on DEVICE, load weights from *path*, return it in eval mode.

    Two key layouts are considered: the checkpoint exactly as saved, and a
    copy with every "backbone." / "projector." substring removed from the
    key names. Whichever layout matches more of the model's entries (same
    name and same shape) is loaded; the as-saved layout wins ties. Entries
    that do not match are skipped instead of raising, keeping the best-effort
    ``strict=False`` intent. A ``RuntimeWarning`` reports how many model
    entries were left at their random initialisation.

    Args:
        path: checkpoint file passed to ``torch.load``. Assumed to contain a
            flat ``{name: tensor}`` mapping — TODO confirm; a wrapped
            checkpoint (e.g. ``{"state_dict": ...}``) would match nothing.

    Returns:
        The ``SimCLRResNet18`` instance, on ``DEVICE`` and in eval mode.
    """
    import warnings
    from collections import OrderedDict

    model = SimCLRResNet18().to(DEVICE)
    # NOTE(review): torch.load unpickles the file, which can run arbitrary
    # code; only load checkpoints from trusted sources.
    state = torch.load(path, map_location=DEVICE)

    # Alternative layout: "backbone." / "projector." removed anywhere in a key.
    # NOTE(review): this model registers its own trunk under "backbone."
    # (self.backbone), so this remap cannot match those weights; it can only
    # help keys outside the trunk (e.g. a "projector."-prefixed head).
    remapped = {
        key.replace("backbone.", "").replace("projector.", ""): value
        for key, value in state.items()
    }

    expected = model.state_dict()

    def _compatible(candidate):
        # Keep only entries whose name exists in the model with an equal shape,
        # so load_state_dict cannot fail on a size mismatch.
        return OrderedDict(
            (name, tensor)
            for name, tensor in candidate.items()
            if name in expected
            and getattr(tensor, "shape", None) == expected[name].shape
        )

    primary = _compatible(state)
    fallback = _compatible(remapped)
    if len(fallback) > len(primary):
        chosen = fallback
    else:
        chosen = primary
        # Carry over the per-module version metadata torch attaches to saved
        # state dicts, so loading behaves as it would for the original dict.
        metadata = getattr(state, "_metadata", None)
        if metadata is not None:
            chosen._metadata = metadata

    # strict=False: unmatched entries were already filtered out, so this only
    # tolerates model entries the checkpoint does not provide.
    result = model.load_state_dict(chosen, strict=False)
    if result.missing_keys:
        warnings.warn(
            f"load_encoder: {len(result.missing_keys)} of {len(expected)} model "
            f"entries were not found in {path!r} and keep their random "
            "initialisation",
            RuntimeWarning,
            stacklevel=2,
        )
    model.eval()
    return model