ash12321 committed on
Commit
13ee6ba
·
verified ·
1 Parent(s): 0865846

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +256 -0
model.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SDXL Detector Model
3
+ ===================
4
+
5
+ Vision Transformer-based model for detecting SDXL-generated images.
6
+
7
+ This model is a binary classifier that detects whether an image
8
+ was generated by Stable Diffusion XL (SDXL).
9
+
10
+ ⚠️ IMPORTANT: This model ONLY detects SDXL images!
11
+ - SDXL images → Classified as "Fake"
12
+ - Real images → Classified as "Real"
13
+ - FLUX/Midjourney/other AI → Classified as "Real" (not trained on these!)
14
+
15
+ For comprehensive AI detection, use this as part of an ensemble with
16
+ other specialized detectors.
17
+
18
+ Architecture:
19
+ - Base: Vision Transformer (ViT-base-patch16-224)
20
+ - Classifier: Dropout + Linear (768 → 2)
21
+ - Output: Binary (0=Real, 1=SDXL-Fake)
22
+
23
+ Quick Start:
24
+ import torch
+ from transformers import ViTForImageClassification, ViTImageProcessor
25
+ from PIL import Image
26
+
27
+ # Load model
28
+ model = ViTForImageClassification.from_pretrained(
29
+ "ash12321/sdxl-detector-vit"
30
+ )
31
+ processor = ViTImageProcessor.from_pretrained(
32
+ "google/vit-base-patch16-224"
33
+ )
34
+
35
+ # Process image
36
+ image = Image.open("test.jpg")
37
+ inputs = processor(images=image, return_tensors="pt")
38
+
39
+ # Get prediction
40
+ outputs = model(**inputs)
41
+ probs = torch.softmax(outputs.logits, dim=1)
42
+
43
+ if probs[0][1] > 0.5:
44
+ print(f"SDXL-Generated: {probs[0][1]:.2%}")
45
+ else:
46
+ print(f"Not SDXL: {probs[0][0]:.2%}")
47
+
48
+ Performance:
49
+ Test Accuracy: 99.60%
50
+ Precision: 99.30%
51
+ Recall: 99.90%
52
+ False Positive Rate: 0.70%
53
+ False Negative Rate: 0.10%
54
+ """
55
+
56
+ import torch
57
+ import torch.nn as nn
58
+ from transformers import ViTForImageClassification, ViTImageProcessor
59
+ from PIL import Image
60
+ from typing import Dict, Union, Optional
61
+ from pathlib import Path
62
+
63
+
64
class SDXLDetector:
    """
    SDXL Image Detector

    Easy-to-use wrapper around a two-class ViT image classifier. Class
    index 0 is treated as "real / not SDXL" and class index 1 as
    "SDXL-generated".
    """

    def __init__(
        self,
        model_path: str = "ash12321/sdxl-detector-vit",
        device: Optional[str] = None
    ):
        """
        Initialize SDXL detector

        Loads the classifier weights, moves the model to the requested
        device, switches it to eval mode and loads the image processor.

        Args:
            model_path: HuggingFace model repo or local path
            device: Device to use ('cuda', 'cpu', or None for auto)
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.device = device
        self.model_path = model_path

        # Load model and put it in inference mode on the target device
        self.model = ViTForImageClassification.from_pretrained(model_path)
        self.model.to(device)
        self.model.eval()

        # The processor is loaded from the base ViT checkpoint,
        # not from model_path.
        self.processor = ViTImageProcessor.from_pretrained(
            "google/vit-base-patch16-224"
        )

        print(f"✅ SDXL Detector loaded on {device}")

    @staticmethod
    def _as_rgb(image):
        """Return *image* converted to RGB when it reports another mode.

        Objects without a string ``mode`` attribute are returned untouched
        so that whatever the processor accepted before still passes through.
        """
        mode = getattr(image, "mode", None)
        if isinstance(mode, str) and mode != "RGB":
            return image.convert("RGB")
        return image

    def detect(
        self,
        image: "Union[str, Path, Image.Image]",
        threshold: float = 0.5
    ) -> Dict[str, Union[bool, float, str]]:
        """
        Detect if image is SDXL-generated

        Args:
            image: Image path or PIL Image. Images are converted to RGB
                before being handed to the processor, whether they come
                from a path or are passed in directly.
            threshold: Classification threshold (default 0.5); the image
                is flagged when the SDXL probability is strictly greater.

        Returns:
            dict with keys:
                - is_sdxl: bool - True if SDXL-generated
                - confidence: float - Probability of the predicted class
                - sdxl_probability: float - Probability of being SDXL
                - real_probability: float - Probability of being real
                - label: str - Human-readable label
        """
        if isinstance(image, (str, Path)):
            # convert() returns an independent image, so the file handle
            # can be released as soon as the pixels are decoded.
            with Image.open(image) as opened:
                image = opened.convert('RGB')
        else:
            # RGBA / grayscale / palette images would otherwise reach the
            # processor with the wrong number of channels.
            image = self._as_rgb(image)

        # Process image
        inputs = self.processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Get prediction
        with torch.no_grad():
            outputs = self.model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)
            sdxl_prob = probs[0][1].item()
            real_prob = probs[0][0].item()

        is_sdxl = sdxl_prob > threshold

        return {
            'is_sdxl': is_sdxl,
            'confidence': sdxl_prob if is_sdxl else real_prob,
            'sdxl_probability': sdxl_prob,
            'real_probability': real_prob,
            'label': 'SDXL-Generated' if is_sdxl else 'Not SDXL'
        }

    def batch_detect(
        self,
        images: list,
        threshold: float = 0.5
    ) -> list:
        """
        Detect SDXL on multiple images

        Each image is run through :meth:`detect` one at a time, in order.

        Args:
            images: List of image paths or PIL Images
            threshold: Classification threshold

        Returns:
            List of detection results, one dict per input image
        """
        return [self.detect(img, threshold) for img in images]
