import os
from functools import lru_cache
from typing import Dict, List, Tuple

import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification

LOCAL_MODEL_PATH = os.path.join(os.path.dirname(__file__), "models", "local_model")

# Substrings that mark a (normalized) class label as "fake". "fake" already
# matches labels such as "deepfake", so those need no separate entry.
_FAKE_KEYWORDS: Tuple[str, ...] = ("fake", "manipulated")
# Substrings that mark a (normalized) class label as "real".
_REAL_KEYWORDS: Tuple[str, ...] = ("real", "authentic")


class ModelLoadError(RuntimeError):
    """Raised when the deepfake model cannot be loaded."""


def _normalized_labels(id2label: Dict[int, str]) -> Dict[int, str]:
    """Return a copy of *id2label* with each label stringified, stripped and lower-cased."""
    return {idx: str(label).strip().lower() for idx, label in id2label.items()}


def _resolve_label_indices(id2label: Dict[int, str]) -> Tuple[List[int], List[int]]:
    """Split class indices into "fake" and "real" groups by keyword-matching their labels.

    Each label is normalized (see ``_normalized_labels``). It is then tested for
    any substring in ``_FAKE_KEYWORDS`` and, failing that, in ``_REAL_KEYWORDS``.
    A label matching both groups is classified as fake only, so the two returned
    lists never share an index. Labels matching neither group are ignored.

    Args:
        id2label: Mapping from class index to label, e.g. ``model.config.id2label``.

    Returns:
        ``(fake_indices, real_indices)``, each in ``id2label`` iteration order.
    """
    fake_indices: List[int] = []
    real_indices: List[int] = []
    for idx, label in _normalized_labels(id2label).items():
        if any(keyword in label for keyword in _FAKE_KEYWORDS):
            fake_indices.append(idx)
        # `elif` keeps the groups disjoint: an index counted as fake must not
        # also contribute to the real score.
        elif any(keyword in label for keyword in _REAL_KEYWORDS):
            real_indices.append(idx)
    return fake_indices, real_indices


@lru_cache(maxsize=1)
def load_model() -> Tuple[AutoImageProcessor, AutoModelForImageClassification, List[int], List[int]]:
    """Load the local image-classification model and resolve its fake/real class indices.

    The successful result is memoized by ``lru_cache``. Every caller therefore
    receives the same objects, including the same index lists, which should not
    be mutated. A call that raises is not cached, so a later call retries the
    load (e.g. after the model has been downloaded).

    Returns:
        ``(processor, model, fake_indices, real_indices)``. The model has been
        moved to CPU and switched to eval mode. ``fake_indices`` is non-empty;
        ``real_indices`` may be empty.

    Raises:
        ModelLoadError: If ``LOCAL_MODEL_PATH`` is not a directory, if the
            processor or model fails to load, or if no fake label can be found
            in ``model.config.id2label``.
    """
    if not os.path.isdir(LOCAL_MODEL_PATH):
        raise ModelLoadError("Local model not found. Please download model first.")

    try:
        processor = AutoImageProcessor.from_pretrained(LOCAL_MODEL_PATH)
        model = AutoModelForImageClassification.from_pretrained(LOCAL_MODEL_PATH)
        model = model.to(torch.device("cpu"))
        model.eval()
    except Exception as exc:
        # Loader boundary: surface any failure from the library as our own error
        # type, keeping the original exception chained for debugging.
        raise ModelLoadError(f"Failed to load local model from '{LOCAL_MODEL_PATH}': {exc}") from exc

    # Tolerate configs where id2label is missing or None.
    id2label = getattr(model.config, "id2label", {}) or {}
    fake_indices, real_indices = _resolve_label_indices(id2label)

    # Give a more specific message when neither group could be identified.
    if not fake_indices and not real_indices:
        raise ModelLoadError(
            f"Could not infer fake/real labels from model.config.id2label: {id2label}"
        )
    if not fake_indices:
        raise ModelLoadError(
            f"Could not find a fake/deepfake label in model.config.id2label: {id2label}"
        )

    return processor, model, fake_indices, real_indices