DeepTrust2 / model /predict.py
priyansh-nagar's picture
Upload 4 files
b154f75 verified
raw
history blame contribute delete
749 Bytes
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Hugging Face Hub checkpoint used for deepfake classification.
MODEL_ID = "prithivMLmods/Deepfake-Detection-Exp-02-21"

# Both objects are loaded once at import time and shared by every predict() call.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
# nn.Module.eval() returns the module itself, so inference mode is set inline.
model = AutoModelForImageClassification.from_pretrained(MODEL_ID).eval()
def predict(image: Image.Image):
    """Classify a single image as a deepfake ("Fake") or genuine ("Real").

    Args:
        image: PIL image to classify. Any mode other than "RGB" (e.g. RGBA
            PNGs, grayscale "L", palette "P") is converted to RGB first so
            the image processor receives three colour channels.

    Returns:
        A ``(label, conf, trust_score)`` tuple:
        - label: ``"Fake"`` when the top-scoring class index is 0, else ``"Real"``.
        - conf: softmax probability of the top-scoring class (float in [0, 1]).
        - trust_score: int in 0..100; ``conf * 100`` for "Real" and
          ``(1 - conf) * 100`` for "Fake", truncated toward zero. With a
          two-class head this equals the probability of "Real" as a percentage.
    """
    # Uploaded images frequently carry an alpha channel or a single channel;
    # normalise to RGB so preprocessing sees the expected channel count.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        outputs = model(**inputs)

    probs = torch.softmax(outputs.logits, dim=1)
    conf, pred = torch.max(probs, dim=1)
    # Batch size is 1 here, so .item() extracts the single scalar.
    pred = pred.item()
    conf = conf.item()

    # NOTE(review): class index 0 is assumed to be the fake class; confirm
    # against model.config.id2label for this checkpoint.
    label = "Fake" if pred == 0 else "Real"
    trust_score = int(conf * 100 if label == "Real" else (1 - conf) * 100)
    return label, conf, trust_score