bithal26 commited on
Commit
212d8bf
·
verified ·
1 Parent(s): 8442b78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +466 -306
app.py CHANGED
@@ -1,323 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
 
 
 
 
 
 
 
 
 
 
 
2
  import cv2
3
- import torch
4
  import numpy as np
5
- from PIL import Image
6
- import gradio as gr
 
 
 
7
  from gradio_client import Client, handle_file
8
- from facenet_pytorch.models.mtcnn import MTCNN
9
- import concurrent.futures
10
- import tempfile
11
- from fastapi import FastAPI, UploadFile, File
12
- from fastapi.responses import HTMLResponse
13
- import shutil
14
-
15
- # ==========================================
16
- # 1. API ROUTER
17
- # ==========================================
18
- WORKER_SPACES = [
19
- "bithal26/DeepFake-Worker-1",
20
- "bithal26/DeepFake-Worker-2",
21
- "bithal26/DeepFake-Worker-3",
22
- "bithal26/DeepFake-Worker-4",
23
- "bithal26/DeepFake-Worker-5",
24
- "bithal26/DeepFake-Worker-6",
25
- "bithal26/DeepFake-Worker-7"
26
- ]
27
-
28
- clients = []
29
- print("Initializing connections to 7 API Workers...")
30
- for space in WORKER_SPACES:
31
- try:
32
- clients.append(Client(space))
33
- except Exception as e:
34
- print(f"Warning: Could not connect to {space}. Error: {e}")
35
 
