fengruilin commited on
Commit
eeda2a7
·
verified ·
1 Parent(s): f4511f0

Trying to show top 5

Browse files
Files changed (1) hide show
  1. app.py +14 -5
app.py CHANGED
@@ -318,10 +318,19 @@ def predict(img: Image.Image):
318
  with torch.no_grad():
319
  logits = model(x)
320
  probs = torch.nn.functional.softmax(logits, dim=1)[0]
321
- top_idx = torch.argmax(probs).item()
322
- top_label = labels[top_idx] if labels[top_idx] is not None else f"Class {top_idx}"
323
- confidence = float(probs[top_idx])
324
- return {top_label: confidence}
 
 
 
 
 
 
 
 
 
325
 
326
  # --------------------------
327
  # 6. Gradio interface
@@ -329,7 +338,7 @@ def predict(img: Image.Image):
329
  demo = gr.Interface(
330
  fn=predict,
331
  inputs=gr.Image(type="pil"),
332
- outputs=gr.Label(num_top_classes=1),
333
  title="🚗Car Model Classifier🚗",
334
  description="⬆Upload a car image and see what it is and how confident our model is on this particular picture!"
335
  )
 
318
  with torch.no_grad():
319
  logits = model(x)
320
  probs = torch.nn.functional.softmax(logits, dim=1)[0]
321
+
322
+ # Get top 5 predictions
323
+ top_5_probs, top_5_indices = torch.topk(probs, 5)
324
+
325
+ # Create dictionary with top 5 predictions
326
+ results = {}
327
+ for i in range(5):
328
+ idx = top_5_indices[i].item()
329
+ label = labels[idx] if idx < len(labels) else f"Class {idx}"
330
+ confidence = float(top_5_probs[i])
331
+ results[label] = confidence
332
+
333
+ return results
334
 
335
  # --------------------------
336
  # 6. Gradio interface
 
338
  demo = gr.Interface(
339
  fn=predict,
340
  inputs=gr.Image(type="pil"),
341
+ outputs=gr.Label(num_top_classes=5),
342
  title="🚗Car Model Classifier🚗",
343
  description="⬆Upload a car image and see what it is and how confident our model is on this particular picture!"
344
  )