# gannushalini2006's picture
# Update app.py
# 2db0bbe verified
from collections import Counter

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# -------------------------------------------------
# Load Models (AI Image Detectors)
# -------------------------------------------------
# Display name -> Hugging Face Hub checkpoint id.
# NOTE(review): every entry currently points at the same checkpoint, so the
# hard-voting "ensemble" is one model voting three times and the
# "ViT"/"ResNet"/"EfficientNet" labels cannot all describe what is loaded.
# Substitute distinct detector checkpoints to get a meaningful vote.
models = {
    "ViT": "umm-maybe/AI-image-detector",
    "ResNet": "umm-maybe/AI-image-detector",
    "EfficientNet": "umm-maybe/AI-image-detector",
}

# Per-display-name lookup tables read by predict_model().
processors = {}
detectors = {}

# Load each distinct checkpoint only once. Names that share a checkpoint share
# the same processor and model object instead of each holding a duplicate copy
# of the weights on `device`.
_loaded_by_id = {}
for name, model_id in models.items():
    if model_id not in _loaded_by_id:
        _loaded_by_id[model_id] = (
            AutoImageProcessor.from_pretrained(model_id),
            AutoModelForImageClassification.from_pretrained(model_id)
            .to(device)
            .eval(),  # inference mode: disables dropout etc.
        )
    processors[name], detectors[name] = _loaded_by_id[model_id]
del _loaded_by_id
# -------------------------------------------------
# Prediction
# -------------------------------------------------
def predict_model(image, name):
    """Classify ``image`` with the detector registered under ``name``.

    Args:
        image: The uploaded image (a PIL image, since the Gradio input uses
            ``type="pil"``).
        name: A key of ``models``; selects the processor/model pair.

    Returns:
        ``(label, confidence)``: the top-scoring class label taken from the
        model's ``config.id2label`` and its softmax probability as a float.
    """
    # The image processor normalises 3-channel input. RGBA / palette /
    # grayscale uploads (common with PNGs) would otherwise error out or be
    # normalised with mismatched per-channel statistics.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processors[name](images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = detectors[name](**inputs).logits

    # Batch of one: take the single row of class probabilities.
    probs = torch.softmax(logits, dim=-1)[0]
    confidence, pred = probs.max(dim=-1)
    label = detectors[name].config.id2label[pred.item()]
    return label, confidence.item()
# -------------------------------------------------
# HARD VOTING ENSEMBLE
# -------------------------------------------------
def ensemble_predict(image):
    """Run every entry of ``models`` on ``image`` and combine them by hard voting.

    Args:
        image: The image passed through to ``predict_model``.

    Returns:
        ``(final_label, details)`` where ``final_label`` is the label predicted
        by the most models and ``details`` maps each model name to
        ``{"prediction": label, "confidence": probability rounded to 3 places}``.

    Ties are broken in favour of the label first predicted in ``models``
    iteration order. The previous ``max(set(...))`` depended on per-process
    string hash randomisation, so tied votes could change between runs.
    """
    votes = Counter()
    details = {}
    # Entries that share a checkpoint produce the same output, so each
    # checkpoint is evaluated only once per request.
    results_by_id = {}
    for name, model_id in models.items():
        if model_id not in results_by_id:
            results_by_id[model_id] = predict_model(image, name)
        label, conf = results_by_id[model_id]
        votes[label] += 1
        details[name] = {
            "prediction": label,
            "confidence": round(conf, 3),
        }

    # most_common orders equal counts by first insertion -> deterministic.
    final_label, _ = votes.most_common(1)[0]
    return final_label, details
# -------------------------------------------------
# Gradio Interface
# -------------------------------------------------
def detect_ai(image):
    """Gradio callback: return the ensemble's final label and per-model details.

    Args:
        image: The PIL image from the ``gr.Image`` input, or ``None`` when the
            user submits without uploading anything.

    Returns:
        ``(label, details)`` for the ``gr.Label`` and ``gr.JSON`` outputs.

    Raises:
        gr.Error: if no image was supplied. Gradio shows it to the user as a
            message instead of a stack trace from the image processor.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    label, details = ensemble_predict(image)
    return label, details
# Build the UI pieces first so the Interface call reads as a short summary:
# one PIL image in; the ensemble's final label plus per-model JSON out.
_image_input = gr.Image(type="pil")
_result_outputs = [
    gr.Label(label="Final Decision"),
    gr.JSON(label="Model-wise Predictions"),
]
_description = (
    "Detects whether an image is AI-generated or Real using an ensemble of models.\n"
    "Hard voting is applied for robust prediction."
)

demo = gr.Interface(
    fn=detect_ai,
    inputs=_image_input,
    outputs=_result_outputs,
    title="AI Image Detector (Ensemble Hard Voting)",
    description=_description,
)
demo.launch()