jjuarez commited on
Commit
d0eb989
·
verified ·
1 Parent(s): 84d5fc9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -2
app.py CHANGED
@@ -21,14 +21,27 @@ def classify_image(image):
21
  outputs = model(**inputs)
22
  logits = outputs.logits
23
 
24
- # Retrieve the highest probability class label
25
  predicted_class_idx = logits.argmax(-1).item()
26
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # Convert the index to the model's class label
28
- label = model.config.id2label[predicted_class_idx]
29
 
30
  return label
31
 
 
32
  # Create Gradio interface
33
  iface = gr.Interface(fn=classify_image,
34
  inputs=gr.Image(), # Accepts image of any size
 
21
  outputs = model(**inputs)
22
  logits = outputs.logits
23
 
24
+ # Retrieve the highest probability class label index
25
  predicted_class_idx = logits.argmax(-1).item()
26
 
27
+ # Define a manual mapping of label indices to human-readable labels
28
+ # This should be adjusted based on your dataset and labels
29
+ index_to_label = {
30
+ "0": "Aluminium",
31
+ "1": "Batteries",
32
+ "2": "Cardboard",
33
+ "3": "Glass",
34
+ "4": "Hard Plastic",
35
+ "5": "Paper",
36
+ "6": "Soft Plastics"
37
+ }
38
+
39
  # Convert the index to the model's class label
40
+ label = index_to_label.get(predicted_class_idx, "Unknown Label")
41
 
42
  return label
43
 
44
+
45
  # Create Gradio interface
46
  iface = gr.Interface(fn=classify_image,
47
  inputs=gr.Image(), # Accepts image of any size