161
+
162
+
163
def detect_sdxl(
    image_path: Union[str, Path],
    threshold: float = 0.5,
    device: Optional[str] = None
) -> Dict[str, Union[bool, float, str]]:
    """
    Quick function to detect SDXL image

    A new SDXLDetector (default model) is constructed on every call, so
    the model is loaded each time. When classifying many images, create
    one SDXLDetector and reuse it instead.

    Args:
        image_path: Path to image
        threshold: Classification threshold
        device: Device to use ('cuda', 'cpu', or None for auto)

    Returns:
        Detection results dictionary, as returned by SDXLDetector.detect

    Example:
        >>> result = detect_sdxl("image.jpg")
        >>> print(f"Is SDXL: {result['is_sdxl']}")
        >>> print(f"Confidence: {result['confidence']:.2%}")
    """
    detector = SDXLDetector(device=device)
    return detector.detect(image_path, threshold)
186
+
187
+
188
# Static description of the released model: what it detects, how it is
# built, and the reported evaluation and training figures.
MODEL_INFO = {
    "name": "SDXL Detector",
    "version": "1.0",
    "type": "Binary Classifier",
    "detects": "Stable Diffusion XL (SDXL) images",
    # Generators the model was not trained on.
    "does_not_detect": [
        "FLUX images",
        "Midjourney images",
        "DALL-E images",
        "Other AI generators",
    ],
    "architecture": "Vision Transformer (ViT-base-patch16-224)",
    "input_size": (224, 224),
    # Class index -> meaning.
    "classes": {
        0: "Real / Not SDXL",
        1: "SDXL-Generated",
    },
    # Metrics are stored as fractions (0.9960 == 99.60%).
    "performance": {
        "test_accuracy": 0.9960,
        "precision": 0.9930,
        "recall": 0.9990,
        "f1_score": 0.9960,
        "false_positive_rate": 0.0070,
        "false_negative_rate": 0.0010,
    },
    "training": {
        "real_images": 8000,
        "sdxl_images": 8000,
        "epochs": 12,
        "best_epoch": 3,
    },
}
221
+
222
+
223
def _main() -> None:
    """Print the model summary and a short usage example to stdout."""
    banner = "=" * 60
    performance = MODEL_INFO['performance']

    print(banner)
    print("SDXL Detector - Model Information")
    print(banner)
    print(f"\nModel: {MODEL_INFO['name']}")
    print(f"Detects: {MODEL_INFO['detects']}")
    print("\n⚠️ Does NOT detect:")
    for item in MODEL_INFO['does_not_detect']:
        print(f" - {item}")
    print("\n📊 Performance:")
    print(f" Accuracy: {performance['test_accuracy']:.2%}")
    print(f" Precision: {performance['precision']:.2%}")
    print(f" Recall: {performance['recall']:.2%}")
    print(f" FPR: {performance['false_positive_rate']:.2%}")
    print(f" FNR: {performance['false_negative_rate']:.2%}")

    print("\n" + banner)
    print("Example Usage:")
    print(banner)
    # Plain (non-f) string: the braces below are printed literally as
    # part of the example code.
    print("""
from model import SDXLDetector

# Initialize detector
detector = SDXLDetector()

# Detect single image
result = detector.detect("image.jpg")
print(f"Is SDXL: {result['is_sdxl']}")
print(f"Confidence: {result['confidence']:.2%}")

# Or use quick function
from model import detect_sdxl
result = detect_sdxl("image.jpg")
""")


if __name__ == "__main__":
    _main()