ArtFace / lib /models.py
Anjith GEORGE
initial commit
53fe336
# SPDX-FileCopyrightText: Copyright © 2025 Idiap Research Institute <contact@idiap.ch>
# SPDX-FileContributor: Francois Poh <francois.poh22@imperial.ac.uk>
# SPDX-License-Identifier: GPL-3.0-or-later
# ArtFace contains the code for the paper: https://www.idiap.ch/paper/artface/
# It provides a facial recognition model for historical portraits, and scripts to reproduce the experiments in the paper.
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import os
class Model:
    """Base class for image-embedding models.

    Instances are callable with an image path; subclasses supply the actual
    model by overriding :meth:`get_embedding`.
    """

    def __init__(self):
        # Use the GPU when one is visible to torch, otherwise the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __call__(self, path):
        """Load the image at ``path`` and return its embedding.

        Args:
            path: Location of the image file, passed to ``PIL.Image.open``.

        Returns:
            ``None`` when ``get_embedding`` returns ``None``; a squeezed NumPy
            array when it returns a ``torch.Tensor``; otherwise whatever
            ``get_embedding`` returned, unchanged.
        """
        # Open inside a context manager so the underlying file handle is
        # released even for formats that keep it open after loading.
        with Image.open(path) as source:
            image = source.convert("RGB")
        # The converted image is a new object; record where it came from.
        image.filename = path

        embedding = self.get_embedding(image)
        if embedding is None:
            return None
        if isinstance(embedding, torch.Tensor):
            return embedding.cpu().detach().numpy().squeeze()
        return embedding

    def get_embedding(self, image):
        """Return the embedding of an RGB PIL image.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement get_embedding")
class CLIPModel(Model):
    """CLIP vision encoder loaded from a Hugging Face checkpoint."""

    def __init__(self, checkpoint="openai/clip-vit-base-patch16"):
        """Load the vision model and its processor from ``checkpoint``."""
        super().__init__()
        # Imported here so transformers is only required when this class is used.
        from transformers import AutoProcessor, CLIPVisionModel

        # NOTE(review): "eager" attention is presumably requested so that
        # attention weights can be returned — confirm.
        vision_model = CLIPVisionModel.from_pretrained(
            checkpoint, attn_implementation="eager"
        )
        self.model = vision_model.to(self.device)
        self.processor = AutoProcessor.from_pretrained(checkpoint)

    def get_embedding(self, image, output_attentions=False):
        """Run the processor and vision model on ``image``.

        Returns:
            The pooled output with dimension 0 squeezed when
            ``output_attentions`` is false; otherwise the tuple
            ``(pooler_output, attentions)`` with the pooled output left as the
            model produced it.
        """
        batch = self.processor(images=image, return_tensors="pt").to(self.device)
        outputs = self.model(**batch, output_attentions=output_attentions)
        if not output_attentions:
            return outputs.pooler_output.squeeze(0)
        return outputs.pooler_output, outputs.attentions

    def torch(self):
        """Return ``(module, preprocess)`` for use as a plain torch module.

        ``module`` wraps the vision model so that it takes pixel values
        directly; ``preprocess`` turns an image into those pixel values.
        """

        class CLIPVisionModelWrapper(torch.nn.Module):
            """Adapter exposing the pooled output of the wrapped vision model."""

            def __init__(self, model):
                super().__init__()
                self.model = model
                self.device = model.device

            def forward(self, x, output_attentions=False):
                result = self.model(
                    pixel_values=x, output_attentions=output_attentions
                )
                pooled = result.pooler_output
                if not output_attentions:
                    return pooled
                return pooled, result.attentions

        def preprocess(image):
            # Drop dimension 0 of the processor output; callers re-add it.
            processed = self.processor(images=image, return_tensors="pt")
            processed = processed.to(self.device)
            return processed["pixel_values"].squeeze(0)

        wrapper = CLIPVisionModelWrapper(self.model)
        return wrapper, preprocess
class OnnxModel(Model):
    """Embedding model converted from an ONNX file into a torch module.

    If the ONNX file is missing, ``glintr100.onnx`` is fetched from the
    InsightFace ``antelopev2`` package and moved to ``model_path``.
    """

    def __init__(self, model_path, checkpoint=None):
        """Build the preprocessing pipeline and load the model.

        Args:
            model_path: Path of the ONNX file; downloaded when it does not exist.
            checkpoint: Optional path to a state dict that replaces the
                converted model's weights.
        """
        super().__init__()  # also selects self.device
        from onnx2torch import convert

        self.transform = transforms.Compose(
            [
                transforms.Resize((112, 112)),
                transforms.Lambda(self._to_float_tensor),
                transforms.Normalize(
                    mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
                ),  # InsightFace uses [-1, 1] pixel range
            ]
        )

        if not os.path.exists(model_path):
            self._download_weights(model_path)

        self.model = convert(model_path)
        if checkpoint:
            print(f"Loading weights from {checkpoint}")
            # map_location lets a checkpoint saved on a GPU load on a
            # CPU-only machine (and vice versa).
            state_dict = torch.load(checkpoint, map_location=self.device)
            self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _to_float_tensor(x):
        """Cast tensors to float; convert anything else with ``ToTensor``."""
        if isinstance(x, torch.Tensor):
            # Tensor inputs are only cast, not rescaled.
            return x.float()
        return transforms.ToTensor()(x)

    @staticmethod
    def _download_weights(model_path):
        """Fetch the ``antelopev2`` package and move ``glintr100.onnx`` to ``model_path``."""
        import shutil
        from insightface.utils import download

        # A bare filename has an empty dirname; treat it as the current
        # directory so the paths below stay relative and makedirs succeeds.
        dirname = os.path.dirname(model_path) or "."
        os.makedirs(dirname, exist_ok=True)
        download(sub_dir=".", name="antelopev2", root=dirname)

        extracted = os.path.join(dirname, "antelopev2")
        os.replace(
            os.path.join(extracted, "antelopev2", "glintr100.onnx"), model_path
        )
        # Remove the rest of the extracted package and the downloaded archive.
        shutil.rmtree(extracted)
        os.remove(os.path.join(dirname, "antelopev2.zip"))

    def get_embedding(self, image):
        """Return the model output for ``image`` with dimension 0 squeezed."""
        image = self.transform(image).unsqueeze(0).to(self.device)
        return self.model(image).squeeze(0)

    def torch(self):
        """Return ``(module, preprocess)``; the module gains a ``device`` attribute."""

        def preprocess(image):
            return self.transform(image).to(self.device)

        self.model.device = self.device
        return self.model, preprocess
class TrainedModel(Model):
    """A registered base model with weights restored from a checkpoint."""

    def __init__(self, model_name, checkpoint, base_model_args=None):
        """Build the base model and apply ``checkpoint`` to it.

        Args:
            model_name: Registry name of the base model, resolved by ``get_model``.
            checkpoint: Passed to ``lib.utils.load_checkpoint`` together with
                the base module (presumably a local path or hub identifier —
                confirm against that helper).
            base_model_args: Optional extra arguments for the base model, in
                any form ``get_model`` accepts. Defaults to no extra arguments.
        """
        from lib.utils import load_checkpoint

        super().__init__()
        # Avoid a shared mutable default: create a fresh dict per call.
        if base_model_args is None:
            base_model_args = {}
        base_model, self.processor = get_model(model_name, base_model_args).torch()
        self.model = load_checkpoint(base_model, checkpoint)
        self.model.to(self.device)
        self.model.eval()

    def get_embedding(self, image, output_attentions=False):
        """Embed ``image`` without tracking gradients.

        Returns:
            The model output with dimension 0 squeezed, or, when
            ``output_attentions`` is true, a tuple of that squeezed first
            output and the model's second output.
        """
        inputs = self.processor(image)
        # The processor may return a list of tensors; move each one to the
        # device and add a leading dimension.
        if isinstance(inputs, list):
            inputs = [x.to(self.device).unsqueeze(0) for x in inputs]
        else:
            inputs = inputs.to(self.device).unsqueeze(0)

        with torch.no_grad():
            if output_attentions:
                image_features = self.model(
                    inputs, output_attentions=output_attentions
                )
                return image_features[0].squeeze(0), image_features[1]
            # The keyword is only passed when requested, so wrapped models
            # are called with the inputs alone on the default path.
            image_features = self.model(inputs)
            return image_features.squeeze(0)

    def torch(self):
        """Return ``(module, preprocess)`` for the checkpointed model."""
        return self.model, self.processor
class FusionInput(list):
    """List of per-model inputs that mimics the tensor calls used on it.

    ``unsqueeze`` and ``to`` are applied to every element, replace the list
    contents in place and return the list itself so calls can be chained.
    """

    def __init__(self, images):
        super().__init__(images)

    def unsqueeze(self, dim):
        """Call ``unsqueeze(dim)`` on every element and return ``self``."""
        expanded = list(map(lambda item: item.unsqueeze(dim), self))
        self[:] = expanded
        return self

    def to(self, device):
        """Call ``to(device)`` on every element and return ``self``."""
        moved = list(map(lambda item: item.to(device), self))
        self[:] = moved
        return self
class FusionModel(Model):
    """Combines several registered models by concatenating their embeddings."""

    def __init__(self, models, head="none"):
        """Instantiate every ``(name, args)`` pair in ``models`` via ``get_model``.

        ``head`` is accepted but not used by this class.
        """
        super().__init__()

        names = []
        for model_name, _ in models:
            names.append(model_name)
        self.model_names = names

        built = []
        for model_name, model_args in models:
            built.append(get_model(model_name, model_args))
        self.models = built

    def get_embedding(self, image):
        """Return the normalized concatenation of each model's normalized embedding."""
        raw = []
        for model in self.models:
            raw.append(model.get_embedding(image))

        unit = []
        for emb in raw:
            unit.append(F.normalize(emb, dim=0))

        fused = torch.cat(unit, dim=0)
        return F.normalize(fused, dim=0)

    def torch(self):
        """Return ``(module, preprocess)`` for the fused model.

        ``preprocess`` yields a ``FusionInput`` holding one preprocessed input
        per sub-model, in the same order as ``self.models``.
        """
        from lib.ModelWrappers import FusionModelWrapper

        preprocessors = []
        for model in self.models:
            _, model_preprocess = model.torch()
            preprocessors.append(model_preprocess)

        def preprocess(image):
            return FusionInput([fn(image) for fn in preprocessors])

        wrapper = FusionModelWrapper(self.models, self.model_names, self.device)
        return wrapper, preprocess

    def get_submodel(self, name):
        """Return the attribute of this object called ``name``."""
        return getattr(self, name)

    def set_submodel(self, name, model):
        """Set the attribute of this object called ``name`` to ``model``."""
        return setattr(self, name, model)
