Spaces:
Running
Running
| # SPDX-FileCopyrightText: Copyright © 2025 Idiap Research Institute <contact@idiap.ch> | |
| # SPDX-FileContributor: Francois Poh <francois.poh22@imperial.ac.uk> | |
| # SPDX-License-Identifier: GPL-3.0-or-later | |
| # ArtFace contains the code for the paper: https://www.idiap.ch/paper/artface/ | |
| # It provides a facial recognition model for historical portraits, and scripts to reproduce the experiments in the paper. | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from PIL import Image | |
| import os | |
class Model:
    """Base class for models that turn an image file into an embedding.

    Subclasses implement :meth:`get_embedding`; calling an instance with a
    path loads the image and returns the embedding in NumPy form.
    """

    def __init__(self):
        # Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __call__(self, path):
        """Embed the image stored at ``path``.

        Args:
            path: Location of the image file, handed to ``PIL.Image.open``.

        Returns:
            ``None`` when ``get_embedding`` returns ``None``; a NumPy array
            with all size-1 dimensions removed when it returns a tensor;
            otherwise whatever ``get_embedding`` returned, unchanged.
        """
        # Pillow keeps the underlying file open for lazily-loaded and
        # multi-frame images; the context manager releases the handle as soon
        # as the RGB copy (which owns its own pixel data) has been made.
        with Image.open(path) as source:
            image = source.convert("RGB")
        # ``convert`` returns a fresh image, so re-attach the source path for
        # any downstream code that reads ``image.filename``.
        image.filename = path
        embedding = self.get_embedding(image)
        if embedding is None:
            return None
        if isinstance(embedding, torch.Tensor):
            return embedding.cpu().detach().numpy().squeeze()
        return embedding

    def get_embedding(self, image):
        """Return the embedding for an RGB PIL image (subclass hook).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement get_embedding")
class CLIPModel(Model):
    """CLIP vision encoder loaded from a Hugging Face checkpoint."""

    def __init__(self, checkpoint="openai/clip-vit-base-patch16"):
        """Load the vision model and its processor from ``checkpoint``."""
        super().__init__()
        # Imported lazily so the module can be loaded without transformers.
        from transformers import AutoProcessor, CLIPVisionModel

        # NOTE(review): "eager" attention is presumably requested so that
        # attention weights can be returned on demand -- confirm.
        vision_model = CLIPVisionModel.from_pretrained(
            checkpoint, attn_implementation="eager"
        )
        self.model = vision_model.to(self.device)
        self.processor = AutoProcessor.from_pretrained(checkpoint)

    def get_embedding(self, image, output_attentions=False):
        """Return the pooled CLIP output for ``image``.

        Args:
            image: Image accepted by the Hugging Face processor.
            output_attentions: When true, also return the attention maps.

        Returns:
            The pooled output with its leading dimension squeezed, or the
            tuple ``(pooled output, attentions)`` (pooled output *not*
            squeezed) when ``output_attentions`` is true.
        """
        batch = self.processor(images=image, return_tensors="pt").to(self.device)
        outputs = self.model(**batch, output_attentions=output_attentions)
        pooled = outputs.pooler_output
        if not output_attentions:
            return pooled.squeeze(0)
        return pooled, outputs.attentions

    def torch(self):
        """Return ``(module, preprocess)`` for use as a plain PyTorch model.

        ``module`` wraps the vision model so that it takes pixel values
        directly; ``preprocess`` maps an image to a single pixel-value tensor
        placed on this model's device.
        """
        owner = self

        class CLIPVisionModelWrapper(torch.nn.Module):
            def __init__(self, model):
                super().__init__()
                self.model = model
                self.device = model.device

            def forward(self, x, output_attentions=False):
                result = self.model(
                    pixel_values=x, output_attentions=output_attentions
                )
                pooled = result.pooler_output
                return (pooled, result.attentions) if output_attentions else pooled

        def preprocess(image):
            encoded = owner.processor(images=image, return_tensors="pt")
            encoded = encoded.to(owner.device)
            # Drop the batch dimension the processor adds for a single image.
            return encoded["pixel_values"].squeeze(0)

        return CLIPVisionModelWrapper(self.model), preprocess
class OnnxModel(Model):
    """Face-embedding model converted from an ONNX file with onnx2torch."""

    def __init__(self, model_path, checkpoint=None):
        """Build the model, downloading the ONNX weights when missing.

        Args:
            model_path: Path of the ``.onnx`` file. If it does not exist, the
                ``antelopev2`` archive is fetched through insightface and
                ``glintr100.onnx`` is moved to this path.
            checkpoint: Optional path of a state dict loaded on top of the
                converted model.
        """
        super().__init__()
        from onnx2torch import convert

        self.transform = transforms.Compose(
            [
                transforms.Resize((112, 112)),
                # PIL images go through ToTensor; tensors are only cast to
                # float -- presumably callers pass tensors already scaled to
                # [0, 1]; verify against callers.
                transforms.Lambda(
                    lambda x: x.float()
                    if isinstance(x, torch.Tensor)
                    else transforms.ToTensor()(x)
                ),
                transforms.Normalize(
                    mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
                ),  # InsightFace uses [-1, 1] pixel range
            ]
        )

        if not os.path.exists(model_path):
            import shutil
            from insightface.utils import download

            # A bare file name has an empty dirname, which would make
            # ``os.makedirs`` fail and turn the paths below into root-level
            # absolute paths; fall back to the current directory instead.
            dirname = os.path.dirname(model_path) or "."
            # Make sure the target directory exists before anything is
            # written into it.
            os.makedirs(dirname, exist_ok=True)
            download(sub_dir=".", name="antelopev2", root=dirname)
            os.replace(f"{dirname}/antelopev2/antelopev2/glintr100.onnx", model_path)
            # Remove the extracted archive and the zip once the one needed
            # file has been moved out.
            shutil.rmtree(f"{dirname}/antelopev2")
            os.remove(f"{dirname}/antelopev2.zip")

        self.model = convert(model_path)
        if checkpoint:
            print(f"Loading weights from {checkpoint}")
            # ``map_location`` lets a checkpoint saved on a GPU be loaded on a
            # CPU-only host (and vice versa).
            # NOTE(review): ``torch.load`` unpickles its input -- only load
            # checkpoints from trusted sources.
            state_dict = torch.load(checkpoint, map_location=self.device)
            self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

    def get_embedding(self, image):
        """Return the model output for one image, batch dimension removed."""
        image = self.transform(image).unsqueeze(0).to(self.device)
        return self.model(image).squeeze(0)

    def torch(self):
        """Return ``(module, preprocess)`` for use as a plain PyTorch model."""

        def preprocess(image):
            return self.transform(image).to(self.device)

        # Expose the device on the converted module so callers can read
        # ``module.device``.
        self.model.device = self.device
        return self.model, preprocess
class TrainedModel(Model):
    """A registered base model with trained weights loaded from a checkpoint."""

    def __init__(self, model_name, checkpoint, base_model_args=None):
        """Build the base model and load ``checkpoint`` into it.

        Args:
            model_name: Registry name of the base model, passed to
                ``get_model``.
            checkpoint: Checkpoint identifier handed to
                ``lib.utils.load_checkpoint``.
            base_model_args: Optional overrides for the base model's
                constructor arguments (a dict or a list of ``key=value``
                strings). ``None`` means no overrides.
        """
        from lib.utils import load_checkpoint

        super().__init__()
        # A ``None`` sentinel avoids sharing one mutable default dict across
        # every instance.
        if base_model_args is None:
            base_model_args = {}
        base_model, self.processor = get_model(model_name, base_model_args).torch()
        self.model = load_checkpoint(base_model, checkpoint)
        self.model.to(self.device)
        self.model.eval()

    def get_embedding(self, image, output_attentions=False):
        """Return the model output for one image without tracking gradients.

        Args:
            image: Image accepted by the base model's preprocessing function.
            output_attentions: When true, the model is called with
                ``output_attentions=True`` and its second output is returned
                as well.

        Returns:
            The first model output with its batch dimension removed, or the
            tuple ``(that output, second model output)`` when
            ``output_attentions`` is true.
        """
        inputs = self.processor(image)
        # Fusion preprocessing yields a list with one tensor per sub-model;
        # every tensor gets moved to the device and a batch dimension.
        inputs = (
            [x.to(self.device).unsqueeze(0) for x in inputs]
            if isinstance(inputs, list)
            else inputs.to(self.device).unsqueeze(0)
        )
        with torch.no_grad():
            if output_attentions:
                image_features = self.model(inputs, output_attentions=output_attentions)
                return image_features[0].squeeze(0), image_features[1]
            image_features = self.model(inputs)
            return image_features.squeeze(0)

    def torch(self):
        """Return ``(module, preprocess)`` for use as a plain PyTorch model."""
        return self.model, self.processor
class FusionInput(list):
    """List of per-model inputs that answers tensor-style ``unsqueeze``/``to``.

    Both methods update the list in place and return it, so a ``FusionInput``
    can be used where a single tensor's ``.to(...)`` / ``.unsqueeze(...)``
    call chain is expected.
    """

    def __init__(self, images):
        super().__init__(images)

    def _replace_each(self, transform):
        # Build the complete replacement first so the list is only swapped
        # once every element has been transformed successfully.
        replaced = [transform(element) for element in self]
        self[:] = replaced
        return self

    def unsqueeze(self, dim):
        """Apply ``unsqueeze(dim)`` to every element; return ``self``."""
        return self._replace_each(lambda element: element.unsqueeze(dim))

    def to(self, device):
        """Apply ``to(device)`` to every element; return ``self``."""
        return self._replace_each(lambda element: element.to(device))
class FusionModel(Model):
    """Combination of several registered models into one embedding."""

    def __init__(self, models, head="none"):
        """Instantiate every sub-model.

        Args:
            models: Sequence of ``(name, args)`` pairs; each pair is passed
                to ``get_model``.
            head: Accepted but not used by this class.
        """
        super().__init__()
        self.model_names = [label for label, _unused in models]
        self.models = [get_model(label, options) for label, options in models]

    def get_embedding(self, image):
        """Return the normalised concatenation of the sub-model embeddings.

        Each sub-model embedding is normalised along dimension 0, the results
        are concatenated along dimension 0, and the concatenation is
        normalised once more.
        """
        raw = [submodel.get_embedding(image) for submodel in self.models]
        unit = [F.normalize(vector, dim=0) for vector in raw]
        fused = torch.cat(unit, dim=0)
        return F.normalize(fused, dim=0)

    def torch(self):
        """Return ``(module, preprocess)`` for use as a plain PyTorch model.

        ``preprocess`` runs every sub-model's own preprocessing function on
        the image and collects the results in a ``FusionInput``.
        """
        from lib.ModelWrappers import FusionModelWrapper

        preprocessors = []
        for submodel in self.models:
            preprocessors.append(submodel.torch()[1])

        def preprocess(image):
            return FusionInput(step(image) for step in preprocessors)

        wrapper = FusionModelWrapper(self.models, self.model_names, self.device)
        return wrapper, preprocess

    def get_submodel(self, name):
        """Return the attribute of this object called ``name``."""
        return getattr(self, name)

    def set_submodel(self, name, model):
        """Set the attribute of this object called ``name`` to ``model``."""
        return setattr(self, name, model)
# Registry of available models: each name maps to a pair
# ``(model class, default constructor keyword arguments)``. ``get_model``
# looks names up here and merges caller-supplied overrides into the defaults.
models = {
    # Plain CLIP vision encoder.
    "clip": (CLIPModel, {"checkpoint": "openai/clip-vit-base-patch16"}),
    # TrainedModel on top of the "clip" base with the ArtFace-CLIP-LoRA checkpoint.
    "lora": (
        TrainedModel,
        {
            "model_name": "clip",
            "checkpoint": "Idiap/ArtFace-CLIP-LoRA",
        },
    ),
    # ONNX model loaded from the antelopev2 glintr100 file.
    "ires100": (OnnxModel, {"model_path": "checkpoints/antelopev2/glintr100.onnx"}),
    # TrainedModel on top of the "ires100" base with the tuned ArtFace checkpoint.
    "ires100-tune": (
        TrainedModel,
        {
            "model_name": "ires100",
            "checkpoint": "Idiap/ArtFace-IResNet100-Tuned",
        },
    ),
}
def get_model(name, args=None):
    """Instantiate the registered model called ``name``.

    Args:
        name: Key into the module-level ``models`` registry.
        args: Optional constructor overrides. Either a dict of keyword
            arguments, or a list of ``"key=value"`` strings (for example from
            a command line) whose values are parsed as Python literals when
            possible and kept as plain strings otherwise. Any other value,
            including ``None``, means no overrides.

    Returns:
        A new instance of the registered class, built from the registry's
        default keyword arguments updated with the overrides.

    Raises:
        ValueError: If a list entry does not contain ``=``, or if ``name`` is
            not in the registry.
    """
    import ast

    model_args = {}
    if isinstance(args, list):
        for arg in args:
            if "=" not in arg:
                raise ValueError(
                    f"Invalid argument format for model arguments. Expected 'key=value' pairs, got '{arg}'."
                )
            # Split on the first "=" only so values may themselves contain "=".
            key, value = arg.split("=", 1)
            try:
                model_args[key] = ast.literal_eval(value)
            except (ValueError, SyntaxError, TypeError):
                # Not a Python literal (TypeError covers inputs such as an
                # unhashable dict key): keep the raw string.
                model_args[key] = value
    elif isinstance(args, dict):
        model_args = args

    if name not in models:
        raise ValueError("Unrecognised model name!")
    model_cls, default_kwargs = models[name]
    # Build a fresh dict so neither the registry defaults nor the caller's
    # dict is modified.
    return model_cls(**dict(default_kwargs, **model_args))