36
- # ==========================================
37
- # 2. PREPROCESSING ENGINE
38
- # ==========================================
39
- device = torch.device('cpu')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  h, w = img.shape[:2]
43
- if max(w, h) == size: return img
44
- scale = size / w if w > h else size / h
45
- w, h = w * scale, h * scale
46
- interpolation = interpolation_up if scale > 1 else interpolation_down
47
- return cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
48
-
49
- def put_to_center(img, input_size):
50
- img = img[:input_size, :input_size]
51
- image = np.zeros((input_size, input_size, 3), dtype=np.uint8)
52
- start_w = (input_size - img.shape[1]) // 2
53
- start_h = (input_size - img.shape[0]) // 2
54
- image[start_h:start_h + img.shape[0], start_w: start_w + img.shape[1], :] = img
55
- return image
56
-
57
- class VideoReader:
58
- def read_frames(self, path, num_frames):
59
- capture = cv2.VideoCapture(path)
60
- frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
61
- if frame_count <= 0: return None
62
- frame_idxs = np.linspace(0, frame_count - 1, num_frames, endpoint=True, dtype=np.int32)
63
- frames, idxs_read = [], []
64
- for frame_idx in range(frame_idxs[0], frame_idxs[-1] + 1):
65
- ret = capture.grab()
66
- if not ret: break
67
- current = len(idxs_read)
68
- if frame_idx == frame_idxs[current]:
69
- ret, frame = capture.retrieve()
70
- if not ret or frame is None: break
71
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
72
- frames.append(frame)
73
- idxs_read.append(frame_idx)
74
- capture.release()
75
- return np.stack(frames), idxs_read if len(frames) > 0 else None
76
-
77
- class FaceExtractor:
78
- def __init__(self):
79
- self.video_reader = VideoReader()
80
- self.detector = MTCNN(margin=0, thresholds=[0.7, 0.8, 0.8], device=device)
81
-
82
- def process_video(self, video_path, frames_per_video=32):
83
- result = self.video_reader.read_frames(video_path, num_frames=frames_per_video)
84
- if result is None: return []
85
- my_frames, my_idxs = result
86
- results = []
87
- for frame in my_frames:
88
- img = Image.fromarray(frame.astype(np.uint8))
89
- img = img.resize(size=[s // 2 for s in img.size])
90
- batch_boxes, probs = self.detector.detect(img, landmarks=False)
91
- faces = []
92
- if batch_boxes is not None:
93
- for bbox in batch_boxes:
94
- if bbox is not None:
95
- xmin, ymin, xmax, ymax = [int(b * 2) for b in bbox]
96
- w, h = xmax - xmin, ymax - ymin
97
- p_h, p_w = h // 3, w // 3
98
- crop = frame[max(ymin - p_h, 0):ymax + p_h, max(xmin - p_w, 0):xmax + p_w]
99
- faces.append(crop)
100
- if faces:
101
- results.append({"faces": faces})
102
- return results
103
-
104
- face_extractor = FaceExtractor()
105
-
106
- def confident_strategy(pred, t=0.8):
107
- pred = np.array(pred)
108
- sz = len(pred)
109
- if sz == 0: return 0.5
110
- fakes = np.count_nonzero(pred > t)
111
- if fakes > sz // 2.5 and fakes > 11:
112
- return float(np.mean(pred[pred > t]))
113
- elif np.count_nonzero(pred < 0.2) > 0.9 * sz:
114
- return float(np.mean(pred[pred < 0.2]))
115
- else:
116
- return float(np.mean(pred))
117
-
118
- def call_worker(client, filepath):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  try:
120
- # Pass positional argument to avoid Gradio 4 keyword parsing issues
121
- result = client.predict(handle_file(filepath), api_name="/predict")
122
-
123
- # Log backend errors if they happen
124
- if isinstance(result, dict) and "error" in result:
125
- print(f"Worker Error: {result['error']}")
126
- return 0.5
127
-
128
- preds = result.get("predictions", []) if isinstance(result, dict) else []
129
- if not preds: return 0.5
130
- return confident_strategy(preds)
131
- except Exception as e:
132
- print(f"Network Connection Error: {e}")
133
- return 0.5
134
-
135
- # ==========================================
136
- # 3. FASTAPI SERVER & HTML INJECTION
137
- # ==========================================
138
- app = FastAPI()
139
-
140
- JS_OVERRIDE = """
141
- <script>
142
- function handleDrop(e) {
143
- e.preventDefault();
144
- document.getElementById('uploadZone').classList.remove('dragging');
145
- const file = e.dataTransfer.files[0];
146
- if (file) startAnalysis(file);
147
- }
148
-
149
- function startAnalysis(file) {
150
- if (!file) return;
151
- const overlay = document.getElementById('analyzeOverlay');
152
- overlay.classList.add('visible');
153
-
154
- const steps = ['step1','step2','step3','step4','step5','step6'];
155
- const labels = [
156
- 'Decoding video frames...',
157
- 'Extracting facial landmarks...',
158
- 'Running 7 parallel neural models...',
159
- 'Frequency domain analysis...',
160
- 'Temporal coherence check...',
161
- 'Generating forensic report...'
162
- ];
163
-
164
- let currentStep = 0;
165
- const interval = setInterval(() => {
166
- if (currentStep > 0) document.getElementById(steps[currentStep - 1]).className = 'a-step done';
167
- if (currentStep < steps.length) {
168
- document.getElementById(steps[currentStep]).className = 'a-step active';
169
- document.getElementById('analyzeText').textContent = labels[currentStep];
170
- currentStep++;
171
- }
172
- }, 450);
173
-
174
- const formData = new FormData();
175
- formData.append('file', file);
176
- const startTime = performance.now();
177
-
178
- fetch('/api/analyze', { method: 'POST', body: formData })
179
- .then(res => res.json())
180
- .then(data => {
181
- clearInterval(interval);
182
- steps.forEach(s => document.getElementById(s).className = 'a-step');
183
- overlay.classList.remove('visible');
184
-
185
- if (data.error) {
186
- alert("Analysis Error: " + data.error);
187
- return;
 
 
 
 
188
  }
 
 
 
 
 
 
 
 
 
189
 
190
- const duration = ((performance.now() - startTime) / 1000).toFixed(1);
191
- updateRealMetrics(data.final_score, data.worker_scores);
192
- showRealResult(file.name, data.final_score, data.worker_scores, duration);
193
- })
194
- .catch(err => {
195
- clearInterval(interval);
196
- overlay.classList.remove('visible');
197
- alert("System Error: " + err);
198
- });
199
- }
200
-
201
- function updateRealMetrics(finalScore, workerScores) {
202
- const isFake = finalScore >= 0.5;
203
- const confidence = isFake ? finalScore * 100 : (1 - finalScore) * 100;
204
-
205
- const scoreEl = document.getElementById('authScore');
206
- scoreEl.textContent = confidence.toFixed(1) + '%';
207
- scoreEl.className = 'result-score ' + (isFake ? 'fake' : 'authentic');
208
-
209
- for(let i=1; i<=5; i++) {
210
- let wScore = workerScores[i-1] ? workerScores[i-1] * 100 : confidence;
211
- document.getElementById('m' + i).textContent = wScore.toFixed(1) + '%';
212
- document.getElementById('b' + i).style.width = wScore + '%';
213
- }
214
- }
215
-
216
- function showRealResult(fileName, finalScore, workerScores, duration) {
217
- const isFake = finalScore >= 0.5;
218
- const confidence = isFake ? (finalScore * 100).toFixed(1) : ((1 - finalScore) * 100).toFixed(1);
219
- const overlay = document.getElementById('resultOverlay');
220
-
221
- document.getElementById('modalScore').textContent = confidence + '%';
222
- document.getElementById('modalScore').style.color = isFake ? 'var(--red)' : 'var(--green)';
223
- document.getElementById('modalVerdict').textContent = isFake ? 'DEEPFAKE DETECTED' : 'AUTHENTIC CONTENT';
224
- document.getElementById('modalVerdict').className = 'verdict-title ' + (isFake ? '' : 'authentic');
225
- document.getElementById('modalDesc').textContent = isFake
226
- ? `High confidence manipulation detected in "${fileName}". Ensemble forensic signals indicate AI-generated modifications.`
227
- : `No significant manipulation detected in "${fileName}". All forensic signals within normal parameters.`;
228
-
229
- document.getElementById('mm1').textContent = confidence + '%';
230
- document.getElementById('mm2').textContent = workerScores[1] ? (workerScores[1]*100).toFixed(1) + '%' : confidence + '%';
231
- document.getElementById('mm3').textContent = duration + 's';
232
-
233
- overlay.classList.add('visible');
234
- }
235
-
236
- function closeResult() { document.getElementById('resultOverlay').classList.remove('visible'); }
237
- document.getElementById('resultOverlay').addEventListener('click', function(e) { if (e.target === this) closeResult(); });
238
-
239
- setTimeout(() => {
240
- const observer = new IntersectionObserver((entries) => {
241
- entries.forEach(e => {
242
- if (e.isIntersecting) {
243
- e.target.style.opacity = '1';
244
- e.target.style.transform = 'translateY(0)';
245
- }
246
- });
247
- }, { threshold: 0.1 });
248
- document.querySelectorAll('.how-step, .feature-card, .report-card').forEach(el => {
249
- el.style.opacity = '0';
250
- el.style.transform = 'translateY(24px)';
251
- el.style.transition = 'opacity 0.6s ease, transform 0.6s ease, border-color 0.3s';
252
- observer.observe(el);
253
- });
254
- }, 500);
255
- </script>
256
- </body>
257
- </html>
258
- """
259
 
260
- @app.get("/")
261
- def read_root():
262
- try:
263
- with open("deepfake-detector.html", "r", encoding="utf-8") as f:
264
- html_content = f.read()
265
- html_parts = html_content.split("<script>")
266
- live_html = html_parts[0] + JS_OVERRIDE
267
- return HTMLResponse(content=live_html)
268
- except FileNotFoundError:
269
- return HTMLResponse(content="<h1>Error: deepfake-detector.html not found.</h1><p>Please upload the HTML file to this space.</p>")
270
-
271
- @app.post("/api/analyze")
272
- async def analyze_api(file: UploadFile = File(...)):
273
- temp_dir = tempfile.mkdtemp()
274
- video_path = os.path.join(temp_dir, file.filename)
275
- with open(video_path, "wb") as buffer:
276
- shutil.copyfileobj(file.file, buffer)
277
-
278
- input_size = 380
279
- frames_per_video = 32
280
- batch_size = frames_per_video * 4
281
-
282
- faces = face_extractor.process_video(video_path, frames_per_video=frames_per_video)
283
-
284
- x = np.zeros((batch_size, input_size, input_size, 3), dtype=np.uint8)
285
- n = 0
286
- for frame_data in faces:
287
- for face in frame_data["faces"]:
288
- resized_face = isotropically_resize_image(face, input_size)
289
- resized_face = put_to_center(resized_face, input_size)
290
- if n < batch_size:
291
- x[n] = resized_face
292
- n += 1
293
-
294
- if n == 0:
295
- shutil.rmtree(temp_dir, ignore_errors=True)
296
- return {"error": "No faces detected."}
297
-
298
- # Save as highly compressed uint8 numpy array instead of float32 tensor
299
- x_final = x[:n]
300
- np_path = os.path.join(temp_dir, "batch_tensor.npy")
301
- np.save(np_path, x_final)
302
-
303
- worker_scores = []
304
- with concurrent.futures.ThreadPoolExecutor(max_workers=7) as executor:
305
- futures = [executor.submit(call_worker, client, np_path) for client in clients]
306
- for future in concurrent.futures.as_completed(futures):
307
- worker_scores.append(future.result())
308
-
309
- final_score = np.mean(worker_scores)
310
- shutil.rmtree(temp_dir, ignore_errors=True)
311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  return {
313
- "final_score": float(final_score),
314
- "worker_scores": [float(s) for s in worker_scores]
 
 
 
 
315
  }
316
 
317
- demo = gr.Blocks()
318
- app = gr.mount_gradio_app(app, demo, path="/gradio")
319
 
320
- # --- KEEP THE SERVER AWAKE ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
  if __name__ == "__main__":
322
- import uvicorn
323
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
 
1
+ """
2
+ ================================================================================
3
+ VERIDEX β€” Master UI / Orchestrator Space (DeepFake-Detector-UI)
4
+ ──────────────────────────────────────────────────────────────────
5
+ Architecture
6
+ ────────────
7
+ β€’ FastAPI serves the custom deepfake-detector.html at GET /
8
+ β€’ POST /predict/ accepts a raw .mp4 upload
9
+ 1. Saves video to a temp file
10
+ 2. MTCNN extracts up to NUM_FRAMES faces (380 Γ— 380, uint8 HWC)
11
+ 3. Batch is saved as a compressed .npy file
12
+ 4. Fires the .npy at all 7 Workers in parallel via gradio_client
13
+ 5. Aggregates per-frame predictions with confident_strategy
14
+ 6. Returns JSON { prediction, score, filename, worker_results }
15
+
16
+ ENV VARS (set in HF Space settings)
17
+ ─────────────────────────────────────
18
+ WORKER_1_URL … WORKER_7_URL β€” public Gradio Space URLs for each worker
19
+ e.g. https://your-user-deepfake-worker-1.hf.space
20
+ NUM_FRAMES default 32 β€” frames to sample per video
21
+ WORKER_TIMEOUT default 120 β€” seconds to wait per worker call
22
+ ================================================================================
23
+ """
24
+
25
  import os
26
+ import io
27
+ import time
28
+ import uuid
29
+ import logging
30
+ import tempfile
31
+ import traceback
32
+ import traceback as _tb
33
+ from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError as FuturesTimeout
34
+ from pathlib import Path
35
+ from typing import Optional
36
+
37
  import cv2
 
38
  import numpy as np
39
+ import torch
40
+ from fastapi import FastAPI, File, UploadFile, HTTPException
41
+ from fastapi.responses import HTMLResponse, JSONResponse
42
+ from fastapi.staticfiles import StaticFiles
43
+ import uvicorn
44
  from gradio_client import Client, handle_file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
# ─────────────────────────────────────────────────────────────────────────────
# Optional: facenet-pytorch for MTCNN face detection
# ─────────────────────────────────────────────────────────────────────────────
try:
    from facenet_pytorch import MTCNN
except ImportError:
    FACENET_AVAILABLE = False
    logging.warning(
        "facenet-pytorch not installed β€” falling back to full-frame "
        "centre-crop for face extraction."
    )
else:
    FACENET_AVAILABLE = True

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [UI] %(levelname)s %(message)s",
)
logger = logging.getLogger(__name__)

# ══════════════════════════════════════════════════════════════════════════════
# Configuration
# ══════════════════════════════════════════════════════════════════════════════

NUM_FRAMES = int(os.environ.get("NUM_FRAMES", "32"))            # frames sampled per video
WORKER_TIMEOUT = int(os.environ.get("WORKER_TIMEOUT", "120"))   # seconds per worker call
INPUT_SIZE = 380  # must match worker expectation

# Worker URLs β€” read from env vars so no secrets are hard-coded.
# Empty / unset slots are silently skipped.
WORKER_URLS: list[str] = []
for _slot in range(1, 8):
    _candidate = os.environ.get(f"WORKER_{_slot}_URL", "").strip()
    if _candidate:
        WORKER_URLS.append(_candidate)

if not WORKER_URLS:
    logger.warning(
        "No WORKER_*_URL env vars set. "
        "Set WORKER_1_URL … WORKER_7_URL in Space settings."
    )

# ── HTML template path ────────────────────────────────────────────────────────
HTML_FILE = Path(__file__).parent / "deepfake-detector.html"

# ── MTCNN ─────────────────────────────────────────────────────────────────────
if not FACENET_AVAILABLE:
    _mtcnn = None
else:
    # keep_all=True returns every detected face per frame
    _mtcnn = MTCNN(
        keep_all=True,
        device="cuda" if torch.cuda.is_available() else "cpu",
        select_largest=False,
        post_process=False,  # return raw uint8 tensors, not normalised
        image_size=INPUT_SIZE,
        margin=20,
    )
    logger.info("MTCNN initialised.")
105
+
106
+
107
+ # ══════════════════════════════════════════════════════════════════════════════
108
+ # Face extraction helpers
109
+ # ══════════════════════════════════════════════════════════════════════════════
110
+
111
def _isotropic_resize(img: np.ndarray, size: int) -> np.ndarray:
    """Resize *img* so its longest side equals *size*, preserving aspect ratio."""
    height, width = img.shape[:2]
    longest = max(height, width)
    if longest == size:
        return img
    factor = size / longest
    target = (int(width * factor), int(height * factor))
    # Cubic interpolation when enlarging, area interpolation when shrinking —
    # each is the better-quality choice for its direction.
    if factor > 1:
        return cv2.resize(img, target, interpolation=cv2.INTER_CUBIC)
    return cv2.resize(img, target, interpolation=cv2.INTER_AREA)
119
+
120
+
121
+ def _put_to_center(img: np.ndarray, size: int) -> np.ndarray:
122
+ img = img[:size, :size]
123
+ canvas = np.zeros((size, size, 3), dtype=np.uint8)
124
+ sh = (size - img.shape[0]) // 2
125
+ sw = (size - img.shape[1]) // 2
126
+ canvas[sh : sh + img.shape[0], sw : sw + img.shape[1]] = img
127
+ return canvas
128
+
129
+
130
def _extract_faces_mtcnn(video_path: str, num_frames: int) -> Optional[np.ndarray]:
    """
    Use MTCNN to detect and crop faces from evenly-spaced video frames.

    Returns a uint8 numpy array of shape (N, INPUT_SIZE, INPUT_SIZE, 3), or
    None when the video cannot be read / yields no frames. Frames where
    detection finds nothing (or raises) fall back to a letterboxed whole-frame
    crop so every sampled frame contributes something.
    """
    # FIX: the old code ran `from PIL import Image` inside the per-frame loop;
    # import once per call instead (function-local because PIL is only needed
    # on this code path).
    from PIL import Image

    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total <= 0:
        cap.release()
        return None

    idxs = np.linspace(0, total - 1, num_frames, dtype=np.int32)
    faces_collected: list[np.ndarray] = []

    def _whole_frame(frame_rgb: np.ndarray) -> np.ndarray:
        # Fallback crop: the entire frame, isotropically resized and centred.
        return _put_to_center(_isotropic_resize(frame_rgb, INPUT_SIZE), INPUT_SIZE)

    for idx in idxs:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ret, frame_bgr = cap.read()
        if not ret:
            continue
        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        pil_frame = Image.fromarray(frame_rgb)

        try:
            boxes, _ = _mtcnn.detect(pil_frame)
            if boxes is None:
                # No face detected β€” fall back to centre crop of whole frame
                faces_collected.append(_whole_frame(frame_rgb))
                continue

            for box in boxes:
                x1, y1, x2, y2 = [int(c) for c in box]
                # Clamp the box to the frame bounds before cropping.
                x1, y1 = max(0, x1), max(0, y1)
                x2 = min(frame_rgb.shape[1], x2)
                y2 = min(frame_rgb.shape[0], y2)
                crop = frame_rgb[y1:y2, x1:x2]
                if crop.size == 0:
                    continue
                face = _isotropic_resize(crop, INPUT_SIZE)
                face = _put_to_center(face, INPUT_SIZE)
                faces_collected.append(face)

        except Exception as exc:
            logger.warning(f"MTCNN failed on frame {idx}: {exc}")
            faces_collected.append(_whole_frame(frame_rgb))

    cap.release()
    if not faces_collected:
        return None
    # Cap the batch (frames may yield multiple faces) so the .npy payload
    # stays bounded.
    return np.stack(faces_collected[:num_frames * 4], axis=0).astype(np.uint8)
183
+
184
+
185
def _extract_faces_fallback(video_path: str, num_frames: int) -> Optional[np.ndarray]:
    """Centre-crop fallback when facenet-pytorch is not available."""
    cap = cv2.VideoCapture(video_path)
    frame_total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if frame_total <= 0:
        cap.release()
        return None

    collected: list[np.ndarray] = []
    for target in np.linspace(0, frame_total - 1, num_frames, dtype=np.int32):
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(target))
        ok, raw = cap.read()
        if not ok:
            continue
        rgb = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)
        collected.append(_put_to_center(_isotropic_resize(rgb, INPUT_SIZE), INPUT_SIZE))
    cap.release()

    if not collected:
        return None
    return np.stack(collected, axis=0).astype(np.uint8)
