Spaces:
Running
Running
| # SPDX-FileCopyrightText: Copyright © 2026 Idiap Research Institute <contact@idiap.ch> | |
| # SPDX-FileContributor: Samuel Michel <samuel.michel@idiap.ch> | |
| # SPDX-License-Identifier: GPL-3.0-or-later | |
| # ArtFace contains the code for the paper: https://www.idiap.ch/paper/artface/ | |
| # It provides a facial recognition model for historical portraits, and scripts to reproduce the experiments in the paper. | |
| import os | |
| from lib.ModelWrappers import FusionModelWrapper | |
| import glob | |
| import zipfile | |
| import torch | |
| from huggingface_hub import snapshot_download, repo_info | |
| from typing import Optional | |
| from huggingface_hub.utils import RepositoryNotFoundError | |
def repo_exists(
    repo_id: str, repo_type: Optional[str] = None, token: Optional[str] = None
) -> bool:
    """Return True if ``repo_id`` names a repository on the Hugging Face Hub.

    Args:
        repo_id: Candidate repository id (``"name"`` or ``"namespace/name"``).
            Any string is accepted; strings that are not a syntactically
            valid repo id (e.g. multi-component or absolute filesystem
            paths) yield ``False`` instead of raising.
        repo_type: Forwarded unchanged to ``repo_info``.
        token: Forwarded unchanged to ``repo_info``.

    Returns:
        True when ``repo_info`` succeeds or reports a gated repository;
        False when the repository is not found or ``repo_id`` is invalid.

    Other errors raised by ``repo_info`` (network, auth, ...) propagate.
    """
    # Local import: only needed for the error classification below.
    from huggingface_hub.utils import GatedRepoError, HFValidationError

    try:
        repo_info(repo_id, repo_type=repo_type, token=token)
    except GatedRepoError:
        # Subclass of RepositoryNotFoundError, but the repo does exist; a
        # subsequent download reports the access problem explicitly.
        # Must be listed before RepositoryNotFoundError.
        return True
    except RepositoryNotFoundError:
        return False
    except HFValidationError:
        # repo_info validates the id before any request, so a string that is
        # really a (missing) local path, e.g. "ckpt/fusion/backbone", lands
        # here. Such a string cannot name a Hub repo.
        return False
    return True
def resolve_checkpoint(checkpoint):
    """Turn ``checkpoint`` into a local directory when it names a Hub repo.

    An existing local path is returned unchanged. Otherwise, if
    ``repo_exists`` reports ``checkpoint`` as a Hub repository, it is fetched
    with ``snapshot_download``. Every top-level ``*.zip`` in the downloaded
    directory is then extracted in place, and that directory is returned.
    Any other string is returned unchanged.
    """
    if not os.path.exists(checkpoint) and repo_exists(checkpoint):
        snapshot_dir = snapshot_download(repo_id=checkpoint)
        # Unpack archives beside themselves so their contents are available
        # inside the returned directory.
        archives = glob.glob(os.path.join(snapshot_dir, "*.zip"))
        for archive in archives:
            with zipfile.ZipFile(archive, "r") as bundle:
                bundle.extractall(snapshot_dir)
        return snapshot_dir
    return checkpoint
def load_checkpoint(model, checkpoint):
    """Load weights from ``checkpoint`` into ``model`` and return the result.

    ``checkpoint`` is first passed through ``resolve_checkpoint``, so it may be
    a local path or a Hugging Face Hub repo id.

    * ``FusionModelWrapper``: every ``(name, submodel)`` pair from
      ``model.named_submodels()`` is loaded recursively from
      ``<checkpoint>/<name>`` and stored back with ``model.set_submodel``.
    * Anything else:
      - Zip extraction: the first ``*.zip`` in ``checkpoint`` (or, if there
        is none, in its parent directory) is extracted next to itself.
      - State dict: if a ``*.pth`` file exists in ``checkpoint`` (directly,
        else in any subdirectory), it is loaded onto the CPU with
        ``torch.load`` and applied with ``model.load_state_dict``.
      - PEFT adapter: otherwise ``model`` is passed to
        ``peft.PeftModel.from_pretrained(..., is_trainable=True)``. The
        directory used is that of the shallowest ``adapter_config.json``
        under ``checkpoint``, or ``checkpoint`` itself when none is found.

    When several candidate files match, the shallowest one wins and ties are
    broken lexicographically, so the choice does not depend on directory
    listing order.

    Returns:
        ``model`` for the fusion and state-dict branches; the object returned
        by ``PeftModel.from_pretrained`` for the adapter branch.
    """
    checkpoint = resolve_checkpoint(checkpoint)

    if isinstance(model, FusionModelWrapper):
        for name, submodel in model.named_submodels():
            loaded = load_checkpoint(submodel, f"{checkpoint}/{name}")
            model.set_submodel(name, loaded)
        return model

    def _sorted_glob(*parts, recursive=False):
        # Shallowest match first, then lexicographic: deterministic, and keeps
        # the preference for files directly inside the searched directory.
        matches = glob.glob(os.path.join(*parts), recursive=recursive)
        return sorted(matches, key=lambda path: (path.count(os.sep), path))

    # Escape the paths so glob metacharacters in them ("[", "*", "?") are
    # matched literally. dirname() of a bare relative name is "", which
    # os.path.join turns into a cwd-relative pattern ("*.zip") rather than
    # a search of the filesystem root ("/*.zip").
    root = glob.escape(checkpoint)
    parent = glob.escape(os.path.dirname(checkpoint))

    # Extract zip files if present (for local checkpoints). The checkpoint's
    # own archives take precedence over those of its parent directory.
    zips = _sorted_glob(root, "*.zip") + _sorted_glob(parent, "*.zip")
    if zips:
        with zipfile.ZipFile(zips[0], "r") as zip_ref:
            zip_ref.extractall(os.path.dirname(zips[0]))

    # Look for .pth state-dict files (direct or in subdirectories).
    ckpt = _sorted_glob(root, "*.pth") or _sorted_glob(
        root, "**", "*.pth", recursive=True
    )
    if ckpt:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load trusted checkpoints, or pass
        # weights_only=True where the installed torch supports it.
        model.load_state_dict(torch.load(ckpt[0], map_location="cpu"))
        return model

    from peft import PeftModel

    # Adapter files may be in a subdirectory after zip extraction.
    adapter_configs = _sorted_glob(
        root, "**", "adapter_config.json", recursive=True
    )
    adapter_dir = (
        os.path.dirname(adapter_configs[0]) if adapter_configs else checkpoint
    )
    return PeftModel.from_pretrained(model, adapter_dir, is_trainable=True)