File size: 2,133 Bytes
cfcf570
 
 
 
 
 
 
3486e63
 
 
 
 
 
cfcf570
3486e63
cfcf570
 
 
 
 
 
 
 
 
 
3486e63
 
 
 
 
 
 
 
cfcf570
3486e63
cfcf570
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3486e63
 
cfcf570
 
 
 
 
 
 
 
 
3486e63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""
Model loading for the deepfake detector.
"""

from dataclasses import dataclass
from typing import Optional

from transformers import (
    AutoImageProcessor,
    BaseImageProcessor,
    SiglipForImageClassification,
    ViTImageProcessor,
    ViTForImageClassification,
)

from app.core.detector.config import SIGLIP_MODEL_NAME, VIT_MODEL_NAME, DEVICE


@dataclass(frozen=True)
class SiglipResources:
    """
    Frozen container for the SigLIP classifier and its image processor.

    ``processor`` is annotated as ``BaseImageProcessor`` rather than
    ``AutoImageProcessor``: the latter is only a factory whose
    ``from_pretrained`` returns a concrete processor subclass, so an
    ``AutoImageProcessor`` instance is never actually stored here.
    """

    # Image-classification model (SigLIP architecture).
    model: SiglipForImageClassification
    # Concrete image processor resolved by AutoImageProcessor.from_pretrained.
    processor: BaseImageProcessor


@dataclass(frozen=True)
class ViTResources:
    """
    Frozen pairing of the ViT image classifier with its image processor.

    Instances cannot be mutated after construction (``frozen=True``).
    """

    # Image-classification model (ViT architecture).
    model: ViTForImageClassification
    # ViT image processor paired with ``model``.
    processor: ViTImageProcessor


# Module-level cache for the SigLIP bundle; None until first loaded.
_siglip_resources: Optional[SiglipResources] = None
# Module-level cache for the ViT bundle; None until first loaded.
_vit_resources: Optional[ViTResources] = None


def get_siglip_model() -> SiglipResources:
    """
    Return the SigLIP detector model and processor, loading them on first use.

    On the first call the processor and model are loaded with
    ``from_pretrained(SIGLIP_MODEL_NAME)``, the model is moved to ``DEVICE``
    and put into evaluation mode, and the pair is stored in the module-level
    cache. Every later call returns that same cached object.

    Returns:
        SiglipResources: The cached model/processor pair.
    """

    global _siglip_resources

    # Fast path: already loaded.
    if _siglip_resources is not None:
        return _siglip_resources

    # NOTE(review): no lock guards this check-then-load; concurrent first
    # calls could each load the model. Confirm callers serialize startup.
    print("Loading SigLIP Model...")

    processor = AutoImageProcessor.from_pretrained(SIGLIP_MODEL_NAME)
    model = SiglipForImageClassification.from_pretrained(SIGLIP_MODEL_NAME).to(DEVICE)
    model.eval()  # inference only: switch the module to evaluation mode

    _siglip_resources = SiglipResources(model=model, processor=processor)
    return _siglip_resources


def get_vit_model() -> ViTResources:
    """
    Return the ViT detector model and processor, loading them on first use.

    On the first call the processor and model are loaded with
    ``from_pretrained(VIT_MODEL_NAME)``, the model is moved to ``DEVICE``
    and put into evaluation mode, and the pair is stored in the module-level
    cache. Every later call returns that same cached object.

    Returns:
        ViTResources: The cached model/processor pair.
    """

    global _vit_resources

    # Fast path: already loaded.
    if _vit_resources is not None:
        return _vit_resources

    # NOTE(review): no lock guards this check-then-load; concurrent first
    # calls could each load the model. Confirm callers serialize startup.
    print("Loading ViT Model...")

    processor = ViTImageProcessor.from_pretrained(VIT_MODEL_NAME)
    model = ViTForImageClassification.from_pretrained(VIT_MODEL_NAME).to(DEVICE)
    model.eval()  # inference only: switch the module to evaluation mode

    _vit_resources = ViTResources(model=model, processor=processor)
    return _vit_resources