# SPDX-FileCopyrightText: Copyright © 2025 Idiap Research Institute # SPDX-FileContributor: Francois Poh # SPDX-License-Identifier: GPL-3.0-or-later # ArtFace contains the code for the paper: https://www.idiap.ch/paper/artface/ # It provides a facial recognition model for historical portraits, and scripts to reproduce the experiments in the paper. import torch import torch.nn.functional as F from torchvision import transforms from PIL import Image import os class Model: def __init__(self): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def __call__(self, path): image = Image.open(path).convert("RGB") image.filename = path embedding = self.get_embedding(image) if embedding is None: return None if isinstance(embedding, torch.Tensor): return embedding.cpu().detach().numpy().squeeze() return embedding def get_embedding(self, image): raise NotImplementedError("Subclasses must implement get_embedding") class CLIPModel(Model): def __init__(self, checkpoint="openai/clip-vit-base-patch16"): super().__init__() from transformers import AutoProcessor, CLIPVisionModel self.model = CLIPVisionModel.from_pretrained( checkpoint, attn_implementation="eager" ).to(self.device) self.processor = AutoProcessor.from_pretrained(checkpoint) def get_embedding(self, image, output_attentions=False): inputs = self.processor(images=image, return_tensors="pt").to(self.device) image_features = self.model(**inputs, output_attentions=output_attentions) if output_attentions: return image_features.pooler_output, image_features.attentions return image_features.pooler_output.squeeze(0) def torch(self): class CLIPVisionModelWrapper(torch.nn.Module): def __init__(self, model): super().__init__() self.model = model self.device = model.device def forward(self, x, output_attentions=False): output = self.model(pixel_values=x, output_attentions=output_attentions) if output_attentions: return output.pooler_output, output.attentions return output.pooler_output def preprocess(image): inputs = self.processor(images=image, return_tensors="pt").to(self.device) return inputs["pixel_values"].squeeze(0) return CLIPVisionModelWrapper(self.model), preprocess class OnnxModel(Model): def __init__(self, model_path, checkpoint=None): super().__init__() from onnx2torch import convert self.transform = transforms.Compose( [ transforms.Resize((112, 112)), transforms.Lambda( lambda x: x.float() if isinstance(x, torch.Tensor) else transforms.ToTensor()(x) ), transforms.Normalize( mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5] ), # InsightFace uses [-1, 1] pixel range ] ) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if not os.path.exists(model_path): import shutil from insightface.utils import download dirname = os.path.dirname(model_path) download(sub_dir=".", name="antelopev2", root=dirname) os.makedirs(dirname, exist_ok=True) os.replace(f"{dirname}/antelopev2/antelopev2/glintr100.onnx", model_path) shutil.rmtree(f"{dirname}/antelopev2") os.remove(f"{dirname}/antelopev2.zip") self.model = convert(model_path) if checkpoint: print(f"Loading weights from {checkpoint}") self.model.load_state_dict(torch.load(checkpoint)) self.model.to(self.device) self.model.eval() def get_embedding(self, image): image = self.transform(image).unsqueeze(0).to(self.device) return self.model(image).squeeze(0) def torch(self): def preprocess(image): return self.transform(image).to(self.device) self.model.device = self.device return self.model, preprocess class TrainedModel(Model): def __init__(self, model_name, checkpoint, base_model_args={}): from lib.utils import load_checkpoint super().__init__() base_model, self.processor = get_model(model_name, base_model_args).torch() self.model = load_checkpoint(base_model, checkpoint) self.model.to(self.device) self.model.eval() def get_embedding(self, image, output_attentions=False): inputs = self.processor(image) inputs = ( [x.to(self.device).unsqueeze(0) for x in inputs] if isinstance(inputs, list) else inputs.to(self.device).unsqueeze(0) ) with torch.no_grad(): if output_attentions: image_features = self.model(inputs, output_attentions=output_attentions) return image_features[0].squeeze(0), image_features[1] else: image_features = self.model(inputs) return image_features.squeeze(0) def torch(self): return self.model, self.processor class FusionInput(list): def __init__(self, images): super().__init__(images) def unsqueeze(self, dim): self[:] = [item.unsqueeze(dim) for item in self] return self def to(self, device): self[:] = [item.to(device) for item in self] return self class FusionModel(Model): def __init__(self, models, head="none"): super().__init__() self.model_names = [name for name, _ in models] self.models = [get_model(name, args) for name, args in models] def get_embedding(self, image): embeddings = [model.get_embedding(image) for model in self.models] embeddings = [F.normalize(emb, dim=0) for emb in embeddings] return F.normalize(torch.cat(embeddings, dim=0), dim=0) def torch(self): from lib.ModelWrappers import FusionModelWrapper preprocessors = [model.torch()[1] for model in self.models] def preprocess(image): return FusionInput(preprocessor(image) for preprocessor in preprocessors) return FusionModelWrapper( self.models, self.model_names, self.device ), preprocess def get_submodel(self, name): return getattr(self, name) def set_submodel(self, name, model): return setattr(self, name, model) models = { "clip": (CLIPModel, {"checkpoint": "openai/clip-vit-base-patch16"}), "lora": ( TrainedModel, { "model_name": "clip", "checkpoint": "Idiap/ArtFace-CLIP-LoRA", }, ), "ires100": (OnnxModel, {"model_path": "checkpoints/antelopev2/glintr100.onnx"}), "ires100-tune": ( TrainedModel, { "model_name": "ires100", "checkpoint": "Idiap/ArtFace-IResNet100-Tuned", }, ), } def get_model(name, args={}): import ast model_args = {} if isinstance(args, list): for arg in args: if "=" not in arg: raise ValueError( f"Invalid argument format for model arguments. Expected 'key=value' pairs, got '{arg}'." ) key, value = arg.split("=", 1) try: model_args[key] = ast.literal_eval(value) except (ValueError, SyntaxError): model_args[key] = value elif isinstance(args, dict): model_args = args if name not in models: raise ValueError("Unrecognised model name!") return models[name][0](**dict(models[name][1], **model_args))