# Hugging Face Space page header (user: bithal26, "Update app.py", commit 212d8bf verified)
# — scraped page chrome converted to a comment so it does not break the Python module.
"""
================================================================================
VERIDEX — Master UI / Orchestrator Space (DeepFake-Detector-UI)
──────────────────────────────────────────────────────────────────
Architecture
────────────
• FastAPI serves the custom deepfake-detector.html at GET /
• POST /predict/ accepts a raw .mp4 upload
    1. Saves video to a temp file
    2. MTCNN extracts up to NUM_FRAMES faces (380 × 380, uint8 HWC)
    3. Batch is saved as a compressed .npy file
    4. Fires the .npy at all 7 Workers in parallel via gradio_client
    5. Aggregates per-frame predictions with confident_strategy
    6. Returns JSON { prediction, score, filename, worker_results }
ENV VARS (set in HF Space settings)
─────────────────────────────────────
WORKER_1_URL … WORKER_7_URL — public Gradio Space URLs for each worker,
    e.g. https://your-user-deepfake-worker-1.hf.space
NUM_FRAMES       default 32  — frames to sample per video
WORKER_TIMEOUT   default 120 — seconds to wait per worker call
================================================================================
"""
import os
import io
import time
import uuid
import logging
import tempfile
import traceback
import traceback as _tb
from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError as FuturesTimeout
from pathlib import Path
from typing import Optional
import cv2
import numpy as np
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
import uvicorn
from gradio_client import Client, handle_file
# ─────────────────────────────────────────────────────────────────────────────
# Optional: facenet-pytorch for MTCNN face detection
# ─────────────────────────────────────────────────────────────────────────────
try:
    from facenet_pytorch import MTCNN
    FACENET_AVAILABLE = True
except ImportError:
    # facenet-pytorch is optional: without it, extract_faces() routes every
    # frame through the full-frame centre-crop fallback instead of MTCNN.
    FACENET_AVAILABLE = False
    logging.warning(
        "facenet-pytorch not installed — falling back to full-frame "
        "centre-crop for face extraction."
    )

# Root logging config for the whole UI process; the [UI] tag distinguishes
# these lines from worker logs when tailing a combined console.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [UI] %(levelname)s %(message)s",
)
logger = logging.getLogger(__name__)
# ══════════════════════════════════════════════════════════════════════════════
# Configuration
# ══════════════════════════════════════════════════════════════════════════════
NUM_FRAMES = int(os.environ.get("NUM_FRAMES", "32"))           # frames sampled per video
WORKER_TIMEOUT = int(os.environ.get("WORKER_TIMEOUT", "120"))  # seconds per worker call
INPUT_SIZE = 380  # must match worker expectation

# Worker URLs — read from env vars so no secrets are hard-coded.
# Missing/blank WORKER_i_URL entries are silently skipped, so any subset of
# the 7 workers may be configured.
WORKER_URLS: list[str] = [
    url for url in (
        os.environ.get(f"WORKER_{i}_URL", "").strip()
        for i in range(1, 8)
    )
    if url
]
if not WORKER_URLS:
    logger.warning(
        "No WORKER_*_URL env vars set. "
        "Set WORKER_1_URL … WORKER_7_URL in Space settings."
    )

# ── HTML template path ────────────────────────────────────────────────────────
# Served verbatim by GET /; must live next to this file in the Space repo.
HTML_FILE = Path(__file__).parent / "deepfake-detector.html"

# ── MTCNN ─────────────────────────────────────────────────────────────────────
if FACENET_AVAILABLE:
    # keep_all=True returns every detected face per frame
    _mtcnn = MTCNN(
        keep_all=True,
        device="cuda" if torch.cuda.is_available() else "cpu",
        select_largest=False,
        post_process=False,  # return raw uint8 tensors, not normalised
        image_size=INPUT_SIZE,
        margin=20,
    )
    logger.info("MTCNN initialised.")
else:
    _mtcnn = None  # sentinel checked by extract_faces() to pick the fallback
