File size: 7,437 Bytes
c9caccb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import cv2
import numpy as np
import os
import tempfile
import shutil
import warnings
import logging

# Suppress Keras and TensorFlow warnings.
# NOTE: this must run BEFORE importing tensorflow below — TF reads
# TF_CPP_MIN_LOG_LEVEL at import time.
warnings.filterwarnings('ignore', category=UserWarning)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
logging.getLogger('absl').setLevel(logging.ERROR)

from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import load_model
from ultralytics import YOLO

# Application instance; all routes below attach to this object.
app = FastAPI(title="Deepfake Detection API")

# Update CORS for frontend connectivity
# Wildcard origins with allow_credentials=True is permissive by design here;
# tighten before production (see inline note).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], # In production, replace with your frontend URL
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
async def health_check():
    """Liveness probe: confirms the backend is up and reachable."""
    payload = {"status": "online", "message": "Backend is running!"}
    return payload

# Module-level model loading: runs once at import time, before the server
# accepts any requests. Both files are expected in the working directory.
print("Loading model and YOLO11...")
# compile=False: inference only — no optimizer/loss state is needed.
model = load_model('model.h5', compile=False)
# YOLOv11 pose model for extremely tight and precise facial feature cropping (matches MTCNN style)
detector = YOLO('yolo11n-pose.pt')
print("Model loaded successfully.")

def detect_and_crop_face(img):
    """Detect a face in a BGR frame and return a 128x128 crop.

    Detection strategy, in order of preference:
      1. YOLO11 pose keypoints (nose, eyes, ears) -> tight facial box,
         padded outward to include forehead, chin and cheeks.
      2. First detection box; if it looks like a full body (much taller
         than wide), keep only the top ~30% as the head region.
      3. No detection at all: resize the whole frame.
    """
    # Run YOLO11 pose detection
    results = detector.predict(img, verbose=False)

    if len(results) > 0:
        keypoints = results[0].keypoints
        # Guard the batch dimension too: on a zero-detection result,
        # keypoints.xy is empty and indexing xy[0] raises IndexError
        # before any length check on it could run.
        if keypoints is not None and len(keypoints.xy) > 0 and len(keypoints.xy[0]) > 0:
            # Get the first person's keypoints
            kpts = keypoints.xy[0].cpu().numpy()

            # COCO order: 0 nose, 1 left eye, 2 right eye, 3 left ear, 4 right ear
            face_kpts = kpts[0:5]
            # Undetected keypoints come back as (0, 0); drop them.
            valid_kpts = [k for k in face_kpts if k[0] > 0 and k[1] > 0]

            if valid_kpts:
                valid_kpts = np.array(valid_kpts)
                x_min, y_min = np.min(valid_kpts, axis=0)
                x_max, y_max = np.max(valid_kpts, axis=0)

                # Expand this tight keypoint box to capture the full face
                w = x_max - x_min
                h = y_max - y_min

                # Degenerate box (e.g. a single visible keypoint) -> fall
                # through to the box-based fallback below.
                if w > 0 and h > 0:
                    pad_x = w * 0.3
                    pad_y_top = h * 0.5   # Expand more upward for the forehead
                    pad_y_bot = h * 0.8   # Expand downward for the chin/mouth

                    final_x1 = max(0, int(x_min - pad_x))
                    final_y1 = max(0, int(y_min - pad_y_top))
                    final_x2 = min(img.shape[1], int(x_max + pad_x))
                    final_y2 = min(img.shape[0], int(y_max + pad_y_bot))

                    face = img[final_y1:final_y2, final_x1:final_x2]

                    if face.size > 0:
                        return cv2.resize(face, (128, 128))

    # Fallback to normal YOLO box heuristic if face keypoints fail but person is found
    if len(results) > 0 and len(results[0].boxes) > 0:
        box = results[0].boxes[0].xyxy[0].cpu().numpy()
        x1, y1, x2, y2 = map(int, box)
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(img.shape[1], x2), min(img.shape[0], y2)

        h = y2 - y1
        w = x2 - x1
        # A box much taller than wide is likely a full body; keep the head.
        if h > w * 1.5:
            y2 = y1 + int(h * 0.3)

        face = img[y1:y2, x1:x2]
        if face.size > 0:
            return cv2.resize(face, (128, 128))

    # If no person is detected at all, score the whole frame.
    return cv2.resize(img, (128, 128))

