ShunTay12
Add ViT detector api
3486e63
"""
Model loading for the deepfake detector.
"""
from dataclasses import dataclass
from typing import Optional
from transformers import (
AutoImageProcessor,
SiglipForImageClassification,
ViTImageProcessor,
ViTForImageClassification,
)
from app.core.detector.config import SIGLIP_MODEL_NAME, VIT_MODEL_NAME, DEVICE
@dataclass(frozen=True)
class SiglipResources:
    """
    Immutable bundle of the loaded SigLIP classifier and its image processor.

    ``frozen=True`` means the fields cannot be reassigned after construction.
    """

    # Image-classification network used for detection.
    model: SiglipForImageClassification
    # Preprocessor paired with ``model``.
    # NOTE(review): ``AutoImageProcessor.from_pretrained`` returns an instance of
    # a concrete processor class rather than ``AutoImageProcessor`` itself, so
    # this annotation is nominal — confirm whether a base processor type fits better.
    processor: AutoImageProcessor
@dataclass(frozen=True)
class ViTResources:
    """
    Immutable bundle of the loaded ViT classifier and its image processor.

    ``frozen=True`` means the fields cannot be reassigned after construction.
    """

    # Image-classification network used for detection.
    model: ViTForImageClassification
    # Preprocessor paired with ``model``.
    processor: ViTImageProcessor
# Module-level caches for the loaded detectors. Each stays ``None`` until the
# matching ``get_siglip_model()`` / ``get_vit_model()`` call succeeds. That call
# stores the loaded resources here, and later calls return the same object.
# NOTE(review): nothing serializes the first load; concurrent first calls could
# each load a model — confirm callers do not race on startup.
_siglip_resources: Optional[SiglipResources] = None
_vit_resources: Optional[ViTResources] = None
def get_siglip_model() -> SiglipResources:
    """
    Return the SigLIP detector, loading it the first time it is requested.

    On the first call, the image processor and the classification model are
    fetched from ``SIGLIP_MODEL_NAME``. The model is then moved to ``DEVICE``
    and put into evaluation mode. The resulting pair is kept in the
    module-level cache, so every later call returns that same object without
    reloading.

    Returns:
        SiglipResources: The cached model/processor pair.
    """
    global _siglip_resources
    if _siglip_resources is not None:
        # Already loaded: reuse the cached pair.
        return _siglip_resources

    print("Loading SigLIP Model...")
    processor = AutoImageProcessor.from_pretrained(SIGLIP_MODEL_NAME)
    model = SiglipForImageClassification.from_pretrained(
        SIGLIP_MODEL_NAME
    ).to(DEVICE)
    model.eval()

    # The cache is filled only after every step above has succeeded. If the
    # load fails, the exception propagates and the next call tries again.
    _siglip_resources = SiglipResources(model=model, processor=processor)
    return _siglip_resources
def get_vit_model() -> ViTResources:
    """
    Return the ViT detector, loading it the first time it is requested.

    On the first call, the image processor and the classification model are
    fetched from ``VIT_MODEL_NAME``. The model is then moved to ``DEVICE``
    and put into evaluation mode. The resulting pair is kept in the
    module-level cache, so every later call returns that same object without
    reloading.

    Returns:
        ViTResources: The cached model/processor pair.
    """
    global _vit_resources
    if _vit_resources is not None:
        # Already loaded: reuse the cached pair.
        return _vit_resources

    print("Loading ViT Model...")
    processor = ViTImageProcessor.from_pretrained(VIT_MODEL_NAME)
    model = ViTForImageClassification.from_pretrained(
        VIT_MODEL_NAME
    ).to(DEVICE)
    model.eval()

    # The cache is filled only after every step above has succeeded. If the
    # load fails, the exception propagates and the next call tries again.
    _vit_resources = ViTResources(model=model, processor=processor)
    return _vit_resources