# deepscan-ai-api / model_loader.py
# Uploaded by Anas-Mohiuddin-Syed in "Upload 10 files" (commit 6a8443b, verified).
import os
from functools import lru_cache
from typing import Dict, List, Tuple
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Directory holding the local processor/model files, resolved relative to the
# directory containing this module: <module dir>/models/local_model.
_MODULE_DIR = os.path.dirname(__file__)
LOCAL_MODEL_PATH = os.path.join(_MODULE_DIR, "models", "local_model")
class ModelLoadError(RuntimeError):
    """Signal that the local deepfake-detection model could not be loaded.

    Callers can catch this single type for every loader failure: a missing
    model directory, an error while reading the processor or model, or a
    label mapping from which no usable "fake" class can be identified.
    """
def _normalized_labels(id2label: Dict[int, str]) -> Dict[int, str]:
return {idx: str(label).strip().lower() for idx, label in id2label.items()}
# Keywords whose presence in a normalized label marks it as describing fake or
# real content. "deepfake" is listed even though it contains "fake" so that the
# earliest-match rule in _keyword_negations sees the whole word ("nondeepfake").
_FAKE_TERMS: Tuple[str, ...] = ("fake", "deepfake", "manipulated")
_REAL_TERMS: Tuple[str, ...] = ("real", "authentic")
# Standalone words that negate the keyword in the following token ("not_fake").
_NEGATING_WORDS = frozenset({"not", "non", "no"})
# Prefixes fused directly onto a keyword that negate it ("unreal", "inauthentic").
_NEGATING_PREFIXES = frozenset({"un", "in", "non"})


def _keyword_negations(label: str, terms: Tuple[str, ...]) -> List[bool]:
    """Return one flag per token of *label* containing a keyword from *terms*.

    Tokens are the alphanumeric runs of *label*. Every keyword is purely
    alphabetic, so any substring match in the full label also lies inside a
    single token. A flag is True when that token's keyword is negated, either
    by a prefix fused onto it (``"unreal"``) or by a negating word as the
    preceding token (``"not_fake"``).
    """
    tokens = "".join(ch if ch.isalnum() else " " for ch in label).split()
    flags: List[bool] = []
    for pos, token in enumerate(tokens):
        starts = [token.find(term) for term in terms if term in token]
        if not starts:
            continue
        # Judge negation at the earliest keyword so "nondeepfake" reads as
        # "non" + "deepfake", not "nondeep" + "fake".
        prefix = token[: min(starts)]
        negated = prefix in _NEGATING_PREFIXES or (
            pos > 0 and tokens[pos - 1] in _NEGATING_WORDS
        )
        flags.append(negated)
    return flags


def _resolve_label_indices(id2label: Dict[int, str]) -> Tuple[List[int], List[int]]:
    """Split the class indices of ``id2label`` into fake-like and real-like groups.

    Labels are normalized with ``_normalized_labels`` and matched against the
    keywords in ``_FAKE_TERMS`` / ``_REAL_TERMS`` by substring. Naive substring
    matching is refined in two ways:

    * A negated keyword counts for the opposite class. ``"not_fake"`` and
      ``"non-deepfake"`` are real; ``"unreal"`` and ``"inauthentic"`` are fake.
    * The two groups are disjoint. A label with evidence for both classes
      (e.g. ``"real_vs_fake"``) goes only into the fake group, so no index is
      ever returned in both lists.

    Labels that contain no keyword at all end up in neither group.

    Args:
        id2label: Mapping from class index to label, as found in
            ``model.config.id2label``.

    Returns:
        ``(fake_indices, real_indices)``, each in ``id2label`` iteration order.
    """
    fake_indices: List[int] = []
    real_indices: List[int] = []
    for idx, label in _normalized_labels(id2label).items():
        fake_flags = _keyword_negations(label, _FAKE_TERMS)
        real_flags = _keyword_negations(label, _REAL_TERMS)
        # A plain fake keyword or a negated real keyword is evidence of "fake";
        # the mirror image is evidence of "real". Fake wins any conflict.
        if any(not neg for neg in fake_flags) or any(real_flags):
            fake_indices.append(idx)
        elif any(not neg for neg in real_flags) or any(fake_flags):
            real_indices.append(idx)
    return fake_indices, real_indices
@lru_cache(maxsize=1)
def load_model() -> Tuple[AutoImageProcessor, AutoModelForImageClassification, List[int], List[int]]:
    """Load the deepfake classifier from ``LOCAL_MODEL_PATH`` and memoize it.

    The image processor and model are read from the local directory. The model
    is moved to CPU and put in eval mode. The fake and real class indices are
    then resolved from ``model.config.id2label``.

    Because of ``lru_cache``, a successful result is reused by later calls. A
    call that raises is not cached, so the next call tries again.

    Returns:
        ``(processor, model, fake_indices, real_indices)``.

    Raises:
        ModelLoadError: if the model directory does not exist, if loading the
            processor or model fails (the original exception is chained), or
            if no fake-like label can be found in ``id2label``.
    """
    if not os.path.isdir(LOCAL_MODEL_PATH):
        raise ModelLoadError("Local model not found. Please download model first.")

    try:
        processor = AutoImageProcessor.from_pretrained(LOCAL_MODEL_PATH)
        classifier = AutoModelForImageClassification.from_pretrained(LOCAL_MODEL_PATH)
        classifier = classifier.to(torch.device("cpu"))
        classifier.eval()
    except Exception as exc:
        raise ModelLoadError(f"Failed to load local model from '{LOCAL_MODEL_PATH}': {exc}") from exc

    id2label = getattr(classifier.config, "id2label", {}) or {}
    fake_indices, real_indices = _resolve_label_indices(id2label)

    if not fake_indices:
        # Report "only real labels found" separately from "nothing recognisable".
        if real_indices:
            raise ModelLoadError(
                f"Could not find a fake/deepfake label in model.config.id2label: {id2label}"
            )
        raise ModelLoadError(
            f"Could not infer fake/real labels from model.config.id2label: {id2label}"
        )

    return processor, classifier, fake_indices, real_indices