209
+
210
+
211
def extract_faces(video_path: str) -> Optional[np.ndarray]:
    """Dispatch to MTCNN extraction when available, else the centre-crop fallback."""
    use_mtcnn = FACENET_AVAILABLE and _mtcnn is not None
    extractor = _extract_faces_mtcnn if use_mtcnn else _extract_faces_fallback
    return extractor(video_path, NUM_FRAMES)
215
+
216
+
217
+ # ══════════════════════════════════════════════════════════════════════════════
218
+ # Aggregation strategy (mirrors deepfake_det.py confident_strategy)
219
+ # ══════════════════════════════════════════════════════════════════════════════
220
+
221
def confident_strategy(pred: np.ndarray, t: float = 0.8) -> float:
    """
    Collapse per-frame fake probabilities into one score.

    Priority: mean of confidently-fake frames (> t); otherwise mean of
    confidently-real frames (< 1 - t); otherwise the overall mean.
    An empty input yields the neutral score 0.5.
    """
    scores = np.asarray(pred, dtype=np.float32)
    if scores.size == 0:
        return 0.5
    # Check the two "confident" subsets in priority order.
    for mask in (scores > t, scores < (1 - t)):
        subset = scores[mask]
        if subset.size:
            return float(np.mean(subset))
    return float(np.mean(scores))