# ══════════════════════════════════════════════════════════════════════════════
# Face extraction helpers
# ══════════════════════════════════════════════════════════════════════════════
def _isotropic_resize(img: np.ndarray, size: int) -> np.ndarray:
    """Scale *img* so its longer side equals *size*, preserving aspect ratio."""
    height, width = img.shape[:2]
    longest = max(height, width)
    if longest == size:
        # Already at the target extent — avoid a redundant resize pass.
        return img
    factor = size / longest
    target_wh = (int(width * factor), int(height * factor))
    # Cubic interpolation when enlarging, area averaging when shrinking.
    method = cv2.INTER_CUBIC if factor > 1 else cv2.INTER_AREA
    return cv2.resize(img, target_wh, interpolation=method)
def _put_to_center(img: np.ndarray, size: int) -> np.ndarray:
img = img[:size, :size]
canvas = np.zeros((size, size, 3), dtype=np.uint8)
sh = (size - img.shape[0]) // 2
sw = (size - img.shape[1]) // 2
canvas[sh : sh + img.shape[0], sw : sw + img.shape[1]] = img
return canvas
def _extract_faces_mtcnn(video_path: str, num_frames: int) -> Optional[np.ndarray]:
    """
    Use MTCNN to detect and crop faces from evenly-spaced video frames.

    Args:
        video_path: path to a local video file readable by OpenCV.
        num_frames: number of evenly-spaced frames to sample.

    Returns:
        uint8 array of shape (N, INPUT_SIZE, INPUT_SIZE, 3), or None when the
        video yields no readable frames. N may exceed num_frames because
        keep_all=True returns every face per frame; the batch is capped at
        num_frames * 4 to bound the payload size.
    """
    # Fix: import PIL once per call instead of inside the per-frame loop.
    from PIL import Image as _PILImage

    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total <= 0:
        cap.release()
        return None
    idxs = np.linspace(0, total - 1, num_frames, dtype=np.int32)
    faces_collected: list[np.ndarray] = []
    for idx in idxs:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ret, frame_bgr = cap.read()
        if not ret:
            continue
        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        pil_frame = _PILImage.fromarray(frame_rgb)
        try:
            boxes, _ = _mtcnn.detect(pil_frame)
            if boxes is None:
                # No face detected — fall back to centre crop of whole frame
                face = _isotropic_resize(frame_rgb, INPUT_SIZE)
                face = _put_to_center(face, INPUT_SIZE)
                faces_collected.append(face)
                continue
            for box in boxes:
                x1, y1, x2, y2 = (int(c) for c in box)
                # Clamp the detection box so slicing never goes out of frame.
                x1, y1 = max(0, x1), max(0, y1)
                x2 = min(frame_rgb.shape[1], x2)
                y2 = min(frame_rgb.shape[0], y2)
                crop = frame_rgb[y1:y2, x1:x2]
                if crop.size == 0:
                    continue
                face = _isotropic_resize(crop, INPUT_SIZE)
                face = _put_to_center(face, INPUT_SIZE)
                faces_collected.append(face)
        except Exception as exc:
            # One bad frame must not abort the batch — use the full frame
            # for this index and keep going.
            logger.warning(f"MTCNN failed on frame {idx}: {exc}")
            face = _isotropic_resize(frame_rgb, INPUT_SIZE)
            face = _put_to_center(face, INPUT_SIZE)
            faces_collected.append(face)
    cap.release()
    if not faces_collected:
        return None
    # Cap the batch so the serialized .npy stays within HF payload limits.
    return np.stack(faces_collected[:num_frames * 4], axis=0).astype(np.uint8)
