import gradio as gr
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image
import torch
import numpy as np

# Load the pre-trained model and preprocessor (feature extractor)
model_name = "jjuarez/Vit_waste_image_class"
model = ViTForImageClassification.from_pretrained(model_name)
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
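
# Note: on recent transformers releases, ViTFeatureExtractor is deprecated in favor of
# ViTImageProcessor; an equivalent alternative (not tested against this Space) would be:
#   from transformers import ViTImageProcessor
#   feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")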

def classify_image(image):
    # Convert the PIL Image to a format compatible with the feature extractor
    image = np.array(image)
    # Preprocess the image and prepare it for the model
    inputs = feature_extractor(images=image, return_tensors="pt")
    # Make prediction
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
    # Retrieve the highest-probability class label index
    predicted_class_idx = logits.argmax(-1).item()
    # Define a manual mapping of label indices to human-readable labels
    index_to_label = {
        0: "Aluminium",
        1: "Batteries",
        2: "Cardboard",
        3: "Glass",
        4: "Hard Plastic",
        5: "Paper",
        6: "Soft Plastics",
    }
    # Convert the index to the model's class label
    label = index_to_label.get(predicted_class_idx, "Unknown Label")
    return label
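
# A minimal local smoke test (the filename "sample.jpg" is hypothetical; place any test
# image next to this script). Left commented out so it does not run on Spaces startup:
# if __name__ == "__main__":
#     print(classify_image(Image.open("sample.jpg")))
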
# Create Gradio interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(),  # Accepts an image of any size
    outputs=gr.Label(),
    title="Waste Classification with ViT",
    description="Upload an image of waste, and the model will classify it.",
)

# Launch the app
iface.launch()
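
# Assumed Space dependencies (a requirements.txt sketch, not taken from the repo):
#   gradio
#   transformers
#   torch
#   Pillow
#   numpy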