# Author: jjuarez
# Commit a060c24 ("Update app.py", verified)
import gradio as gr
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image
import torch
import numpy as np
# Fine-tuned ViT waste classifier plus the image preprocessor that prepares its
# inputs. The preprocessor is loaded from the base "google/vit-base-patch16-224"
# checkpoint rather than from model_name — presumably the checkpoint the
# classifier was fine-tuned from; confirm its resize/normalisation settings match.
# NOTE(review): recent transformers releases deprecate ViTFeatureExtractor in
# favour of ViTImageProcessor.
model_name = "jjuarez/Vit_waste_image_class"
_preprocessor_checkpoint = "google/vit-base-patch16-224"
model = ViTForImageClassification.from_pretrained(model_name)
feature_extractor = ViTFeatureExtractor.from_pretrained(_preprocessor_checkpoint)
# Class index -> human-readable label. This explicit mapping is kept instead of
# model.config.id2label; it must list the classes in the order the fine-tuned
# checkpoint was trained with.
_INDEX_TO_LABEL = {
    0: "Aluminium",
    1: "Batteries",
    2: "Cardboard",
    3: "Glass",
    4: "Hard Plastic",
    5: "Paper",
    6: "Soft Plastics",
}


def _to_rgb_array(image):
    """Return *image* as a channels-last array with exactly three channels.

    Grayscale inputs (``(H, W)`` or ``(H, W, 1)``) are replicated across three
    channels, and RGBA inputs (``(H, W, 4)``) have their alpha channel dropped,
    so the 3-channel preprocessor can handle them. Any other shape is returned
    unchanged.
    """
    array = np.asarray(image)
    if array.ndim == 2:
        # Give plain 2-D grayscale images an explicit channel axis.
        array = array[..., np.newaxis]
    if array.ndim == 3 and array.shape[-1] == 1:
        array = np.repeat(array, 3, axis=-1)
    elif array.ndim == 3 and array.shape[-1] == 4:
        array = array[..., :3]
    return array


def classify_image(image):
    """Classify a waste image and return its human-readable label.

    Args:
        image: The uploaded image, as an array-like (presumably a NumPy array,
            given ``gr.Image()``'s default ``type="numpy"``) or a PIL image.
            Grayscale and RGBA images are converted to three channels before
            preprocessing.

    Returns:
        The predicted label string, or ``"Unknown Label"`` if the model
        predicts an index missing from the mapping.

    Raises:
        gr.Error: If no image was provided. Gradio passes ``None`` when the user
            submits without uploading, and this surfaces a readable message
            instead of a crash inside the preprocessor.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Resize/normalise the image into the tensor dict the model expects.
    inputs = feature_extractor(images=_to_rgb_array(image), return_tensors="pt")

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single image is passed in, so argmax yields one class index.
    predicted_class_idx = logits.argmax(-1).item()
    return _INDEX_TO_LABEL.get(predicted_class_idx, "Unknown Label")
# Build the web UI: one image upload in, one predicted label out.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(),  # images of any size are accepted
    outputs=gr.Label(),
    title="Waste Classification with ViT",
    description="Upload an image of waste, and the model will classify it.",
)

# Start the Gradio app.
iface.launch()