# cy4nide's picture
# Update detector.py
# c8a7c2e verified
# raw
# history blame contribute delete
# 871 Bytes
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
import streamlit as st
# Model identifier handed to `from_pretrained` for both the image processor
# and the classification model in `load_model`.
MODEL_NAME = "Wvolf/ViT_Deepfake_Detection"
@st.cache_resource
def load_model():
    """Load the image processor and classifier identified by ``MODEL_NAME``.

    The function is decorated with ``st.cache_resource``, so Streamlit
    caches the returned objects instead of rebuilding them on every call.

    Returns:
        A ``(processor, model)`` tuple. The model has been switched to
        evaluation mode before it is returned.
    """
    image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    # ``eval()`` returns the module itself, so it can be chained onto the load.
    classifier = AutoModelForImageClassification.from_pretrained(MODEL_NAME).eval()
    return image_processor, classifier
def predict_image(image: Image.Image):
    """Classify a single PIL image with the model from ``load_model``.

    Args:
        image: Image to classify. It is converted to RGB first when its
            mode is anything other than ``"RGB"``.

    Returns:
        A ``(label, confidence)`` tuple: the ``model.config.id2label`` entry
        for the highest-scoring class, and that class's softmax probability
        as a Python float.
    """
    processor, model = load_model()

    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    batch = processor(images=rgb_image, return_tensors="pt")

    # Inference only: skip gradient tracking for the forward pass.
    with torch.no_grad():
        logits = model(**batch).logits

    probabilities = logits.softmax(dim=1)
    top_class = probabilities.argmax(dim=1).item()
    top_probability = probabilities[0, top_class].item()

    return model.config.id2label[top_class], top_probability