File size: 1,662 Bytes
1a7e2f0
84d5fc9
1a7e2f0
 
84d5fc9
1a7e2f0
 
 
84d5fc9
 
1a7e2f0
 
84d5fc9
271847e
177fc8c
84d5fc9
1a7e2f0
84d5fc9
1a7e2f0
 
271847e
 
177fc8c
d0eb989
1a7e2f0
177fc8c
d0eb989
 
a060c24
 
 
 
 
 
 
d0eb989
 
1a7e2f0
d0eb989
177fc8c
1a7e2f0
 
d0eb989
84d5fc9
 
 
598d759
1a7e2f0
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    ViTImageProcessor,
)

# Load the fine-tuned waste classifier and the image preprocessor for its inputs.
model_name = "jjuarez/Vit_waste_image_class"
model = ViTForImageClassification.from_pretrained(model_name)

# ViTImageProcessor is the maintained replacement for the deprecated
# ViTFeatureExtractor; the variable keeps its old name because
# classify_image refers to `feature_extractor`.
# NOTE(review): the processor is loaded from the base checkpoint rather than
# from `model_name`, presumably because the fine-tune reused the base model's
# preprocessing — confirm against the fine-tuned repo's preprocessor config.
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

# Manual mapping of classifier output indices to human-readable labels.
# Built once at import time instead of on every prediction.
# NOTE(review): presumably mirrors the label order the model was fine-tuned
# with; verify against model.config.id2label for the loaded checkpoint.
INDEX_TO_LABEL = {
    0: "Aluminium",
    1: "Batteries",
    2: "Cardboard",
    3: "Glass",
    4: "Hard Plastic",
    5: "Paper",
    6: "Soft Plastics",
}


def _to_rgb_array(image):
    """Return *image* as a channels-last array with exactly 3 (RGB) channels.

    Grayscale input (2-D, or 3-D with a single channel) is replicated across
    three channels; a 4th (alpha) channel is dropped. Anything else is
    returned unchanged as an array.
    """
    array = np.asarray(image)
    if array.ndim == 2:
        # Promote (H, W) grayscale to (H, W, 1) so it is handled below.
        array = array[..., np.newaxis]
    if array.ndim == 3 and array.shape[-1] == 1:
        array = np.repeat(array, 3, axis=-1)
    elif array.ndim == 3 and array.shape[-1] == 4:
        # Discard alpha so the processor always receives RGB.
        array = array[..., :3]
    return array


def classify_image(image):
    """Classify a waste image and return its human-readable label.

    Args:
        image: An image as a NumPy array or PIL image (whatever the Gradio
            Image input supplies), or ``None`` when no image was provided.

    Returns:
        The predicted label string (``"Unknown Label"`` if the predicted index
        is not in ``INDEX_TO_LABEL``), or ``None`` when *image* is ``None`` so
        the Label output is simply cleared instead of the app erroring.
    """
    if image is None:
        return None

    # Normalise channels, then preprocess into model-ready PyTorch tensors.
    inputs = feature_extractor(images=_to_rgb_array(image), return_tensors="pt")

    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Single image in the batch, so argmax over the class axis yields one index.
    predicted_class_idx = logits.argmax(-1).item()
    return INDEX_TO_LABEL.get(predicted_class_idx, "Unknown Label")


# Build the Gradio UI: one image input and one label output, both wired to
# classify_image.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(),  # Accepts image of any size
    outputs=gr.Label(),
    title="Waste Classification with ViT",
    description="Upload an image of waste, and the model will classify it.",
)

# Start the web server only when this file is run as a script, so importing
# the module (e.g. from tests or another app) has no server side effect.
if __name__ == "__main__":
    iface.launch()