| import torch |
| from torch import nn |
| import torch.nn.functional as F |
|
|
| from configs.paths_config import model_paths |
|
|
|
|
class MocoLoss(nn.Module):
    """Feature-similarity loss built on a frozen, MoCo-pretrained ResNet-50.

    The backbone (ResNet-50 without its final ``fc`` layer) is loaded from
    ``model_paths["moco"]``, put in eval mode and frozen. ``forward`` compares
    the normalized backbone features of a generated batch against those of a
    target batch and reports ``1 - similarity`` averaged over the batch.
    """

    def __init__(self, opts):
        """Load and freeze the backbone.

        Args:
            opts: Accepted for interface compatibility; not read by this class.
        """
        super().__init__()
        print("Loading MOCO model from path: {}".format(model_paths["moco"]))
        self.model = self.__load_model()
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    @staticmethod
    def __load_model():
        """Build a ResNet-50 trunk initialised from the MoCo checkpoint.

        Returns:
            An ``nn.Sequential`` holding every child of the ResNet-50 except
            the last one (the ``fc`` head), moved to CUDA when it is
            available and left on the CPU otherwise.
        """
        import torchvision.models as models

        model = models.__dict__["resnet50"]()
        # Freeze everything except the classification head (which is dropped
        # below anyway).
        for name, param in model.named_parameters():
            if name not in ['fc.weight', 'fc.bias']:
                param.requires_grad = False

        # NOTE(review): torch.load unpickles the file; only point
        # model_paths['moco'] at checkpoints from a trusted source.
        checkpoint = torch.load(model_paths['moco'], map_location="cpu")
        state_dict = checkpoint['state_dict']

        prefix = "module.encoder_q."
        for k in list(state_dict.keys()):
            # Keep only the query-encoder weights, excluding its fc head,
            # and strip the wrapper prefix so the names match ResNet-50.
            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                state_dict[k[len(prefix):]] = state_dict[k]
            # Drop the original key whether it was renamed or is unused.
            del state_dict[k]

        msg = model.load_state_dict(state_dict, strict=False)
        # Only the fc head may be missing; anything else means the checkpoint
        # does not match the architecture.
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}, \
            "unexpected missing keys: {}".format(msg.missing_keys)

        # Remove the fc head. Fall back to the CPU instead of crashing when
        # no CUDA device is present.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = nn.Sequential(*list(model.children())[:-1]).to(device)
        return model

    def extract_feats(self, x):
        """Return normalized backbone features for a batch.

        Args:
            x: Input batch; presumably images shaped (N, C, H, W) -- TODO
                confirm against callers. Spatial dims are resized to 224.

        Returns:
            A 2-D tensor of shape (N, D), normalized along dim 1.
        """
        x = F.interpolate(x, size=224)
        x_feats = self.model(x)
        x_feats = nn.functional.normalize(x_feats, dim=1)
        # flatten (not squeeze) so the batch axis survives when N == 1;
        # a bare squeeze() would collapse it and break per-sample indexing.
        x_feats = x_feats.flatten(start_dim=1)
        return x_feats

    def forward(self, y_hat, y, x):
        """Compute the similarity loss and per-sample similarity logs.

        Args:
            y_hat: Generated batch; gradients flow through its features.
            y: Target batch; its features are treated as constants.
            x: Input batch; used only for logging similarities.
                All three are expected to share the same batch size.

        Returns:
            A tuple ``(loss, sim_improvement, sim_logs)`` where ``loss`` is a
            0-d tensor holding the batch mean of ``1 - <y_hat_feats, y_feats>``,
            ``sim_improvement`` is a Python float holding the batch mean of
            ``<y_hat_feats, y_feats> - <y_feats, x_feats>``, and ``sim_logs``
            is a list with one dict per sample containing the float values
            ``diff_target``, ``diff_input`` and ``diff_views``.
        """
        n_samples = x.shape[0]

        # x and y features never receive gradients (y was detached and x is
        # only read as floats), so skip building an autograd graph for them.
        with torch.no_grad():
            x_feats = self.extract_feats(x)
            y_feats = self.extract_feats(y)
        y_hat_feats = self.extract_feats(y_hat)

        # Row-wise dot products for the whole batch at once.
        diff_target = (y_hat_feats * y_feats).sum(dim=1)
        diff_input = (y_hat_feats * x_feats).sum(dim=1)
        diff_views = (y_feats * x_feats).sum(dim=1)

        # One host transfer per quantity instead of three per sample.
        target_vals = diff_target.detach().tolist()
        input_vals = diff_input.detach().tolist()
        views_vals = diff_views.detach().tolist()

        sim_logs = [
            {'diff_target': t, 'diff_input': i, 'diff_views': v}
            for t, i, v in zip(target_vals, input_vals, views_vals)
        ]
        loss = (1 - diff_target).sum()
        sim_improvement = sum(t - v for t, v in zip(target_vals, views_vals))

        return loss / n_samples, sim_improvement / n_samples, sim_logs
|
|