Spaces:
Running
Running
| # SPDX-FileCopyrightText: Copyright © 2025 Idiap Research Institute <contact@idiap.ch> | |
| # SPDX-FileContributor: Francois Poh <francois.poh22@imperial.ac.uk> | |
| # SPDX-License-Identifier: GPL-3.0-or-later | |
| # ArtFace contains the code for the paper: https://www.idiap.ch/paper/artface/ | |
| # It provides a facial recognition model for historical portraits, and scripts to reproduce the experiments in the paper. | |
| from PIL import Image | |
| import cv2 | |
| import numpy as np | |
| import torch | |
class ImagePreprocessor:
    """Base class for callable image preprocessors.

    Calling an instance normalises the input to an RGB ``PIL.Image.Image`` and
    hands it to :meth:`process`, which subclasses must implement.
    """

    def __init__(self):
        pass

    def __call__(self, image):
        """Load/convert *image* to RGB and run :meth:`process` on it.

        Args:
            image: A filesystem path (``str`` or any ``os.PathLike`` such as
                ``pathlib.Path``) or a ``PIL.Image.Image``.

        Returns:
            Whatever the subclass's :meth:`process` returns.

        Raises:
            TypeError: If *image* is neither a path nor a PIL image.
        """
        import os

        if isinstance(image, (str, os.PathLike)):
            # Open in a context manager: Pillow may keep the underlying file
            # open after loading (e.g. for multi-frame formats), so close it
            # explicitly once the RGB copy has been made.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        elif isinstance(image, Image.Image):
            image = image.convert("RGB")
        else:
            raise TypeError(
                f"Unsupported input type {type(image)}. "
                "Expected file path or PIL.Image."
            )
        return self.process(image)

    def process(self, image):
        """Transform an RGB PIL image; must be overridden by subclasses."""
        raise NotImplementedError("Subclasses should implement this method.")
class FaceAligner(ImagePreprocessor):
    """Detect a face with InsightFace and return it aligned and cropped.

    The detector's keypoints are warped onto the reference template of a
    ``lib.face_alignment.mtcnn.MTCNN`` instance via
    ``mtcnn.warp_and_crop_face``.

    Args:
        detector: InsightFace model-pack name, passed as
            ``FaceAnalysis(name=detector, root=".")``.
        crop_size: Output crop size. Accepts a pair of ints or numeric
            strings (``get_preprocessor`` turns a CLI value ``"a,b"`` into a
            list of strings), a comma-separated string, or a single int /
            numeric string, which is used for both dimensions.
        padding: Padding forwarded to ``mtcnn.MTCNN`` after ``float()``.
    """

    def __init__(self, detector="buffalo_l", crop_size=(112, 112), padding=0):
        crop_size = self._normalize_crop_size(crop_size)
        super().__init__()
        # Local imports: this module can be imported without these packages.
        from lib.face_alignment import mtcnn
        from insightface.app import FaceAnalysis

        # --------------------
        # Device selection: when torch sees a GPU, run MTCNN on cuda:0 and
        # give InsightFace the CUDA execution provider (CPU as fallback).
        # --------------------
        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda:
            device = "cuda:0"
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
            ctx_id = 0
        else:
            device = "cpu"
            providers = ["CPUExecutionProvider"]
            ctx_id = -1

        # --------------------
        # MTCNN: process() uses its reference template and crop size for
        # warping; detection itself is done by InsightFace below.
        # --------------------
        self.mtcnn = mtcnn.MTCNN(
            device=device,
            crop_size=crop_size,
            padding=float(padding),
        )

        # --------------------
        # InsightFace detector
        # --------------------
        self.detector = FaceAnalysis(
            name=detector,
            root=".",
            providers=providers,
        )
        self.detector.prepare(ctx_id=ctx_id)

        print(
            f"✅ FaceAligner initialized | "
            f"CUDA: {self.use_cuda} | "
            f"providers: {providers}"
        )

    @staticmethod
    def _normalize_crop_size(crop_size):
        """Return *crop_size* as a tuple of ints; a single value becomes a square."""
        if isinstance(crop_size, str):
            crop_size = crop_size.split(",")
        elif isinstance(crop_size, (int, np.integer)):
            crop_size = (crop_size,)
        size = tuple(int(s) for s in crop_size)
        if len(size) == 1:
            # A scalar such as "112" (from the CLI) means a square crop, not
            # one dimension per digit.
            size *= 2
        return size

    def process(self, image):
        """Align and crop the face in an RGB PIL image.

        Args:
            image: RGB ``PIL.Image.Image`` (as produced by ``__call__``).

        Returns:
            The aligned face crop as an RGB ``PIL.Image.Image``, or ``None``
            when the detector returns no keypoints.
        """
        from lib.face_alignment import mtcnn

        # Detector and warp are both fed an OpenCV-style BGR array.
        bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

        # Ask for at most one detection.
        _, kpss = self.detector.det_model.detect(bgr, max_num=1, metric="default")
        if kpss is None or len(kpss) == 0:
            return None

        # [x, y] pairs for each keypoint of the first detection.
        facial5points = [[x, y] for x, y in zip(kpss[0, :, 0], kpss[0, :, 1])]

        warped_face = mtcnn.warp_and_crop_face(
            bgr,
            facial5points,
            self.mtcnn.refrence,  # attribute name as spelled by the mtcnn lib
            crop_size=self.mtcnn.crop_size,
        )
        rgb_face = cv2.cvtColor(warped_face, cv2.COLOR_BGR2RGB)
        return Image.fromarray(rgb_face)
