dgapi / app.py
Hayloo9838's picture
Update app.py
ac11708 verified
raw
history blame
918 Bytes
import gradio as gr
import torch
from PIL import Image
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    ViTImageProcessor,
)
# Load model + processor from the same Hugging Face Hub checkpoint.
MODEL_ID = "Falconsai/nsfw_image_detection"

# from_pretrained() returns the model already in evaluation mode (dropout off),
# so no explicit model.eval() is needed before inference.
model = ViTForImageClassification.from_pretrained(MODEL_ID)

# ViTImageProcessor is the supported replacement for the deprecated
# ViTFeatureExtractor; it is called the same way (processor(images=..., ...)).
processor = ViTImageProcessor.from_pretrained(MODEL_ID)
def predict(image):
    """Classify a single image with the NSFW ViT classifier.

    Args:
        image: A PIL image, as supplied by ``gr.Image(type="pil")``, or
            ``None`` when the form is submitted without an uploaded image.

    Returns:
        The label string from ``model.config.id2label`` for the
        highest-scoring class, or ``None`` if no image was given.
    """
    # Submitting without an upload passes None; the processor would raise on it.
    if image is None:
        return None

    # Direct callers may pass RGBA / grayscale / palette images; convert so the
    # processor always receives 3-channel RGB input.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
# Gradio Interface: one image input wired to `predict`, shown in a Label output.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Label(label="Prediction"),
    title="NSFW Image Detection (ViT)",
    description="Upload an image to classify if it's NSFW or SFW.",
)

if __name__ == "__main__":
    # Start the web server only when run as a script (e.g. `python app.py`),
    # so importing this module does not launch the app as a side effect.
    demo.launch()