jenithjain commited on
Commit
5a125c6
·
1 Parent(s): c5de110

Switch Space backend to extension-compatible deepfake API

Browse files
Files changed (3) hide show
  1. README.md +23 -13
  2. main.py +355 -459
  3. requirements.txt +4 -7
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: DuneNet Model API
3
- emoji: 🚀
4
  colorFrom: yellow
5
  colorTo: red
6
  sdk: docker
@@ -8,18 +8,28 @@ app_port: 7860
8
  pinned: false
9
  ---
10
 
11
- # DuneNet Model API
12
 
13
- FastAPI backend for DuneNet - Autonomous UGV Perception Platform.
14
-
15
- Runs a fine-tuned Segformer (nvidia/mit-b4) model for semantic segmentation of desert terrain, providing:
16
- - Semantic segmentation masks
17
- - Traversability maps for autonomous navigation
18
- - Live simulation inference with costmap grids
19
 
20
  ## Endpoints
21
 
22
- - `GET /` Health check
23
- - `POST /predict` Full segmentation prediction
24
- - `POST /predict/sim` — Simulation-optimized prediction with traversability grid
25
- - `GET /model/info` Model metadata
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Deepfake Detection API
3
+ emoji: 🧠
4
  colorFrom: yellow
5
  colorTo: red
6
  sdk: docker
 
8
  pinned: false
9
  ---
10
 
11
+ # Deepfake Detection API (Hugging Face Space)
12
 
13
+ This Space serves a Deepfake detection backend that is compatible with your browser extension.
 
 
 
 
 
14
 
15
  ## Endpoints
16
 
17
+ - `GET /` - Health check
18
+ - `GET /health` - Extension health endpoint
19
+ - `POST /analyze` - Analyze one frame (`multipart/form-data`, field: `frame`)
20
+ - `POST /reset` - Reset temporal tracker
21
+
22
+ ## Model Files
23
+
24
+ Put your deepfake checkpoint in:
25
+
26
+ - `models/best_model.pth` (preferred)
27
+
28
+ If no compatible checkpoint is found, the API runs in forensic-only mode.
29
+
30
+ ## Extension Backend URL
31
+
32
+ After deployment, set the extension backend URL to your Space URL:
33
+
34
+ - `https://<your-username>-<your-space-name>.hf.space`
35
+
main.py CHANGED
@@ -1,496 +1,392 @@
1
- from fastapi import FastAPI, UploadFile, File, HTTPException
2
- from fastapi.middleware.cors import CORSMiddleware
3
- from pydantic import BaseModel
4
- import torch
5
- import torch.nn.functional as F
6
- from PIL import Image
7
  import io
8
- import base64
 
 
 
9
  import numpy as np
10
- from typing import Optional
11
- import uvicorn
12
- import albumentations as A
13
- from albumentations.pytorch import ToTensorV2
14
- from transformers import SegformerConfig, SegformerForSemanticSegmentation
 
15
 
16
- app = FastAPI(title="DuneNet Model API", version="1.0.0")
17
 
18
- # CORS middleware for Next.js frontend
19
  app.add_middleware(
20
  CORSMiddleware,
21
- allow_origins=[
22
- "http://localhost:3000",
23
- "https://*.vercel.app",
24
- "*",
25
- ],
26
  allow_credentials=True,
27
  allow_methods=["*"],
28
  allow_headers=["*"],
29
  )
30
 
