# deepfake_detection / model_loader.py
# Pranithkumar7's picture
# Upload 11 files
# a972d65 verified
from pathlib import Path
import torch
from cross_efficient_vit_model import CrossEfficientViT
from npr_model import NPRModel
# Absolute directory containing this module; weight files are searched for
# relative to it.
BASE_DIR = Path(__file__).resolve().parents[0]
def _find_model_file(file_name):
    """Return the path where the weights file ``file_name`` should be read from.

    The ``pretrained_model`` sub-directory of ``BASE_DIR`` is checked first,
    then ``BASE_DIR`` itself. If neither location has the file, the
    ``pretrained_model`` path is returned anyway, so a later "not found"
    error names the preferred location.
    """
    preferred = BASE_DIR / "pretrained_model" / file_name
    candidates = (preferred, BASE_DIR / file_name)
    return next((path for path in candidates if path.exists()), preferred)
# Checkpoints are loaded with map_location set to this device, i.e. onto the CPU.
device = torch.device("cpu")

# Resolved locations of the image (NPR) and video (Cross Efficient ViT)
# weight files; the search order lives in _find_model_file.
IMAGE_MODEL_PATH = _find_model_file("NPR.pth")
VIDEO_MODEL_PATH = _find_model_file("cross_efficient_vit.pth")

# Lazily-populated model caches, filled by get_image_model / get_video_model.
_image_model, _video_model = None, None
def _load_state_dict(model_path):
    """Deserialize the checkpoint at ``model_path`` after sanity-checking it.

    Args:
        model_path: ``pathlib.Path`` of the weights file.

    Returns:
        The object ``torch.load`` reads from the file, with storages mapped to
        the module-level ``device``. Callers pass it directly to
        ``load_state_dict``, so it is expected to be a state dict.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
        RuntimeError: if the file is a Git LFS pointer instead of real weights.
    """
    if not model_path.exists():
        raise FileNotFoundError(
            f"Model weights not found at {model_path}. "
            "Make sure Git LFS files are available in the deployed backend."
        )

    # An un-fetched Git LFS object is a small text file starting with this
    # signature. Only a short prefix is read, so large checkpoints are not
    # pulled into memory just for this check.
    lfs_signature = b"version https://git-lfs.github.com/spec"
    with model_path.open("rb") as handle:
        is_lfs_pointer = handle.read(64).startswith(lfs_signature)
    if is_lfs_pointer:
        raise RuntimeError(
            f"{model_path.name} is a Git LFS pointer, not the real model file. "
            "Enable Git LFS for the Hugging Face Space or upload the real weights to the backend."
        )

    # NOTE(review): torch.load unpickles the file, so it must only ever point at
    # trusted, deployed weights. Consider weights_only=True if the installed
    # torch version supports it.
    return torch.load(model_path, map_location=device)
def get_image_model():
    """Return the NPR image model, building and loading it on first use.

    The model is created, has the weights from ``IMAGE_MODEL_PATH`` loaded,
    is switched to eval mode, and is then cached in ``_image_model``. Later
    calls return the cached instance.

    Raises:
        FileNotFoundError, RuntimeError: propagated from ``_load_state_dict``
            when the weights file is missing or is a Git LFS pointer.
    """
    global _image_model
    if _image_model is None:
        model = NPRModel()
        model.load_state_dict(_load_state_dict(IMAGE_MODEL_PATH))
        model.eval()
        # Publish the model only after its weights loaded. Otherwise a failed
        # load would leave a randomly-initialised model cached, and later calls
        # would return it silently.
        _image_model = model
    return _image_model
def get_video_model():
    """Return the Cross Efficient ViT video model, loading it on first use.

    The model is created, has the weights from ``VIDEO_MODEL_PATH`` loaded,
    is switched to eval mode, and is then cached in ``_video_model``. Later
    calls return the cached instance.

    Raises:
        FileNotFoundError, RuntimeError: propagated from ``_load_state_dict``
            when the weights file is missing or is a Git LFS pointer.
    """
    global _video_model
    if _video_model is None:
        model = CrossEfficientViT()
        model.load_state_dict(_load_state_dict(VIDEO_MODEL_PATH))
        model.eval()
        # Publish the model only after its weights loaded. Otherwise a failed
        # load would leave a randomly-initialised model cached, and later calls
        # would return it silently.
        _video_model = model
    return _video_model