ShunTay12 commited on
Commit
3486e63
·
1 Parent(s): c9f266b

Add ViT detector API

Browse files
app/core/detector/config.py CHANGED
@@ -14,7 +14,8 @@ if torch.cuda.is_available():
14
  print(f"GPU: {torch.cuda.get_device_name(0)}")
15
 
16
  # Model configuration
17
- BASE_MODEL_NAME = "shunda012/siglip-deepfake-detector"
 
18
 
19
  # Prediction threshold
20
  REAL_THRESHOLD = 0.90 # classify as real only when P(real) >= 90%
 
14
  print(f"GPU: {torch.cuda.get_device_name(0)}")
15
 
16
  # Model configuration
17
+ SIGLIP_MODEL_NAME = "shunda012/siglip-deepfake-detector"
18
+ VIT_MODEL_NAME = "shunda012/vit-deepfake-detector"
19
 
20
  # Prediction threshold
21
  REAL_THRESHOLD = 0.90 # classify as real only when P(real) >= 90%
app/core/detector/model.py CHANGED
@@ -5,9 +5,14 @@ Model loading for the deepfake detector.
5
  from dataclasses import dataclass
6
  from typing import Optional
7
 
8
- from transformers import AutoImageProcessor, SiglipForImageClassification
 
 
 
 
 
9
 
10
- from app.core.detector.config import BASE_MODEL_NAME, DEVICE
11
 
12
 
13
  @dataclass(frozen=True)
@@ -18,7 +23,16 @@ class SiglipResources:
18
  processor: AutoImageProcessor
19
 
20
 
 
 
 
 
 
 
 
 
21
  _siglip_resources: Optional[SiglipResources] = None
 
22
 
23
 
24
  def get_siglip_model() -> SiglipResources:
@@ -34,8 +48,8 @@ def get_siglip_model() -> SiglipResources:
34
  if _siglip_resources is None:
35
  print("Loading SigLIP Model...")
36
 
37
- siglip_processor = AutoImageProcessor.from_pretrained(BASE_MODEL_NAME)
38
- siglip_model = SiglipForImageClassification.from_pretrained(BASE_MODEL_NAME)
39
  siglip_model = siglip_model.to(DEVICE)
40
  siglip_model.eval()
41
 
@@ -45,3 +59,29 @@ def get_siglip_model() -> SiglipResources:
45
  )
46
 
47
  return _siglip_resources
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from dataclasses import dataclass
6
  from typing import Optional
7
 
8
+ from transformers import (
9
+ AutoImageProcessor,
10
+ SiglipForImageClassification,
11
+ ViTImageProcessor,
12
+ ViTForImageClassification,
13
+ )
14
 
15
+ from app.core.detector.config import SIGLIP_MODEL_NAME, VIT_MODEL_NAME, DEVICE
16
 
17
 
18
  @dataclass(frozen=True)
 
23
  processor: AutoImageProcessor
24
 
25
 
26
@dataclass(frozen=True)
class ViTResources:
    """Immutable bundle pairing the loaded ViT classifier with its processor."""

    # Fine-tuned ViT classification model used for deepfake detection.
    model: ViTForImageClassification
    # Image processor supplying the resize/normalization settings that match the model.
    processor: ViTImageProcessor
32
+
33
+
34
  _siglip_resources: Optional[SiglipResources] = None
35
+ _vit_resources: Optional[ViTResources] = None
36
 
37
 
38
  def get_siglip_model() -> SiglipResources:
 
48
  if _siglip_resources is None:
49
  print("Loading SigLIP Model...")
50
 
51
+ siglip_processor = AutoImageProcessor.from_pretrained(SIGLIP_MODEL_NAME)
52
+ siglip_model = SiglipForImageClassification.from_pretrained(SIGLIP_MODEL_NAME)
53
  siglip_model = siglip_model.to(DEVICE)
54
  siglip_model.eval()
55
 
 
59
  )
60
 
61
  return _siglip_resources
62
+
63
+
64
def get_vit_model() -> ViTResources:
    """
    Get or load the merged ViT detector model.

    Returns:
        ViTResources: Loaded model and processor (cached singleton).
    """

    global _vit_resources

    # Fast path: reuse the process-wide singleton once it has been built.
    if _vit_resources is not None:
        return _vit_resources

    print("Loading ViT Model...")

    # Pull both processor and weights from the same hub repo so their
    # preprocessing settings stay in sync with the fine-tuned model.
    processor = ViTImageProcessor.from_pretrained(VIT_MODEL_NAME)
    model = ViTForImageClassification.from_pretrained(VIT_MODEL_NAME).to(DEVICE)
    model.eval()  # inference only — disable dropout/batch-norm updates

    _vit_resources = ViTResources(model=model, processor=processor)
    return _vit_resources
