File size: 1,777 Bytes
a972d65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from pathlib import Path

import torch

from cross_efficient_vit_model import CrossEfficientViT
from npr_model import NPRModel

# Absolute directory containing this module; the weight-file search in
# _find_model_file resolves candidate paths relative to it.
BASE_DIR = Path(__file__).resolve().parent


def _find_model_file(file_name):
    """Locate ``file_name`` among the known weight directories.

    Candidates are probed in order:

    1. ``BASE_DIR / "pretrained_model" / file_name``
    2. ``BASE_DIR / file_name``

    The first candidate that exists is returned. If neither exists, the
    ``pretrained_model`` location is returned anyway. A later "not found" error
    therefore names the preferred location.
    """
    preferred = BASE_DIR / "pretrained_model" / file_name
    fallback = BASE_DIR / file_name

    return next(
        (candidate for candidate in (preferred, fallback) if candidate.exists()),
        preferred,
    )


# Resolved weight-file locations (search order is defined by _find_model_file).
IMAGE_MODEL_PATH = _find_model_file("NPR.pth")
VIDEO_MODEL_PATH = _find_model_file("cross_efficient_vit.pth")
# Passed as torch.load(map_location=...) so checkpoint tensors are loaded onto the CPU.
device = torch.device("cpu")

# Lazily populated model caches; filled on first call to get_image_model() /
# get_video_model().
_image_model = None
_video_model = None


def _load_state_dict(model_path):
    if not model_path.exists():
        raise FileNotFoundError(
            f"Model weights not found at {model_path}. "
            "Make sure Git LFS files are available in the deployed backend."
        )

    with model_path.open("rb") as model_file:
        header = model_file.read(64)

    if header.startswith(b"version https://git-lfs.github.com/spec"):
        raise RuntimeError(
            f"{model_path.name} is a Git LFS pointer, not the real model file. "
            "Enable Git LFS for the Hugging Face Space or upload the real weights to the backend."
        )

    return torch.load(model_path, map_location=device)


def get_image_model():
    """Return the lazily built, cached NPR image model in eval mode.

    On the first call an ``NPRModel`` is constructed, its weights are loaded
    from ``IMAGE_MODEL_PATH``, and it is switched to eval mode. Later calls
    return the same instance.

    Raises:
        FileNotFoundError, RuntimeError: propagated from ``_load_state_dict``
            (missing weights or Git LFS pointer) or from ``load_state_dict``.
            Nothing is cached on failure, so the next call retries the load.
    """
    global _image_model

    if _image_model is None:
        # Build into a local and publish only after the weights load. Assigning
        # the global first would cache a randomly initialised model if loading
        # raised, and every later call would silently return it.
        model = NPRModel()
        model.load_state_dict(_load_state_dict(IMAGE_MODEL_PATH))
        model.eval()
        _image_model = model

    return _image_model


def get_video_model():
    """Return the lazily built, cached CrossEfficientViT video model in eval mode.

    On the first call a ``CrossEfficientViT`` is constructed, its weights are
    loaded from ``VIDEO_MODEL_PATH``, and it is switched to eval mode. Later
    calls return the same instance.

    Raises:
        FileNotFoundError, RuntimeError: propagated from ``_load_state_dict``
            (missing weights or Git LFS pointer) or from ``load_state_dict``.
            Nothing is cached on failure, so the next call retries the load.
    """
    global _video_model

    if _video_model is None:
        # Build into a local and publish only after the weights load. Assigning
        # the global first would cache a randomly initialised model if loading
        # raised, and every later call would silently return it.
        model = CrossEfficientViT()
        model.load_state_dict(_load_state_dict(VIDEO_MODEL_PATH))
        model.eval()
        _video_model = model

    return _video_model