31
- # Configuration
32
- NUM_CLASSES = 10
33
- IMG_SIZE = 512
34
- MODEL_NAME = 'nvidia/mit-b4'
35
-
36
- CLASS_NAMES = [
37
- 'Trees', 'Lush Bushes', 'Dry Grass', 'Dry Bushes', 'Ground Clutter',
38
- 'Flowers', 'Logs', 'Rocks', 'Landscape', 'Sky'
39
- ]
40
-
41
- CLASS_COLORS = np.array([
42
- [34, 139, 34], # Trees
43
- [0, 255, 127], # Lush Bushes
44
- [189, 183, 107], # Dry Grass
45
- [139, 119, 101], # Dry Bushes
46
- [160, 82, 45], # Ground Clutter
47
- [255, 105, 180], # Flowers
48
- [139, 69, 19], # Logs
49
- [128, 128, 128], # Rocks
50
- [210, 180, 140], # Landscape
51
- [135, 206, 235], # Sky
52
- ], dtype=np.uint8)
53
-
54
- # Traversability mapping
55
- TRAVERSABILITY = {
56
- 0: 'no_go', # Trees
57
- 1: 'no_go', # Lush Bushes
58
- 2: 'go', # Dry Grass
59
- 3: 'caution', # Dry Bushes
60
- 4: 'caution', # Ground Clutter
61
- 5: 'go', # Flowers
62
- 6: 'no_go', # Logs
63
- 7: 'caution', # Rocks
64
- 8: 'go', # Landscape
65
- 9: 'sky', # Sky
66
- }
67
-
68
- TRAV_COLORS = {
69
- 'go': np.array([0, 200, 0], dtype=np.uint8), # Green
70
- 'caution': np.array([255, 180, 0], dtype=np.uint8), # Orange
71
- 'no_go': np.array([220, 30, 30], dtype=np.uint8), # Red
72
- 'sky': np.array([180, 210, 240], dtype=np.uint8), # Light blue
73
- }
74
-
75
- # Global model variable
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  model = None
77
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
78
-
79
- class PredictionResponse(BaseModel):
80
- prediction: int
81
- class_name: str
82
- confidence: float
83
- device_used: str
84
- class_distribution: dict
85
- segmentation_mask: str # base64 encoded image
86
- overlay_image: str # base64 encoded overlay
87
- traversability_map: str # base64 encoded traversability
88
- traversability_overlay: str # base64 encoded traversability overlay
89
- traversability_stats: dict # safe, caution, blocked percentages
90
-
91
- class HealthResponse(BaseModel):
92
- status: str
93
- model_loaded: bool
94
- device: str
95
 
96
- @app.on_event("startup")
97
- async def load_model():
98
- """Load the Segformer model on startup"""
99
  global model
100
- try:
101
- import os
102
- possible_paths = [
103
- "models/latest_model_ft.pth",
104
- "api_server/models/latest_model_ft.pth",
105
- os.path.join(os.path.dirname(__file__), "models/latest_model_ft.pth"),
106
- ]
107
-
108
- model_path = None
109
- for path in possible_paths:
110
- if os.path.exists(path):
111
- model_path = path
112
- break
113
-
114
- if model_path is None:
115
- raise FileNotFoundError("latest_model_ft.pth not found in api_server/models/")
116
-
117
- print(f"Loading Segformer model from: {os.path.abspath(model_path)}")
118
-
119
- # Build Segformer model
120
- config = SegformerConfig.from_pretrained(MODEL_NAME)
121
- config.num_labels = NUM_CLASSES
122
- model = SegformerForSemanticSegmentation(config)
123
-
124
- # Load checkpoint
125
- checkpoint = torch.load(model_path, map_location=device, weights_only=False)
126
- model.load_state_dict(checkpoint['model_state_dict'])
127
-
128
- model = model.to(device)
129
  model.eval()
130
-
131
- miou = checkpoint.get('miou', 0)
132
- epoch = checkpoint.get('epoch', '?')
133
-
134
- print(f"✓ Segformer model loaded successfully on {device}")
135
- print(f" Epoch: {epoch}, Val mIoU: {miou:.4f}")
136
- print(f" Classes: {NUM_CLASSES}")
137
- print(f" Model: {MODEL_NAME}")
138
-
139
- except Exception as e:
140
- print(f"✗ Error loading model: {e}")
141
- import traceback
142
- traceback.print_exc()
143
- model = None
144
-
145
- @app.get("/", response_model=HealthResponse)
146
- async def health_check():
147
- """Health check endpoint"""
148
  return {
149
- "status": "running",
150
- "model_loaded": model is not None,
151
- "device": str(device)
 
152
  }
153
 
