Spaces:
Sleeping
Sleeping
File size: 7,437 Bytes
c9caccb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 | from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import cv2
import numpy as np
import os
import tempfile
import shutil
import warnings
import logging
import os
# Suppress Keras and TensorFlow warnings
warnings.filterwarnings('ignore', category=UserWarning)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
logging.getLogger('absl').setLevel(logging.ERROR)
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import load_model
from ultralytics import YOLO
# FastAPI application instance; all routes below attach to this object.
app = FastAPI(title="Deepfake Detection API")
# Update CORS for frontend connectivity
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with your frontend URL
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def health_check():
    """Liveness probe so the frontend can verify the backend is reachable."""
    payload = {"status": "online", "message": "Backend is running!"}
    return payload
print("Loading model and YOLO11...")
# Deepfake classifier loaded from disk; used for inference only, so
# compile=False skips optimizer/loss restoration. Fed 128x128 crops
# produced by detect_and_crop_face below.
model = load_model('model.h5', compile=False)
# YOLOv11 pose model for extremely tight and precise facial feature cropping (matches MTCNN style)
detector = YOLO('yolo11n-pose.pt')
print("Model loaded successfully.")
def detect_and_crop_face(img):
    """Detects face/person and crops it to 128x128.

    Three-stage fallback:
      1. Use the 5 facial keypoints (nose/eyes/ears) from YOLO pose to build
         a tight face box, padded outward to include forehead and chin.
      2. If keypoints are unusable, fall back to the first detected person
         box, trimmed to its upper portion when it looks like a full body.
      3. If nothing is detected, resize the whole frame.
    Always returns a 128x128 BGR crop.
    """
    # Run YOLO11 pose detection
    results = detector.predict(img, verbose=False)
    if len(results) > 0 and results[0].keypoints is not None and len(results[0].keypoints.xy[0]) > 0:
        # Get the first person's keypoints (move to CPU numpy for math below)
        kpts = results[0].keypoints.xy[0].cpu().numpy()
        # 0: nose, 1: left eye, 2: right eye, 3: left ear, 4: right ear
        face_kpts = kpts[0:5]
        # Filter out keypoints that weren't detected (YOLO reports them as (0, 0))
        valid_kpts = [k for k in face_kpts if k[0] > 0 and k[1] > 0]
        if valid_kpts:
            valid_kpts = np.array(valid_kpts)
            # Tight bounding box around the visible facial keypoints
            x_min, y_min = np.min(valid_kpts, axis=0)
            x_max, y_max = np.max(valid_kpts, axis=0)
            # Expand this tight keypoint box to capture the full face (forehead, chin, cheeks)
            w = x_max - x_min
            h = y_max - y_min
            # Safety for edge cases: a single valid keypoint gives w == h == 0
            if w > 0 and h > 0:
                pad_x = w * 0.3
                pad_y_top = h * 0.5  # Expand more upward for the forehead
                pad_y_bot = h * 0.8  # Expand downward for the chin/mouth
                # Clamp the padded box to the image bounds
                final_x1 = max(0, int(x_min - pad_x))
                final_y1 = max(0, int(y_min - pad_y_top))
                final_x2 = min(img.shape[1], int(x_max + pad_x))
                final_y2 = min(img.shape[0], int(y_max + pad_y_bot))
                face = img[final_y1:final_y2, final_x1:final_x2]
                if face.size > 0:
                    return cv2.resize(face, (128, 128))
    # Fallback to normal YOLO box heuristic if face keypoints fail but person is found
    if len(results) > 0 and len(results[0].boxes) > 0:
        box = results[0].boxes[0].xyxy[0].cpu().numpy()
        x1, y1, x2, y2 = map(int, box)
        # Clamp to the image bounds
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(img.shape[1], x2), min(img.shape[0], y2)
        h = y2 - y1
        w = x2 - x1
        # A tall box is likely a full body; keep only the top 30% (head region)
        if h > w * 1.5:
            y2 = y1 + int(h * 0.3)
        face = img[y1:y2, x1:x2]
        if face.size > 0:
            return cv2.resize(face, (128, 128))
    # If no person is detected at all, classify the whole frame
    return cv2.resize(img, (128, 128))
