Spaces:
Sleeping
Sleeping
import io
import os
import tempfile

import cv2
import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
from transformers import AutoImageProcessor, AutoModelForImageClassification
app = FastAPI(title="DeepFake Detection API")

# CORS settings so the React frontend (served from another origin) can call the API.
# NOTE(review): wildcard origins combined with allow_credentials=True is a risky
# combination; consider listing the frontend's origin(s) explicitly - confirm what
# the frontend actually needs.
_CORS_SETTINGS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# Global model loading: done once at import time so every request reuses the
# same processor and model instead of reloading weights per call.
MODEL_NAME = "dima806/deepfake_vs_real_image_detection"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pre-bind both names so they always exist. If loading fails they stay None
# instead of being undefined (or half-initialised when only the processor
# loaded), which lets request code detect the failure explicitly.
processor = None
model = None
try:
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).to(device)
    model.eval()  # put layers such as dropout into inference behaviour
    print(f"✅ Model loaded on {device}")
except Exception as e:
    # Best-effort: keep the process up (e.g. health_check does not need the model),
    # but never leave a partially loaded pair behind.
    processor = None
    model = None
    print(f"❌ Error loading model: {e}")
def predict_frame(image: Image.Image):
    """Classify a single PIL image as "FAKE" or "REAL".

    Args:
        image: the image to classify; callers in this file pass RGB images.

    Returns:
        (label, confidence) where label is "FAKE" if the model's top class name
        contains "fake" (case-insensitive) and "REAL" otherwise, and confidence
        is the softmax probability of that top class for the first (only) item.

    Raises:
        RuntimeError: if the module-level processor/model are unavailable
            (e.g. loading failed at startup).
    """
    try:
        image_processor, classifier = processor, model
    except NameError as exc:  # globals never bound because loading failed early
        raise RuntimeError(f"Model {MODEL_NAME!r} is not loaded") from exc
    if image_processor is None or classifier is None:
        raise RuntimeError(f"Model {MODEL_NAME!r} is not loaded")

    inputs = image_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = classifier(**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)
    pred_id = logits.argmax(-1).item()
    conf = probs[0][pred_id].item()
    label = classifier.config.id2label[pred_id].lower()
    final_label = "FAKE" if "fake" in label else "REAL"
    return final_label, conf
def health_check():
    """Report that the API is up and which model checkpoint it is configured with.

    This does not check whether the model actually loaded.
    """
    # NOTE(review): no route decorator is visible on this handler; confirm it is registered on `app`.
    payload = {"status": "online"}
    payload["model"] = MODEL_NAME
    return payload
async def predict_image_api(file: UploadFile = File(...)):
    """Classify an uploaded image as "FAKE" or "REAL".

    Args:
        file: multipart upload containing image bytes in any format PIL can decode.

    Returns:
        dict with ``success`` (True), ``prediction`` ("FAKE"/"REAL"),
        ``confidence`` (top-class probability rounded to 4 decimals) and ``status``.

    Raises:
        HTTPException: 400 when the upload is empty or cannot be decoded as an
            image (a client error); 500 for failures while reading the upload
            or running the model.
    """
    # NOTE(review): no route decorator is visible on this handler; confirm it is registered on `app`.
    try:
        content = await file.read()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    if not content:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    try:
        # convert() forces a full decode, so truncated/corrupt data fails here too.
        image = Image.open(io.BytesIO(content)).convert("RGB")
    except (UnidentifiedImageError, Image.DecompressionBombError, OSError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    try:
        label, confidence = predict_frame(image)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "success": True,
        "prediction": label,
        "confidence": round(confidence, 4),
        "status": "Detection complete",
    }
# Read uploads in 1 MiB chunks so a large video is never held in memory at once.
_UPLOAD_CHUNK_SIZE = 1024 * 1024
# Classify every 15th frame to keep Hugging Face CPU happy.
_VIDEO_FRAME_STRIDE = 15


def _safe_video_suffix(filename):
    """Return a harmless extension (e.g. ".mp4") derived from an untrusted upload name.

    Only a plain ASCII alphanumeric extension survives. Directory components and
    anything else are discarded, so the client cannot influence where the
    temporary file is written.
    """
    suffix = os.path.splitext(os.path.basename(filename or ""))[1]
    tail = suffix[1:]
    return suffix if tail.isascii() and tail.isalnum() else ""


def _sample_frame_labels(video_path, stride=_VIDEO_FRAME_STRIDE):
    """Run ``predict_frame`` on every ``stride``-th frame of the video at ``video_path``.

    Returns the per-frame labels in frame order. The list is empty when OpenCV
    cannot open or decode the file. The capture is released even on error.
    """
    cap = cv2.VideoCapture(video_path)
    labels = []
    frame_count = 0
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % stride == 0:
                # OpenCV yields BGR frames; the image processor expects RGB.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                label, _ = predict_frame(Image.fromarray(frame_rgb))
                labels.append(label)
            frame_count += 1
    finally:
        cap.release()
    return labels


async def predict_video_api(file: UploadFile = File(...)):
    """Classify an uploaded video as "FAKE" or "REAL" by majority vote over sampled frames.

    The upload is streamed to a private temporary file (name chosen by
    ``tempfile``, never by the client). Every 15th frame is classified, and the
    temporary file is always removed afterwards.

    Returns:
        On success, a dict with ``success`` (True), ``prediction`` and ``stats``
        (sampled/fake/real frame counts). If no frame could be read, a dict with
        ``success`` False and a ``message``.

    Raises:
        HTTPException: 500 for any failure while saving or processing the video.
    """
    # NOTE(review): no route decorator is visible on this handler; confirm it is registered on `app`.
    temp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=_safe_video_suffix(file.filename)
        ) as tmp:
            temp_path = tmp.name
            while True:
                chunk = await file.read(_UPLOAD_CHUNK_SIZE)
                if not chunk:
                    break
                tmp.write(chunk)
        # The file is closed before OpenCV opens it by path (required on Windows).
        predictions = _sample_frame_labels(temp_path)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)

    if not predictions:
        return {"success": False, "message": "No frames processed"}

    fake_count = predictions.count("FAKE")
    # Strict majority: a tie between FAKE and REAL frames is reported as REAL.
    final_pred = "FAKE" if fake_count > (len(predictions) / 2) else "REAL"
    return {
        "success": True,
        "prediction": final_pred,
        "stats": {
            "total_frames_sampled": len(predictions),
            "fake_frames": fake_count,
            "real_frames": len(predictions) - fake_count,
        },
    }
if __name__ == "__main__":
    # Local entry point: serve the app on all interfaces at port 7860
    # (the default app port for Hugging Face Spaces).
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")