Spaces:
Running
Running
| # SPDX-FileCopyrightText: Copyright © 2025 Idiap Research Institute <contact@idiap.ch> | |
| # SPDX-FileContributor: Francois Poh <francois.poh22@imperial.ac.uk> | |
| # SPDX-License-Identifier: GPL-3.0-or-later | |
| # ArtFace contains the code for the paper: https://www.idiap.ch/paper/artface/ | |
| # It provides a facial recognition model for historical portraits, and scripts to reproduce the experiments in the paper. | |
| from collections import defaultdict | |
| from pathlib import Path | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
class FaceLossHead(nn.Module):
    """Margin-based classification head for face-recognition training.

    Computes cosine similarities between L2-normalised features and
    L2-normalised class weights, applies an additive margin to the target
    class, and scales every logit by ``scale``:

    * ``"cosface"``: ``s * (cos(theta) - m)`` for the target class,
      ``s * cos(theta)`` for all other classes.
    * ``"arcface"``: ``s * cos(theta + m)`` for the target class,
      ``s * cos(theta)`` for all other classes.

    Args:
        in_features: Dimensionality of the input embeddings.
        out_features: Number of classes (rows of ``weight``).
        scale: Logit scale ``s``.
        margin: Margin ``m``. ``None`` selects 0.35 for CosFace and 0.5 for
            ArcFace; an explicit ``0.0`` is honoured.
        mode: Either ``"cosface"`` or ``"arcface"``.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """

    # Keeps acos() strictly inside (-1, 1): its derivative is infinite at
    # +/-1, which would otherwise turn gradients into inf/NaN.
    _ACOS_EPS = 1e-7

    def __init__(self, in_features, out_features, scale, margin, mode):
        super().__init__()
        if mode not in ("cosface", "arcface"):
            raise ValueError(f"Unsupported mode: {mode}. Use 'cosface' or 'arcface'.")
        self.mode = mode
        self.scale = scale
        # `is None` instead of `or` so that margin=0.0 is not silently
        # replaced by the default.
        if margin is None:
            margin = 0.35 if mode == "cosface" else 0.5
        self.margin = margin
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, features, labels):
        """Return margin-adjusted logits of shape ``(batch, out_features)``.

        Args:
            features: Embeddings of shape ``(batch, in_features)`` (the
                one-hot scatter along dim 1 requires 2-D logits).
            labels: Integer class index per sample; flattened to ``(-1, 1)``.

        Raises:
            ValueError: If ``self.mode`` was changed to an unsupported value.
        """
        if self.mode == "cosface":
            return self.cosface_forward(features, labels)
        elif self.mode == "arcface":
            return self.arcface_forward(features, labels)
        raise ValueError(f"Unsupported mode: {self.mode}. Use 'cosface' or 'arcface'.")

    def _cosine_and_one_hot(self, features, labels):
        """Return per-class cosine similarities and a one-hot target mask."""
        cosine = F.linear(F.normalize(features), F.normalize(self.weight))
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, labels.view(-1, 1), 1)
        return cosine, one_hot

    def cosface_forward(self, features, labels):
        """CosFace logits: subtract ``margin`` from the target cosine, then scale."""
        cosine, one_hot = self._cosine_and_one_hot(features, labels)
        return self.scale * (cosine - one_hot * self.margin)

    def arcface_forward(self, features, labels):
        """ArcFace logits: add ``margin`` to the target angle, then scale.

        Both target and non-target logits are multiplied by ``scale``.
        """
        cosine, one_hot = self._cosine_and_one_hot(features, labels)
        eps = self._ACOS_EPS
        theta = torch.acos(cosine.clamp(-1.0 + eps, 1.0 - eps))
        target_logit = torch.cos(theta + self.margin)
        # NOTE(review): no correction is applied when theta + margin > pi,
        # where cos() stops being monotonic in theta; common ArcFace
        # implementations add an "easy margin" or threshold fallback there.
        return self.scale * (target_logit * one_hot + cosine * (1 - one_hot))
