ash12321 committed on
Commit
74e3cab
Β·
verified Β·
1 Parent(s): b4ebd38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +337 -298
app.py CHANGED
@@ -1,10 +1,14 @@
1
  """
2
- ═══════════════════════════════════════════════════════════════════════
3
- V13 DEEPFAKE DETECTOR - GRADIO APP
4
- ═══════════════════════════════════════════════════════════════════════
5
- Upload an image and detect if it's real or AI-generated/deepfake
6
- Uses the best Model 3 (Swin-Large) with 99.96% accuracy
7
- ═══════════════════════════════════════════════════════════════════════
 
 
 
 
8
  """
9
 
10
  import gradio as gr
@@ -13,337 +17,372 @@ import torch.nn as nn
13
  from torchvision import transforms
14
  from PIL import Image
15
  import timm
16
- import json
17
- from huggingface_hub import hf_hub_download
18
  import numpy as np
 
 
 
 
19
 
20
- print("πŸš€ Loading Deepfake Detector...")
 
21
 
22
- # ═══════════════════════════════════════════════════════════════════════
23
- # CONFIGURATION
24
- # ═══════════════════════════════════════════════════════════════════════
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- REPO_ID = "ash12321/deepfake-detector-v13-optimized"
27
- MODEL_NUM = 1 # Using Model 1 (ConvNeXt - most reliable, 99.90% test F1)
28
 
