Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| from criteria.model_irse import Backbone | |
| from criteria.backbones import get_model | |
class IDLoss(nn.Module):
    """
    Identity loss: one minus the similarity between face embeddings of two images.

    Taken from TreB1eN's [1] implementation of InsightFace [2, 3], as used in pixel2style2pixel [4].
    [1] https://github.com/TreB1eN/InsightFace_Pytorch
    [2] https://github.com/deepinsight/insightface
    [3] Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos.
        ArcFace: Additive Angular Margin Loss for Deep Face Recognition. In CVPR, 2019
    [4] https://github.com/eladrich/pixel2style2pixel

    The wrapped face-recognition network (``self.facenet``) is treated as a
    fixed feature extractor. It is kept in eval mode even when ``train()`` is
    called on this module or on a parent module, and its parameters do not
    require gradients. Gradients still flow back to the input images.
    """

    def __init__(self, model_path, official=False, device="cpu"):
        """
        Arguments:
            model_path (str): Path to the face-recognition checkpoint (a state
                dict for IR-SE50, or for the ``r100`` model when ``official``).
            official (bool): If True, build the network with
                ``get_model("r100", fp16=False)`` and L2-normalise its
                embeddings in ``forward``. Otherwise build the IR-SE50
                ``Backbone`` (input_size=112, num_layers=50, mode="ir_se").
            device (str | torch.device): Device the checkpoint is mapped to
                and the network is moved to.
        """
        super(IDLoss, self).__init__()
        print("Loading ResNet ArcFace")
        self.official = official
        if official:
            self.facenet = get_model("r100", fp16=False)
        else:
            self.facenet = Backbone(
                input_size=112, num_layers=50, drop_ratio=0.6, mode="ir_se"
            )
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location=device)
        self.facenet.load_state_dict(state_dict)
        # load_state_dict copies values into the existing parameters and keeps
        # their device, so move the network explicitly to honour `device`.
        self.facenet.to(device)
        # Frozen feature extractor: an optimizer over a parent's parameters
        # must never update it. Input gradients are unaffected.
        self.facenet.requires_grad_(False)
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()

    def train(self, mode=True):
        """
        Set this module's training flag while keeping ``facenet`` in eval mode.

        ``nn.Module.train`` recurses into every submodule. Without this
        override, a ``.train()`` call here or on a parent would re-enable any
        dropout / batch-norm statistic updates inside the identity network
        and change the embeddings the loss is computed from.

        Returns:
            IDLoss: ``self``, matching ``nn.Module.train``.
        """
        super(IDLoss, self).train(mode)
        self.facenet.eval()
        return self

    def extract_feats(self, x):
        """
        Crop, resize and embed a batch of images.

        Arguments:
            x (Tensor): 4-D image batch, presumably (N, C, H, W). The fixed
                crop window below presumably assumes 256x256 aligned faces,
                as in pixel2style2pixel (TODO confirm for this project).
                Smaller inputs are silently truncated by the slice.
        Returns:
            Tensor: Embeddings produced by ``self.facenet``.
        """
        x = x[:, :, 35:223, 32:220]  # Crop interesting region (188x188 window)
        x = self.face_pool(x)  # Resize to 112x112 (Backbone uses input_size=112)
        x_feats = self.facenet(x)
        return x_feats

    def forward(self, x, y):
        """
        Arguments:
            x (Tensor): The batch of original images
            y (Tensor): The batch of generated images
        Returns:
            loss (Tensor): Scalar ``mean_i(1 - <f(x_i), f(y_i)>)``, where ``f``
                is ``extract_feats``. This is one minus the mean cosine
                similarity when the embeddings are unit-length. They are
                explicitly L2-normalised for the official model. The IR-SE50
                ``Backbone`` is assumed to return normalised embeddings
                already (TODO confirm in criteria.model_irse).
        """
        x_feats = self.extract_feats(x)
        y_feats = self.extract_feats(y)
        if self.official:
            x_feats = F.normalize(x_feats)
            y_feats = F.normalize(y_feats)
        loss = (1 - (x_feats * y_feats).sum(dim=1)).mean()
        return loss