ash12321 commited on
Commit
41e3e77
Β·
verified Β·
1 Parent(s): 94e4b1c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +335 -0
app.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ═══════════════════════════════════════════════════════════════════════
3
+ V13 DEEPFAKE DETECTOR - GRADIO APP
4
+ ═══════════════════════════════════════════════════════════════════════
5
+ Upload an image and detect if it's real or AI-generated/deepfake
6
+ Uses the best Model 3 (Swin-Large) with 99.96% accuracy
7
+ ═══════════════════════════════════════════════════════════════════════
8
+ """
9
+
10
+ import gradio as gr
11
+ import torch
12
+ import torch.nn as nn
13
+ from torchvision import transforms
14
+ from PIL import Image
15
+ import timm
16
+ import json
17
+ from huggingface_hub import hf_hub_download
18
+ import numpy as np
19
+
20
+ print("πŸš€ Loading Deepfake Detector...")
21
+
22
+ # ═══════════════════════════════════════════════════════════════════════
23
+ # CONFIGURATION
24
+ # ═══════════════════════════════════════════════════════════════════════
25
+
26
+ REPO_ID = "ash12321/deepfake-detector-v13-optimized"
27
+ MODEL_NUM = 3 # Using Model 3 (best performance)
28
+
29
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
30
+ print(f"Device: {device}")
31
+
32
+ # ═══════════════════════════════════════════════════════════════════════
33
+ # MODEL DEFINITION
34
+ # ═══════════════════════════════════════════════════════════════════════
35
+
36
+ class DeepfakeDetector(nn.Module):
37
+ def __init__(self, backbone_name, dropout=0.3, hidden_dim=512, use_batch_norm=True):
38
+ super().__init__()
39
+
40
+ self.backbone = timm.create_model(backbone_name, pretrained=False, num_classes=0)
41
+
42
+ if hasattr(self.backbone, 'num_features'):
43
+ feat_dim = self.backbone.num_features
44
+ else:
45
+ with torch.no_grad():
46
+ feat_dim = self.backbone(torch.randn(1, 3, 224, 224)).shape[1]
47
+
48
+ if use_batch_norm:
49
+ self.classifier = nn.Sequential(
50
+ nn.Linear(feat_dim, hidden_dim),
51
+ nn.BatchNorm1d(hidden_dim),
52
+ nn.GELU(),
53
+ nn.Dropout(dropout),
54
+ nn.Linear(hidden_dim, hidden_dim // 4),
55
+ nn.BatchNorm1d(hidden_dim // 4),
56
+ nn.GELU(),
57
+ nn.Dropout(dropout * 0.5),
58
+ nn.Linear(hidden_dim // 4, 1)
59
+ )
60
+ else:
61
+ self.classifier = nn.Sequential(
62
+ nn.Linear(feat_dim, hidden_dim),
63
+ nn.LayerNorm(hidden_dim),
64
+ nn.GELU(),
65
+ nn.Dropout(dropout),
66
+ nn.Linear(hidden_dim, hidden_dim // 4),
67
+ nn.LayerNorm(hidden_dim // 4),
68
+ nn.GELU(),
69
+ nn.Dropout(dropout * 0.5),
70
+ nn.Linear(hidden_dim // 4, 1)
71
+ )
72
+
73
+ def forward(self, x):
74
+ features = self.backbone(x)
75
+ return self.classifier(features).squeeze(-1)
76
+
77
+ # ═══════════════════════════════════════════════════════════════════════
78
+ # LOAD MODEL
79
+ # ═══════════════════════════════════════════════════════════════════════
80
+
81
+ print("πŸ“₯ Downloading model from HuggingFace...")
82
+
83
+ # Download model files
84
+ model_path = hf_hub_download(
85
+ repo_id=REPO_ID,
86
+ filename=f"best_model_{MODEL_NUM}.pt"
87
+ )
88
+
89
+ params_path = hf_hub_download(
90
+ repo_id=REPO_ID,
91
+ filename=f"best_params_model_{MODEL_NUM}.json"
92
+ )
93
+
94
+ # Load parameters
95
+ with open(params_path, 'r') as f:
96
+ best_params = json.load(f)
97
+
98
+ params = best_params['params']
99
+ threshold = params['classification_threshold']
100
+
101
+ print(f"βœ“ Using Model {MODEL_NUM}")
102
+ print(f"βœ“ Threshold: {threshold:.4f}")
103
+ print(f"βœ“ Test F1 Score: {best_params.get('f1_score', 'N/A')}")
104
+
105
+ # Model architecture map
106
+ backbone_map = {
107
+ 1: 'convnext_large',
108
+ 2: 'vit_large_patch16_224',
109
+ 3: 'swin_large_patch4_window7_224'
110
+ }
111
+
112
+ # Create model
113
+ print("πŸ”¨ Building model...")
114
+ model = DeepfakeDetector(
115
+ backbone_name=backbone_map[MODEL_NUM],
116
+ dropout=params['dropout'],
117
+ hidden_dim=params['hidden_dim'],
118
+ use_batch_norm=params['use_batch_norm']
119
+ )
120
+
121
+ # Load weights
122
+ checkpoint = torch.load(model_path, map_location=device)
123
+ model.load_state_dict(checkpoint['model_state_dict'])
124
+ model = model.to(device)
125
+ model.eval()
126
+
127
+ print("βœ… Model loaded successfully!\n")
128
+
129
+ # ═══════════════════════════════════════════════════════════════════════
130
+ # IMAGE PREPROCESSING
131
+ # ═══════════════════════════════════════════════════════════════════════
132
+
133
+ transform = transforms.Compose([
134
+ transforms.Resize((224, 224)),
135
+ transforms.ToTensor(),
136
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
137
+ ])
138
+
139
+ # ═══════════════════════════════════════════════════════════════════════
140
+ # PREDICTION FUNCTION
141
+ # ═══════════════════════════════════════════════════════════════════════
142
+
143
+ def predict_image(image):
144
+ """
145
+ Predict if an image is real or fake
146
+
147
+ Args:
148
+ image: PIL Image
149
+
150
+ Returns:
151
+ dict: Prediction results with confidence scores
152
+ """
153
+ if image is None:
154
+ return {
155
+ "Error": "Please upload an image"
156
+ }
157
+
158
+ try:
159
+ # Convert to RGB if needed
160
+ if image.mode != 'RGB':
161
+ image = image.convert('RGB')
162
+
163
+ # Preprocess
164
+ img_tensor = transform(image).unsqueeze(0).to(device)
165
+
166
+ # Predict
167
+ with torch.no_grad():
168
+ logit = model(img_tensor)
169
+ probability = torch.sigmoid(logit).item()
170
+
171
+ # Determine prediction
172
+ is_fake = probability > threshold
173
+
174
+ # Calculate confidence
175
+ if is_fake:
176
+ confidence = probability * 100
177
+ label = "🚨 FAKE / AI-GENERATED"
178
+ color = "red"
179
+ else:
180
+ confidence = (1 - probability) * 100
181
+ label = "βœ… REAL"
182
+ color = "green"
183
+
184
+ # Create result dictionary for Gradio
185
+ result = {
186
+ "Prediction": label,
187
+ "Confidence": f"{confidence:.2f}%",
188
+ "Raw Score": f"{probability:.4f}",
189
+ "Threshold": f"{threshold:.4f}"
190
+ }
191
+
192
+ # Additional context
193
+ if confidence > 95:
194
+ certainty = "Very High Certainty"
195
+ elif confidence > 85:
196
+ certainty = "High Certainty"
197
+ elif confidence > 70:
198
+ certainty = "Moderate Certainty"
199
+ else:
200
+ certainty = "Low Certainty - Manual Review Recommended"
201
+
202
+ result["Certainty Level"] = certainty
203
+
204
+ return result
205
+
206
+ except Exception as e:
207
+ return {
208
+ "Error": f"Prediction failed: {str(e)}"
209
+ }
210
+
211
+ # ═══════════════════════════════════════════════════════════════════════
212
+ # GRADIO INTERFACE
213
+ # ═══════════════════════════════════════════════════════════════════════
214
+
215
+ # Create the interface
216
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
217
+
218
+ gr.Markdown(
219
+ """
220
+ # πŸ” Deepfake Detector V13
221
+
222
+ Upload an image to detect if it's **REAL** or **AI-GENERATED/DEEPFAKE**
223
+
224
+ **Model Performance:**
225
+ - βœ… 99.96% Accuracy on test set
226
+ - βœ… 100% Recall (catches all fakes)
227
+ - βœ… Model 3: Swin-Large (197M parameters)
228
+
229
+ **Supported:** Faces, portraits, AI-generated images, deepfakes
230
+ """
231
+ )
232
+
233
+ with gr.Row():
234
+ with gr.Column():
235
+ image_input = gr.Image(
236
+ type="pil",
237
+ label="Upload Image",
238
+ height=400
239
+ )
240
+
241
+ predict_btn = gr.Button(
242
+ "πŸ” Analyze Image",
243
+ variant="primary",
244
+ size="lg"
245
+ )
246
+
247
+ gr.Markdown(
248
+ """
249
+ ### πŸ’‘ Tips:
250
+ - Upload clear images with visible faces
251
+ - Works best with portraits and headshots
252
+ - Supports: JPG, PNG, WebP
253
+ """
254
+ )
255
+
256
+ with gr.Column():
257
+ result_output = gr.JSON(
258
+ label="Detection Results"
259
+ )
260
+
261
+ gr.Markdown(
262
+ """
263
+ ### πŸ“Š Understanding Results:
264
+
265
+ **Prediction:** REAL or FAKE classification
266
+
267
+ **Confidence:** How certain the model is (0-100%)
268
+
269
+ **Raw Score:** Internal probability (0-1)
270
+ - Above threshold β†’ FAKE
271
+ - Below threshold β†’ REAL
272
+
273
+ **Certainty Level:**
274
+ - Very High (>95%): Trust the result
275
+ - High (85-95%): Reliable
276
+ - Moderate (70-85%): Generally accurate
277
+ - Low (<70%): Consider manual review
278
+ """
279
+ )
280
+
281
+ # Examples
282
+ gr.Markdown("### πŸ“Έ Try These Examples:")
283
+ gr.Examples(
284
+ examples=[
285
+ # Add example image paths here if you have them
286
+ ],
287
+ inputs=image_input,
288
+ outputs=result_output,
289
+ fn=predict_image,
290
+ cache_examples=False
291
+ )
292
+
293
+ # Connect button to function
294
+ predict_btn.click(
295
+ fn=predict_image,
296
+ inputs=image_input,
297
+ outputs=result_output
298
+ )
299
+
300
+ # Auto-predict on upload
301
+ image_input.change(
302
+ fn=predict_image,
303
+ inputs=image_input,
304
+ outputs=result_output
305
+ )
306
+
307
+ gr.Markdown(
308
+ """
309
+ ---
310
+ **Model Details:**
311
+ - Architecture: Swin Transformer Large
312
+ - Parameters: 197M
313
+ - Training Data: 60,000 balanced real/fake images
314
+ - Optimized with Optuna hyperparameter search
315
+
316
+ **Limitations:**
317
+ - Best for human faces and portraits
318
+ - May not work well on heavily compressed images
319
+ - Performance may vary on new AI generation methods
320
+
321
+ **Version:** V13 Model 3 | **Accuracy:** 99.96%
322
+ """
323
+ )
324
+
325
+ # ═══════════════════════════════════════════════════════════════════════
326
+ # LAUNCH
327
+ # ═══════════════════════════════════════════════════════════════════════
328
+
329
+ if __name__ == "__main__":
330
+ print("🌐 Launching Gradio interface...")
331
+ demo.launch(
332
+ share=True, # Creates public link
333
+ server_name="0.0.0.0",
334
+ server_port=7860
335
+ )