Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| import torch | |
| from cross_efficient_vit_model import CrossEfficientViT | |
| from npr_model import NPRModel | |
# Absolute directory containing this module; checkpoint lookups are relative to it.
BASE_DIR = Path(__file__).resolve().parent


def _find_model_file(file_name):
    """Return the path where the checkpoint *file_name* should be loaded from.

    Two candidate locations are probed in order:

    1. ``BASE_DIR / "pretrained_model" / file_name``
    2. ``BASE_DIR / file_name``

    The first candidate that exists wins. When neither exists, the
    ``pretrained_model`` location is returned anyway, so the path still names
    the preferred place for the file.
    """
    preferred = BASE_DIR / "pretrained_model" / file_name
    fallback = BASE_DIR / file_name
    # Lazy generator: ``fallback.exists()`` is only checked if ``preferred`` is absent.
    existing = (candidate for candidate in (preferred, fallback) if candidate.exists())
    return next(existing, preferred)
# Checkpoint locations, resolved once at import time via _find_model_file.
IMAGE_MODEL_PATH = _find_model_file("NPR.pth")
VIDEO_MODEL_PATH = _find_model_file("cross_efficient_vit.pth")

# Device that loaded checkpoints are mapped onto (passed as torch.load's map_location).
device = torch.device("cpu")

# Lazily populated model caches; None until get_image_model / get_video_model
# build them on first use.
_image_model = _video_model = None
def _load_state_dict(model_path):
    """Deserialize the checkpoint at *model_path*, mapping tensors onto ``device``.

    Two deployment problems are detected before ``torch.load`` is attempted,
    so they surface with an actionable message.

    Args:
        model_path: ``pathlib.Path`` of the checkpoint file.

    Returns:
        Whatever object ``torch.load`` produces; callers pass it straight to
        ``load_state_dict``.

    Raises:
        FileNotFoundError: *model_path* does not exist.
        RuntimeError: the file starts with the Git LFS pointer header instead
            of containing the real weights.
    """
    lfs_pointer_prefix = b"version https://git-lfs.github.com/spec"

    if not model_path.exists():
        missing_msg = (
            f"Model weights not found at {model_path}. "
            "Make sure Git LFS files are available in the deployed backend."
        )
        raise FileNotFoundError(missing_msg)

    # Only the first 64 bytes are inspected; that is enough to contain the
    # pointer prefix without reading the (potentially large) weights file.
    with open(model_path, "rb") as handle:
        leading_bytes = handle.read(64)

    if leading_bytes.startswith(lfs_pointer_prefix):
        pointer_msg = (
            f"{model_path.name} is a Git LFS pointer, not the real model file. "
            "Enable Git LFS for the Hugging Face Space or upload the real weights to the backend."
        )
        raise RuntimeError(pointer_msg)

    # NOTE(review): torch.load unpickles the file, and its default for
    # `weights_only` differs between torch releases. Since the result is fed
    # directly to load_state_dict, `weights_only=True` should suffice — confirm
    # the installed torch version supports it before pinning it.
    return torch.load(model_path, map_location=device)
def get_image_model():
    """Return the shared NPR image model, building and loading it on first use.

    The model is constructed, loaded from ``IMAGE_MODEL_PATH`` and switched to
    ``eval()`` mode in a local variable. It is stored in the module-level cache
    only after all of that succeeds. Previously the cache was assigned before
    the weights were loaded, so a failed load (missing file, Git LFS pointer)
    left an instance whose weights were never loaded in the cache. Every later
    call then returned that instance silently. Now a failed load is simply
    retried on the next call.

    Returns:
        The cached ``NPRModel`` instance in eval mode.

    Raises:
        Any exception from constructing ``NPRModel``, from
        ``_load_state_dict`` (``FileNotFoundError`` / ``RuntimeError``), or
        from ``load_state_dict``; the cache is left unset in that case.
    """
    global _image_model
    if _image_model is None:
        model = NPRModel()
        model.load_state_dict(_load_state_dict(IMAGE_MODEL_PATH))
        model.eval()
        # Publish only a fully loaded model.
        _image_model = model
    return _image_model
def get_video_model():
    """Return the shared CrossEfficientViT video model, building and loading it on first use.

    The model is constructed, loaded from ``VIDEO_MODEL_PATH`` and switched to
    ``eval()`` mode in a local variable. It is stored in the module-level cache
    only after all of that succeeds. Previously the cache was assigned before
    the weights were loaded, so a failed load (missing file, Git LFS pointer)
    left an instance whose weights were never loaded in the cache. Every later
    call then returned that instance silently. Now a failed load is simply
    retried on the next call.

    Returns:
        The cached ``CrossEfficientViT`` instance in eval mode.

    Raises:
        Any exception from constructing ``CrossEfficientViT``, from
        ``_load_state_dict`` (``FileNotFoundError`` / ``RuntimeError``), or
        from ``load_state_dict``; the cache is left unset in that case.
    """
    global _video_model
    if _video_model is None:
        model = CrossEfficientViT()
        model.load_state_dict(_load_state_dict(VIDEO_MODEL_PATH))
        model.eval()
        # Publish only a fully loaded model.
        _video_model = model
    return _video_model