def _extract_faces_fallback(video_path: str, num_frames: int) -> Optional[np.ndarray]:
    """Centre-crop fallback when facenet-pytorch is not available."""
    capture = cv2.VideoCapture(video_path)
    frame_total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    if frame_total <= 0:
        capture.release()
        return None
    collected: list[np.ndarray] = []
    # Sample frame indices evenly across the whole clip.
    for position in np.linspace(0, frame_total - 1, num_frames, dtype=np.int32):
        capture.set(cv2.CAP_PROP_POS_FRAMES, int(position))
        ok, bgr = capture.read()
        if not ok:
            continue
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        resized = _isotropic_resize(rgb, INPUT_SIZE)
        collected.append(_put_to_center(resized, INPUT_SIZE))
    capture.release()
    if not collected:
        return None
    return np.stack(collected, axis=0).astype(np.uint8)
def extract_faces(video_path: str) -> Optional[np.ndarray]:
    """Extract a face batch from *video_path* using MTCNN when available."""
    use_mtcnn = FACENET_AVAILABLE and _mtcnn is not None
    extractor = _extract_faces_mtcnn if use_mtcnn else _extract_faces_fallback
    return extractor(video_path, NUM_FRAMES)
# ══════════════════════════════════════════════════════════════════════════════
# Aggregation strategy (mirrors deepfake_det.py confident_strategy)
# ══════════════════════════════════════════════════════════════════════════════
def confident_strategy(pred: np.ndarray, t: float = 0.8) -> float:
    """
    Collapse per-frame fake probabilities into one video-level score.

    Trust "confident" frames first: average the frames scored above *t*
    (confident fake); failing that, average the frames below 1 - t
    (confident real); otherwise average everything. Empty input yields
    the neutral score 0.5.
    """
    scores = np.array(pred, dtype=np.float32)
    if scores.size == 0:
        return 0.5
    # Check the confident-fake bucket, then the confident-real bucket.
    for bucket in (scores[scores > t], scores[scores < (1 - t)]):
        if bucket.size:
            return float(bucket.mean())
    return float(scores.mean())
# ══════════════════════════════════════════════════════════════════════════════
# Worker communication
# ══════════════════════════════════════════════════════════════════════════════
def _call_worker(worker_url: str, npy_path: str, worker_idx: int) -> dict:
    """
    Call one Worker Space via gradio_client.

    Args:
        worker_url: public Gradio Space URL of the worker.
        npy_path:   local path to the uint8 face batch (.npy).
        worker_idx: 1-based index used for logging and result labelling.

    Returns:
        dict with keys: worker, predictions, n_frames, error, score.
        On any failure the stub is returned with ``error`` set and the
        neutral score 0.5, so one broken worker never sinks the ensemble.
    """
    import json  # fix: hoisted to function top instead of mid-branch

    result_stub = {"worker": worker_idx, "predictions": None, "n_frames": 0,
                   "error": None, "score": 0.5}
    try:
        client = Client(worker_url, verbose=False)
        # handle_file wraps the filepath so gradio_client sends it correctly
        response = client.predict(
            npy_file=handle_file(npy_path),
            api_name="/predict",
        )
        # response may be the dict directly or a JSON string
        if isinstance(response, str):
            response = json.loads(response)
        if not isinstance(response, dict):
            raise TypeError(f"Unexpected worker response type: {type(response)}")
        worker_error = response.get("error")
        predictions = response.get("predictions")
        if worker_error:
            # Worker returned an application-level error — log it fully
            logger.error(
                f"[Worker {worker_idx}] Application error:\n{worker_error}"
            )
            result_stub["error"] = worker_error
            return result_stub
        if predictions is None or len(predictions) == 0:
            msg = f"Worker returned empty predictions list: {response}"
            logger.error(f"[Worker {worker_idx}] {msg}")
            result_stub["error"] = msg
            return result_stub
        score = confident_strategy(predictions)
        logger.info(
            f"[Worker {worker_idx}] OK — frames={len(predictions)}, score={score:.4f}"
        )
        result_stub.update({
            "predictions": predictions,
            "n_frames": response.get("n_frames", len(predictions)),
            "score": score,
        })
        return result_stub
    except FuturesTimeout:
        # NOTE(review): client.predict() is not given a timeout here, so this
        # branch likely never fires in practice — kept defensively; confirm
        # against gradio_client internals before removing.
        msg = f"Timed out after {WORKER_TIMEOUT}s"
        logger.error(f"[Worker {worker_idx}] {msg}")
        result_stub["error"] = msg
        return result_stub
    except Exception:
        full_tb = _tb.format_exc()
        logger.error(f"[Worker {worker_idx}] Exception:\n{full_tb}")
        result_stub["error"] = full_tb
        return result_stub
