deepguard-api / api.py
Fuuji's picture
Update api.py
c7f5818 verified
Raw
History Blame Contribute Delete
7.93 kB
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import torch
import torch.nn as nn
from torchvision import models, transforms
from facenet_pytorch import MTCNN
import cv2
import numpy as np
from PIL import Image
import shutil
import os
import uuid
# --- CONFIG ---
MODEL_PATH = "deepfake_detector.pth" # Path ที่เก็บโมเดล
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
IMG_SIZE = 224
FRAME_INTERVAL = 10 # ตรวจทุกๆ 10 เฟรม (เพื่อความเร็ว)
app = FastAPI(title="DeepGuard API", description="Deepfake Detection API")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # ใน Production ควรเปลี่ยนเป็น domain จริง แต่ตอนนี้ใส่ * (รับหมด) ไปก่อน
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# --- 1. Model Loader ---
def build_model():
# ต้องสร้างโครงสร้างโมเดลให้เหมือนตอนเทรนเป๊ะๆ
model = models.efficientnet_b0(weights=None) # ไม่ต้องโหลด weight เน็ต เพราะเราจะโหลดเอง
num_ftrs = model.classifier[1].in_features
# โครงสร้างต้องตรงกับที่เราแก้ล่าสุด (Dropout + Linear)
model.classifier[1] = nn.Sequential(
nn.Dropout(p=0.6),
nn.Linear(num_ftrs, 2)
)
return model
print(f"Loading model from {MODEL_PATH}...")
try:
model = build_model()
# map_location สำคัญมากถ้าเทรนบน GPU แต่รันบน CPU
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval() # สั่งให้เป็นโหมดทำนาย (ปิด Dropout)
print("✅ Model loaded successfully!")
except Exception as e:
print(f"❌ Error loading model: {e}")
# ถ้าโหลดไม่ผ่าน ให้รันต่อไม่ได้
raise e
# --- 2. Helper Components ---
# Face Detector
mtcnn = MTCNN(keep_all=False, select_largest=True, device=DEVICE, margin=20)
# Preprocessing (เหมือนตอน Validation)
transform = transforms.Compose([
transforms.Resize((IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# --- 3. Processing Logic ---
def process_video(video_path):
cap = cv2.VideoCapture(video_path)
frames_processed = 0
fake_probs = []
frame_count = 0
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
# ข้ามเฟรมเพื่อความเร็ว
if frame_count % FRAME_INTERVAL != 0:
frame_count += 1
continue
# แปลงสี BGR -> RGB
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pil_img = Image.fromarray(frame_rgb)
# Detect Face & Crop
# MTCNN จะคืนค่ามาเป็น PIL Image ที่ Crop แล้ว (ถ้า detected)
face = mtcnn(pil_img)
if face is not None:
# face ที่ได้มาเป็น Tensor อยู่แล้ว (เพราะ MTCNN ของ library นี้ทำให้)
# แต่เราต้อง Normalize ให้เหมือนตอนเทรน (EfficientNet ต้องการ)
# แปลงกลับเป็น PIL เพื่อเข้า Transform มาตรฐาน (หรือจะทำ manual ก็ได้)
# เพื่อความชัวร์ ใช้ transform ที่เราประกาศไว้ดีกว่า
# แต่ mtcnn คืนค่ามาเป็น tensor ที่ normalize มาแบบนึงแล้ว (0-1)
# ซึ่งมักจะไม่ตรงกับ ImageNet mean/std
# **แก้ไข Logic:** เพื่อความแม่นยำสูงสุด เราจะใช้ box จาก mtcnn แล้ว crop เอง
# เพื่อส่งเข้า transform pipeline เดียวกับตอนเทรน
boxes, _ = mtcnn.detect(pil_img)
if boxes is not None:
box = boxes[0]
face_img = pil_img.crop(box)
# Preprocess
input_tensor = transform(face_img).unsqueeze(0).to(DEVICE)
# Inference
with torch.no_grad():
outputs = model(input_tensor)
# output คือ [logit_fake, logit_real] หรือ [logit_real, logit_fake] ขึ้นอยู่กับ class index
# ปกติ ImageFolder จะเรียงตามตัวอักษร: 0=fake, 1=real
# เราจะใช้ Softmax เพื่อแปลงเป็น %
probs = torch.softmax(outputs, dim=1)
fake_prob = probs[0][0].item() # สมมติว่า class 0 คือ fake
fake_probs.append(fake_prob)
frames_processed += 1
frame_count += 1
cap.release()
return fake_probs
# --- 4. API Endpoints ---
@app.get("/")
def home():
return {"message": "DeepGuard API is running!"}
@app.post("/predict")
async def predict_video(file: UploadFile = File(...)):
# 1. เช็คชนิดไฟล์
if not file.filename.endswith(('.mp4', '.avi', '.mov')):
raise HTTPException(status_code=400, detail="Invalid file format. Please upload video.")
# 2. Save ไฟล์วิดีโอลงเครื่องชั่วคราว (เพราะ OpenCV อ่านจาก RAM ไม่ได้ง่ายๆ)
temp_filename = f"temp_{uuid.uuid4()}.mp4"
try:
with open(temp_filename, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
# 3. ส่งเข้า Process
fake_probs = process_video(temp_filename)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Processing Error: {str(e)}")
finally:
# 4. ลบไฟล์ทิ้งเมื่อเสร็จ (Clean up)
if os.path.exists(temp_filename):
os.remove(temp_filename)
# 5. สรุปผล
if not fake_probs:
return {"status": "failed", "message": "No faces detected in the video."}
# หาค่าเฉลี่ยความน่าจะเป็น
avg_fake_prob = np.mean(fake_probs)
# Thresholding (ปรับได้)
# ถ้า Fake Probability > 0.5 ให้ตอบว่า FAKE
# แต่เนื่องจากเรามี Label Smoothing ค่ามันอาจจะไม่ใช่ 0 หรือ 1 เป๊ะๆ
detection_result = "FAKE" if avg_fake_prob > 0.5 else "REAL"
confidence = avg_fake_prob if detection_result == "FAKE" else (1 - avg_fake_prob)
return {
"filename": file.filename,
"result": detection_result,
"confidence": f"{confidence*100:.2f}%",
"fake_probability": float(avg_fake_prob),
"frames_analyzed": len(fake_probs)
}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)