jjuarez commited on
Commit
1a7e2f0
·
verified ·
1 Parent(s): c787f34

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -0
app.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForImageClassification, AutoFeatureExtractor
3
+ import requests
4
+ from PIL import Image
5
+ import torch
6
+
7
+ # Load the pre-trained model and preprocessor (feature extractor)
8
+ model_name = "jjuarez/Vit_waste_image_class"
9
+ model = AutoModelForImageClassification.from_pretrained(model_name)
10
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
11
+
12
+ def classify_image(image):
13
+ # Preprocess the image
14
+ inputs = feature_extractor(images=image, return_tensors="pt")
15
+ # Make prediction
16
+ with torch.no_grad():
17
+ logits = model(**inputs).logits
18
+ # Retrieve the highest probability class label
19
+ predicted_class_idx = logits.argmax(-1).item()
20
+ # Convert the index to the model's class label
21
+ label = model.config.id2label[predicted_class_idx]
22
+ return label
23
+
24
+ # Create Gradio interface
25
+ iface = gr.Interface(fn=classify_image,
26
+ inputs=gr.inputs.Image(shape=(224, 224)),
27
+ outputs="label",
28
+ title="Waste Classification with ViT",
29
+ description="Upload an image of waste, and the model will classify it.")
30
+
31
+ # Launch the app
32
+ iface.launch()