ArtFace / lib /utils.py
Anjith GEORGE
initial commit
53fe336
# SPDX-FileCopyrightText: Copyright © 2026 Idiap Research Institute <contact@idiap.ch>
# SPDX-FileContributor: Samuel Michel <samuel.michel@idiap.ch>
# SPDX-License-Identifier: GPL-3.0-or-later
# ArtFace contains the code for the paper: https://www.idiap.ch/paper/artface/
# It provides a facial recognition model for historical portraits, and scripts to reproduce the experiments in the paper.
import os
from lib.ModelWrappers import FusionModelWrapper
import glob
import zipfile
import torch
from huggingface_hub import snapshot_download, repo_info
from typing import Optional
from huggingface_hub.utils import RepositoryNotFoundError
def repo_exists(
    repo_id: str, repo_type: Optional[str] = None, token: Optional[str] = None
) -> bool:
    """Return whether ``repo_id`` names a repository on the Hugging Face Hub.

    The check is done by calling ``repo_info``.

    Args:
        repo_id: Candidate Hub repository id (e.g. ``"namespace/name"``).
        repo_type: Forwarded to ``repo_info`` (``None`` lets the Hub default apply).
        token: Optional access token forwarded to ``repo_info``.

    Returns:
        True if ``repo_info`` succeeds. False if the Hub reports the repository
        as not found, or if ``repo_id`` is not even a syntactically valid repo
        id (for instance an absolute filesystem path).

    Any other error raised by ``repo_info`` (e.g. network failures) propagates.
    """
    # repo_info validates its repo_id argument before any request. Strings such
    # as "/abs/path/to/ckpt" raise HFValidationError rather than
    # RepositoryNotFoundError. Callers pass arbitrary local-path candidates, so
    # treat "not a valid repo id" the same as "no such repo".
    from huggingface_hub.utils import HFValidationError

    try:
        repo_info(repo_id, repo_type=repo_type, token=token)
    except (RepositoryNotFoundError, HFValidationError):
        return False
    return True
def resolve_checkpoint(checkpoint):
    """Turn a checkpoint reference into something loadable from disk.

    Resolution order:

    1. If ``checkpoint`` already exists on the local filesystem, return it unchanged.
    2. Otherwise, if ``repo_exists`` reports it as a Hub repository, download a
       snapshot with ``snapshot_download``. Extract every ``*.zip`` at the top
       level of that snapshot into the snapshot directory, then return the
       snapshot directory.
    3. Otherwise return ``checkpoint`` unchanged and let the caller deal with it.
    """
    # Local paths win: never reach out to the Hub for something already on disk.
    if os.path.exists(checkpoint):
        return checkpoint

    # Neither a local path nor a known Hub repo -> pass it through untouched.
    if not repo_exists(checkpoint):
        return checkpoint

    snapshot_dir = snapshot_download(repo_id=checkpoint)
    # Archives shipped in the repo are unpacked alongside themselves so later
    # globbing for weights / adapter files can find their contents.
    archives = glob.glob(os.path.join(snapshot_dir, "*.zip"))
    for archive in archives:
        with zipfile.ZipFile(archive, "r") as bundle:
            bundle.extractall(snapshot_dir)
    return snapshot_dir
def _shallowest(paths):
    """Return the least-nested path in ``paths`` (ties broken alphabetically), or None."""
    # glob() yields entries in arbitrary filesystem order; choose deterministically
    # and prefer files closest to the checkpoint root.
    return min(paths, key=lambda p: (p.count(os.sep), p), default=None)


def _extract_first_zip(checkpoint):
    """Extract the first zip in ``checkpoint`` (else its parent directory) next to itself."""
    # normpath drops a trailing separator so dirname() really yields the parent.
    # os.path.join keeps an empty parent (checkpoint "ckpt" -> "") relative to
    # the current directory instead of turning the pattern into "/*.zip".
    parent = os.path.dirname(os.path.normpath(checkpoint))
    # Archives inside the checkpoint directory take priority over the parent's.
    archives = sorted(glob.glob(os.path.join(checkpoint, "*.zip"))) + sorted(
        glob.glob(os.path.join(parent, "*.zip"))
    )
    if not archives:
        return
    archive = archives[0]
    # NOTE: this runs on every load and re-extracts (overwrites) the contents.
    with zipfile.ZipFile(archive, "r") as zf:
        zf.extractall(os.path.dirname(archive) or ".")


def load_checkpoint(model, checkpoint):
    """Load weights from ``checkpoint`` into ``model`` and return the resulting model.

    ``checkpoint`` is first passed through ``resolve_checkpoint``, so it may be
    a local path or a Hugging Face Hub repo id.

    * ``FusionModelWrapper``: each ``(name, submodel)`` from ``named_submodels()``
      is loaded recursively from ``"<checkpoint>/<name>"`` and stored back with
      ``set_submodel``. The wrapper itself is returned.
    * Any other model:
      1. The first zip archive found in the checkpoint directory (or, if there
         is none, in its parent directory) is extracted next to itself.
      2. If any ``*.pth`` file exists under the checkpoint directory, the
         least-nested one is loaded as a state dict into ``model``. The same
         ``model`` object is returned.
      3. Otherwise, the directory holding the least-nested
         ``adapter_config.json`` (or ``checkpoint`` itself when none is found)
         is loaded as a trainable PEFT adapter on top of ``model``. The
         ``PeftModel`` returned by ``PeftModel.from_pretrained`` is returned.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` / ``PeftModel.from_pretrained``
        raise. With ``weights_only=True``, a ``.pth`` containing arbitrary
        pickled objects is rejected rather than executed.
    """
    checkpoint = resolve_checkpoint(checkpoint)

    if isinstance(model, FusionModelWrapper):
        # Each sub-model's checkpoint lives in a sub-directory named after it.
        for name, submodel in model.named_submodels():
            submodel = load_checkpoint(submodel, f"{checkpoint}/{name}")
            model.set_submodel(name, submodel)
        return model

    # Local checkpoints may still be zipped (Hub snapshots were already
    # unpacked by resolve_checkpoint).
    _extract_first_zip(checkpoint)

    # Plain state-dict checkpoint, directly in the directory or nested below it
    # ("**" with recursive=True also matches zero directories).
    state_dict_path = _shallowest(
        glob.glob(os.path.join(checkpoint, "**", "*.pth"), recursive=True)
    )
    if state_dict_path is not None:
        # The file may come from the Hub: restrict unpickling to tensors and
        # plain containers instead of running arbitrary pickle code.
        state_dict = torch.load(state_dict_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
        return model

    # No state dict: treat the checkpoint as a PEFT adapter. peft is imported
    # lazily so it is only required for this path.
    from peft import PeftModel

    # After zip extraction the adapter files may sit in a sub-directory.
    adapter_config = _shallowest(
        glob.glob(os.path.join(checkpoint, "**", "adapter_config.json"), recursive=True)
    )
    adapter_dir = (
        os.path.dirname(adapter_config) if adapter_config is not None else checkpoint
    )
    return PeftModel.from_pretrained(model, adapter_dir, is_trainable=True)