154
- def colorize_mask(class_mask):
155
- """Convert class mask to RGB colored image"""
156
- h, w = class_mask.shape
157
- rgb = np.zeros((h, w, 3), dtype=np.uint8)
158
- for c in range(NUM_CLASSES):
159
- rgb[class_mask == c] = CLASS_COLORS[c]
160
- return rgb
161
-
162
-
163
- def create_overlay(image_np, class_mask, alpha=0.5):
164
- """Blend original image with colored segmentation mask"""
165
- colored = colorize_mask(class_mask)
166
- overlay = (image_np.astype(np.float32) * (1 - alpha) + colored.astype(np.float32) * alpha)
167
- return overlay.astype(np.uint8)
168
-
169
-
170
- def create_traversability_map(class_mask):
171
- """Generate traversability map from segmentation mask"""
172
- h, w = class_mask.shape
173
- trav_mask = np.zeros((h, w, 3), dtype=np.uint8)
174
-
175
- for class_id, category in TRAVERSABILITY.items():
176
- region = (class_mask == class_id)
177
- trav_mask[region] = TRAV_COLORS[category]
178
-
179
- return trav_mask
180
-
181
-
182
- def calculate_traversability_stats(class_mask):
183
- """Calculate traversability statistics"""
184
- total_pixels = class_mask.size
185
- sky_pixels = (class_mask == 9).sum() # Sky class
186
- ground_pixels = total_pixels - sky_pixels
187
-
188
- if ground_pixels == 0:
189
- return {'safe': '0%', 'caution': '0%', 'blocked': '0%'}
190
-
191
- safe_pixels = 0
192
- caution_pixels = 0
193
- blocked_pixels = 0
194
-
195
- for class_id, category in TRAVERSABILITY.items():
196
- if category == 'sky':
197
- continue
198
- count = (class_mask == class_id).sum()
199
- if category == 'go':
200
- safe_pixels += count
201
- elif category == 'caution':
202
- caution_pixels += count
203
- elif category == 'no_go':
204
- blocked_pixels += count
205
-
206
  return {
207
- 'safe': f"{(safe_pixels / ground_pixels * 100):.1f}%",
208
- 'caution': f"{(caution_pixels / ground_pixels * 100):.1f}%",
209
- 'blocked': f"{(blocked_pixels / ground_pixels * 100):.1f}%"
 
 
 
 
 
 
210
  }
211
 
212
 
213
- def numpy_to_base64(image_np):
214
- """Convert numpy array to base64 string"""
215
- img = Image.fromarray(image_np)
216
- buffered = io.BytesIO()
217
- img.save(buffered, format="PNG")
218
- img_str = base64.b64encode(buffered.getvalue()).decode()
219
- return f"data:image/png;base64,{img_str}"
 
 
 
 
 
 
 
220
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
- @app.post("/predict", response_model=PredictionResponse)
223
- async def predict(file: UploadFile = File(...)):
224
- """Make prediction on uploaded image using Segformer"""
225
- if model is None:
226
- raise HTTPException(status_code=503, detail="Model not loaded")
227
-
228
  try:
229
- # Read and process image
230
- contents = await file.read()
231
- image = Image.open(io.BytesIO(contents)).convert('RGB')
232
- image_np = np.array(image)
233
- orig_h, orig_w = image_np.shape[:2]
234
-
235
- # Preprocessing with albumentations
236
- transform = A.Compose([
237
- A.Resize(height=IMG_SIZE, width=IMG_SIZE),
238
- A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
239
- ToTensorV2(),
240
- ])
241
-
242
- aug = transform(image=image_np)
243
- tensor = aug['image'].unsqueeze(0).to(device)
244
-
245
- # Inference
246
  with torch.no_grad():