232
+
233
+
234
+ # ══════════════════════════════════════════════════════════════════════════════
235
+ # Worker communication
236
+ # ══════════════════════════════════════════════════════════════════════════════
237
+
238
def _call_worker(worker_url: str, npy_path: str, worker_idx: int) -> dict:
    """
    Call one Worker Space via gradio_client.

    Returns a dict with keys: worker, predictions, n_frames, error, score.
    On any failure the stub is returned with `error` set and a neutral
    score of 0.5, so one bad worker never sinks the whole ensemble.
    """
    import json  # local: only needed when a worker answers with a JSON string

    result_stub = {"worker": worker_idx, "predictions": None, "n_frames": 0,
                   "error": None, "score": 0.5}
    try:
        client = Client(worker_url, verbose=False)
        # handle_file wraps the filepath so gradio_client sends it correctly
        response = client.predict(
            npy_file=handle_file(npy_path),
            api_name="/predict",
        )

        # response may be the dict directly or a JSON string
        if isinstance(response, str):
            response = json.loads(response)

        if not isinstance(response, dict):
            raise TypeError(f"Unexpected worker response type: {type(response)}")

        worker_error = response.get("error")
        predictions = response.get("predictions")

        if worker_error:
            # Worker returned an application-level error β€” log it fully
            logger.error(
                f"[Worker {worker_idx}] Application error:\n{worker_error}"
            )
            result_stub["error"] = worker_error
            return result_stub

        if predictions is None or len(predictions) == 0:
            msg = f"Worker returned empty predictions list: {response}"
            logger.error(f"[Worker {worker_idx}] {msg}")
            result_stub["error"] = msg
            return result_stub

        score = confident_strategy(predictions)
        logger.info(
            f"[Worker {worker_idx}] OK β€” frames={len(predictions)}, score={score:.4f}"
        )
        result_stub.update({
            "predictions": predictions,
            "n_frames": response.get("n_frames", len(predictions)),
            "score": score,
        })
        return result_stub

    # FIX: the old `except FuturesTimeout` branch was unreachable β€” nothing in
    # this try block uses concurrent.futures, so that exception can never be
    # raised here. Timeouts are enforced by dispatch_to_workers instead; any
    # client-side timeout surfaces as a generic exception below.
    except Exception:
        full_tb = _tb.format_exc()
        logger.error(f"[Worker {worker_idx}] Exception:\n{full_tb}")
        result_stub["error"] = full_tb
        return result_stub