def preprocess_face(face):
    """Format a cropped face for the classifier.

    Args:
        face: HxWx3 uint8 BGR array (128x128 from detect_and_crop_face).

    Returns:
        float32 array of shape (1, H, W, 3), values scaled to [0, 1].
    """
    # np.asarray(..., dtype=np.float32) is equivalent to Keras'
    # image.img_to_array for an ndarray input (a channels-last float32
    # copy), but keeps TensorFlow out of the preprocessing path and
    # avoids the in-place division on a freshly built batch.
    batch = np.asarray(face, dtype=np.float32)[np.newaxis, ...]
    return batch / 255.0  # Normalize

def process_image(img):
    """Score a single BGR frame; returns the model's fake probability."""
    cropped = detect_and_crop_face(img)
    batch = preprocess_face(cropped)
    prediction = model.predict(batch, verbose=0)
    return float(prediction[0][0])

@app.post("/predict")
async def predict_media(file: UploadFile = File(...)):
    """Classify an uploaded image or video as "Real" or "Fake".

    Images are scored on the single decoded frame. Videos are sampled
    every 5th frame; the verdict uses the MAX frame score because
    deepfakes may manipulate only a subset of frames, which an average
    score would dilute.

    Raises:
        HTTPException 400: unsupported extension, undecodable image,
            unreadable video, or no extractable frames.
        HTTPException 500: any unexpected processing error.
    """
    filename = file.filename.lower()

    is_video = filename.endswith(('.mp4', '.avi', '.mov', '.mkv'))
    is_image = filename.endswith(('.jpg', '.jpeg', '.png', '.bmp'))

    if not is_image and not is_video:
        raise HTTPException(status_code=400, detail="Unsupported file format.")

    try:
        if is_image:
            # Read image directly from bytes (no temp file needed).
            contents = await file.read()
            nparr = np.frombuffer(contents, np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

            if img is None:
                raise HTTPException(status_code=400, detail="Invalid image file.")

            score = process_image(img)
            result = "Real" if score < 0.5 else "Fake"

            return {
                "filename": filename,
                "type": "image",
                "prediction": result,
                "confidence_score": score
            }

        else:  # is_video
            # cv2.VideoCapture needs a real path, so spool to a temp file.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
                shutil.copyfileobj(file.file, temp_video)
                temp_video_path = temp_video.name

            # try/finally guarantees the temp file is removed even when a
            # frame raises mid-loop (previously it leaked on that path).
            try:
                cap = cv2.VideoCapture(temp_video_path)
                if not cap.isOpened():
                    raise HTTPException(status_code=400, detail="Could not open video file.")

                frame_scores = []
                frame_count = 0

                # Process 1 frame every 5 frames (~6 fps for 30fps video) for better accuracy
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    if frame_count % 5 == 0:
                        frame_scores.append(process_image(frame))
                    frame_count += 1

                cap.release()
            finally:
                os.unlink(temp_video_path)

            if not frame_scores:
                raise HTTPException(status_code=400, detail="Could not extract frames from video.")

            # Deepfakes often only manipulate specific frames, so average score can mask the spoof.
            # We use max_score to find the most manipulated frame.
            max_score = max(frame_scores)
            avg_score = sum(frame_scores) / len(frame_scores)

            fake_frames_count = sum(1 for s in frame_scores if s >= 0.5)

            final_result = "Real" if max_score < 0.5 else "Fake"

            return {
                "filename": filename,
                "type": "video",
                "prediction": final_result,
                "confidence_score": max_score,
                "frames_analyzed": len(frame_scores),
                "fake_frames_count": fake_frames_count,
                "max_fake_score": max_score,
                "avg_score": avg_score
            }

    except HTTPException:
        # Re-raise unchanged: otherwise the generic handler below would
        # convert deliberate 400 client errors into 500s.
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    # Direct launch: serve on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)