md896's picture
Update model.py
4f09ddd verified
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18
# Run on the default CUDA device when PyTorch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class SimCLRResNet18(nn.Module):
    """ResNet-18 feature extractor with a linear head and L2-normalised output.

    A torchvision ``resnet18`` is created with ``weights=None`` (no pretrained
    weights). Every child module except its final classifier is kept as
    ``self.backbone``. A new ``nn.Linear`` (``self.fc``) maps the flattened
    backbone features to ``out_dim`` values.

    Args:
        out_dim: Width of the embedding returned by ``forward``.
    """

    def __init__(self, out_dim=256):
        super().__init__()
        trunk = resnet18(weights=None)
        # Unpack all children but the last one (the classifier). The attribute
        # names ``backbone`` and ``fc`` define the state-dict keys, so keep them.
        *feature_layers, _classifier = trunk.children()
        self.backbone = nn.Sequential(*feature_layers)
        # The dropped classifier's input width (512 for ResNet-18) is the
        # flattened feature size this head consumes.
        self.fc = nn.Linear(trunk.fc.in_features, out_dim)

    def forward(self, x):
        """Embed ``x`` and scale each row to unit L2 norm.

        ``x`` is presumably an image batch of shape (N, 3, H, W); TODO confirm
        against callers. The backbone output is flattened from dim 1 before the
        head, so the result has shape (N, out_dim).
        """
        features = self.backbone(x).flatten(start_dim=1)
        return F.normalize(self.fc(features), dim=1)
def _strip_prefixes(state, prefixes):
    """Return a copy of ``state`` with each of ``prefixes``, in order, removed from the start of every key."""
    stripped = {}
    for key, value in state.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
        stripped[key] = value
    return stripped


def load_encoder(path="encoder_resnet18_simclr.pth"):
    """Build a ``SimCLRResNet18`` on ``DEVICE``, load weights from ``path``, and return it in eval mode.

    The checkpoint is expected to be a state dict (parameter name -> tensor).
    Checkpoints may be saved with different key prefixes, so several key
    layouts are tried. The layout sharing the most keys with the model is
    loaded:

    1. the keys as stored;
    2. with a leading ``"module."`` removed (``nn.DataParallel`` /
       ``DistributedDataParallel`` wrappers);
    3. additionally with leading ``"backbone."`` and ``"projector."`` removed
       (the remapping earlier versions of this loader attempted).

    Loading is non-strict: keys that do not line up are skipped. Unlike a bare
    ``strict=False`` load, this is reported through ``warnings.warn`` instead of
    silently leaving the encoder (partly) randomly initialised.

    Args:
        path: File path passed to ``torch.load``.

    Returns:
        The model, moved to ``DEVICE`` and switched to ``eval()`` mode.

    Raises:
        RuntimeError: raised by ``load_state_dict`` when a matching key holds a
            tensor of the wrong shape.
    """
    model = SimCLRResNet18().to(DEVICE)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (consider weights_only=True if the installed PyTorch supports it).
    state = torch.load(path, map_location=DEVICE)

    candidates = (
        state,
        _strip_prefixes(state, ("module.",)),
        _strip_prefixes(state, ("module.", "backbone.", "projector.")),
    )
    model_keys = set(model.state_dict())
    # max() returns the first of equally good layouts, so the stored keys win ties.
    best = max(candidates, key=lambda sd: len(model_keys.intersection(sd)))
    matched = model_keys.intersection(best)

    result = model.load_state_dict(best, strict=False)
    if not matched:
        warnings.warn(
            f"No parameter names in {path!r} match SimCLRResNet18; "
            "the encoder keeps its random initialization.",
            stacklevel=2,
        )
    elif result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"Partial checkpoint load from {path!r}: "
            f"{len(result.missing_keys)} missing key(s), "
            f"{len(result.unexpected_keys)} unexpected key(s).",
            stacklevel=2,
        )

    model.eval()
    return model