def dispatch_to_workers(npy_path: str) -> list[dict]:
    """
    Fire the .npy file at all configured workers in parallel.

    Each worker gets its own thread; WORKER_TIMEOUT (+10s grace) caps the
    overall wait. Workers that fail — or never finish in time — contribute a
    score=0.5 fallback entry but log the real error, so aggregation always
    receives one result per worker.

    Returns:
        One result dict per worker (see _call_worker), plus a single neutral
        stub when no workers are configured at all.
    """
    if not WORKER_URLS:
        logger.warning("No workers configured — returning neutral score.")
        return [{"worker": 0, "predictions": None, "n_frames": 0,
                 "error": "No workers configured.", "score": 0.5}]
    results: list[dict] = []
    with ThreadPoolExecutor(max_workers=len(WORKER_URLS)) as pool:
        futures = {
            pool.submit(_call_worker, url, npy_path, i + 1): i + 1
            for i, url in enumerate(WORKER_URLS)
        }
        try:
            for fut in as_completed(futures, timeout=WORKER_TIMEOUT + 10):
                try:
                    results.append(fut.result())
                except Exception:
                    w = futures[fut]
                    full_tb = _tb.format_exc()
                    logger.error(f"[Worker {w}] Future raised:\n{full_tb}")
                    results.append({"worker": w, "predictions": None,
                                    "n_frames": 0, "error": full_tb, "score": 0.5})
        except FuturesTimeout:
            # Fix: previously this exception escaped as_completed() and turned
            # the whole /predict/ request into a 500 even when some workers had
            # already answered. Record a fallback entry for each straggler.
            finished = {r["worker"] for r in results}
            for fut, w in futures.items():
                if w in finished:
                    continue
                fut.cancel()  # best effort — no-op if the call already started
                msg = f"Worker did not respond within {WORKER_TIMEOUT + 10}s"
                logger.error(f"[Worker {w}] {msg}")
                results.append({"worker": w, "predictions": None,
                                "n_frames": 0, "error": msg, "score": 0.5})
    return results
# ══════════════════════════════════════════════════════════════════════════════
# FastAPI app
# ══════════════════════════════════════════════════════════════════════════════
# Single ASGI application served by uvicorn (see entry point at file bottom).
app = FastAPI(title="VERIDEX DeepFake Detector UI")
@app.get("/", response_class=HTMLResponse)
async def serve_ui():
    """Serve the custom VERIDEX HTML interface from the repository root."""
    # Happy path first: stream the template straight back when it exists.
    if HTML_FILE.exists():
        return HTMLResponse(content=HTML_FILE.read_text(encoding="utf-8"))
    raise HTTPException(
        status_code=404,
        detail=f"deepfake-detector.html not found at {HTML_FILE}. "
               "Ensure the file is committed to the Space repository root.",
    )
