Spaces:
Sleeping
Sleeping
File size: 3,211 Bytes
ca2a79c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 | from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import tempfile, os
# -- Lazy load: the model module is imported on the first request, not at startup --
_predict_image = None
_predict_video = None


def get_models():
    """Return the ``(predict_image, predict_video)`` callables from ``model``.

    The ``model`` import is deferred until the first call, so importing this
    module (and starting the server) does not import ``model``. The two
    functions are cached in module globals and reused on every later call.
    """
    global _predict_image, _predict_video
    # Fast path: already imported on an earlier call.
    if _predict_image is not None:
        return _predict_image, _predict_video
    from model import predict_image as image_fn, predict_video as video_fn
    _predict_image, _predict_video = image_fn, video_fn
    return _predict_image, _predict_video
app = FastAPI()

# Browser origins allowed to call this API cross-origin: a localhost
# development origin, the Vercel-hosted frontend, and the Hugging Face Space.
_ALLOWED_ORIGINS = [
    "http://localhost:5173",
    "https://deepfake-sentinel.vercel.app",
    "https://devendra174-deepfake-sentinel.hf.space",
]

# Credentials are allowed, and every method/header is accepted from the
# origins listed above.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/health")
def health():
    """Report that the API process is up.

    Always returns the constant payload ``{"status": "ok"}``; the model is not
    loaded or consulted.
    """
    status_payload = dict(status="ok")
    return status_payload
@app.get("/")
def root():
    """Landing endpoint that returns a fixed banner message."""
    banner = "Deepfake API running"
    return {"status": banner}
def _prediction_response(label, prob):
    """Build the JSON body shared by all ``/predict/*`` endpoints.

    Args:
        label: Verdict string returned by the model (compared against ``"FAKE"``).
        prob: Raw model score, treated here as the probability that the input
            is fake.

    Returns:
        dict with ``verdict``, ``confidence`` (percentage for the returned
        label), ``probability`` (fake percentage), ``raw_score``,
        ``threshold`` and ``model_used``.
    """
    raw_score = float(prob)
    fake_pct = round(raw_score * 100, 2)
    # Confidence is expressed for whichever label was returned.
    conf_pct = fake_pct if label == "FAKE" else round(100 - fake_pct, 2)
    return {
        "verdict": label,
        "confidence": conf_pct,
        "probability": fake_pct,
        "raw_score": raw_score,
        # NOTE(review): constant reported to clients; presumably mirrors the
        # decision threshold used in model.py -- confirm.
        "threshold": 0.5,
        "model_used": "XceptionViT",
    }


async def _run_in_thread(func, *args):
    """Await ``func(*args)`` on the event loop's default thread pool.

    Image decoding and model inference are blocking calls. Running them off
    the event loop keeps other requests (e.g. ``/health``) responsive while a
    prediction is in progress.

    NOTE(review): simultaneous requests can now run inference in parallel
    threads; confirm the model in model.py tolerates concurrent calls.
    """
    import asyncio

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, func, *args)


def _classify_image_file(fileobj):
    """Decode an uploaded image and run image inference (blocking).

    Args:
        fileobj: Binary file object holding the uploaded image bytes.

    Returns:
        ``(label, prob)`` as returned by ``predict_image``.

    Raises:
        HTTPException: 400 if Pillow cannot decode the upload.
    """
    predict_image, _ = get_models()
    try:
        img = Image.open(fileobj).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError and truncated-image errors are OSError
        # subclasses; oversized images raise DecompressionBombError. All are
        # bad client input, so answer 400 instead of an unhandled 500.
        raise HTTPException(status_code=400, detail="Could not decode image.") from err
    return predict_image(img)


def _classify_video_file(fileobj, suffix):
    """Spool an uploaded video to a temporary file and run video inference (blocking).

    ``predict_video`` is called with a filesystem path, so the upload is
    written to disk first. The temporary file is always removed.

    Args:
        fileobj: Binary file object holding the uploaded video bytes.
        suffix: Extension for the temporary file, e.g. ``".mp4"``.

    Returns:
        ``(label, prob)`` as returned by ``predict_video``.

    Raises:
        HTTPException: 500 if ``predict_video`` reports ``"ERROR"``.
    """
    import shutil

    _, predict_video = get_models()
    temp_name = None
    try:
        # delete=False keeps the path valid after the handle is closed;
        # the finally clause below removes it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            temp_name = tmp.name
            # Stream in chunks instead of reading the whole upload into memory.
            shutil.copyfileobj(fileobj, tmp)
        label, prob = predict_video(temp_name)
    finally:
        if temp_name and os.path.exists(temp_name):
            os.remove(temp_name)
    if label == "ERROR":
        raise HTTPException(status_code=500, detail="Could not extract frames.")
    return label, prob


@app.post("/predict/image")
async def predict_image_api(file: UploadFile = File(...)):
    """Classify an uploaded image; responds 400 if it cannot be decoded."""
    label, prob = await _run_in_thread(_classify_image_file, file.file)
    return _prediction_response(label, prob)


@app.post("/predict/video")
async def predict_video_api(file: UploadFile = File(...)):
    """Classify an uploaded video (spooled to disk as ``.mp4``)."""
    label, prob = await _run_in_thread(_classify_video_file, file.file, ".mp4")
    return _prediction_response(label, prob)


@app.post("/predict/webcam")
async def predict_webcam_api(file: UploadFile = File(...)):
    """Classify a webcam recording (spooled to disk as ``.webm``)."""
    label, prob = await _run_in_thread(_classify_video_file, file.file, ".webm")
    return _prediction_response(label, prob)