def preprocess_face(face):
    """Format a cropped face for the classifier.

    Args:
        face: HxWxC (typically 128x128x3 BGR uint8) image array.

    Returns:
        float32 array of shape (1, H, W, C) with pixel values scaled to [0, 1].
    """
    # The input is already an HxWxC ndarray, so a plain float32 conversion is
    # equivalent to Keras' image.img_to_array here -- no need to route a
    # simple cast through TensorFlow.
    img_array = np.asarray(face, dtype=np.float32)
    img_array = np.expand_dims(img_array, axis=0)  # add batch dimension
    img_array /= 255.0  # Normalize to [0, 1]
    return img_array
def process_image(img):
    """Run the full pipeline on one BGR frame and return the fake probability."""
    cropped = detect_and_crop_face(img)
    batch = preprocess_face(cropped)
    # model.predict returns shape (1, 1); unwrap to a plain Python float.
    return float(model.predict(batch, verbose=0)[0][0])
@app.post("/predict")
async def predict_media(file: UploadFile = File(...)):
    """Classify an uploaded image or video as Real/Fake.

    Images are decoded in memory; videos are spooled to a temp file (required
    by cv2.VideoCapture) and every 5th frame is scored.

    Raises:
        HTTPException 400: unsupported extension, undecodable image, unreadable
            video, or a video with no extractable frames.
        HTTPException 500: any unexpected processing error.
    """
    filename = file.filename.lower()
    is_video = filename.endswith(('.mp4', '.avi', '.mov', '.mkv'))
    is_image = filename.endswith(('.jpg', '.jpeg', '.png', '.bmp'))
    if not is_image and not is_video:
        raise HTTPException(status_code=400, detail="Unsupported file format.")
    try:
        if is_image:
            # Read image directly from bytes
            contents = await file.read()
            nparr = np.frombuffer(contents, np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            if img is None:
                raise HTTPException(status_code=400, detail="Invalid image file.")
            score = process_image(img)
            result = "Real" if score < 0.5 else "Fake"
            return {
                "filename": filename,
                "type": "image",
                "prediction": result,
                "confidence_score": score
            }
        # Video path: save to a temporary file for cv2.VideoCapture
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
            shutil.copyfileobj(file.file, temp_video)
            temp_video_path = temp_video.name
        cap = cv2.VideoCapture(temp_video_path)
        try:
            if not cap.isOpened():
                raise HTTPException(status_code=400, detail="Could not open video file.")
            frame_scores = []
            frame_count = 0
            # Process 1 frame every 5 frames (~6 fps for 30fps video) for better accuracy
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_count % 5 == 0:
                    frame_scores.append(process_image(frame))
                frame_count += 1
        finally:
            # BUGFIX: always release the capture and remove the temp file,
            # even when frame scoring raises (previously both leaked).
            cap.release()
            os.unlink(temp_video_path)
        if not frame_scores:
            raise HTTPException(status_code=400, detail="Could not extract frames from video.")
        # Deepfakes often only manipulate specific frames, so average score can mask the spoof.
        # We use max_score to find the most manipulated frame.
        max_score = max(frame_scores)
        avg_score = sum(frame_scores) / len(frame_scores)
        fake_frames_count = sum(1 for s in frame_scores if s >= 0.5)
        final_result = "Real" if max_score < 0.5 else "Fake"
        return {
            "filename": filename,
            "type": "video",
            "prediction": final_result,
            "confidence_score": max_score,
            "frames_analyzed": len(frame_scores),
            "fake_frames_count": fake_frames_count,
            "max_fake_score": max_score,
            "avg_score": avg_score
        }
    except HTTPException:
        # BUGFIX: the generic handler below used to swallow these, turning
        # deliberate 400 responses into 500s. Re-raise untouched.
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # 7860 is the conventional Hugging Face Spaces port; bind all interfaces
    # so the container's port mapping works.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|