# Spaces:
# Sleeping
# Sleeping
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from transformers import ViTFeatureExtractor, ViTForImageClassification | |
| import os | |
# --- Part 1: Model & Prediction Function ---

# Hugging Face Hub ID of the garbage-classification ViT checkpoint.
MODEL_ID = "yangy50/garbage-classification"

# Pick the device up front so it is ALWAYS defined, even when loading fails
# below — predict_image reads `device` as a module-level global, and defining
# it inside the try block could leave it unbound after a load failure.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the feature extractor and model once at app start-up.
# NOTE(review): ViTFeatureExtractor is deprecated in newer transformers
# releases in favor of ViTImageProcessor — confirm against the pinned version.
try:
    feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_ID)
    model = ViTForImageClassification.from_pretrained(MODEL_ID)
    model.to(device)
    model.eval()  # inference mode: disables dropout, etc.
    print("Model loaded successfully.")
except Exception as e:
    # Degrade gracefully: predict_image checks for None and tells the user.
    print(f"Error loading model: {e}")
    feature_extractor = None
    model = None
def predict_image(image_file):
    """Classify a waste-item image and return its label plus confidences.

    Args:
        image_file: Filesystem path to the uploaded image (Gradio
            ``type="filepath"``), or ``None`` when nothing was uploaded.

    Returns:
        tuple: ``(predicted_class, confidences)`` where ``confidences``
        maps every class label to its softmax probability. On failure,
        ``(error_message, None)``.
    """
    # Guard BOTH globals: either can be None if start-up loading failed
    # (checking only `model` would crash on calling a None extractor).
    if model is None or feature_extractor is None:
        return "Model not available. Please check the logs.", None
    if image_file is None:
        return "Please upload an image.", None
    try:
        # Gradio hands us a plain path string; normalize to RGB for ViT.
        image = Image.open(image_file).convert("RGB")
        # Preprocess and move tensors to the same device as the model.
        inputs = feature_extractor(images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # Inference only — no gradients needed.
        with torch.no_grad():
            logits = model(**inputs).logits
        # Top-1 prediction.
        predicted_class_idx = logits.argmax(-1).item()
        predicted_class = model.config.id2label[predicted_class_idx]
        # Full probability distribution for a richer Label output.
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
        confidences = {
            model.config.id2label[i]: float(probabilities[i])
            for i in range(len(model.config.id2label))
        }
        return predicted_class, confidences
    except Exception as e:
        # UI boundary: surface a friendly message instead of crashing Gradio.
        return f"An error occurred: {e}. Please ensure the input is a valid image file.", None
| # --- Part 2: Gradio Interface --- | |
| # Automatically discover and use all .jpg and .png files in the 'examples' folder | |
| examples_dir = "examples" | |
| if os.path.exists(examples_dir): | |
| example_paths = [ | |
| os.path.join(examples_dir, f) for f in os.listdir(examples_dir) | |
| if f.endswith((".jpg", ".png")) | |
| ] | |
| else: | |
| example_paths = [] | |
# Assemble the Gradio UI, then start serving it.
waste_classifier_app = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="filepath", label="Upload an Image of a Waste Item"),
    outputs=[
        gr.Label(label="Predicted Class"),
        gr.Label(label="Confidences"),
    ],
    title="🗑️ Smart Recycling Assistant ♻️",
    description="This model classifies waste into categories to help you recycle correctly. You can simply upload a photo of a waste item to see its category. The model will classify the item as one of these categories: 'cardboard', 'glass', 'metal', 'paper', 'plastic', or 'trash'.",
    examples=example_paths,
    cache_examples=False,
)
waste_classifier_app.launch()