jjuarez commited on
Commit
598d759
·
verified ·
1 Parent(s): 21b485c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -10,23 +10,32 @@ model = AutoModelForImageClassification.from_pretrained(model_name)
10
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
11
 
12
  def classify_image(image):
 
 
 
13
  # Preprocess the image
14
  inputs = feature_extractor(images=image, return_tensors="pt")
 
15
  # Make prediction
16
  with torch.no_grad():
17
  logits = model(**inputs).logits
 
18
  # Retrieve the highest probability class label
19
  predicted_class_idx = logits.argmax(-1).item()
 
20
  # Convert the index to the model's class label
21
  label = model.config.id2label[predicted_class_idx]
 
22
  return label
23
 
24
- # Create Gradio interface
 
25
  iface = gr.Interface(fn=classify_image,
26
- inputs=gr.inputs.Image(shape=(224, 224)),
27
- outputs="label",
28
  title="Waste Classification with ViT",
29
  description="Upload an image of waste, and the model will classify it.")
30
 
31
  # Launch the app
32
  iface.launch()
 
 
10
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
11
 
12
  def classify_image(image):
13
+ # Resize the input image to (224, 224)
14
+ image = image.resize((224, 224))
15
+
16
  # Preprocess the image
17
  inputs = feature_extractor(images=image, return_tensors="pt")
18
+
19
  # Make prediction
20
  with torch.no_grad():
21
  logits = model(**inputs).logits
22
+
23
  # Retrieve the highest probability class label
24
  predicted_class_idx = logits.argmax(-1).item()
25
+
26
  # Convert the index to the model's class label
27
  label = model.config.id2label[predicted_class_idx]
28
+
29
  return label
30
 
31
+ # Create Gradio interface without specifying shape for the image input
32
+ # This allows the input component to accept images of any size
33
  iface = gr.Interface(fn=classify_image,
34
+ inputs=gr.Image(), # Removed shape parameter
35
+ outputs=gr.Label(),
36
  title="Waste Classification with ViT",
37
  description="Upload an image of waste, and the model will classify it.")
38
 
39
  # Launch the app
40
  iface.launch()
41
+