# Deepfake-things / app.py
# Last commit: 3642241 "Create app.py" by Nick-2x (verified)
import io
import os
import tempfile

import cv2
import numpy as np
import torch
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
app = FastAPI(title="DeepFake Detection API")

# CORS settings so the React frontend (served from another origin) can call
# this API from the browser.
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site make credentialed requests; consider pinning allow_origins to the
# frontend's URL(s) if cookies/auth are ever added.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Global model loading: runs once, at import time.
MODEL_NAME = "dima806/deepfake_vs_real_image_detection"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Bind both names up front so a failed load leaves an explicit "not loaded"
# state (None) instead of undefined globals, which previously surfaced as an
# opaque NameError on every request.
processor = None
model = None
try:
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).to(device)
    model.eval()  # inference mode for layers such as dropout/batch-norm
    print(f"✅ Model loaded on {device}")
except Exception as e:
    # Keep the server up (the health check still answers), but reset both
    # names so a half-finished load is never mistaken for a working model.
    processor = None
    model = None
    print(f"❌ Error loading model: {e}")
def predict_frame(image: Image.Image):
    """Classify a single PIL image as real or fake.

    Args:
        image: The image to classify (callers in this module pass RGB images).

    Returns:
        A ``(label, confidence)`` tuple. ``label`` is ``"FAKE"`` when the
        model's predicted class name contains "fake" (case-insensitive), and
        ``"REAL"`` otherwise. ``confidence`` is the softmax probability of the
        predicted class.

    Raises:
        RuntimeError: If the model or processor failed to load at startup.
    """
    if processor is None or model is None:
        raise RuntimeError(
            f"Model '{MODEL_NAME}' is not loaded; check the server startup logs."
        )

    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**inputs).logits

    probs = torch.nn.functional.softmax(logits, dim=-1)
    pred_id = logits.argmax(-1).item()
    conf = probs[0][pred_id].item()

    # Collapse the model's own class names onto the API's two labels.
    label = model.config.id2label[pred_id].lower()
    final_label = "FAKE" if "fake" in label else "REAL"
    return final_label, conf
@app.get("/")
def health_check():
    """Health endpoint: always reports the service as online, plus the model id."""
    payload = {"status": "online", "model": MODEL_NAME}
    return payload
@app.post("/predict/image")
async def predict_image_api(file: UploadFile = File(...)):
    """Classify an uploaded image as real or fake.

    Returns:
        ``{"success": True, "prediction": "FAKE" | "REAL",
        "confidence": <float rounded to 4 places>,
        "status": "Detection complete"}``.

    Raises:
        HTTPException(400): The upload could not be decoded as an image.
        HTTPException(500): Any other failure (e.g. model not loaded,
            inference error).
    """
    try:
        content = await file.read()
        try:
            image = Image.open(io.BytesIO(content)).convert("RGB")
        except (OSError, ValueError, Image.DecompressionBombError) as e:
            # An undecodable upload is a client error, not a server fault.
            raise HTTPException(
                status_code=400, detail=f"Could not decode image: {e}"
            ) from e

        # Model inference is blocking work; run it in the threadpool so the
        # event loop keeps serving other requests meanwhile.
        label, confidence = await run_in_threadpool(predict_frame, image)
        return {
            "success": True,
            "prediction": label,
            "confidence": round(confidence, 4),
            "status": "Detection complete",
        }
    except HTTPException:
        # Let the deliberate 400 above through unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
def _safe_video_suffix(filename):
    """Return the upload's extension (e.g. ".mp4") if it is alphanumeric, else "".

    Only the extension of the client-supplied name is kept; the name itself is
    never used as a filesystem path (it may contain "../" or collide with
    concurrent uploads).
    """
    ext = os.path.splitext(os.path.basename(filename or ""))[1]
    return ext if ext[1:].isalnum() else ""


def _sample_video_predictions(video_path, sample_every=15):
    """Classify every ``sample_every``-th frame of the video at ``video_path``.

    The default of 15 keeps CPU usage manageable on small hosts.

    Returns:
        Per-frame labels ("FAKE"/"REAL") in frame order. The list is empty if
        the video cannot be opened or yields no frames.

    Raises:
        ValueError: If ``sample_every`` is less than 1.
    """
    if sample_every < 1:
        raise ValueError("sample_every must be >= 1")

    cap = cv2.VideoCapture(video_path)
    predictions = []
    try:
        frame_index = 0
        while cap.isOpened():
            # read() is grab() + retrieve(); skipped frames are only grabbed,
            # so only sampled frames are retrieved into an array and converted.
            if not cap.grab():
                break
            if frame_index % sample_every == 0:
                ok, frame = cap.retrieve()
                if not ok:
                    break
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV is BGR
                label, _ = predict_frame(Image.fromarray(frame_rgb))
                predictions.append(label)
            frame_index += 1
    finally:
        # Release the capture even if inference raises.
        cap.release()
    return predictions


def _predict_video_bytes(content, suffix, sample_every=15):
    """Write ``content`` to a private temp file, classify it, then delete the file."""
    fd, temp_path = tempfile.mkstemp(prefix="upload_", suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(content)
        return _sample_video_predictions(temp_path, sample_every)
    finally:
        # Best-effort cleanup; the file may already be gone.
        try:
            os.remove(temp_path)
        except FileNotFoundError:
            pass


@app.post("/predict/video")
async def predict_video_api(file: UploadFile = File(...)):
    """Classify an uploaded video by majority vote over sampled frames.

    Every 15th frame is classified. The video is labelled "FAKE" only when
    strictly more than half of the sampled frames are FAKE; otherwise "REAL".

    Returns:
        ``{"success": True, "prediction": ..., "stats": {...}}`` on success,
        or ``{"success": False, "message": "No frames processed"}`` when no
        frame could be read.

    Raises:
        HTTPException(500): On any failure while saving or processing the video.
    """
    try:
        content = await file.read()
        # Decoding and inference are blocking; keep them off the event loop.
        predictions = await run_in_threadpool(
            _predict_video_bytes, content, _safe_video_suffix(file.filename)
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    if not predictions:
        return {"success": False, "message": "No frames processed"}

    fake_count = predictions.count("FAKE")
    final_pred = "FAKE" if fake_count > (len(predictions) / 2) else "REAL"
    return {
        "success": True,
        "prediction": final_pred,
        "stats": {
            "total_frames_sampled": len(predictions),
            "fake_frames": fake_count,
            "real_frames": len(predictions) - fake_count,
        },
    }
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port
    # 7860 (presumably the Hugging Face Spaces default; confirm for other hosts).
    # uvicorn is imported only when the file is run directly.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")