# Registry consulted by get_model: each name maps to a
# (model class, default keyword arguments) pair.
models = {
    # CLIP vision encoder from the stock checkpoint.
    "clip": (
        CLIPModel,
        {"checkpoint": "openai/clip-vit-base-patch16"},
    ),
    # "clip" base model with the ArtFace CLIP LoRA checkpoint applied.
    "lora": (
        TrainedModel,
        {
            "model_name": "clip",
            "checkpoint": "Idiap/ArtFace-CLIP-LoRA",
        },
    ),
    # ONNX model read from (or downloaded to) the local checkpoints folder.
    "ires100": (
        OnnxModel,
        {"model_path": "checkpoints/antelopev2/glintr100.onnx"},
    ),
    # "ires100" base model with the ArtFace tuned checkpoint applied.
    "ires100-tune": (
        TrainedModel,
        {
            "model_name": "ires100",
            "checkpoint": "Idiap/ArtFace-IResNet100-Tuned",
        },
    ),
}
def get_model(name, args=None):
    """Instantiate the registered model called ``name``.

    Args:
        name: Key into the module-level ``models`` registry.
        args: Overrides for the registry's default keyword arguments. Either a
            dict, or a list/tuple of ``"key=value"`` strings whose values are
            parsed as Python literals where possible and kept as strings
            otherwise. Any other value, including ``None``, means no overrides.

    Returns:
        The model built from the registry defaults updated with the overrides.

    Raises:
        ValueError: If a string argument lacks ``=``, or ``name`` is not
            registered.
    """
    import ast

    model_args = {}
    if isinstance(args, (list, tuple)):
        for arg in args:
            if "=" not in arg:
                raise ValueError(
                    f"Invalid argument format for model arguments. Expected 'key=value' pairs, got '{arg}'."
                )
            # Split on the first "=" only so values may themselves contain "=".
            key, value = arg.split("=", 1)
            try:
                model_args[key] = ast.literal_eval(value)
            except (ValueError, SyntaxError, TypeError):
                # Not a Python literal (TypeError covers e.g. unhashable
                # dict keys): keep the raw string.
                model_args[key] = value
    elif isinstance(args, dict):
        model_args = args

    if name not in models:
        raise ValueError("Unrecognised model name!")

    entry = models[name]
    model_cls, default_args = entry[0], entry[1]
    # Overrides win over the registry defaults; neither dict is mutated.
    return model_cls(**dict(default_args, **model_args))