bithal26 committed on
Commit
913d66a
·
verified ·
1 Parent(s): f90bad5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -44
app.py CHANGED
@@ -1,21 +1,23 @@
1
  import os
 
2
  import torch
3
  import numpy as np
4
  from PIL import Image
5
- import cv2
6
  import gradio as gr
7
  from gradio_client import Client, handle_file
8
  from torchvision.transforms import Normalize
9
  from facenet_pytorch.models.mtcnn import MTCNN
10
  import concurrent.futures
11
  import tempfile
12
- from huggingface_hub import get_token
13
  from fastapi import FastAPI, UploadFile, File
14
  from fastapi.responses import HTMLResponse
15
  import shutil
16
 
17
  # ==========================================
18
- # 1. API ROUTER CONFIGURATION
 
 
 
19
  # ==========================================
20
  WORKER_SPACES = [
21
  "bithal26/DeepFake-Worker-1",
@@ -28,20 +30,16 @@ WORKER_SPACES = [
28
  ]
29
 
30
  clients = []
31
- print("Initializing connections to 7 API Workers...")
32
- hf_token = get_token()
33
-
34
- if not hf_token:
35
- print("CRITICAL WARNING: No HF_TOKEN found! Private workers will fail to connect.")
36
-
37
  for space in WORKER_SPACES:
38
  try:
39
- clients.append(Client(space, token=hf_token))
 
40
  except Exception as e:
41
  print(f"Warning: Could not connect to {space}. Error: {e}")
42
 
43
  # ==========================================
44
- # 2. MTCNN PREPROCESSING ENGINE
45
  # ==========================================
46
  mean = [0.485, 0.456, 0.406]
47
  std = [0.229, 0.224, 0.225]
@@ -70,7 +68,6 @@ class VideoReader:
70
  frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
71
  if frame_count <= 0: return None
72
  frame_idxs = np.linspace(0, frame_count - 1, num_frames, endpoint=True, dtype=np.int32)
73
-
74
  frames, idxs_read = [], []
75
  for frame_idx in range(frame_idxs[0], frame_idxs[-1] + 1):
76
  ret = capture.grab()
@@ -90,7 +87,7 @@ class FaceExtractor:
90
  self.video_reader = VideoReader()
91
  self.detector = MTCNN(margin=0, thresholds=[0.7, 0.8, 0.8], device=device)
92
 
93
- def process_video(self, video_path, frames_per_video=16):
94
  result = self.video_reader.read_frames(video_path, num_frames=frames_per_video)
95
  if result is None: return []
96
  my_frames, my_idxs = result
@@ -117,14 +114,14 @@ face_extractor = FaceExtractor()
117
  def confident_strategy(pred, t=0.8):
118
  pred = np.array(pred)
119
  sz = len(pred)
120
- if sz == 0: return 0.0
121
  fakes = np.count_nonzero(pred > t)
122
  if fakes > sz // 2.5 and fakes > 11:
123
- return np.mean(pred[pred > t])
124
  elif np.count_nonzero(pred < 0.2) > 0.9 * sz:
125
- return np.mean(pred[pred < 0.2])
126
  else:
127
- return np.mean(pred)
128
 
129
  def call_worker(client, tensor_filepath):
130
  try:
@@ -133,74 +130,195 @@ def call_worker(client, tensor_filepath):
133
  if not preds: return 0.5
134
  return confident_strategy(preds)
135
  except Exception as e:
136
- print(f"API Call Failed: {e}")
137
  return 0.5
138
 
139
  # ==========================================
140
- # 3. FASTAPI SERVER & DIRECT HTML INJECTION
141
  # ==========================================
142
  app = FastAPI()
143
 
144
- # 1. Serve your custom HTML file as the main page
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  @app.get("/")
146
  def read_root():
 
147
  with open("deepfake-detector.html", "r", encoding="utf-8") as f:
148
  html_content = f.read()
149
- return HTMLResponse(content=html_content)
 
 
 
 
150
 
151
- # 2. Create the hidden API endpoint your HTML will call
152
  @app.post("/api/analyze")
153
  async def analyze_api(file: UploadFile = File(...)):
154
- # Save uploaded video temporarily
155
  temp_dir = tempfile.mkdtemp()
156
  video_path = os.path.join(temp_dir, file.filename)
157
  with open(video_path, "wb") as buffer:
158
  shutil.copyfileobj(file.file, buffer)
159
 
160
- # Extract Faces
161
  input_size = 380
162
- faces = face_extractor.process_video(video_path, frames_per_video=16)
 
163
 
164
- if len(faces) == 0:
165
- return {"error": "No faces detected."}
166
-
167
- x = []
168
  for frame_data in faces:
169
  for face in frame_data["faces"]:
170
  resized_face = isotropically_resize_image(face, input_size)
171
  resized_face = put_to_center(resized_face, input_size)
172
- x.append(resized_face)
173
- if len(x) >= 16 * 4: break
 
174
 
175
- x = np.array(x, dtype=np.uint8)
176
- x = torch.tensor(x, device=device).float()
177
- x = x.permute((0, 3, 1, 2))
178
- for i in range(len(x)):
179
- x[i] = normalize_transform(x[i] / 255.)
 
 
 
 
180
 
181
- # Save Tensor
182
  tensor_path = os.path.join(temp_dir, "batch_tensor.pt")
183
- torch.save(x, tensor_path)
184
 
185
- # Ping Workers
186
  worker_scores = []
187
  with concurrent.futures.ThreadPoolExecutor(max_workers=7) as executor:
188
  futures = [executor.submit(call_worker, client, tensor_path) for client in clients]
189
  for future in concurrent.futures.as_completed(futures):
190
  worker_scores.append(future.result())
191
 
192
- # Aggregate
193
  final_score = np.mean(worker_scores)
194
-
195
- # Clean up temp files
196
  shutil.rmtree(temp_dir, ignore_errors=True)
197
 
198
- # Return pure JSON data to the HTML frontend
199
  return {
200
  "final_score": float(final_score),
201
  "worker_scores": [float(s) for s in worker_scores]
202
  }
203
 
204
- # Gradio wrapper just to keep Hugging Face happy, but we mount our custom FastAPI app
205
  demo = gr.Blocks()
206
  app = gr.mount_gradio_app(app, demo, path="/gradio")
 
1
  import os
2
+ import cv2
3
  import torch
4
  import numpy as np
5
  from PIL import Image
 
6
  import gradio as gr
7
  from gradio_client import Client, handle_file
8
  from torchvision.transforms import Normalize
9
  from facenet_pytorch.models.mtcnn import MTCNN
10
  import concurrent.futures
11
  import tempfile
 
12
  from fastapi import FastAPI, UploadFile, File
13
  from fastapi.responses import HTMLResponse
14
  import shutil
15
 
16
  # ==========================================
17
+ # 1. API ROUTER
18
+ # ==========================================
19
+ # ==========================================
20
+ # 1. API ROUTER
21
  # ==========================================
22
  WORKER_SPACES = [
23
  "bithal26/DeepFake-Worker-1",
 
30
  ]
31
 
32
# Connect to every worker Space up front; a failed connection is logged
# and skipped so one dead worker does not take down the router.
clients = []
print("Initializing connections to 7 Public API Workers...")
for space in WORKER_SPACES:
    try:
        # Public Spaces: no auth token required.
        worker_client = Client(space)
    except Exception as e:
        print(f"Warning: Could not connect to {space}. Error: {e}")
    else:
        clients.append(worker_client)
40
 
41
  # ==========================================
42
+ # 2. NOTEBOOK-EXACT PREPROCESSING
43
  # ==========================================
44
  mean = [0.485, 0.456, 0.406]
45
  std = [0.229, 0.224, 0.225]
 
68
  frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
69
  if frame_count <= 0: return None
70
  frame_idxs = np.linspace(0, frame_count - 1, num_frames, endpoint=True, dtype=np.int32)
 
71
  frames, idxs_read = [], []
72
  for frame_idx in range(frame_idxs[0], frame_idxs[-1] + 1):
73
  ret = capture.grab()
 
87
  self.video_reader = VideoReader()
88
  self.detector = MTCNN(margin=0, thresholds=[0.7, 0.8, 0.8], device=device)
89
 
90
+ def process_video(self, video_path, frames_per_video=32):
91
  result = self.video_reader.read_frames(video_path, num_frames=frames_per_video)
92
  if result is None: return []
93
  my_frames, my_idxs = result
 
114
def confident_strategy(pred, t=0.8):
    """Collapse per-face fake probabilities into one video-level score.

    DFDC-winner style heuristic: when many frames are confidently fake
    (above *t*), average only those; when nearly all frames are
    confidently real (below 0.2), average only those; otherwise use the
    plain mean. An empty prediction list yields 0.5 (undecided).
    """
    scores = np.array(pred)
    total = len(scores)
    if total == 0:
        return 0.5
    confident_fakes = np.count_nonzero(scores > t)
    # NOTE: the `> 11` floor means this branch only fires for batches of
    # a dozen-plus confident frames — mirrors the original notebook.
    if confident_fakes > total // 2.5 and confident_fakes > 11:
        return float(np.mean(scores[scores > t]))
    if np.count_nonzero(scores < 0.2) > 0.9 * total:
        return float(np.mean(scores[scores < 0.2]))
    return float(np.mean(scores))
125
 
126
  def call_worker(client, tensor_filepath):
127
  try:
 
130
  if not preds: return 0.5
131
  return confident_strategy(preds)
132
  except Exception as e:
133
+ print(f"Worker Error: {e}")
134
  return 0.5
135
 
136
  # ==========================================
137
+ # 3. FASTAPI SERVER
138
  # ==========================================
139
  app = FastAPI()
140
 
141
+ # Override the JS in the HTML dynamically to make real API calls
142
+ JS_OVERRIDE = """
143
+ <script>
144
+ function handleDrop(e) {
145
+ e.preventDefault();
146
+ document.getElementById('uploadZone').classList.remove('dragging');
147
+ const file = e.dataTransfer.files[0];
148
+ if (file) startAnalysis(file);
149
+ }
150
+
151
+ function startAnalysis(file) {
152
+ if (!file) return;
153
+ const overlay = document.getElementById('analyzeOverlay');
154
+ overlay.classList.add('visible');
155
+
156
+ const steps = ['step1','step2','step3','step4','step5','step6'];
157
+ const labels = [
158
+ 'Decoding video frames...',
159
+ 'Extracting facial landmarks...',
160
+ 'Running 7 parallel neural models...',
161
+ 'Frequency domain analysis...',
162
+ 'Temporal coherence check...',
163
+ 'Generating forensic report...'
164
+ ];
165
+
166
+ let currentStep = 0;
167
+ const interval = setInterval(() => {
168
+ if (currentStep > 0) document.getElementById(steps[currentStep - 1]).className = 'a-step done';
169
+ if (currentStep < steps.length) {
170
+ document.getElementById(steps[currentStep]).className = 'a-step active';
171
+ document.getElementById('analyzeText').textContent = labels[currentStep];
172
+ currentStep++;
173
+ }
174
+ }, 450);
175
+
176
+ const formData = new FormData();
177
+ formData.append('file', file);
178
+ const startTime = performance.now();
179
+
180
+ fetch('/api/analyze', { method: 'POST', body: formData })
181
+ .then(res => res.json())
182
+ .then(data => {
183
+ clearInterval(interval);
184
+ steps.forEach(s => document.getElementById(s).className = 'a-step');
185
+ overlay.classList.remove('visible');
186
+
187
+ if (data.error) {
188
+ alert("Analysis Error: " + data.error);
189
+ return;
190
+ }
191
+
192
+ const duration = ((performance.now() - startTime) / 1000).toFixed(1);
193
+ updateRealMetrics(data.final_score, data.worker_scores);
194
+ showRealResult(file.name, data.final_score, data.worker_scores, duration);
195
+ })
196
+ .catch(err => {
197
+ clearInterval(interval);
198
+ overlay.classList.remove('visible');
199
+ alert("System Error: " + err);
200
+ });
201
+ }
202
+
203
+ function updateRealMetrics(finalScore, workerScores) {
204
+ const isFake = finalScore >= 0.5;
205
+ const confidence = isFake ? finalScore * 100 : (1 - finalScore) * 100;
206
+
207
+ const scoreEl = document.getElementById('authScore');
208
+ scoreEl.textContent = confidence.toFixed(1) + '%';
209
+ scoreEl.className = 'result-score ' + (isFake ? 'fake' : 'authentic');
210
+
211
+ for(let i=1; i<=5; i++) {
212
+ let wScore = workerScores[i-1] ? workerScores[i-1] * 100 : confidence;
213
+ document.getElementById('m' + i).textContent = wScore.toFixed(1) + '%';
214
+ document.getElementById('b' + i).style.width = wScore + '%';
215
+ }
216
+ }
217
+
218
+ function showRealResult(fileName, finalScore, workerScores, duration) {
219
+ const isFake = finalScore >= 0.5;
220
+ const confidence = isFake ? (finalScore * 100).toFixed(1) : ((1 - finalScore) * 100).toFixed(1);
221
+ const overlay = document.getElementById('resultOverlay');
222
+
223
+ document.getElementById('modalScore').textContent = confidence + '%';
224
+ document.getElementById('modalScore').style.color = isFake ? 'var(--red)' : 'var(--green)';
225
+ document.getElementById('modalVerdict').textContent = isFake ? 'DEEPFAKE DETECTED' : 'AUTHENTIC CONTENT';
226
+ document.getElementById('modalVerdict').className = 'verdict-title ' + (isFake ? '' : 'authentic');
227
+ document.getElementById('modalDesc').textContent = isFake
228
+ ? `High confidence manipulation detected in "${fileName}". Ensemble forensic signals indicate AI-generated modifications.`
229
+ : `No significant manipulation detected in "${fileName}". All forensic signals within normal parameters.`;
230
+
231
+ document.getElementById('mm1').textContent = confidence + '%';
232
+ document.getElementById('mm2').textContent = workerScores[1] ? (workerScores[1]*100).toFixed(1) + '%' : confidence + '%';
233
+ document.getElementById('mm3').textContent = duration + 's';
234
+
235
+ overlay.classList.add('visible');
236
+ }
237
+
238
+ function closeResult() { document.getElementById('resultOverlay').classList.remove('visible'); }
239
+ document.getElementById('resultOverlay').addEventListener('click', function(e) { if (e.target === this) closeResult(); });
240
+
241
+ setTimeout(() => {
242
+ const observer = new IntersectionObserver((entries) => {
243
+ entries.forEach(e => {
244
+ if (e.isIntersecting) {
245
+ e.target.style.opacity = '1';
246
+ e.target.style.transform = 'translateY(0)';
247
+ }
248
+ });
249
+ }, { threshold: 0.1 });
250
+ document.querySelectorAll('.how-step, .feature-card, .report-card').forEach(el => {
251
+ el.style.opacity = '0';
252
+ el.style.transform = 'translateY(24px)';
253
+ el.style.transition = 'opacity 0.6s ease, transform 0.6s ease, border-color 0.3s';
254
+ observer.observe(el);
255
+ });
256
+ }, 500);
257
+ </script>
258
+ </body>
259
+ </html>
260
+ """
261
+
262
@app.get("/")
def read_root():
    """Serve the landing page with the live API JavaScript injected.

    Reads the static HTML shell and replaces its trailing demo <script>
    (and everything after it) with JS_OVERRIDE, which wires the page to
    the real /api/analyze endpoint.
    """
    with open("deepfake-detector.html", "r", encoding="utf-8") as f:
        html_content = f.read()
    # rsplit(..., 1) cuts at the LAST <script>, matching the stated
    # intent of stripping only the bottom script tag; a plain split()
    # would cut at the FIRST <script> and drop legitimate page markup
    # whenever the HTML contains more than one script.
    live_html = html_content.rsplit("<script>", 1)[0] + JS_OVERRIDE
    return HTMLResponse(content=live_html)
272
 
 
273
@app.post("/api/analyze")
async def analyze_api(file: UploadFile = File(...)):
    """Run the full deepfake-detection pipeline on an uploaded video.

    Saves the upload to a temp dir, extracts up to ``batch_size`` face
    crops via the MTCNN extractor, normalizes them into a float tensor
    batch, fans the saved tensor out to every connected worker Space in
    parallel, and returns the mean of the workers' scores.

    Returns a JSON dict with ``final_score`` and ``worker_scores``, or
    an ``error`` key when no faces are found / no workers respond.
    """
    temp_dir = tempfile.mkdtemp()
    try:
        # basename() guards against path traversal in the client-supplied
        # filename; fall back to a default when the name is missing.
        safe_name = os.path.basename(file.filename or "upload.mp4")
        video_path = os.path.join(temp_dir, safe_name)
        with open(video_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        input_size = 380
        frames_per_video = 32
        batch_size = frames_per_video * 4

        faces = face_extractor.process_video(video_path, frames_per_video=frames_per_video)

        # Pre-sized uint8 buffer; n tracks how many crops were actually filled.
        x = np.zeros((batch_size, input_size, input_size, 3), dtype=np.uint8)
        n = 0
        for frame_data in faces:
            for face in frame_data["faces"]:
                resized_face = isotropically_resize_image(face, input_size)
                resized_face = put_to_center(resized_face, input_size)
                if n < batch_size:
                    x[n] = resized_face
                    n += 1

        if n == 0:
            return {"error": "No faces detected."}

        # Slice to exactly n crops so workers never see blank padding rows.
        x_tensor = torch.tensor(x[:n]).float()
        x_tensor = x_tensor.permute((0, 3, 1, 2))
        for i in range(n):
            x_tensor[i] = normalize_transform(x_tensor[i] / 255.)

        tensor_path = os.path.join(temp_dir, "batch_tensor.pt")
        torch.save(x_tensor, tensor_path)

        worker_scores = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=7) as executor:
            futures = [executor.submit(call_worker, client, tensor_path) for client in clients]
            for future in concurrent.futures.as_completed(futures):
                worker_scores.append(future.result())

        # np.mean([]) is NaN (and warns) when no clients connected;
        # NaN is not valid JSON, so report the failure explicitly.
        if not worker_scores:
            return {"error": "No workers available."}

        final_score = np.mean(worker_scores)
        return {
            "final_score": float(final_score),
            "worker_scores": [float(s) for s in worker_scores]
        }
    finally:
        # Always remove the temp dir — including on exceptions and the
        # early "no faces" return, which previously leaked it on errors.
        shutil.rmtree(temp_dir, ignore_errors=True)
322
 
 
323
# Minimal empty Gradio app mounted under /gradio so the Space platform
# detects a Gradio interface; the real UI is the custom HTML at "/".
demo = gr.Blocks()
app = gr.mount_gradio_app(app, demo, path="/gradio")