app/detector.py CHANGED
@@ -11,7 +11,12 @@ from typing import Callable
11
  from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
12
  from PIL import Image, UnidentifiedImageError
13
 
14
- from app.core.detector.model import SiglipResources, get_siglip_model
 
 
 
 
 
15
  from app.services.detector.prediction import predict_single_image
16
  from app.services.detector.transforms import get_eval_transforms
17
 
@@ -27,8 +32,16 @@ def get_siglip_transforms():
27
  return get_eval_transforms(resources.processor, "siglip")
28
 
29
 
30
- @detector.post("/detect")
31
- async def detect_deepfake(
 
 
 
 
 
 
 
 
32
  file: UploadFile = File(...),
33
  resources: SiglipResources = Depends(get_siglip_model),
34
  siglip_transforms: Callable = Depends(get_siglip_transforms),
@@ -59,3 +72,37 @@ async def detect_deepfake(
59
  except Exception as exc: # pragma: no cover - defensive server guard
60
  logger.exception("Unhandled error during deepfake detection")
61
  raise HTTPException(status_code=500, detail="Error processing image") from exc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
12
  from PIL import Image, UnidentifiedImageError
13
 
14
+ from app.core.detector.model import (
15
+ SiglipResources,
16
+ ViTResources,
17
+ get_siglip_model,
18
+ get_vit_model,
19
+ )
20
  from app.services.detector.prediction import predict_single_image
21
  from app.services.detector.transforms import get_eval_transforms
22
 
 
32
  return get_eval_transforms(resources.processor, "siglip")
33
 
34
 
35
@lru_cache(maxsize=1)
def get_vit_transforms():
    """Build and cache ViT evaluation transforms once per process."""

    # The model singleton also carries the processor whose normalization
    # statistics drive the transform pipeline.
    vit = get_vit_model()
    return get_eval_transforms(vit.processor, "vit")
41
+
42
+
43
+ @detector.post("/siglip-detect")
44
+ async def siglip_detect_deepfake(
45
  file: UploadFile = File(...),
46
  resources: SiglipResources = Depends(get_siglip_model),
47
  siglip_transforms: Callable = Depends(get_siglip_transforms),
 
72
  except Exception as exc: # pragma: no cover - defensive server guard
73
  logger.exception("Unhandled error during deepfake detection")
74
  raise HTTPException(status_code=500, detail="Error processing image") from exc
75
+
76
+
77
@detector.post("/vit-detect")
async def vit_detect_deepfake(
    file: UploadFile = File(...),
    resources: ViTResources = Depends(get_vit_model),
    vit_transforms: Callable = Depends(get_vit_transforms),
):
    """
    Detect if an image is a deepfake or real using ViT + LoRA model.

    Args:
        file: Uploaded image file

    Returns:
        JSON response with prediction results
    """

    try:
        raw = await file.read()
        rgb_image = Image.open(BytesIO(raw)).convert("RGB")
        return predict_single_image(
            rgb_image, resources.model, vit_transforms, "ViT + LoRA"
        )
    except UnidentifiedImageError:
        # Payload could not be decoded as an image at all.
        raise HTTPException(status_code=422, detail="Invalid or unsupported image file")
    except HTTPException:
        # Let explicit HTTP errors raised downstream pass through unchanged.
        raise
    except Exception as exc:  # pragma: no cover - defensive server guard
        logger.exception("Unhandled error during deepfake detection")
        raise HTTPException(status_code=500, detail="Error processing image") from exc
app/services/detector/transforms.py CHANGED
@@ -8,11 +8,11 @@ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normal
8
  def get_eval_transforms(processor, model_type="vit"):
9
  """
10
  Create evaluation transforms based on processor settings.
11
-
12
  Args:
13
  processor: The image processor from the model
14
  model_type: Type of model ("vit" or "siglip")
15
-
16
  Returns:
17
  Composed transforms for image preprocessing
18
  """
@@ -20,7 +20,7 @@ def get_eval_transforms(processor, model_type="vit"):
20
  image_mean = processor.image_mean
21
  image_std = processor.image_std
22
  normalize = Normalize(mean=image_mean, std=image_std)
23
-
24
  return Compose(
25
  [
26
  Resize(size if model_type == "siglip" else 256),
 
8
  def get_eval_transforms(processor, model_type="vit"):
9
  """
10
  Create evaluation transforms based on processor settings.
11
+
12
  Args:
13
  processor: The image processor from the model
14
  model_type: Type of model ("vit" or "siglip")
15
+
16
  Returns:
17
  Composed transforms for image preprocessing
18
  """
 
20
  image_mean = processor.image_mean
21
  image_std = processor.image_std
22
  normalize = Normalize(mean=image_mean, std=image_std)
23
+
24
  return Compose(
25
  [
26
  Resize(size if model_type == "siglip" else 256),