File size: 2,685 Bytes
53fe336
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# SPDX-FileCopyrightText: Copyright © 2026 Idiap Research Institute <contact@idiap.ch>

# SPDX-FileContributor: Samuel Michel <samuel.michel@idiap.ch>

# SPDX-License-Identifier: GPL-3.0-or-later

# ArtFace contains the code for the paper: https://www.idiap.ch/paper/artface/
# It provides a facial recognition model for historical portraits, and scripts to reproduce the experiments in the paper.

import os
from lib.ModelWrappers import FusionModelWrapper
import glob
import zipfile
import torch
from huggingface_hub import snapshot_download, repo_info
from typing import Optional
from huggingface_hub.utils import RepositoryNotFoundError


def repo_exists(
    repo_id: str, repo_type: Optional[str] = None, token: Optional[str] = None
) -> bool:
    """Return whether ``repo_id`` names a Hugging Face Hub repository.

    Args:
        repo_id: Candidate repository id (e.g. ``"org/name"``). Callers also pass
            arbitrary strings such as non-existent local paths.
        repo_type: Forwarded to ``huggingface_hub.repo_info``.
        token: Forwarded to ``huggingface_hub.repo_info``.

    Returns:
        True if ``repo_info`` succeeds. False if the Hub reports the repository
        as not found, or if ``repo_id`` is not a syntactically valid repo id.

    Raises:
        Any other error from ``repo_info`` (e.g. network failures) propagates.
    """
    from huggingface_hub.utils import HFValidationError

    try:
        repo_info(repo_id, repo_type=repo_type, token=token)
        return True
    except (RepositoryNotFoundError, HFValidationError):
        # HFValidationError: the string cannot be a repo id at all (e.g. a local
        # path such as "./ckpt" or "/data/ckpt/sub"), so it names no Hub repo.
        return False


def resolve_checkpoint(checkpoint):
    """Map ``checkpoint`` to a local path, downloading it from the Hub if needed.

    An existing filesystem path is returned as-is. Otherwise, if ``repo_exists``
    reports a Hub repository with that id, a snapshot of the repo is downloaded.
    Every top-level ``*.zip`` inside the snapshot is extracted into the snapshot
    directory, and that directory is returned. Any other string is returned
    unchanged.

    Note: archives are re-extracted on every call; nothing records that a
    snapshot was already unpacked.
    """
    if os.path.exists(checkpoint) or not repo_exists(checkpoint):
        return checkpoint

    snapshot_dir = snapshot_download(repo_id=checkpoint)
    archives = glob.glob(os.path.join(snapshot_dir, "*.zip"))
    for archive in archives:
        with zipfile.ZipFile(archive, "r") as bundle:
            bundle.extractall(snapshot_dir)
    return snapshot_dir


def _shallowest_first(paths):
    """Sort paths by directory depth, then lexicographically (deterministic order)."""
    return sorted(paths, key=lambda p: (os.path.normpath(p).count(os.sep), p))


def _extract_local_zips(checkpoint):
    """Extract every ``*.zip`` in ``checkpoint`` or its parent, each next to itself.

    All archives are extracted, not only the first one glob happens to return.
    For a fusion model each submodel is loaded from ``<root>/<name>``, and
    ``<root>`` may hold one archive per submodel.
    """
    parent = os.path.dirname(checkpoint)
    patterns = (
        os.path.join(glob.escape(checkpoint), "*.zip"),
        # os.path.join keeps a bare relative name ("ckpt" -> parent "") in the
        # current directory instead of globbing the filesystem root ("/*.zip").
        os.path.join(glob.escape(parent), "*.zip"),
    )
    # dict.fromkeys drops duplicates (e.g. "dir/" has parent "dir", so both
    # patterns match the same files) while keeping a deterministic order.
    zip_paths = dict.fromkeys(
        path for pattern in patterns for path in sorted(glob.glob(pattern))
    )
    for zip_path in zip_paths:
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(os.path.dirname(zip_path) or ".")


def load_checkpoint(model, checkpoint):
    """Load weights from ``checkpoint`` into ``model`` and return the result.

    ``checkpoint`` is first passed through ``resolve_checkpoint``. Then:

    * ``FusionModelWrapper``: each named submodel is loaded recursively from
      ``<checkpoint>/<name>`` and set back on the wrapper, which is returned.
    * Otherwise, zip archives in ``checkpoint`` or its parent directory are
      extracted. The shallowest ``*.pth`` file under ``checkpoint`` is then loaded
      as a state dict into ``model``, which is returned.
    * If there is no ``*.pth`` file, ``model`` is wrapped as a trainable
      ``PeftModel``. The adapter is loaded from the shallowest directory that
      contains ``adapter_config.json``, or from ``checkpoint`` itself if none is
      found. The new ``PeftModel`` is returned.

    Args:
        model: A ``FusionModelWrapper`` or a torch module to load into.
        checkpoint: Local path or Hugging Face Hub repo id.

    Returns:
        The loaded model (see above for which object is returned).
    """
    checkpoint = resolve_checkpoint(checkpoint)

    if isinstance(model, FusionModelWrapper):
        for name, submodel in model.named_submodels():
            ckpt = f"{checkpoint}/{name}"
            submodel = load_checkpoint(submodel, ckpt)
            model.set_submodel(name, submodel)
        return model

    # Extract zip files if present (for local checkpoints)
    _extract_local_zips(checkpoint)

    # Escape so paths containing glob metacharacters ([, ], *, ?) match literally.
    checkpoint_glob = glob.escape(checkpoint)

    # Look for .pth state-dict files (direct or in subdirectories)
    state_dicts = _shallowest_first(
        glob.glob(os.path.join(checkpoint_glob, "**", "*.pth"), recursive=True)
    )
    if state_dicts:
        # NOTE(review): torch.load unpickles the file, and checkpoints may come
        # from the Hub. Consider weights_only=True if the installed torch
        # supports it and the file holds only tensors.
        model.load_state_dict(torch.load(state_dicts[0], map_location="cpu"))
        return model

    from peft import PeftModel

    # Adapter files may be in a subdirectory after zip extraction
    adapter_configs = _shallowest_first(
        glob.glob(
            os.path.join(checkpoint_glob, "**", "adapter_config.json"), recursive=True
        )
    )
    adapter_dir = (
        os.path.dirname(adapter_configs[0]) if adapter_configs else checkpoint
    )
    return PeftModel.from_pretrained(model, adapter_dir, is_trainable=True)