jjuarez commited on
Commit
84d5fc9
·
verified ·
1 Parent(s): 271847e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -9
app.py CHANGED
@@ -1,20 +1,21 @@
1
  import gradio as gr
2
- from transformers import AutoModelForImageClassification, AutoFeatureExtractor
3
  from PIL import Image
4
  import torch
 
5
 
6
  # Load the pre-trained model and preprocessor (feature extractor)
7
  model_name = "jjuarez/Vit_waste_image_class"
8
- model = AutoModelForImageClassification.from_pretrained(model_name)
9
- feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
10
 
11
  def classify_image(image):
12
- # Convert PIL Image to NumPy array if not already done by Gradio
13
  image = np.array(image)
14
 
15
- # Let the feature extractor handle resizing and normalization
16
  inputs = feature_extractor(images=image, return_tensors="pt")
17
-
18
  # Make prediction
19
  with torch.no_grad():
20
  outputs = model(**inputs)
@@ -28,9 +29,9 @@ def classify_image(image):
28
 
29
  return label
30
 
31
- # Create Gradio interface without specifying shape for the image input
32
- iface = gr.Interface(fn=classify_image,
33
- inputs=gr.Image(), # Removed shape parameter, accepts image of any size
34
  outputs=gr.Label(),
35
  title="Waste Classification with ViT",
36
  description="Upload an image of waste, and the model will classify it.")
 
1
  import gradio as gr
2
+ from transformers import ViTForImageClassification, ViTFeatureExtractor
3
  from PIL import Image
4
  import torch
5
+ import numpy as np
6
 
7
  # Load the pre-trained model and preprocessor (feature extractor)
8
  model_name = "jjuarez/Vit_waste_image_class"
9
+ model = ViTForImageClassification.from_pretrained(model_name)
10
+ feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
11
 
12
  def classify_image(image):
13
+ # Convert the PIL Image to a format compatible with the feature extractor
14
  image = np.array(image)
15
 
16
+ # Preprocess the image and prepare it for the model
17
  inputs = feature_extractor(images=image, return_tensors="pt")
18
+
19
  # Make prediction
20
  with torch.no_grad():
21
  outputs = model(**inputs)
 
29
 
30
  return label
31
 
32
+ # Create Gradio interface
33
+ iface = gr.Interface(fn=classify_image,
34
+ inputs=gr.Image(), # Accepts image of any size
35
  outputs=gr.Label(),
36
  title="Waste Classification with ViT",
37
  description="Upload an image of waste, and the model will classify it.")