29
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
30
- print(f"Device: {device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- # ═══════════════════════════���═══════════════════════════════════════════
33
- # MODEL DEFINITION
34
- # ═══════════════════════════════════════════════════════════════════════
35
 
36
- class DeepfakeDetector(nn.Module):
37
- def __init__(self, backbone_name, dropout=0.3, hidden_dim=512, use_batch_norm=True):
 
38
  super().__init__()
39
 
40
- self.backbone = timm.create_model(backbone_name, pretrained=False, num_classes=0)
 
41
 
42
- if hasattr(self.backbone, 'num_features'):
43
- feat_dim = self.backbone.num_features
44
- else:
45
- with torch.no_grad():
46
- feat_dim = self.backbone(torch.randn(1, 3, 224, 224)).shape[1]
47
-
48
- if use_batch_norm:
49
- self.classifier = nn.Sequential(
50
- nn.Linear(feat_dim, hidden_dim),
51
- nn.BatchNorm1d(hidden_dim),
52
- nn.GELU(),
53
- nn.Dropout(dropout),
54
- nn.Linear(hidden_dim, hidden_dim // 4),
55
- nn.BatchNorm1d(hidden_dim // 4),
56
- nn.GELU(),
57
- nn.Dropout(dropout * 0.5),
58
- nn.Linear(hidden_dim // 4, 1)
59
- )
60
  else:
61
- self.classifier = nn.Sequential(
62
- nn.Linear(feat_dim, hidden_dim),
63
- nn.LayerNorm(hidden_dim),
64
- nn.GELU(),
65
- nn.Dropout(dropout),
66
- nn.Linear(hidden_dim, hidden_dim // 4),
67
- nn.LayerNorm(hidden_dim // 4),
68
- nn.GELU(),
69
- nn.Dropout(dropout * 0.5),
70
- nn.Linear(hidden_dim // 4, 1)
71
- )
72
-
73
  def forward(self, x):
74
- features = self.backbone(x)
75
- return self.classifier(features).squeeze(-1)
76
-
77
- # ═══════════════════════════════════════════════════════════════════════
78
- # LOAD MODEL
79
- # ═══════════════════════════════════════════════════════════════════════
80
-
81
- print("πŸ“₯ Downloading model from HuggingFace...")
82
-
83
- # Download model files
84
- model_path = hf_hub_download(
85
- repo_id=REPO_ID,
86
- filename=f"best_model_{MODEL_NUM}.pt"
87
- )
88
-
89
- params_path = hf_hub_download(
90
- repo_id=REPO_ID,
91
- filename=f"best_params_model_{MODEL_NUM}.json"
92
- )
93
-
94
- # Load parameters
95
- with open(params_path, 'r') as f:
96
- best_params = json.load(f)
97
-
98
- params = best_params['params']
99
- threshold = params['classification_threshold']
100
-
101
- print(f"βœ“ Using Model {MODEL_NUM}")
102
- print(f"βœ“ Threshold: {threshold:.4f}")
103
- print(f"βœ“ Test F1 Score: {best_params.get('f1_score', 'N/A')}")
104
 
105
- # Model architecture map
106
- backbone_map = {
107
- 1: 'convnext_large',
108
- 2: 'vit_large_patch16_224',
109
- 3: 'swin_large_patch4_window7_224'
110
- }
111
 
112
- # Create model
113
- print("πŸ”¨ Building model...")
114
- model = DeepfakeDetector(
115
- backbone_name=backbone_map[MODEL_NUM],
116
- dropout=params['dropout'],
117
- hidden_dim=params['hidden_dim'],
118
- use_batch_norm=params['use_batch_norm']
119
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- # Load weights
122
- checkpoint = torch.load(model_path, map_location=device)
123
- model.load_state_dict(checkpoint['model_state_dict'])
124
- model = model.to(device)
125
- model.eval()
126
 
127
- print("βœ… Model loaded successfully!\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
- # ═══════════════════════════════════════════════════════════════════════
130
- # IMAGE PREPROCESSING
131
- # ═══════════════════════════════════════════════════════════════════════
132
 
133
- transform = transforms.Compose([
134
- transforms.Resize((224, 224)),
135
- transforms.ToTensor(),
136
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
137
- ])
 
 
 
 
 
 
 
 
 
 
138
 
139
- # ═══════════════════════════════════════════════════════════════════════
140
- # PREDICTION FUNCTION
141
- # ═══════════════════════════════════════════════════════════════════════
142
 
143
- def predict_image(image, custom_threshold=None):
 
144
  """
145
- Predict if an image is real or fake
146
 
147
  Args:
148
  image: PIL Image
149
- custom_threshold: Optional custom threshold (0-1)
150
-
151
  Returns:
152
- dict: Prediction results with confidence scores
153
  """
154
- if image is None:
155
- return {
156
- "Error": "Please upload an image"
157
- }
158
 
159
- # Use custom threshold if provided, otherwise use default
160
- thresh = custom_threshold if custom_threshold is not None else threshold
 
 
161
 
162
- try:
163
- # Convert to RGB if needed
164
- if image.mode != 'RGB':
165
- image = image.convert('RGB')
166
-
167
- # Preprocess
168
- img_tensor = transform(image).unsqueeze(0).to(device)
169
-
170
- # Predict
171
- with torch.no_grad():
172
- logit = model(img_tensor)
173
- probability = torch.sigmoid(logit).item()
174
-
175
- # Determine prediction
176
- is_fake = probability > thresh
177
-
178
- # Calculate confidence
179
- if is_fake:
180
- confidence = probability * 100
181
- label = "🚨 FAKE / AI-GENERATED"
182
- color = "red"
183
- else:
184
- confidence = (1 - probability) * 100
185
- label = "βœ… REAL"
186
- color = "green"
187
-
188
- # Create result dictionary for Gradio
189
- result = {
190
- "Prediction": label,
191
- "Confidence": f"{confidence:.2f}%",
192
- "Raw Score": f"{probability:.4f}",
193
- "Threshold Used": f"{thresh:.4f}"
194
- }
195
-
196
- # Additional context
197
- if confidence > 95:
198
- certainty = "Very High Certainty"
199
- elif confidence > 85:
200
- certainty = "High Certainty"
201
- elif confidence > 70:
202
- certainty = "Moderate Certainty"
203
- else:
204
- certainty = "Low Certainty - Manual Review Recommended"
205
-
206
- result["Certainty Level"] = certainty
207
-
208
- return result
209
 
210
- except Exception as e:
211
- return {
212
- "Error": f"Prediction failed: {str(e)}"
213
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
- # ═══════════════════════════════════════════════════════════════════════
216
- # GRADIO INTERFACE
217
- # ═══════════════════════════════════════════════════════════════════════
 
218
 
219
- # Create the interface
220
- with gr.Blocks() as demo:
 
 
221
 
222
- gr.Markdown(
223
- """
224
- # πŸ” Deepfake Detector V13
225
 
226
- Upload an image to detect if it's **REAL** or **AI-GENERATED/DEEPFAKE**
 
227
 
228
- **Model Performance:**
229
- - βœ… 99.96% Accuracy on test set
230
- - βœ… 100% Recall (catches all fakes)
231
- - βœ… Model 3: Swin-Large (197M parameters)
232
 
233
- **Supported:** Faces, portraits, AI-generated images, deepfakes
234
- """
235
- )
236
-
237
- with gr.Row():
238
- with gr.Column():
239
- image_input = gr.Image(
240
- type="pil",
241
- label="Upload Image",
242
- height=400
243
- )
244
-
245
- threshold_slider = gr.Slider(
246
- minimum=0.1,
247
- maximum=0.9,
248
- value=threshold,
249
- step=0.05,
250
- label="Detection Threshold (Lower = More Sensitive to Fakes)",
251
- info="Adjust if getting too many false positives/negatives"
252
- )
253
-
254
- predict_btn = gr.Button(
255
- "πŸ” Analyze Image",
256
- variant="primary",
257
- size="lg"
258
- )
259
-
260
- gr.Markdown(
261
- """
262
- ### πŸ’‘ Tips:
263
- - Upload clear images with visible faces
264
- - Works best with portraits and headshots
265
- - Supports: JPG, PNG, WebP
266
- - **Adjust threshold if results seem off**
267
- """
268
- )
269
-
270
- with gr.Column():
271
- result_output = gr.JSON(
272
- label="Detection Results"
273
- )
274
-
275
- gr.Markdown(
276
- """
277
- ### πŸ“Š Understanding Results:
278
-
279
- **Prediction:** REAL or FAKE classification
280
 
281
- **Confidence:** How certain the model is (0-100%)
 
 
 
 
 
 
 
282
 
283
- **Raw Score:** Internal probability (0-1)
284
- - Above threshold β†’ FAKE
285
- - Below threshold β†’ REAL
286
 
287
- **Certainty Level:**
288
- - Very High (>95%): Trust the result
289
- - High (85-95%): Reliable
290
- - Moderate (70-85%): Generally accurate
291
- - Low (<70%): Consider manual review
292
- """
293
- )
294
-
295
- # Examples
296
- gr.Markdown("### πŸ“Έ Try These Examples:")
297
- gr.Examples(
298
- examples=[
299
- # Add example image paths here if you have them
300
- ],
301
- inputs=image_input,
302
- outputs=result_output,
303
- fn=predict_image,
304
- cache_examples=False
305
- )
306
-
307
- # Connect button to function
308
- predict_btn.click(
309
- fn=predict_image,
310
- inputs=[image_input, threshold_slider],
311
- outputs=result_output
312
- )
313
-
314
- # Auto-predict on upload
315
- image_input.change(
316
- fn=predict_image,
317
- inputs=[image_input, threshold_slider],
318
- outputs=result_output
319
- )
320
-
321
- gr.Markdown(
322
- """
323
  ---
324
- **Model Details:**
325
- - Architecture: Swin Transformer Large
326
- - Parameters: 197M
327
- - Training Data: 60,000 balanced real/fake images
328
- - Optimized with Optuna hyperparameter search
329
-
330
- **Limitations:**
331
- - Best for human faces and portraits
332
- - May not work well on heavily compressed images
333
- - Performance may vary on new AI generation methods
334
-
335
- **Version:** V13 Model 3 | **Accuracy:** 99.96%
336
- """
337
- )
 
 
 
338
 
339
- # ═══════════════════════════════════════════════════════════════════════
340
- # LAUNCH
341
- # ═══════════════════════════════════════════════════════════════════════
342
 
 
343
  if __name__ == "__main__":
344
- print("🌐 Launching Gradio interface...")
345
- demo.launch(
346
- share=True, # Creates public link
347
- server_name="0.0.0.0",
348
- server_port=7860
349
- )
 
1
  """
2
+ Gradio App for One-Class Deepfake Detector
3
+ This app loads the DeepSVDD model from HuggingFace and provides an interface to test images.
4
+
5
+ Create a new Space on HuggingFace:
6
+ 1. Go to https://huggingface.co/spaces
7
+ 2. Click "Create new Space"
8
+ 3. Name it (e.g., "deepfake-detector-demo")
9
+ 4. Select SDK: Gradio
10
+ 5. Create the Space
11
+ 6. Upload this file as "app.py"
12
  """
13
 
14
  import gradio as gr
 
17
  from torchvision import transforms
18
  from PIL import Image
19
  import timm
 
 
20
  import numpy as np
21
+ from huggingface_hub import hf_hub_download
22
+ import json
23
+ import warnings
24
+ warnings.filterwarnings('ignore')
25
 
26
+ # ==================== MODEL ARCHITECTURE ====================
27
+ # Copy the same architecture classes from your training script
28
 
29
class FrequencyFeatureExtractor:
    """Radial FFT-magnitude statistics used as frequency-domain features.

    Generative-model artifacts often appear as anomalies in the frequency
    spectrum, so each image is summarized by the (mean, std) of the shifted
    FFT magnitude over concentric rings around the spectrum center.
    """

    @staticmethod
    def extract_fft_features(image_tensor, n_features=64):
        """Return a 1-D float32 tensor of length ``n_features``.

        Args:
            image_tensor: (C, H, W) tensor. A 3-channel input is converted to
                luma with ITU-R BT.601 weights (0.299/0.587/0.114); otherwise
                channel 0 is used as-is.
            n_features: length of the returned vector. Note only
                ``2 * int(sqrt(n_features))`` slots are populated by ring
                statistics; the rest is zero-padded (kept for backward
                compatibility with the trained model).
        """
        if image_tensor.shape[0] == 3:
            gray = 0.299 * image_tensor[0] + 0.587 * image_tensor[1] + 0.114 * image_tensor[2]
        else:
            gray = image_tensor[0]

        # FFT magnitude with the DC component shifted to the center.
        gray_np = gray.cpu().numpy()
        magnitude = np.abs(np.fft.fftshift(np.fft.fft2(gray_np)))

        h, w = magnitude.shape
        center_h, center_w = h // 2, w // 2
        max_radius = min(center_h, center_w)
        n_bins = int(np.sqrt(n_features))

        # Hoisted out of the loop: squared distance of every pixel from the
        # spectrum center (the original rebuilt the ogrid for every ring).
        y, x = np.ogrid[-center_h:h - center_h, -center_w:w - center_w]
        r2 = x * x + y * y

        features = []
        for i in range(n_bins):
            r_inner = int(i * max_radius / n_bins)
            r_outer = int((i + 1) * max_radius / n_bins)
            ring_values = magnitude[(r2 >= r_inner * r_inner) & (r2 < r_outer * r_outer)]
            if len(ring_values) > 0:
                features.extend([np.mean(ring_values), np.std(ring_values)])
            else:
                # Degenerate ring (no pixels) contributes neutral statistics.
                features.extend([0.0, 0.0])

        # Trim or zero-pad to exactly n_features entries.
        features = features[:n_features]
        if len(features) < n_features:
            features.extend([0.0] * (n_features - len(features)))

        return torch.tensor(features, dtype=torch.float32)
67
 
 
 
68
 
69
class CNNEncoder(nn.Module):
    """Plain VGG-style convolutional encoder.

    Each stage applies two 3x3 conv + BatchNorm + ReLU layers and halves the
    spatial resolution with max pooling; a two-layer MLP then projects the
    flattened feature map to ``output_dim``.
    """

    def __init__(self, channels=(64, 128, 256, 512), output_dim=256, image_size=224):
        """
        Args:
            channels: per-stage output channel counts (one 2x pooling per
                stage). Default changed from a mutable list to a tuple; it is
                only read, so behavior is identical.
            output_dim: dimensionality of the returned embedding.
            image_size: input height/width, used to size the first FC layer;
                must be divisible by ``2 ** len(channels)``.
        """
        super().__init__()

        layers = []
        in_channels = 3  # RGB input
        for out_channels in channels:
            layers.extend([
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
            ])
            in_channels = out_channels

        self.conv_layers = nn.Sequential(*layers)
        # Spatial size shrinks by 2 per stage -> flattened feature length.
        self.feature_size = channels[-1] * (image_size // (2 ** len(channels))) ** 2

        self.fc = nn.Sequential(
            nn.Linear(self.feature_size, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, output_dim),
        )

    def forward(self, x):
        """Encode a (B, 3, H, W) batch into (B, output_dim) embeddings."""
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
103
 
 
 
 
104
 
105
class HybridEncoder(nn.Module):
    """Fuses CNN, ViT and (optionally) FFT frequency features into one embedding.

    Three feature streams are projected to fixed widths (CNN 256, ViT 256,
    frequency 128), concatenated, and passed through a small fusion MLP.
    """

    def __init__(self, cnn_channels=[64, 128, 256, 512], vit_model="vit_small_patch16_224",
                 embedding_dim=512, use_frequency=True, image_size=224):
        super().__init__()

        self.use_frequency = use_frequency
        self.cnn_encoder = CNNEncoder(channels=cnn_channels, output_dim=256, image_size=image_size)

        # ViT backbone without its classification head, projected to 256 dims.
        self.vit = timm.create_model(vit_model, pretrained=False, num_classes=0)
        self.vit_projection = nn.Linear(self.vit.num_features, 256)

        if self.use_frequency:
            self.freq_dim = 64
            self.freq_projection = nn.Linear(self.freq_dim, 128)
            fused_width = 256 + 256 + 128
        else:
            fused_width = 256 + 256

        self.fusion = nn.Sequential(
            nn.Linear(fused_width, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, embedding_dim),
            nn.BatchNorm1d(embedding_dim),
        )

        self.freq_extractor = FrequencyFeatureExtractor()

    def forward(self, x):
        """Return fused (B, embedding_dim) embeddings for an image batch."""
        streams = [self.cnn_encoder(x), self.vit_projection(self.vit(x))]

        if self.use_frequency:
            # FFT features are computed per sample (numpy/CPU), then stacked
            # and moved back to the input's device.
            per_sample = [
                self.freq_extractor.extract_fft_features(x[i], self.freq_dim)
                for i in range(x.size(0))
            ]
            freq = torch.stack(per_sample).to(x.device)
            streams.append(self.freq_projection(freq))

        return self.fusion(torch.cat(streams, dim=1))
 
 
 
 
 
 
 
 
 
 
 
154
 
 
 
 
 
 
 
155
 
156
class DeepSVDD(nn.Module):
    """One-class Deep SVDD wrapper around :class:`HybridEncoder`.

    Holds the learned hypersphere center (a buffer, so it follows
    state_dict/device moves) and a frozen radius parameter; the anomaly
    score is the squared distance of an embedding from the center.
    """

    def __init__(self, embedding_dim=512, cnn_channels=[64, 128, 256, 512],
                 vit_model="vit_small_patch16_224", use_frequency=True, image_size=224):
        super().__init__()

        self.encoder = HybridEncoder(
            cnn_channels=cnn_channels,
            vit_model=vit_model,
            embedding_dim=embedding_dim,
            use_frequency=use_frequency,
            image_size=image_size,
        )
        self.embedding_dim = embedding_dim

        # Center is persisted but not trained by the optimizer; radius is a
        # parameter with gradients disabled.
        self.register_buffer('center', torch.zeros(embedding_dim))
        self.radius = nn.Parameter(torch.tensor(0.0), requires_grad=False)

    def forward(self, x):
        """Embed a batch of images (no scoring here)."""
        return self.encoder(x)

    def get_distance(self, embeddings):
        """Squared Euclidean distance of each embedding from the center."""
        diff = embeddings - self.center
        return (diff * diff).sum(dim=1)
179
 
 
 
 
 
 
180
 
181
# ==================== MODEL LOADING ====================
@torch.no_grad()
def load_model_from_hf(repo_id="ash12321/deepsvdd-model"):
    """Load the DeepSVDD model from HuggingFace.

    Downloads the checkpoint and config from the Hub (cached locally by
    huggingface_hub after the first call), rebuilds the architecture from
    config values, and restores the weights plus the hypersphere
    center/radius.

    Args:
        repo_id: Hub repository containing ``deepsvdd_model.pth`` and
            ``config.json``.

    Returns:
        (model, config): the model in eval mode and the raw config dict.
    """
    print("Loading model from HuggingFace...")

    # Download files (network access on first call; cached afterwards)
    model_path = hf_hub_download(repo_id=repo_id, filename="deepsvdd_model.pth")
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")

    # Load config; .get() defaults mirror the architecture defaults above.
    with open(config_path, 'r') as f:
        config = json.load(f)

    # Initialize model
    model = DeepSVDD(
        embedding_dim=config.get('EMBEDDING_DIM', 512),
        cnn_channels=config.get('CNN_CHANNELS', [64, 128, 256, 512]),
        vit_model=config.get('VIT_MODEL', 'vit_small_patch16_224'),
        use_frequency=config.get('USE_FREQUENCY_FEATURES', True),
        image_size=config.get('IMAGE_SIZE', 224)
    )

    # Load weights
    checkpoint = torch.load(model_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    # NOTE(review): center/radius are also stored as separate checkpoint keys
    # and re-assigned here; presumably they are absent from (or stale in)
    # model_state_dict -- confirm against the training script.
    model.center = checkpoint['center']
    model.radius = checkpoint['radius']

    model.eval()
    print(f"βœ“ Model loaded successfully!")
    print(f" Hypersphere radius: {model.radius.item():.4f}")
    print(f" Center norm: {model.center.norm().item():.4f}")

    return model, config
216
 
 
 
 
217
 
218
# ==================== IMAGE PREPROCESSING ====================
# Built once at import time instead of on every call (the original rebuilt
# the whole transform pipeline per image). Resize + ImageNet normalization.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def preprocess_image(image):
    """Convert a PIL image to a normalized (1, 3, 224, 224) float tensor.

    Args:
        image: PIL.Image in any mode; converted to RGB if necessary.

    Returns:
        torch.Tensor with a leading batch dimension, ready for the model.
    """
    # Convert to RGB if needed (e.g. RGBA screenshots, grayscale photos).
    if image.mode != 'RGB':
        image = image.convert('RGB')

    return _PREPROCESS(image).unsqueeze(0)
233
 
 
 
 
234
 
235
# ==================== PREDICTION FUNCTION ====================
def predict_deepfake(image, threshold_multiplier=1.5):
    """
    Predict if an image is a deepfake.

    Args:
        image: PIL Image (or None when nothing was uploaded)
        threshold_multiplier: How many times the radius to use as threshold

    Returns:
        (prediction, confidence_text, details): three strings for the UI
        outputs. (The previous docstring claimed four values; the function
        has always returned three.)
    """
    # Guard: Gradio passes None when no image is uploaded; previously this
    # crashed inside preprocessing.
    if image is None:
        return "⚠️ No image uploaded", "", "Please upload an image to analyze."

    # Preprocess
    image_tensor = preprocess_image(image)

    # Get embedding and anomaly score (inference only, no gradients)
    with torch.no_grad():
        embedding = model(image_tensor)
        distance = model.get_distance(embedding).item()

    # Detection threshold scales the learned hypersphere radius.
    # NOTE(review): assumes radius > 0 (a zero radius would divide by zero
    # below) -- confirm the training script always sets a positive radius.
    radius = model.radius.item()
    threshold = radius * threshold_multiplier

    # Inside the (scaled) hypersphere -> real; outside -> fake.
    is_fake = distance > threshold

    # Confidence (0-100%): how far beyond the threshold (fake) or how close
    # to the center relative to the threshold (real), clamped at 100.
    if is_fake:
        confidence = min(100, (distance - threshold) / threshold * 100)
    else:
        confidence = min(100, (1 - distance / threshold) * 100)

    # Human-readable verdict
    prediction = "🚨 LIKELY FAKE" if is_fake else "βœ… LIKELY REAL"

    details = f"""
**Hypersphere Distance:** {distance:.4f}
**Detection Threshold:** {threshold:.4f}
**Hypersphere Radius:** {radius:.4f}

**How it works:**
- Real images cluster tightly in embedding space (small distance)
- Fake images fall outside this cluster (large distance)
- This model was trained ONLY on real images using one-class learning
"""

    confidence_text = f"{confidence:.1f}% Confidence"

    return prediction, confidence_text, details
290
+
291
 
292
# ==================== LOAD MODEL ====================
# Runs once at import time so every Gradio callback reuses the same model
# instance; `model` and `config` are module-level globals read by
# predict_deepfake.
print("Initializing Deepfake Detector...")
model, config = load_model_from_hf("ash12321/deepsvdd-model")
print("βœ“ Ready!")
296
 
297
+
298
# ==================== GRADIO INTERFACE ====================
def create_interface():
    """Build and return the Gradio Blocks UI (does not launch it).

    Layout: a two-column row with the image upload, threshold slider and
    submit button on the left, and the three result widgets on the right,
    followed by example images and an about section.
    """

    with gr.Blocks(title="One-Class Deepfake Detector", theme=gr.themes.Soft()) as demo:
        # Intro / usage markdown shown at the top of the page.
        gr.Markdown("""
        # πŸ” One-Class Deepfake Detector

        This AI model detects deepfakes using **hypersphere-based anomaly detection** (DeepSVDD).
        It was trained **exclusively on real images** and learns what "real" looks like in embedding space.

        ### How it works:
        1. Upload an image (photo, portrait, scene, etc.)
        2. The model computes how far the image is from the "real image hypersphere"
        3. Images far from the center are flagged as potential deepfakes

        **Note:** This is a research model. Adjust the threshold slider to control sensitivity.
        """)

        with gr.Row():
            # Left column: inputs and controls.
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Upload Image to Test")

                # Multiplier applied to the learned radius inside
                # predict_deepfake; slider range matches its expected scale.
                threshold_slider = gr.Slider(
                    minimum=1.0,
                    maximum=3.0,
                    value=1.5,
                    step=0.1,
                    label="Detection Threshold Multiplier",
                    info="Higher = stricter (fewer false positives, more false negatives)"
                )

                submit_btn = gr.Button("πŸ” Analyze Image", variant="primary", size="lg")

                gr.Markdown("""
                ### πŸ’‘ Tips:
                - Works best on faces, portraits, and natural scenes
                - Higher threshold = more conservative (flags only obvious fakes)
                - Lower threshold = more aggressive (flags anything unusual)
                - Default (1.5x) is a good starting point
                """)

            # Right column: the three outputs returned by predict_deepfake.
            with gr.Column(scale=1):
                prediction_output = gr.Textbox(label="Prediction", lines=2)
                confidence_output = gr.Textbox(label="Confidence Score", lines=1)
                details_output = gr.Markdown(label="Technical Details")

        # Examples
        # NOTE(review): these paths assume an examples/ directory is uploaded
        # alongside app.py in the Space -- confirm the files exist.
        gr.Markdown("### πŸ“Έ Try Example Images:")
        gr.Examples(
            examples=[
                ["examples/real1.jpg", 1.5],
                ["examples/real2.jpg", 1.5],
                ["examples/fake1.jpg", 1.5],
            ],
            inputs=[image_input, threshold_slider],
            label="Example Images"
        )

        # Connect button
        submit_btn.click(
            fn=predict_deepfake,
            inputs=[image_input, threshold_slider],
            outputs=[prediction_output, confidence_output, details_output]
        )

        # Footer: model card / attribution.
        gr.Markdown("""
        ---
        ### 🧠 About the Model

        **Architecture:** Hybrid CNN + Vision Transformer + FFT Frequency Features

        **Training:** Trained on 50,000+ real images using DeepSVDD (Deep Support Vector Data Description)

        **Method:** One-class learning - learns the distribution of real images only

        **Novelty:** Unlike binary classifiers, this model doesn't learn specific fake patterns.
        It learns what's "normal" and flags anything anomalous, making it more robust to new deepfake methods.

        ---
        **Model by:** [ash12321](https://huggingface.co/ash12321) |
        **Source Code:** [GitHub](https://github.com/ash12321/deepfake-detector)
        """)

    return demo
383
 
 
 
 
384
 
385
# ==================== LAUNCH ====================
if __name__ == "__main__":
    # Build the UI and start the Gradio server with default settings.
    create_interface().launch()