247
- use_fp16 = device.type == 'cuda'
248
- with torch.amp.autocast(device_type=device.type, enabled=use_fp16):
249
- outputs = model(pixel_values=tensor)
250
-
251
- # Get logits and resize
252
- logits = F.interpolate(
253
- outputs.logits,
254
- size=(IMG_SIZE, IMG_SIZE),
255
- mode='bilinear',
256
- align_corners=False
257
- )
258
-
259
- # Get probabilities
260
- probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy()
261
-
262
- # Get prediction mask
263
- pred_mask = np.argmax(probs, axis=0).astype(np.uint8)
264
-
265
- # Resize prediction to original image size
266
- pred_mask_orig = np.array(
267
- Image.fromarray(pred_mask).resize((orig_w, orig_h), Image.NEAREST)
268
- )
269
-
270
- # Calculate class distribution
271
- class_dist = {}
272
- total_pixels = pred_mask_orig.size
273
- for c in range(NUM_CLASSES):
274
- count = (pred_mask_orig == c).sum()
275
- if count > 0:
276
- class_dist[CLASS_NAMES[c]] = f"{(count / total_pixels * 100):.1f}%"
277
-
278
- # Get dominant class
279
- dominant_class = np.bincount(pred_mask_orig.flatten()).argmax()
280
- confidence = probs[dominant_class].mean()
281
-
282
- # Generate visualizations
283
- colored_mask = colorize_mask(pred_mask_orig)
284
- overlay = create_overlay(image_np, pred_mask_orig, alpha=0.5)
285
-
286
- # Generate traversability map
287
- print(f"Generating traversability map...")
288
- trav_map = create_traversability_map(pred_mask_orig)
289
- print(f"Traversability map shape: {trav_map.shape}")
290
-
291
- trav_overlay = create_overlay(image_np, pred_mask_orig, alpha=0.6)
292
- # Replace with traversability colors
293
- for class_id, category in TRAVERSABILITY.items():
294
- region = (pred_mask_orig == class_id)
295
- trav_overlay[region] = (
296
- image_np[region].astype(np.float32) * 0.4 +
297
- TRAV_COLORS[category].astype(np.float32) * 0.6
298
- ).astype(np.uint8)
299
-
300
- trav_stats = calculate_traversability_stats(pred_mask_orig)
301
- print(f"Traversability stats: {trav_stats}")
302
-
303
- # Convert to base64
304
- mask_base64 = numpy_to_base64(colored_mask)
305
- overlay_base64 = numpy_to_base64(overlay)
306
- trav_map_base64 = numpy_to_base64(trav_map)
307
- trav_overlay_base64 = numpy_to_base64(trav_overlay)
308
- print(f"All images converted to base64 successfully")
309
-
310
- return {
311
- "prediction": int(dominant_class),
312
- "class_name": CLASS_NAMES[dominant_class],
313
- "confidence": float(confidence),
314
- "device_used": str(device),
315
- "class_distribution": class_dist,
316
- "segmentation_mask": mask_base64,
317
- "overlay_image": overlay_base64,
318
- "traversability_map": trav_map_base64,
319
- "traversability_overlay": trav_overlay_base64,
320
- "traversability_stats": trav_stats
321
- }
322
- except Exception as e:
323
- import traceback
324
- traceback.print_exc()
325
- raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
326
-
327
-
328
- # ═══════════════════════════════════════════════════════════════
329
- # Simulation Live Inference
330
- # ═══════════════════════════════════════════════════════════════
331
-
332
- class SimPredictionResponse(BaseModel):
333
- segmentation_mask: str
334
- traversability_map: str
335
- traversability_overlay: str
336
- traversability_stats: dict
337
- traversability_grid: list
338
- class_distribution: dict
339
- inference_time_ms: float
340
- dominant_class: str
341
- confidence: float
342
-
343
-
344
- def create_traversability_grid(class_mask, grid_cols=12, grid_rows=8):
345
- """Create a coarse traversability grid from the prediction mask.
346
- Uses the bottom 65 % of the image (ground portion, excluding sky).
347
- Returns 2-D list of costmap values: 0 = go, 5 = caution, 10 = no_go.
348
- """
349
- h, w = class_mask.shape
350
- ground_start = int(h * 0.35)
351
- ground_mask = class_mask[ground_start:, :]
352
- gh, gw = ground_mask.shape
353
-
354
- cell_h = max(1, gh // grid_rows)
355
- cell_w = max(1, gw // grid_cols)
356
-
357
- grid = []
358
- for r in range(grid_rows):
359
- row = []
360
- for c in range(grid_cols):
361
- y0 = r * cell_h
362
- y1 = min((r + 1) * cell_h, gh)
363
- x0 = c * cell_w
364
- x1 = min((c + 1) * cell_w, gw)
365
-
366
- cell = ground_mask[y0:y1, x0:x1]
367
- if cell.size == 0:
368
- row.append(0)
369
- continue
370
-
371
- go_count = caution_count = no_go_count = 0
372
- for cid in range(NUM_CLASSES):
373
- cnt = int((cell == cid).sum())
374
- cat = TRAVERSABILITY[cid]
375
- if cat == 'go':
376
- go_count += cnt
377
- elif cat == 'caution':
378
- caution_count += cnt
379
- elif cat == 'no_go':
380
- no_go_count += cnt
381
-
382
- total = go_count + caution_count + no_go_count
383
- if total == 0:
384
- row.append(0)
385
- elif no_go_count / total > 0.3:
386
- row.append(10)
387
- elif caution_count / total > 0.3:
388
- row.append(5)
389
- else:
390
- row.append(0)
391
- grid.append(row)
392
- return grid
393
-
394
-
395
- @app.post("/predict/sim", response_model=SimPredictionResponse)
396
- async def predict_sim(file: UploadFile = File(...)):
397
- """Prediction endpoint optimised for simulation live inference.
398
- Returns a traversability grid suitable for direct costmap updates."""
399
- if model is None:
400
- raise HTTPException(status_code=503, detail="Model not loaded")
401
-
402
- import time
403
- t0 = time.time()
404
 
405
- try:
406
- contents = await file.read()
407
- image = Image.open(io.BytesIO(contents)).convert('RGB')
408
- image_np = np.array(image)
409
- orig_h, orig_w = image_np.shape[:2]
410
 
411
- transform = A.Compose([
412
- A.Resize(height=IMG_SIZE, width=IMG_SIZE),
413
- A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
414
- ToTensorV2(),
415
- ])
416
 
417
- aug = transform(image=image_np)
418
- tensor = aug['image'].unsqueeze(0).to(device)
419
 
420
- with torch.no_grad():
421
- use_fp16 = device.type == 'cuda'
422
- with torch.amp.autocast(device_type=device.type, enabled=use_fp16):
423
- outputs = model(pixel_values=tensor)
424
-
425
- logits = F.interpolate(
426
- outputs.logits,
427
- size=(IMG_SIZE, IMG_SIZE),
428
- mode='bilinear',
429
- align_corners=False,
430
- )
431
-
432
- probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy()
433
- pred_mask = np.argmax(probs, axis=0).astype(np.uint8)
434
- pred_mask_orig = np.array(
435
- Image.fromarray(pred_mask).resize((orig_w, orig_h), Image.NEAREST)
436
- )
437
-
438
- # Visualisations
439
- colored_mask = colorize_mask(pred_mask_orig)
440
- trav_map = create_traversability_map(pred_mask_orig)
441
-
442
- trav_overlay_img = image_np.copy()
443
- for cid, category in TRAVERSABILITY.items():
444
- region = (pred_mask_orig == cid)
445
- trav_overlay_img[region] = (
446
- image_np[region].astype(np.float32) * 0.4
447
- + TRAV_COLORS[category].astype(np.float32) * 0.6
448
- ).astype(np.uint8)
449
-
450
- trav_stats = calculate_traversability_stats(pred_mask_orig)
451
- trav_grid = create_traversability_grid(pred_mask_orig)
452
-
453
- class_dist = {}
454
- total_pixels = pred_mask_orig.size
455
- for cid in range(NUM_CLASSES):
456
- cnt = int((pred_mask_orig == cid).sum())
457
- if cnt > 0:
458
- class_dist[CLASS_NAMES[cid]] = f"{cnt / total_pixels * 100:.1f}%"
459
-
460
- dominant = int(np.bincount(pred_mask_orig.flatten()).argmax())
461
- conf = float(probs[dominant].mean())
462
- elapsed = (time.time() - t0) * 1000
463
 
464
- return {
465
- "segmentation_mask": numpy_to_base64(colored_mask),
466
- "traversability_map": numpy_to_base64(trav_map),
467
- "traversability_overlay": numpy_to_base64(trav_overlay_img),
468
- "traversability_stats": trav_stats,
469
- "traversability_grid": trav_grid,
470
- "class_distribution": class_dist,
471
- "inference_time_ms": round(elapsed, 1),
472
- "dominant_class": CLASS_NAMES[dominant],
473
- "confidence": conf,
474
- }
475
- except Exception as e:
476
- import traceback
477
- traceback.print_exc()
478
- raise HTTPException(status_code=500, detail=f"Sim prediction error: {str(e)}")
479
-
480
-
481
- @app.get("/model/info")
482
- async def model_info():
483
- """Get model information"""
484
- if model is None:
485
- raise HTTPException(status_code=503, detail="Model not loaded")
486
-
487
  return {
488
- "model_type": str(type(model).__name__),
489
- "device": str(device),
490
- "parameters": sum(p.numel() for p in model.parameters() if hasattr(model, 'parameters'))
 
 
 
 
 
 
 
 
 
 
491
  }
492
 
 
493
  if __name__ == "__main__":
494
- import os
 
495
  port = int(os.environ.get("PORT", 7860))
496
  uvicorn.run(app, host="0.0.0.0", port=port)
 
 
 
 
 
 
 
 
1
  import io
2
+ import os
3
+ import time
4
+
5
+ import cv2
6
  import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from efficientnet_pytorch import EfficientNet
10
+ from fastapi import FastAPI, File, HTTPException, UploadFile
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+
13
 
14
+ app = FastAPI(title="Deepfake Detection API", version="2.0.0")
15
 
 
16
  app.add_middleware(
17
  CORSMiddleware,
18
+ allow_origins=["*"],
 
 
 
 
19
  allow_credentials=True,
20
  allow_methods=["*"],
21
  allow_headers=["*"],
22
  )
23
 
24
+
25
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
+ DETECTION_THRESHOLD = 0.40
27
+
28
+
29
+ class DeepfakeEfficientNet(nn.Module):
30
+ def __init__(self, pretrained: bool = True, dropout: float = 0.5):
31
+ super().__init__()
32
+ if pretrained:
33
+ self.net = EfficientNet.from_pretrained("efficientnet-b0")
34
+ else:
35
+ self.net = EfficientNet.from_name("efficientnet-b0")
36
+
37
+ in_features = self.net._fc.in_features
38
+ self.net._fc = nn.Sequential(
39
+ nn.Dropout(dropout),
40
+ nn.Linear(in_features, 512),
41
+ nn.BatchNorm1d(512),
42
+ nn.ReLU(),
43
+ nn.Dropout(dropout * 0.7),
44
+ nn.Linear(512, 256),
45
+ nn.BatchNorm1d(256),
46
+ nn.ReLU(),
47
+ nn.Dropout(dropout * 0.5),
48
+ nn.Linear(256, 1),
49
+ )
50
+
51
+ def forward(self, rgb_input, freq_input=None):
52
+ return self.net(rgb_input)
53
+
54
+
55
+ class TemporalTracker:
56
+ def __init__(self, window_size: int = 60, voting_window: int = 10, threshold: float = DETECTION_THRESHOLD):
57
+ self.window_size = window_size
58
+ self.voting_window = voting_window
59
+ self.threshold = threshold
60
+ self.score_history = []
61
+ self.frame_votes = []
62
+
63
+ def update(self, fake_probability: float):
64
+ self.score_history.append(float(fake_probability))
65
+ if len(self.score_history) > self.window_size:
66
+ self.score_history = self.score_history[-self.window_size :]
67
+
68
+ vote = "FAKE" if fake_probability > self.threshold else "REAL"
69
+ self.frame_votes.append(vote)
70
+ if len(self.frame_votes) > self.voting_window:
71
+ self.frame_votes = self.frame_votes[-self.voting_window :]
72
+
73
+ def get_temporal_average(self) -> float:
74
+ if not self.score_history:
75
+ return 0.0
76
+ return float(sum(self.score_history) / len(self.score_history))
77
+
78
+ def get_stability_score(self) -> float:
79
+ if len(self.score_history) < 10:
80
+ return 0.0
81
+ arr = np.array(self.score_history[-10:], dtype=np.float32)
82
+ variance = float(np.var(arr))
83
+ return float(1.0 - min(variance * 4.0, 1.0))
84
+
85
+ def get_confidence_level(self) -> str:
86
+ if len(self.frame_votes) < self.voting_window:
87
+ return "UNCERTAIN"
88
+ fake_count = sum(1 for x in self.frame_votes if x == "FAKE")
89
+ real_count = len(self.frame_votes) - fake_count
90
+ return "FAKE" if fake_count > real_count else "REAL"
91
+
92
+ def reset(self):
93
+ self.score_history = []
94
+ self.frame_votes = []
95
+
96
+
97
+ class ForensicAnalyzer:
98
+ def __init__(self, analysis_size=(256, 256)):
99
+ self.analysis_size = analysis_size
100
+ self.prev_gray = None
101
+
102
+ def analyze(self, frame_bgr: np.ndarray):
103
+ resized = cv2.resize(frame_bgr, self.analysis_size, interpolation=cv2.INTER_LINEAR)
104
+
105
+ frequency = self._analyze_frequency(resized)
106
+ noise = self._analyze_noise(resized)
107
+ ela = self._analyze_ela(resized)
108
+ edge = self._analyze_edges(resized)
109
+ temporal = self._analyze_temporal(resized)
110
+
111
+ score = (
112
+ 0.32 * frequency
113
+ + 0.20 * noise
114
+ + 0.18 * ela
115
+ + 0.18 * edge
116
+ + 0.12 * temporal
117
+ )
118
+
119
+ return {
120
+ "fake_probability": float(np.clip(score, 0.0, 1.0)),
121
+ "scores": {
122
+ "frequency": float(frequency),
123
+ "noise": float(noise),
124
+ "ela": float(ela),
125
+ "edge": float(edge),
126
+ "temporal": float(temporal),
127
+ },
128
+ }
129
+
130
+ def _analyze_frequency(self, frame):
131
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).astype(np.float32)
132
+ magnitude = np.log1p(np.abs(np.fft.fftshift(np.fft.fft2(gray))))
133
+ h, w = magnitude.shape
134
+ cy, cx = h // 2, w // 2
135
+
136
+ y_grid, x_grid = np.ogrid[:h, :w]
137
+ dist = np.sqrt((x_grid - cx) ** 2 + (y_grid - cy) ** 2)
138
+
139
+ inner = min(h, w) // 8
140
+ outer = min(h, w) // 3
141
+ low = magnitude[dist <= inner]
142
+ high = magnitude[(dist > inner) & (dist <= outer)]
143
+
144
+ low_mean = float(low.mean()) if low.size else 0.0
145
+ high_mean = float(high.mean()) if high.size else 0.0
146
+ ratio = high_mean / (low_mean + high_mean + 1e-9)
147
+
148
+ if ratio < 0.18:
149
+ return 0.75
150
+ if ratio < 0.24:
151
+ return 0.45
152
+ return 0.10
153
+
154
+ def _analyze_noise(self, frame):
155
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).astype(np.float32)
156
+ blurred = cv2.GaussianBlur(gray, (5, 5), 0)
157
+ residual = gray - blurred
158
+ std = float(np.std(residual))
159
+
160
+ if std < 2.0:
161
+ return 0.70
162
+ if std < 4.0:
163
+ return 0.35
164
+ return 0.12
165
+
166
+ def _analyze_ela(self, frame):
167
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 90]
168
+ ok, encoded = cv2.imencode(".jpg", frame, encode_param)
169
+ if not ok:
170
+ return 0.0
171
+
172
+ recompressed = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
173
+ if recompressed is None:
174
+ return 0.0
175
+
176
+ diff = cv2.absdiff(frame, recompressed)
177
+ mean_diff = float(np.mean(diff))
178
+
179
+ if mean_diff > 14:
180
+ return 0.65
181
+ if mean_diff > 8:
182
+ return 0.35
183
+ return 0.08
184
+
185
+ def _analyze_edges(self, frame):
186
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
187
+ edges = cv2.Canny(gray, 50, 150)
188
+ edge_density = float(np.mean(edges > 0))
189
+ lap_var = float(np.var(cv2.Laplacian(gray, cv2.CV_64F)))
190
+
191
+ score = 0.0
192
+ if edge_density < 0.02:
193
+ score += 0.45
194
+ elif edge_density < 0.04:
195
+ score += 0.20
196
+
197
+ if lap_var < 60:
198
+ score += 0.35
199
+ elif lap_var < 120:
200
+ score += 0.15
201
+
202
+ return float(np.clip(score, 0.0, 1.0))
203
+
204
+ def _analyze_temporal(self, frame):
205
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).astype(np.float32)
206
+ if self.prev_gray is None:
207
+ self.prev_gray = gray
208
+ return 0.0
209
+
210
+ diff = cv2.absdiff(gray, self.prev_gray)
211
+ self.prev_gray = gray
212
+ mean_delta = float(np.mean(diff))
213
+
214
+ if mean_delta < 1.2:
215
+ return 0.40
216
+ if mean_delta < 2.5:
217
+ return 0.20
218
+ return 0.08
219
+
220
+ def reset(self):
221
+ self.prev_gray = None
222
+
223
+
224
  model = None
