"""DeepFake Detection API.

FastAPI service exposing endpoints that classify an uploaded image or video
as "REAL" or "FAKE" using a Hugging Face image-classification model.
"""
import io
import os
import tempfile
from typing import List, Tuple

import cv2
import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

app = FastAPI(title="DeepFake Detection API")

# Setup CORS for your React frontend
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; consider listing the frontend origin explicitly in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model loading
MODEL_NAME = "dima806/deepfake_vs_real_image_detection"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Only every N-th video frame is classified, to keep inference time bounded.
FRAME_SAMPLE_INTERVAL = 15
# Uploads are copied to disk in chunks of this many bytes.
UPLOAD_CHUNK_SIZE = 1024 * 1024

# Pre-bind both names so that a failed load leaves them as None instead of
# undefined (which would surface later as an opaque NameError per request).
processor = None
model = None
try:
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).to(device)
    model.eval()
    print(f"✅ Model loaded on {device}")
except Exception as e:
    # Keep the app bootable (health check still works), but make sure a
    # half-initialised pair is never used for inference.
    processor = None
    model = None
    print(f"❌ Error loading model: {e}")


def predict_frame(image: Image.Image) -> Tuple[str, float]:
    """Classify a single image.

    Args:
        image: The PIL image to classify.

    Returns:
        A ``(label, confidence)`` pair where ``label`` is ``"FAKE"`` if the
        model's predicted class name contains "fake" and ``"REAL"`` otherwise,
        and ``confidence`` is the softmax probability of the predicted class.

    Raises:
        RuntimeError: If the model failed to load at startup.
    """
    if processor is None or model is None:
        raise RuntimeError("Model is not loaded; check the startup logs")

    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    probs = torch.nn.functional.softmax(logits, dim=-1)
    pred_id = logits.argmax(-1).item()
    conf = probs[0][pred_id].item()

    label = model.config.id2label[pred_id].lower()
    final_label = "FAKE" if "fake" in label else "REAL"
    return final_label, conf


def _classify_video(path: str) -> List[str]:
    """Return the per-frame labels for every sampled frame of the video at *path*."""
    predictions = []
    cap = cv2.VideoCapture(path)
    try:
        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Sample every 15th frame to keep Hugging Face CPU happy
            if frame_count % FRAME_SAMPLE_INTERVAL == 0:
                # OpenCV decodes frames as BGR; the model pipeline expects RGB.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                label, _ = predict_frame(Image.fromarray(frame_rgb))
                predictions.append(label)
            frame_count += 1
    finally:
        # Always release the capture handle, even if inference raises.
        cap.release()
    return predictions


@app.get("/")
def health_check():
    """Report that the service is up and which model it is configured for."""
    return {"status": "online", "model": MODEL_NAME}


@app.post("/predict/image")
async def predict_image_api(file: UploadFile = File(...)):
    """Classify an uploaded image as REAL or FAKE.

    Returns a JSON object with the prediction and its confidence rounded to
    four decimals.

    Raises:
        HTTPException: With status 500 and the error text if decoding or
            inference fails.
    """
    try:
        content = await file.read()
        image = Image.open(io.BytesIO(content)).convert("RGB")
        # Inference is blocking; run it off the event loop.
        label, confidence = await run_in_threadpool(predict_frame, image)
        return {
            "success": True,
            "prediction": label,
            "confidence": round(confidence, 4),
            "status": "Detection complete",
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/predict/video")
async def predict_video_api(file: UploadFile = File(...)):
    """Classify an uploaded video as REAL or FAKE by majority vote.

    Every ``FRAME_SAMPLE_INTERVAL``-th frame is classified; the video is
    labelled FAKE only if strictly more than half of the sampled frames are
    FAKE.

    Raises:
        HTTPException: With status 500 and the error text if the upload
            cannot be stored or processed.
    """
    temp_path = None
    try:
        # Never build the on-disk path from the client-supplied filename:
        # keep only its extension and let tempfile choose a unique safe name.
        suffix = os.path.splitext(os.path.basename(file.filename or ""))[1]
        fd, temp_path = tempfile.mkstemp(prefix="temp_", suffix=suffix)
        with os.fdopen(fd, "wb") as f:
            # Stream the upload to disk instead of buffering it all in memory.
            while True:
                chunk = await file.read(UPLOAD_CHUNK_SIZE)
                if not chunk:
                    break
                f.write(chunk)

        # Decoding + inference is blocking; run it off the event loop.
        predictions = await run_in_threadpool(_classify_video, temp_path)

        if not predictions:
            return {"success": False, "message": "No frames processed"}

        fake_count = predictions.count("FAKE")
        final_pred = "FAKE" if fake_count > (len(predictions) / 2) else "REAL"

        return {
            "success": True,
            "prediction": final_pred,
            "stats": {
                "total_frames_sampled": len(predictions),
                "fake_frames": fake_count,
                "real_frames": len(predictions) - fake_count,
            },
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Single cleanup point for both the success and the error path.
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)