@app.get("/health")
async def health():
    """Liveness probe: report worker roster and current extraction settings."""
    return dict(
        status="ok",
        workers=len(WORKER_URLS),
        worker_urls=WORKER_URLS,
        facenet=FACENET_AVAILABLE,
        num_frames=NUM_FRAMES,
        worker_timeout=WORKER_TIMEOUT,
    )
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """
    Main prediction endpoint.

    1. Save uploaded video to a temp file.
    2. Extract faces via MTCNN → uint8 .npy.
    3. Dispatch .npy to all workers in parallel.
    4. Aggregate scores, return result.

    Raises:
        HTTPException 422 when no faces can be extracted from the upload.
        HTTPException 500 for any other unhandled error (full traceback in
        ``detail`` — NOTE(review): this exposes internals to callers; fine
        for a demo Space, reconsider for production).
    """
    import shutil  # fix: hoisted from the finally block — import once, up front

    start_time = time.time()
    tmp_dir = tempfile.mkdtemp(prefix="veridex_")
    try:
        # ── 1. Save uploaded video ────────────────────────────────────────────
        video_path = os.path.join(tmp_dir, f"input_{uuid.uuid4().hex}.mp4")
        contents = await file.read()
        with open(video_path, "wb") as f:
            f.write(contents)
        logger.info(f"Video saved: {video_path} ({len(contents)/1024:.1f} KB)")
        # ── 2. Face extraction ────────────────────────────────────────────────
        faces_array = extract_faces(video_path)
        if faces_array is None or faces_array.shape[0] == 0:
            raise HTTPException(
                status_code=422,
                detail="No faces detected in the uploaded video. "
                       "Please upload a video that clearly shows a face.",
            )
        logger.info(f"Face extraction complete: {faces_array.shape}")
        # ── 3. Serialise to uint8 .npy ────────────────────────────────────────
        npy_path = os.path.join(tmp_dir, "faces.npy")
        # Pure uint8 array → no object pickling occurs, and the payload is
        # ~4× smaller than float32, staying within HF upload limits.
        np.save(npy_path, faces_array.astype(np.uint8))
        npy_size_kb = os.path.getsize(npy_path) / 1024
        logger.info(f"NPY payload: {npy_path} ({npy_size_kb:.1f} KB)")
        # ── 4. Dispatch to workers ─────────────────────────────────────────────
        worker_results = dispatch_to_workers(npy_path)
        # ── 5. Aggregate ───────────────────────────────────────────────────────
        # Pool the per-frame predictions from every worker that succeeded.
        all_predictions: list[float] = []
        successful_workers = 0
        for r in worker_results:
            if r.get("predictions") and r.get("error") is None:
                all_predictions.extend(r["predictions"])
                successful_workers += 1
        if not all_predictions:
            logger.warning(
                "All workers failed or returned no predictions. "
                "Returning neutral score. See per-worker errors above."
            )
            final_score = 0.5
        else:
            final_score = confident_strategy(all_predictions)
        label = "FAKE" if final_score >= 0.5 else "REAL"
        elapsed = round(time.time() - start_time, 2)
        logger.info(
            f"Result: {label} score={final_score:.4f} "
            f"workers={successful_workers}/{len(WORKER_URLS)} "
            f"elapsed={elapsed}s"
        )
        return JSONResponse({
            "prediction": label,
            "score": round(final_score, 4),
            "score_pct": f"{final_score * 100:.1f}%",
            "filename": file.filename,
            "faces_extracted": int(faces_array.shape[0]),
            "successful_workers": successful_workers,
            "total_workers": len(WORKER_URLS),
            "elapsed_sec": elapsed,
            "worker_results": [
                {
                    "worker": r["worker"],
                    "score": round(r["score"], 4),
                    "n_frames": r["n_frames"],
                    # Truncate the full traceback in the API response but it
                    # has already been printed in full to the server console.
                    "error": (r["error"][:300] + "…") if r.get("error") else None,
                }
                for r in sorted(worker_results, key=lambda x: x["worker"])
            ],
        })
    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. the 422 above) unchanged.
        raise
    except Exception:
        full_tb = traceback.format_exc()
        logger.error(f"Unhandled error in /predict/:\n{full_tb}")
        raise HTTPException(status_code=500, detail=full_tb)
    finally:
        # Best-effort cleanup; ignore errors if HF locks the temp dir
        try:
            shutil.rmtree(tmp_dir, ignore_errors=True)
        except Exception:
            pass
# ══════════════════════════════════════════════════════════════════════════════
# Entry point
# ══════════════════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    # Fix: the previous call passed port=7860 AND **{"port": int(PORT)} when
    # the PORT env var was set, which raises
    # "TypeError: run() got multiple values for keyword argument 'port'".
    # Resolve the port once up front: honour HF Spaces' injected PORT,
    # default to 7860 otherwise.
    port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=port,
        log_level="info",
    )