kriskraw committed on
Commit
6cfc411
·
verified ·
1 Parent(s): fbbb9b5

Upload pixelprism_detector.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pixelprism_detector.py +98 -0
pixelprism_detector.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PixelPrism v0.1 — DeepfakeDetector subclass for the BitMind DFD Arena.
2
+
3
+ This is a sanitized single-detector submission representing the most
4
+ informative component of PixelPrism's production V9 16-detector ensemble:
5
+ the Swin V2 transformer head (haywoodsloan/ai-image-detector-deploy),
6
+ which V9's permutation-importance audit ranked at 0.271 (38% of total
7
+ discriminative power, the strongest single feature in our ensemble).
8
+
9
+ The full PixelPrism V9 ensemble fuses 16 detectors via a
10
+ HistGradientBoostingClassifier meta-classifier and runs against a live
11
+ production API at https://pixelprism.ai. Some V9 components depend on
12
+ non-MIT-redistributable weights (FLUX-schnell for DIRE-FLUX, SDXL for
13
+ DIRE-SDXL) so this submission is the open-redistributable subset.
14
+
15
+ For the full 16-detector live numbers see https://pixelprism.ai/leaderboard
16
+ (refreshed monthly with each retrain).
17
+ """
18
+ from __future__ import annotations
19
+
20
+ from pathlib import Path
21
+
22
+ import numpy as np
23
+ import torch
24
+ from huggingface_hub import hf_hub_download
25
+ from PIL import Image
26
+ from transformers import AutoImageProcessor, Swinv2ForImageClassification
27
+
28
+ from arena.detectors.deepfake_detectors import DeepfakeDetector
29
+ from arena.detectors import DETECTOR_REGISTRY
30
+
31
+
32
# Module-level cache so repeated instantiations don't reload the model.
# _CACHE_KEY records the (repo, device) pair the cached objects were loaded
# for, so a detector configured with a different repo or device never picks
# up a model that lives on the wrong device or holds the wrong weights.
_PROCESSOR = None
_MODEL = None
_CACHE_KEY = None


@DETECTOR_REGISTRY.register_module(module_name='PixelPrism')
class PixelPrismDetector(DeepfakeDetector):
    """Single-detector submission: Swin V2 wrapped with PixelPrism's production
    inference path. Outputs P(AI) ∈ [0, 1] per image.

    Attributes set from the YAML config:
        hf_repo (str): "pixelprism-ai/dfd-arena-mini"
        backbone_repo (str): "haywoodsloan/ai-image-detector-deploy"
        ai_label_idx (int): the class index that means "AI" in the Swin head (0)

    Attributes set by ``load_model``:
        processor: image processor loaded from the downloaded repo files.
        model: ``Swinv2ForImageClassification`` in eval mode on ``self.device``.

    NOTE(review): ``self.device`` and the config-derived attributes are
    presumably populated by ``DeepfakeDetector.__init__`` — confirm against
    the base class.
    """

    def __init__(
        self,
        model_name: str = 'PixelPrism',
        config: str = 'pixelprism_config.yaml',
        cuda: bool = True,
    ):
        """Forward the detector name, YAML config name and CUDA flag to the base class."""
        super().__init__(model_name, config, cuda)

    # ───────────────────────── model load ─────────────────────────

    def load_model(self):
        """Load (or reuse) the Swin V2 processor and model for this detector.

        The loaded pair is cached at module level together with the
        ``(hf_repo, device)`` it was loaded for. The cache is reused only when
        both match this instance; otherwise the files are downloaded from
        ``self.hf_repo`` and loaded afresh, and the cache is replaced.
        """
        global _PROCESSOR, _MODEL, _CACHE_KEY

        # The Swin V2 model + preprocessor come from haywoodsloan's MIT-licensed
        # repo; we re-host the same files at pixelprism-ai/dfd-arena-mini so the
        # arena evaluator can pull everything from one place.
        repo = self.hf_repo
        cache_key = (repo, str(self.device))

        # Reuse only when the cached objects were built for the same repo and
        # device; a mismatch would otherwise leave the model on a different
        # device than the tensors produced by preprocess().
        if (
            _MODEL is not None
            and _PROCESSOR is not None
            and _CACHE_KEY == cache_key
        ):
            self.processor = _PROCESSOR
            self.model = _MODEL
            return

        # from_pretrained() below is pointed at the download directory, so
        # config.json is fetched as well even though its path is not used.
        hf_hub_download(repo, 'config.json')
        weights_path = hf_hub_download(repo, 'model.safetensors')
        preproc_path = hf_hub_download(repo, 'preprocessor_config.json')

        self.processor = AutoImageProcessor.from_pretrained(Path(preproc_path).parent)
        self.model = Swinv2ForImageClassification.from_pretrained(Path(weights_path).parent)
        self.model.to(self.device).eval()

        _PROCESSOR = self.processor
        _MODEL = self.model
        _CACHE_KEY = cache_key

    # ───────────────────────── inference ─────────────────────────

    def preprocess(self, image: Image.Image) -> torch.Tensor:
        """Convert a PIL image to RGB and return its ``pixel_values`` on ``self.device``."""
        if image.mode != 'RGB':
            image = image.convert('RGB')
        inputs = self.processor(images=image, return_tensors='pt')
        return inputs['pixel_values'].to(self.device)

    def __call__(self, image: Image.Image) -> float:
        """Returns P(AI) ∈ [0, 1] for a single PIL image."""
        pixel_values = self.preprocess(image)
        with torch.no_grad():
            logits = self.model(pixel_values=pixel_values).logits
        probs = torch.softmax(logits, dim=-1).cpu().numpy().ravel()
        # ai_label_idx is set from the YAML config (typically 0 for haywoodsloan);
        # fall back to 0 when the attribute is absent.
        ai_idx = int(getattr(self, 'ai_label_idx', 0))
        return float(probs[ai_idx])