jjuarez commited on
Commit
177fc8c
·
verified ·
1 Parent(s): 598d759

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -10,22 +10,26 @@ model = AutoModelForImageClassification.from_pretrained(model_name)
10
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
11
 
12
  def classify_image(image):
 
 
 
 
13
  # Resize the input image to (224, 224)
14
  image = image.resize((224, 224))
15
-
16
  # Preprocess the image
17
  inputs = feature_extractor(images=image, return_tensors="pt")
18
-
19
  # Make prediction
20
  with torch.no_grad():
21
- logits = model(**inputs).logits
22
-
23
  # Retrieve the highest probability class label
24
  predicted_class_idx = logits.argmax(-1).item()
25
-
26
  # Convert the index to the model's class label
27
  label = model.config.id2label[predicted_class_idx]
28
-
29
  return label
30
 
31
  # Create Gradio interface without specifying shape for the image input
 
10
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
11
 
12
  def classify_image(image):
13
+ # Ensure image is in RGB format (in case it's a PNG with alpha channel, for example)
14
+ if image.mode != 'RGB':
15
+ image = image.convert('RGB')
16
+
17
  # Resize the input image to (224, 224)
18
  image = image.resize((224, 224))
19
+
20
  # Preprocess the image
21
  inputs = feature_extractor(images=image, return_tensors="pt")
22
+
23
  # Make prediction
24
  with torch.no_grad():
25
+ logits = model(**inputs.pixel_values).logits # Ensure to access pixel_values
26
+
27
  # Retrieve the highest probability class label
28
  predicted_class_idx = logits.argmax(-1).item()
29
+
30
  # Convert the index to the model's class label
31
  label = model.config.id2label[predicted_class_idx]
32
+
33
  return label
34
 
35
  # Create Gradio interface without specifying shape for the image input