Spaces:
Sleeping
Sleeping
File size: 2,078 Bytes
6a8443b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 | import os
from functools import lru_cache
from typing import Dict, List, Tuple
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Directory that load_model() reads the processor and model from via
# from_pretrained(); resolved relative to this module's own directory as
# <module dir>/models/local_model. It must be populated before load_model() runs.
LOCAL_MODEL_PATH = os.path.join(os.path.dirname(__file__), "models", "local_model")
class ModelLoadError(RuntimeError):
    """Signal that the local deepfake-detection model could not be made ready.

    ``load_model`` raises this when the local model directory is missing,
    when loading the image processor or model fails (the original error is
    chained as ``__cause__``), or when no fake label can be identified in the
    model's ``config.id2label`` mapping.
    """
def _normalized_labels(id2label: Dict[int, str]) -> Dict[int, str]:
return {idx: str(label).strip().lower() for idx, label in id2label.items()}
def _resolve_label_indices(id2label: Dict[int, str]) -> Tuple[List[int], List[int]]:
    """Split class indices into "fake" and "real" groups by label keywords.

    Labels are first normalized with ``_normalized_labels`` (stringified,
    stripped, lower-cased) and then matched by substring:

    * fake: label contains ``"fake"`` (this also covers ``"deepfake"``) or
      ``"manipulated"``;
    * real: label contains ``"real"`` or ``"authentic"``.

    A label matching both groups (e.g. ``"real_or_fake"``) is assigned to the
    fake group only. The two returned lists are therefore disjoint, and no
    class index can be counted as both fake and real.

    Args:
        id2label: Mapping from class index to label name, e.g.
            ``model.config.id2label``.

    Returns:
        ``(fake_indices, real_indices)``, each in the iteration order of
        *id2label*. Indices whose label matches neither group are omitted.
    """
    # "deepfake" is intentionally not listed: it already contains "fake".
    fake_keywords = ("fake", "manipulated")
    real_keywords = ("real", "authentic")
    # NOTE(review): plain substring matching means labels such as
    # "inauthentic" or "unreal" would be classed as real — confirm the label
    # vocabularies of supported models before relying on this.

    fake_indices: List[int] = []
    real_indices: List[int] = []
    for idx, label in _normalized_labels(id2label).items():
        if any(keyword in label for keyword in fake_keywords):
            fake_indices.append(idx)
        elif any(keyword in label for keyword in real_keywords):
            # elif (not a separate if) keeps the groups disjoint.
            real_indices.append(idx)
    return fake_indices, real_indices
@lru_cache(maxsize=1)
def load_model() -> Tuple[AutoImageProcessor, AutoModelForImageClassification, List[int], List[int]]:
    """Load the local image-classification model once and classify its labels.

    The image processor and model are read from ``LOCAL_MODEL_PATH``. The
    model is moved to the CPU and put into evaluation mode, and its
    ``config.id2label`` is split into fake/real class indices via
    ``_resolve_label_indices``. Because of ``lru_cache(maxsize=1)``, a
    successful result is memoized and returned to every later caller. A call
    that raises is not cached, so the next call tries again.

    Returns:
        ``(processor, model, fake_indices, real_indices)``.

    Raises:
        ModelLoadError: if the model directory does not exist, if loading the
            processor or model raises, or if no fake label can be identified
            in ``model.config.id2label``.
    """
    model_dir = LOCAL_MODEL_PATH
    if not os.path.isdir(model_dir):
        raise ModelLoadError("Local model not found. Please download model first.")

    try:
        image_processor = AutoImageProcessor.from_pretrained(model_dir)
        classifier = AutoModelForImageClassification.from_pretrained(model_dir)
        classifier = classifier.to(torch.device("cpu"))
        classifier.eval()
    except Exception as err:
        # Wrap any loader failure so callers only need to handle ModelLoadError;
        # the original exception stays available as __cause__.
        raise ModelLoadError(f"Failed to load local model from '{model_dir}': {err}") from err

    # A config without id2label (or with an empty/None one) is treated as {}.
    label_map = getattr(classifier.config, "id2label", {}) or {}
    fake_indices, real_indices = _resolve_label_indices(label_map)

    if not fake_indices:
        # Distinguish "no recognizable labels at all" from "no fake label",
        # so the error message points at the actual problem.
        if not real_indices:
            raise ModelLoadError(
                f"Could not infer fake/real labels from model.config.id2label: {label_map}"
            )
        raise ModelLoadError(
            f"Could not find a fake/deepfake label in model.config.id2label: {label_map}"
        )

    return image_processor, classifier, fake_indices, real_indices
|