300
+
301
+
302
def dispatch_to_workers(npy_path: str) -> list[dict]:
    """
    Fire the .npy file at all configured workers in parallel.

    Each worker gets its own thread; WORKER_TIMEOUT (+10s grace) caps the
    overall wait. Workers that fail or time out contribute a score=0.5
    fallback but log the real error, so the request never dies on one
    slow or broken worker.
    """
    if not WORKER_URLS:
        logger.warning("No workers configured β€” returning neutral score.")
        return [{"worker": 0, "predictions": None, "n_frames": 0,
                 "error": "No workers configured.", "score": 0.5}]

    results: list[dict] = []
    with ThreadPoolExecutor(max_workers=len(WORKER_URLS)) as pool:
        futures = {
            pool.submit(_call_worker, url, npy_path, i + 1): i + 1
            for i, url in enumerate(WORKER_URLS)
        }
        try:
            for fut in as_completed(futures, timeout=WORKER_TIMEOUT + 10):
                try:
                    results.append(fut.result())
                except Exception:
                    w = futures[fut]
                    full_tb = _tb.format_exc()
                    logger.error(f"[Worker {w}] Future raised:\n{full_tb}")
                    results.append({"worker": w, "predictions": None,
                                    "n_frames": 0, "error": full_tb, "score": 0.5})
        except FuturesTimeout:
            # FIX: previously this TimeoutError escaped uncaught, 500-ing the
            # whole request and discarding every result already collected.
            # Record each still-pending worker as a timeout failure instead.
            # (cancel() cannot stop an already-running call, so the executor
            # shutdown may still block briefly, but completed results survive.)
            done_workers = {r["worker"] for r in results}
            for fut, w in futures.items():
                if w not in done_workers:
                    fut.cancel()
                    msg = f"Timed out after {WORKER_TIMEOUT + 10}s"
                    logger.error(f"[Worker {w}] {msg}")
                    results.append({"worker": w, "predictions": None,
                                    "n_frames": 0, "error": msg, "score": 0.5})

    return results