225
+ model_loaded = False
226
+ tracker = TemporalTracker()
227
+ forensics = ForensicAnalyzer()
228
+ frame_count = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
+
231
+ def load_checkpoint_model():
 
232
  global model
233
+ global model_loaded
234
+
235
+ checkpoint_candidates = [
236
+ os.path.join(os.path.dirname(__file__), "models", "best_model.pth"),
237
+ os.path.join(os.path.dirname(__file__), "models", "latest_model_ft.pth"),
238
+ ]
239
+
240
+ model = DeepfakeEfficientNet(pretrained=True).to(DEVICE)
241
+
242
+ loaded_any = False
243
+ for path in checkpoint_candidates:
244
+ if not os.path.exists(path):
245
+ continue
246
+
247
+ try:
248
+ checkpoint = torch.load(path, map_location=DEVICE, weights_only=False)
249
+ state_dict = checkpoint.get("model_state_dict", checkpoint)
250
+ model.load_state_dict(state_dict, strict=False)
251
+ loaded_any = True
252
+ print(f"Loaded checkpoint: {path}")
253
+ break
254
+ except Exception as ex:
255
+ print(f"Failed loading checkpoint {path}: {ex}")
256
+
257
+ if loaded_any:
 
 
 
 
258
  model.eval()
259
+ model_loaded = True
260
+ else:
261
+ model_loaded = False
262
+ print("No compatible deepfake checkpoint found; running forensic-only mode.")
263
+
264
+
265
+ @app.on_event("startup")
266
+ async def startup_event():
267
+ load_checkpoint_model()
268
+
269
+
270
+ @app.get("/")
271
+ async def root_health():
 
 
 
 
 
