# Last change: Nick-2x, commit b3bf88b "Update app.py" (verified)
import io

import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor
# Hugging Face Hub identifier of the ViT image-classification checkpoint
# whose predictions the /predict endpoint reports.
MODEL_ID: str = "prithivMLmods/Deep-Fake-Detector-v2-Model"

# ASGI application; route handlers below are registered on it.
app = FastAPI()

# Loaded once at import time and shared by every request: the processor
# turns a PIL image into model-ready tensors, the model produces class logits.
processor = ViTImageProcessor.from_pretrained(MODEL_ID)
model = ViTForImageClassification.from_pretrained(MODEL_ID)
def _decode_image(image_data: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGB PIL image.

    Args:
        image_data: Raw bytes of the uploaded file.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        HTTPException: 400 when the payload is empty or cannot be decoded
            as an image.
    """
    if not image_data:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")
    try:
        # Image.open only parses the header; convert() forces the full decode,
        # so truncated or corrupt files fail here rather than in the processor.
        return Image.open(io.BytesIO(image_data)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # UnidentifiedImageError (unrecognised format) is a subclass of OSError.
        raise HTTPException(
            status_code=400, detail=f"Invalid image file: {e}"
        ) from e


def _classify(image: Image.Image) -> dict:
    """Run the classifier on one RGB image and build the response payload.

    Returns:
        Dict with the predicted label, its softmax probability rounded to
        4 decimals, and whether the label equals "deepfake" (case-insensitive).
    """
    inputs = processor(images=image, return_tensors="pt")
    # Grad mode is thread-local, so no_grad must be entered in the same
    # (worker) thread that runs the forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)
    # A single image is processed, so row 0 holds its class scores.
    predicted_class_idx = torch.argmax(logits, dim=-1)[0].item()
    label = model.config.id2label[predicted_class_idx]
    confidence = probs[0, predicted_class_idx].item()
    return {
        "prediction": label,
        "confidence": round(confidence, 4),
        # NOTE(review): relies on the checkpoint's id2label naming the fake
        # class "Deepfake" (any case) — confirm against the model config.
        "is_deepfake": label.lower() == "deepfake",
    }


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the deepfake-detection model.

    Decoding and the forward pass are synchronous CPU work, so they are
    dispatched to FastAPI's thread pool instead of blocking the event loop.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        JSON object with "prediction", "confidence" and "is_deepfake".

    Raises:
        HTTPException: 400 for an empty or undecodable upload; 500 for any
            other failure (detail carries the exception message).
    """
    try:
        image_data = await file.read()
        image = await run_in_threadpool(_decode_image, image_data)
        return await run_in_threadpool(_classify, image)
    except HTTPException:
        # Client errors raised by the helpers keep their own status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this file is executed
    # directly as a script, not when an external server imports `app`.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=7860)