File size: 7,927 Bytes
c7f5818
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78b703e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import torch
import torch.nn as nn
from torchvision import models, transforms
from facenet_pytorch import MTCNN
import cv2
import numpy as np
from PIL import Image
import shutil
import os
import uuid

# --- CONFIG ---
MODEL_PATH = "deepfake_detector.pth"  # Path ที่เก็บโมเดล
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
IMG_SIZE = 224
FRAME_INTERVAL = 10  # ตรวจทุกๆ 10 เฟรม (เพื่อความเร็ว)

app = FastAPI(title="DeepGuard API", description="Deepfake Detection API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # ใน Production ควรเปลี่ยนเป็น domain จริง แต่ตอนนี้ใส่ * (รับหมด) ไปก่อน
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- 1. Model Loader ---
def build_model():
    # ต้องสร้างโครงสร้างโมเดลให้เหมือนตอนเทรนเป๊ะๆ
    model = models.efficientnet_b0(weights=None) # ไม่ต้องโหลด weight เน็ต เพราะเราจะโหลดเอง
    num_ftrs = model.classifier[1].in_features
    # โครงสร้างต้องตรงกับที่เราแก้ล่าสุด (Dropout + Linear)
    model.classifier[1] = nn.Sequential(
        nn.Dropout(p=0.6),
        nn.Linear(num_ftrs, 2)
    )
    return model

print(f"Loading model from {MODEL_PATH}...")
try:
    model = build_model()
    # map_location สำคัญมากถ้าเทรนบน GPU แต่รันบน CPU
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.to(DEVICE)
    model.eval() # สั่งให้เป็นโหมดทำนาย (ปิด Dropout)
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    # ถ้าโหลดไม่ผ่าน ให้รันต่อไม่ได้
    raise e

# --- 2. Helper Components ---
# Face Detector
mtcnn = MTCNN(keep_all=False, select_largest=True, device=DEVICE, margin=20)

# Preprocessing (เหมือนตอน Validation)
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# --- 3. Processing Logic ---
def process_video(video_path):
    cap = cv2.VideoCapture(video_path)
    frames_processed = 0
    fake_probs = []
    
    frame_count = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
            
        # ข้ามเฟรมเพื่อความเร็ว
        if frame_count % FRAME_INTERVAL != 0:
            frame_count += 1
            continue
            
        # แปลงสี BGR -> RGB
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_img = Image.fromarray(frame_rgb)
        
        # Detect Face & Crop
        # MTCNN จะคืนค่ามาเป็น PIL Image ที่ Crop แล้ว (ถ้า detected)
        face = mtcnn(pil_img) 
        
        if face is not None:
            # face ที่ได้มาเป็น Tensor อยู่แล้ว (เพราะ MTCNN ของ library นี้ทำให้)
            # แต่เราต้อง Normalize ให้เหมือนตอนเทรน (EfficientNet ต้องการ)
            
            # แปลงกลับเป็น PIL เพื่อเข้า Transform มาตรฐาน (หรือจะทำ manual ก็ได้)
            # เพื่อความชัวร์ ใช้ transform ที่เราประกาศไว้ดีกว่า
            # แต่ mtcnn คืนค่ามาเป็น tensor ที่ normalize มาแบบนึงแล้ว (0-1)
            # ซึ่งมักจะไม่ตรงกับ ImageNet mean/std
            
            # **แก้ไข Logic:** เพื่อความแม่นยำสูงสุด เราจะใช้ box จาก mtcnn แล้ว crop เอง
            # เพื่อส่งเข้า transform pipeline เดียวกับตอนเทรน
            boxes, _ = mtcnn.detect(pil_img)
            if boxes is not None:
                box = boxes[0]
                face_img = pil_img.crop(box)
                
                # Preprocess
                input_tensor = transform(face_img).unsqueeze(0).to(DEVICE)
                
                # Inference
                with torch.no_grad():
                    outputs = model(input_tensor)
                    # output คือ [logit_fake, logit_real] หรือ [logit_real, logit_fake] ขึ้นอยู่กับ class index
                    # ปกติ ImageFolder จะเรียงตามตัวอักษร: 0=fake, 1=real
                    # เราจะใช้ Softmax เพื่อแปลงเป็น %
                    probs = torch.softmax(outputs, dim=1)
                    fake_prob = probs[0][0].item() # สมมติว่า class 0 คือ fake
                    
                    fake_probs.append(fake_prob)
                    frames_processed += 1

        frame_count += 1
    
    cap.release()
    return fake_probs

# --- 4. API Endpoints ---
@app.get("/")
def home():
    return {"message": "DeepGuard API is running!"}

@app.post("/predict")
async def predict_video(file: UploadFile = File(...)):
    # 1. เช็คชนิดไฟล์
    if not file.filename.endswith(('.mp4', '.avi', '.mov')):
        raise HTTPException(status_code=400, detail="Invalid file format. Please upload video.")
    
    # 2. Save ไฟล์วิดีโอลงเครื่องชั่วคราว (เพราะ OpenCV อ่านจาก RAM ไม่ได้ง่ายๆ)
    temp_filename = f"temp_{uuid.uuid4()}.mp4"
    try:
        with open(temp_filename, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
            
        # 3. ส่งเข้า Process
        fake_probs = process_video(temp_filename)
        
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing Error: {str(e)}")
    finally:
        # 4. ลบไฟล์ทิ้งเมื่อเสร็จ (Clean up)
        if os.path.exists(temp_filename):
            os.remove(temp_filename)
            
    # 5. สรุปผล
    if not fake_probs:
        return {"status": "failed", "message": "No faces detected in the video."}
    
    # หาค่าเฉลี่ยความน่าจะเป็น
    avg_fake_prob = np.mean(fake_probs)
    
    # Thresholding (ปรับได้)
    # ถ้า Fake Probability > 0.5 ให้ตอบว่า FAKE
    # แต่เนื่องจากเรามี Label Smoothing ค่ามันอาจจะไม่ใช่ 0 หรือ 1 เป๊ะๆ
    detection_result = "FAKE" if avg_fake_prob > 0.5 else "REAL"
    confidence = avg_fake_prob if detection_result == "FAKE" else (1 - avg_fake_prob)
    
    return {
        "filename": file.filename,
        "result": detection_result,
        "confidence": f"{confidence*100:.2f}%",
        "fake_probability": float(avg_fake_prob),
        "frames_analyzed": len(fake_probs)
    }

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)