272
  return {
273
+ "status": "healthy",
274
+ "model_loaded": model_loaded,
275
+ "device": DEVICE,
276
+ "frame_count": frame_count,
277
  }
278
 
279
+
280
+ @app.get("/health")
281
+ async def health_check():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  return {
283
+ "status": "healthy",
284
+ "model_loaded": model_loaded,
285
+ "device": DEVICE,
286
+ "frame_count": frame_count,
287
+ "capabilities": {
288
+ "frame_forensics": True,
289
+ "temporal_tracking": True,
290
+ "face_detection": False,
291
+ },
292
  }
293
 
294
 
295
+ @app.post("/reset")
296
+ async def reset_state():
297
+ global frame_count
298
+
299
+ tracker.reset()
300
+ forensics.reset()
301
+ frame_count = 0
302
+ return {"success": True, "message": "Detector state reset"}
303
+
304
+
305
+ def _prepare_model_tensor(frame_bgr: np.ndarray) -> torch.Tensor:
306
+ rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
307
+ resized = cv2.resize(rgb, (224, 224), interpolation=cv2.INTER_AREA)
308
+ arr = resized.astype(np.float32) / 255.0
309
 
310
+ mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
311
+ std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
312
+ arr = (arr - mean) / std
313
+
314
+ chw = np.transpose(arr, (2, 0, 1))
315
+ tensor = torch.from_numpy(chw).unsqueeze(0).to(DEVICE)
316
+ return tensor
317
+
318
+
319
+ def _run_model(frame_bgr: np.ndarray):
320
+ if not model_loaded or model is None:
321
+ return None
322
 
 
 
 
 
 
 
