File size: 2,078 Bytes
6a8443b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
from functools import lru_cache
from typing import Dict, List, Tuple

import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification


# Directory that holds the downloaded image-classification model. It is built
# from the directory containing this module, plus "models/local_model".
LOCAL_MODEL_PATH = os.path.join(
    os.path.dirname(__file__),
    "models",
    "local_model",
)


class ModelLoadError(RuntimeError):
    """Signals that the local deepfake classification model could not be loaded or used."""


def _normalized_labels(id2label: Dict[int, str]) -> Dict[int, str]:
    return {idx: str(label).strip().lower() for idx, label in id2label.items()}


def _resolve_label_indices(id2label: Dict[int, str]) -> Tuple[List[int], List[int]]:
    normalized = _normalized_labels(id2label)
    fake_indices = [
        idx
        for idx, label in normalized.items()
        if ("fake" in label) or ("deepfake" in label) or ("manipulated" in label)
    ]
    real_indices = [idx for idx, label in normalized.items() if ("real" in label) or ("authentic" in label)]
    return fake_indices, real_indices


@lru_cache(maxsize=1)
def load_model() -> Tuple[AutoImageProcessor, AutoModelForImageClassification, List[int], List[int]]:
    """Load the local image processor and classifier, and locate fake/real class indices.

    The model is read from ``LOCAL_MODEL_PATH``, moved to CPU and switched to
    eval mode. Thanks to ``lru_cache`` a successful result is reused by later
    calls; a call that raises is not cached, so the next call retries.

    Returns:
        ``(processor, model, fake_indices, real_indices)``. The index lists
        come from the model's ``config.id2label``.

    Raises:
        ModelLoadError: if the model directory is missing, loading fails, or no
            fake-like label can be found in ``config.id2label``.
    """
    model_dir = LOCAL_MODEL_PATH
    if not os.path.isdir(model_dir):
        raise ModelLoadError("Local model not found. Please download model first.")

    try:
        processor = AutoImageProcessor.from_pretrained(model_dir)
        network = AutoModelForImageClassification.from_pretrained(model_dir).to(torch.device("cpu"))
        network.eval()
    except Exception as err:
        # Wrap any loader failure in the module's own error type, keeping the cause chained.
        raise ModelLoadError(f"Failed to load local model from '{model_dir}': {err}") from err

    label_map = getattr(network.config, "id2label", {}) or {}
    fake_ids, real_ids = _resolve_label_indices(label_map)

    if not fake_ids:
        # A fake class is mandatory. Pick the message based on whether any
        # real-like label was recognized.
        if real_ids:
            raise ModelLoadError(
                f"Could not find a fake/deepfake label in model.config.id2label: {label_map}"
            )
        raise ModelLoadError(
            f"Could not infer fake/real labels from model.config.id2label: {label_map}"
        )

    return processor, network, fake_ids, real_ids