# Hugging Face file-view header: saibhagawan's picture | commit "Update app.py" | 0c0ce45 (verified)
import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification
import os
# --- Part 1: Model & Prediction Function ---
# Hugging Face Hub ID of the pretrained garbage-classification model.
MODEL_ID = "yangy50/garbage-classification"

# Choose the inference device before attempting to load anything so that
# `device` is always bound, even when the load below fails.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model and feature extractor once, when the app starts.
# NOTE(review): recent transformers releases deprecate ViTFeatureExtractor in
# favour of ViTImageProcessor -- consider migrating once the pinned
# transformers version is confirmed to provide it.
try:
    feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_ID)
    model = ViTForImageClassification.from_pretrained(MODEL_ID)
    # Move the model to the selected device and switch to inference mode.
    model.to(device)
    model.eval()
except Exception as e:
    # Startup boundary: keep the app importable and let predict_image()
    # report that the model is unavailable instead of crashing here.
    print(f"Error loading model: {e}")
    feature_extractor = None
    model = None
else:
    print("Model loaded successfully.")
def predict_image(image_file):
    """Classify a waste-item image for the Gradio app.

    Args:
        image_file: Path of the image to classify (it is passed straight to
            ``PIL.Image.open``), or ``None`` when nothing was uploaded.

    Returns:
        A ``(label, confidences)`` tuple. On success ``label`` is the name of
        the top-scoring class and ``confidences`` maps every class name to
        its softmax probability. If the model is unavailable, no image was
        supplied, or processing fails, ``label`` is a user-facing message and
        ``confidences`` is ``None``.
    """
    if model is None:
        return "Model not available. Please check the logs.", None
    if image_file is None:
        return "Please upload an image.", None

    try:
        # Open inside a context manager so the file handle is released
        # immediately; convert() returns an independent in-memory image.
        with Image.open(image_file) as raw_image:
            image = raw_image.convert("RGB")

        # Preprocess the image and move the tensors to the model's device.
        inputs = feature_extractor(images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Inference only: no gradients are needed.
        with torch.no_grad():
            logits = model(**inputs).logits

        id2label = model.config.id2label
        predicted_class = id2label[logits.argmax(-1).item()]

        # Report the probability of every class for a richer output. Iterate
        # the mapping itself rather than range(len(...)) so the class indices
        # are not assumed to be exactly 0..n-1.
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
        confidences = {
            label: float(probabilities[int(idx)])
            for idx, label in id2label.items()
        }
        return predicted_class, confidences
    except Exception as e:
        # UI boundary: show a readable message instead of a stack trace.
        return f"An error occurred: {e}. Please ensure the input is a valid image file.", None
# --- Part 2: Gradio Interface ---
def list_example_images(directory, extensions=(".jpg", ".png")):
    """Return the image files found directly inside *directory*.

    Args:
        directory: Folder to scan (not searched recursively).
        extensions: File-name suffixes to accept; matched case-insensitively.

    Returns:
        A list of ``directory``-joined paths, sorted by file name so the
        result is deterministic. Empty when *directory* does not exist or is
        not a directory.
    """
    # isdir (not exists): a regular file with this name must not reach listdir.
    if not os.path.isdir(directory):
        return []
    suffixes = tuple(ext.lower() for ext in extensions)
    return [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if name.lower().endswith(suffixes)
        # Skip sub-directories that merely happen to have an image-like name.
        and os.path.isfile(os.path.join(directory, name))
    ]


# Automatically discover and use all .jpg and .png files in the 'examples' folder
examples_dir = "examples"
example_paths = list_example_images(examples_dir)
# Create the Gradio Interface. It is bound to a module-level name so it can be
# reused by importers instead of existing only as a temporary expression.
demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="filepath", label="Upload an Image of a Waste Item"),
    outputs=[
        gr.Label(label="Predicted Class"),
        gr.Label(label="Confidences"),
    ],
    title="🗑️ Smart Recycling Assistant ♻️",
    description="This model classifies waste into categories to help you recycle correctly. You can simply upload a photo of a waste item to see its category. The model will classify the item as one of these categories: 'cardboard', 'glass', 'metal', 'paper', 'plastic', or 'trash'.",
    examples=example_paths,
    cache_examples=False,
)

# Start the web server only when this file is run as a script, so importing
# the module does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()