330
+
331
+
332
+ # ══════════════════════════════════════════════════════════════════════════════
333
+ # FastAPI app
334
+ # ══════════════════════════════════════════════════════════════════════════════
335
+
336
# Single FastAPI application instance; routes are registered via the
# decorators below.
app = FastAPI(title="VERIDEX DeepFake Detector UI")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
 
339
@app.get("/", response_class=HTMLResponse)
async def serve_ui():
    """Serve the custom VERIDEX HTML interface."""
    # Happy path first; 404 with a self-diagnosing message otherwise.
    if HTML_FILE.exists():
        return HTMLResponse(content=HTML_FILE.read_text(encoding="utf-8"))
    raise HTTPException(
        status_code=404,
        detail=f"deepfake-detector.html not found at {HTML_FILE}. "
               "Ensure the file is committed to the Space repository root.",
    )
349
+
350
+
351
@app.get("/health")
async def health():
    """Liveness probe plus a snapshot of the runtime configuration."""
    snapshot = {"status": "ok"}
    snapshot["workers"] = len(WORKER_URLS)
    snapshot["worker_urls"] = WORKER_URLS
    snapshot["facenet"] = FACENET_AVAILABLE
    snapshot["num_frames"] = NUM_FRAMES
    snapshot["worker_timeout"] = WORKER_TIMEOUT
    return snapshot
