import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SimCLRResNet18(nn.Module):
    """ResNet-18 trunk followed by a linear head with L2-normalized output.

    Args:
        out_dim: Size of the embedding produced by the linear head ``fc``.
    """

    def __init__(self, out_dim=256):
        super().__init__()
        backbone = resnet18(weights=None)
        # Keep every child of torchvision's ResNet-18 except its final
        # classification layer ``fc`` (the last child), so the trunk ends
        # at the global average pool.
        modules = list(backbone.children())[:-1]
        self.backbone = nn.Sequential(*modules)
        self.fc = nn.Linear(512, out_dim)

    def forward(self, x):
        """Encode ``x`` and return embeddings normalized to unit L2 norm along dim 1.

        NOTE(review): assumes ``x`` is an image batch shaped (N, 3, H, W) as
        torchvision's ResNet-18 expects — confirm against callers.
        """
        x = self.backbone(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return F.normalize(x, dim=1)


def _remap_keys(state):
    """Return a copy of ``state`` with every "backbone." and "projector." substring removed from keys."""
    remapped = {}
    for key, value in state.items():
        key = key.replace("backbone.", "").replace("projector.", "")
        remapped[key] = value
    return remapped


def _compatible_entries(state, reference):
    """Return the entries of ``state`` whose key exists in ``reference`` with an identical shape."""
    return {
        key: value
        for key, value in state.items()
        if key in reference
        and hasattr(value, "shape")
        and value.shape == reference[key].shape
    }


def load_encoder(path="encoder_resnet18_simclr.pth"):
    """Build a SimCLRResNet18 on DEVICE, load weights from ``path``, and return it in eval mode.

    The checkpoint may be a plain state dict, a dict holding one under the
    key ``"state_dict"``, or a pickled ``nn.Module``. Two key layouts are
    considered: the keys as saved, and the keys with every ``"backbone."``
    and ``"projector."`` substring removed. The layout that matches more of
    the model's parameters/buffers (by name and shape) is loaded, and
    non-matching entries are skipped.

    Args:
        path: Filesystem path handed to ``torch.load``.

    Returns:
        The model, on ``DEVICE`` and switched to ``eval()`` mode.

    Warns:
        UserWarning: if some model parameters/buffers were not found in the
            checkpoint (or had a different shape) and therefore keep their
            random initialization.
    """
    model = SimCLRResNet18().to(DEVICE)
    state = torch.load(path, map_location=DEVICE)
    # Unwrap common checkpoint containers into a flat name -> tensor mapping.
    if isinstance(state, nn.Module):
        state = state.state_dict()
    elif isinstance(state, dict) and isinstance(state.get("state_dict"), dict):
        state = state["state_dict"]

    reference = model.state_dict()
    # load_state_dict(strict=False) does not raise on missing/unexpected keys,
    # so catching its exception cannot tell which layout fits; compare the
    # keys and shapes directly instead.
    candidates = [
        _compatible_entries(candidate, reference)
        for candidate in (state, _remap_keys(state))
    ]
    best = max(candidates, key=len)  # max() keeps the first on ties: keys as saved
    model.load_state_dict(best, strict=False)

    missing = len(reference) - len(best)
    if missing:
        warnings.warn(
            f"load_encoder: {missing} of {len(reference)} model tensors were not "
            f"found in '{path}' (or had mismatched shapes); they keep their "
            "random initialization.",
            stacklevel=2,
        )

    model.eval()
    return model