class FaceLossWrapper(nn.Module):
    """Pair a feature backbone with a ``FaceLossHead`` for margin-loss training.

    In training mode, when ``labels`` are given, ``forward`` returns the
    head's margin logits; otherwise it returns the backbone output unchanged.

    Args:
        backbone: Feature extractor module. It must expose a ``device``
            attribute (read here) and a ``save_pretrained`` method (used by
            ``save_pretrained``).
        input_shape: Shape of the all-zero dummy tensor passed once through
            the backbone to infer the embedding size (presumably including
            the batch dimension — confirm against callers).
        out_features: Number of classes for the head.
        scale: Logit scale passed to the head.
        margin: Margin passed to the head. ``None`` selects 0.35 (CosFace) or
            0.5 (ArcFace); an explicit ``0.0`` is honoured.
        mode: ``"cosface"`` or ``"arcface"``.
    """

    def __init__(
        self,
        backbone,
        input_shape,
        out_features,
        scale=64.0,
        margin=None,
        mode="cosface",
    ):
        super().__init__()
        # `is None` instead of `or` so that margin=0.0 is kept.
        if margin is None:
            margin = 0.35 if mode == "cosface" else 0.5
        self.backbone = backbone
        self.device = backbone.device
        in_features = self._infer_feature_dim(input_shape)
        self.head = FaceLossHead(in_features, out_features, scale, margin, mode).to(
            self.device
        )

    def _infer_feature_dim(self, input_shape):
        """Run an all-zero probe through the backbone and return its last dim.

        The probe runs in eval mode under ``no_grad``. This stops
        normalisation layers such as BatchNorm from updating their running
        statistics with the zero input (or rejecting a batch of size 1).
        Each submodule's original train/eval flag is restored afterwards.
        """
        saved_modes = [(module, module.training) for module in self.backbone.modules()]
        self.backbone.eval()
        try:
            with torch.no_grad():
                feat = self.backbone(torch.zeros(*input_shape, device=self.device))
        finally:
            for module, was_training in saved_modes:
                module.training = was_training
        if isinstance(feat, (tuple, list)):
            feat = feat[0]
        return feat.shape[-1]

    def forward(self, x, labels=None):
        """Return head logits when training with labels, else backbone output.

        As in ``__init__``, a tuple/list backbone output is reduced to its
        first element before it is passed to the head.
        """
        features = self.backbone(x)
        if self.training and labels is not None:
            head_input = features[0] if isinstance(features, (tuple, list)) else features
            return self.head(head_input, labels)
        return features

    def save_pretrained(self, path):
        """Save the backbone via its own ``save_pretrained``.

        Only the backbone is saved; the head's weights are not written.
        """
        self.backbone.save_pretrained(path)
class FusionModelWrapper(nn.Module):
    """Concatenate L2-normalised embeddings produced by several models.

    Each sub-model is registered as a child module named ``"{name}_{k}"``,
    where ``k`` counts earlier occurrences of ``name`` in ``model_names``.
    ``forward`` feeds the i-th input to the i-th sub-model, L2-normalises
    each embedding, concatenates them along the last dimension, and
    L2-normalises the result.

    Args:
        models: Objects exposing ``torch()``, whose first element is the
            ``nn.Module`` that gets registered.
        model_names: Base name for each model, in the same order as ``models``.
        device: Stored as ``self.device``; the models are not moved here.

    Raises:
        ValueError: If ``models`` and ``model_names`` differ in length.
    """

    def __init__(self, models, model_names, device="cuda"):
        # nn.Module.__init__ rejects extra positional arguments, so the
        # device must not be forwarded to it.
        super().__init__()
        self.device = device
        models = list(models)
        counts = defaultdict(int)
        self.model_names = []
        for name in model_names:
            self.model_names.append(f"{name}_{counts[name]}")
            counts[name] += 1
        if len(models) != len(self.model_names):
            raise ValueError(
                f"Got {len(models)} models but {len(self.model_names)} model names."
            )
        for name, model in zip(self.model_names, models):
            self.set_submodel(name, model.torch()[0])

    def forward(self, xs):
        """Return the normalised concatenation of per-model normalised embeddings.

        Args:
            xs: One input per sub-model, in the order of ``model_names``.
        """
        embeddings = [
            F.normalize(model(x), dim=-1)
            for (_, model), x in zip(self.named_submodels(), xs)
        ]
        return F.normalize(torch.cat(embeddings, dim=-1), dim=-1)

    def named_submodels(self):
        """Return ``(unique_name, module)`` pairs in registration order."""
        return [(name, self.get_submodel(name)) for name in self.model_names]

    def save_pretrained(self, path):
        """Save each sub-model into its own ``<path>/<unique_name>`` directory."""
        root = Path(path)
        for name, submodel in self.named_submodels():
            target = root / name
            target.mkdir(parents=True, exist_ok=True)
            submodel.save_pretrained(str(target))

    def get_submodel(self, name):
        """Return the sub-model registered under ``name``."""
        return getattr(self, name)

    def set_submodel(self, name, model):
        """Register ``model`` as a child module under ``name``."""
        setattr(self, name, model)