Spaces:
Sleeping
Sleeping
| import os | |
| from functools import lru_cache | |
| from typing import Dict, List, Tuple | |
| import torch | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
# Directory the model is loaded from: "<this file's directory>/models/local_model".
LOCAL_MODEL_PATH = os.path.join(
    os.path.dirname(__file__),
    "models",
    "local_model",
)
class ModelLoadError(RuntimeError):
    """Signal that the deepfake detection model could not be loaded."""
| def _normalized_labels(id2label: Dict[int, str]) -> Dict[int, str]: | |
| return {idx: str(label).strip().lower() for idx, label in id2label.items()} | |
def _resolve_label_indices(id2label: Dict[int, str]) -> Tuple[List[int], List[int]]:
    """Split class ids into "fake" and "real" groups by keyword in their label.

    Labels are normalised with ``_normalized_labels`` and then tested by
    substring: a label containing "fake", "deepfake" or "manipulated" puts its
    id in the first list; a label containing "real" or "authentic" puts its id
    in the second list. The two tests are independent, so one id can land in
    both lists or in neither.

    Returns:
        ``(fake_indices, real_indices)``, each in the iteration order of
        ``id2label``.
    """
    fake_markers = ("fake", "deepfake", "manipulated")
    real_markers = ("real", "authentic")

    fake_indices: List[int] = []
    real_indices: List[int] = []
    for class_id, label in _normalized_labels(id2label).items():
        if any(marker in label for marker in fake_markers):
            fake_indices.append(class_id)
        # Deliberately a separate `if`: the original checks are not exclusive.
        if any(marker in label for marker in real_markers):
            real_indices.append(class_id)
    return fake_indices, real_indices
def load_model() -> Tuple[AutoImageProcessor, AutoModelForImageClassification, List[int], List[int]]:
    """Load the local image-classification model and locate its fake/real classes.

    The processor and model are read from ``LOCAL_MODEL_PATH``; the model is
    moved to the CPU and switched to evaluation mode.

    Returns:
        ``(processor, model, fake_indices, real_indices)``. The index lists
        come from ``_resolve_label_indices`` applied to
        ``model.config.id2label``; ``real_indices`` may be empty.

    Raises:
        ModelLoadError: if ``LOCAL_MODEL_PATH`` is not a directory, if loading
            the processor or model raises, or if no fake-like label is found
            in ``model.config.id2label``.
    """
    model_dir_exists = os.path.isdir(LOCAL_MODEL_PATH)
    if not model_dir_exists:
        raise ModelLoadError("Local model not found. Please download model first.")

    try:
        processor = AutoImageProcessor.from_pretrained(LOCAL_MODEL_PATH)
        model = AutoModelForImageClassification.from_pretrained(LOCAL_MODEL_PATH).to(
            torch.device("cpu")
        )
        model.eval()
    except Exception as exc:
        # Any loading failure is reported as ModelLoadError, with the cause chained.
        raise ModelLoadError(f"Failed to load local model from '{LOCAL_MODEL_PATH}': {exc}") from exc

    # A missing or empty id2label is treated as an empty mapping.
    id2label = getattr(model.config, "id2label", None) or {}
    fake_indices, real_indices = _resolve_label_indices(id2label)

    if not fake_indices:
        if not real_indices:
            # Neither kind of label was recognised.
            raise ModelLoadError(
                f"Could not infer fake/real labels from model.config.id2label: {id2label}"
            )
        # Real-like labels exist, but there is no fake-like one.
        raise ModelLoadError(
            f"Could not find a fake/deepfake label in model.config.id2label: {id2label}"
        )

    return processor, model, fake_indices, real_indices