# Registry of named presets used by get_preprocessor:
#   preset name -> (preprocessor class, default constructor kwargs).
# Overrides given to get_preprocessor are merged on top of these defaults.
preprocessors = {
    # 112x112 crop with FaceAligner's default padding (0).
    "align": (FaceAligner, dict(crop_size=(112, 112))),
    # 224x224 crop with FaceAligner's default padding (0).
    "align-224": (FaceAligner, dict(crop_size=(224, 224))),
    # 224x224 crop with padding=0.5 forwarded to FaceAligner.
    "align-pad": (FaceAligner, dict(crop_size=(224, 224), padding=0.5)),
}
def get_preprocessor(name, args=None):
    """Instantiate the preprocessor registered under *name* in ``preprocessors``.

    Args:
        name: Registry key, e.g. ``"align"``.
        args: Optional keyword-argument overrides for the preset, either:

            * an iterable of ``"key=value"`` strings (e.g. from a command
              line; a single string counts as one argument). Single quotes
              are stripped from each value, and a value containing ``,``
              becomes a list of strings, e.g. ``"crop_size=224,224"`` ->
              ``["224", "224"]``; or
            * a mapping of keyword arguments, used as-is.

            ``None`` (the default) means no overrides.

    Returns:
        A new instance of the registered class, built from the preset's
        default kwargs updated with the overrides.

    Raises:
        ValueError: If a string argument has no ``=``, or *name* is not a
            registered preprocessor.
    """
    from collections.abc import Mapping

    if args is None:
        model_args = {}
    elif isinstance(args, Mapping):
        # Previously a dict (the type of the old default) was silently ignored.
        model_args = dict(args)
    else:
        if isinstance(args, str):
            # Avoid iterating a lone "key=value" string character by character.
            args = [args]
        model_args = {}
        for arg in args:
            if "=" not in arg:
                raise ValueError(
                    f"Invalid argument format for model arguments. Expected 'key=value' pairs, got '{arg}'."
                )
            key, value = arg.split("=", 1)
            value = value.strip("'")
            if "," in value:
                value = [v.strip("'") for v in value.split(",")]
            model_args[key] = value

    try:
        cls, default_kwargs = preprocessors[name]
    except KeyError:
        raise ValueError(
            f"Unknown preprocessor: {name}\n Please choose from: {', '.join(preprocessors.keys())}"
        ) from None
    # Caller overrides win over the preset defaults.
    return cls(**{**default_kwargs, **model_args})