361
 
 
 
362
 
363
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """
    Main prediction endpoint.

    1. Save uploaded video to a temp file.
    2. Extract faces via MTCNN β†’ uint8 .npy.
    3. Dispatch .npy to all workers in parallel.
    4. Aggregate scores, return result.

    Raises:
        HTTPException 422 when no faces can be extracted.
        HTTPException 500 on any unexpected internal error.
    """
    # FIX: imported up front rather than inside `finally`, so cleanup can
    # never itself fail on a first-time import error.
    import shutil

    start_time = time.time()
    tmp_dir = tempfile.mkdtemp(prefix="veridex_")

    try:
        # ── 1. Save uploaded video ────────────────────────────────────────────
        video_path = os.path.join(tmp_dir, f"input_{uuid.uuid4().hex}.mp4")
        contents = await file.read()
        with open(video_path, "wb") as f:
            f.write(contents)
        logger.info(f"Video saved: {video_path} ({len(contents)/1024:.1f} KB)")

        # ── 2. Face extraction ────────────────────────────────────────────────
        faces_array = extract_faces(video_path)
        if faces_array is None or faces_array.shape[0] == 0:
            raise HTTPException(
                status_code=422,
                detail="No faces detected in the uploaded video. "
                       "Please upload a video that clearly shows a face.",
            )
        logger.info(f"Face extraction complete: {faces_array.shape}")

        # ── 3. Serialise to compressed uint8 .npy ─────────────────────────────
        npy_path = os.path.join(tmp_dir, "faces.npy")
        # FIX: the old comment promised allow_pickle=False but the call never
        # passed it; now explicit. uint8 is ~4x smaller than float32, which
        # keeps the payload within HF limits.
        np.save(npy_path, faces_array.astype(np.uint8), allow_pickle=False)
        npy_size_kb = os.path.getsize(npy_path) / 1024
        logger.info(f"NPY payload: {npy_path} ({npy_size_kb:.1f} KB)")

        # ── 4. Dispatch to workers ────────────────────────────────────────────
        worker_results = dispatch_to_workers(npy_path)

        # ── 5. Aggregate ──────────────────────────────────────────────────────
        # Pool per-frame predictions from every worker that succeeded.
        all_predictions: list[float] = []
        successful_workers = 0
        for r in worker_results:
            if r.get("predictions") and r.get("error") is None:
                all_predictions.extend(r["predictions"])
                successful_workers += 1

        if not all_predictions:
            logger.warning(
                "All workers failed or returned no predictions. "
                "Returning neutral score. See per-worker errors above."
            )
            final_score = 0.5
        else:
            final_score = confident_strategy(all_predictions)

        label = "FAKE" if final_score >= 0.5 else "REAL"
        elapsed = round(time.time() - start_time, 2)

        logger.info(
            f"Result: {label} score={final_score:.4f} "
            f"workers={successful_workers}/{len(WORKER_URLS)} "
            f"elapsed={elapsed}s"
        )

        return JSONResponse({
            "prediction": label,
            "score": round(final_score, 4),
            "score_pct": f"{final_score * 100:.1f}%",
            "filename": file.filename,
            "faces_extracted": int(faces_array.shape[0]),
            "successful_workers": successful_workers,
            "total_workers": len(WORKER_URLS),
            "elapsed_sec": elapsed,
            "worker_results": [
                {
                    "worker": r["worker"],
                    "score": round(r["score"], 4),
                    "n_frames": r["n_frames"],
                    # Truncated in the API response; the full traceback has
                    # already been printed to the server console.
                    "error": (r["error"][:300] + "…") if r.get("error") else None,
                }
                for r in sorted(worker_results, key=lambda x: x["worker"])
            ],
        })

    except HTTPException:
        raise
    except Exception:
        full_tb = traceback.format_exc()
        logger.error(f"Unhandled error in /predict/:\n{full_tb}")
        raise HTTPException(status_code=500, detail=full_tb)

    finally:
        # Best-effort cleanup; ignore_errors already swallows FS failures,
        # so the old extra try/except around this call was redundant.
        shutil.rmtree(tmp_dir, ignore_errors=True)
468
+
469
+
470
# ══════════════════════════════════════════════════════════════════════════════
# Entry point
# ══════════════════════════════════════════════════════════════════════════════

if __name__ == "__main__":
    # BUG FIX: the old call passed port=7860 AND **{"port": int(PORT)} whenever
    # the PORT env var was set β€” a duplicate keyword argument, so startup
    # crashed with TypeError exactly when HF Spaces injected PORT.
    # Resolve the port once, honouring PORT when present (and non-empty).
    port = int(os.environ.get("PORT") or 7860)
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=port,
        log_level="info",
    )