# Hugging Face Spaces status at time of capture: Runtime error
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    ViTImageProcessor,
)
# Load the fine-tuned ViT classifier and its image preprocessor once, at import
# time, so every request reuses the same objects.
model_name = "jjuarez/Vit_waste_image_class"
model = ViTForImageClassification.from_pretrained(model_name)

# ViTImageProcessor is the supported replacement for the deprecated
# ViTFeatureExtractor and is called the same way. The preprocessing config is
# taken from the base "google/vit-base-patch16-224" checkpoint, presumably the
# one model_name was fine-tuned from. TODO: confirm it matches the
# preprocessing used at training time.
# The variable keeps the name `feature_extractor` because classify_image()
# refers to it.
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
# Human-readable names for the classifier's output indices. Built once at
# import time instead of on every request. This mapping is hard-coded rather
# than read from model.config.id2label, which assumes the checkpoint was
# trained with its classes in exactly this order. TODO: confirm against the
# model card.
INDEX_TO_LABEL = {
    0: "Aluminium",
    1: "Batteries",
    2: "Cardboard",
    3: "Glass",
    4: "Hard Plastic",
    5: "Paper",
    6: "Soft Plastics",
}


def _to_rgb_array(image):
    """Return *image* as a three-channel (RGB) NumPy array.

    Accepts a PIL image or a NumPy array (arrays are assumed to be
    channels-last, height x width [x channels]). Grayscale input is
    replicated to three channels, and an alpha channel is dropped. This
    matters because the preprocessor is configured for RGB images and fails
    on 1- or 4-channel input.
    """
    if isinstance(image, Image.Image):
        return np.array(image.convert("RGB"))

    array = np.asarray(image)
    if array.ndim == 2:
        # Single-channel image without an explicit channel axis.
        array = array[..., np.newaxis]
    if array.ndim == 3 and array.shape[-1] == 1:
        array = np.repeat(array, 3, axis=-1)
    elif array.ndim == 3 and array.shape[-1] == 4:
        # Discard alpha, matching what PIL's convert("RGB") does.
        array = array[..., :3]
    return array


def classify_image(image):
    """Classify a waste image and return the predicted material name.

    Args:
        image: The uploaded image, either a NumPy array or a PIL image.
            ``None`` (what Gradio passes when nothing was uploaded) raises
            ``gr.Error`` so the UI shows a readable message.

    Returns:
        The entry of ``INDEX_TO_LABEL`` for the highest-scoring class, or
        ``"Unknown Label"`` if the predicted index is not in that mapping.
    """
    if image is None:
        raise gr.Error("Please upload an image to classify.")

    inputs = feature_extractor(images=_to_rgb_array(image), return_tensors="pt")

    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits

    # argmax over the class dimension; .item() relies on a batch of one image.
    predicted_class_idx = logits.argmax(-1).item()
    return INDEX_TO_LABEL.get(predicted_class_idx, "Unknown Label")
# Web UI: one image in, one predicted label out.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(),  # no size constraint is imposed on uploads
    outputs=gr.Label(),
    title="Waste Classification with ViT",
    description="Upload an image of waste, and the model will classify it.",
)

# Start serving the app.
iface.launch()