323
  try:
324
+ tensor = _prepare_model_tensor(frame_bgr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  with torch.no_grad():
326
+ logits = model(tensor).squeeze()
327
+ prob = torch.sigmoid(logits).item()
328
+ return float(np.clip(prob, 0.0, 1.0))
329
+ except Exception as ex:
330
+ print(f"Model inference failed: {ex}")
331
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
 
 
 
 
 
 
333
 
334
+ @app.post("/analyze")
335
+ async def analyze_frame(frame: UploadFile = File(None), file: UploadFile = File(None)):
336
+ global frame_count
 
 
337
 
338
+ start = time.time()
 
339
 
340
+ uploaded = frame or file
341
+ if uploaded is None:
342
+ raise HTTPException(status_code=400, detail="No frame provided. Use multipart form field 'frame'.")
343
+
344
+ raw = await uploaded.read()
345
+ if not raw:
346
+ raise HTTPException(status_code=400, detail="Empty file")
347
+
348
+ np_bytes = np.frombuffer(raw, np.uint8)
349
+ image = cv2.imdecode(np_bytes, cv2.IMREAD_COLOR)
350
+ if image is None:
351
+ raise HTTPException(status_code=400, detail="Invalid image format")
352
+
353
+ frame_count += 1
354
+
355
+ forensic_result = forensics.analyze(image)
356
+ forensic_prob = float(forensic_result["fake_probability"])
357
+
358
+ model_prob = _run_model(image)
359
+ if model_prob is None:
360
+ combined_prob = forensic_prob
361
+ analysis_mode = "frame_only"
362
+ else:
363
+ combined_prob = float(np.clip(0.70 * model_prob + 0.30 * forensic_prob, 0.0, 1.0))
364
+ analysis_mode = "model+frame"
365
+
366
+ tracker.update(combined_prob)
367
+
368
+ elapsed_ms = (time.time() - start) * 1000.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  return {
371
+ "success": True,
372
+ "analysis_mode": analysis_mode,
373
+ "faces_detected": 0,
374
+ "fake_probability": combined_prob,
375
+ "model_probability": model_prob,
376
+ "frame_forensic_probability": forensic_prob,
377
+ "real_probability": float(1.0 - combined_prob),
378
+ "confidence_level": tracker.get_confidence_level(),
379
+ "temporal_average": tracker.get_temporal_average(),
380
+ "stability_score": tracker.get_stability_score(),
381
+ "frame_count": frame_count,
382
+ "processing_time_ms": round(elapsed_ms, 1),
383
+ "forensic_scores": forensic_result["scores"],
384
  }
385
 
386
+
387
  if __name__ == "__main__":
388
+ import uvicorn
389
+
390
  port = int(os.environ.get("PORT", 7860))
391
  uvicorn.run(app, host="0.0.0.0", port=port)
392
+
requirements.txt CHANGED
@@ -1,10 +1,7 @@
1
  fastapi==0.109.0
2
  uvicorn[standard]==0.27.0
3
- torch>=2.6.0
4
- torchvision>=0.21.0
5
- pillow>=10.2.0
6
- python-multipart>=0.0.6
7
  numpy>=1.26.0
8
- pydantic>=2.5.0
9
- albumentations>=1.3.0
10
- transformers>=4.30.0
 
 
1
  fastapi==0.109.0
2
  uvicorn[standard]==0.27.0
 
 
 
 
3
  numpy>=1.26.0
4
+ python-multipart>=0.0.6
5
+ opencv-python-headless>=4.10.0
6
+ efficientnet-pytorch>=0.7.1
7
+