PsychicFireSong commited on
Commit
5bbce66
·
verified ·
1 Parent(s): 841fbac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -41,7 +41,7 @@ transform = transforms.Compose([
41
 
42
  # 4. Prediction Function
43
  def predict(image):
44
- """Takes a PIL image and returns a dictionary of top 5 predictions."""
45
  if model is None:
46
  return {"Error": "Model is not loaded. Please check the logs for errors."}
47
 
@@ -50,24 +50,24 @@ def predict(image):
50
  outputs = model(image)
51
  probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
52
 
53
- # Get top 5 predictions
54
- top5_prob, top5_indices = torch.topk(probabilities, 5)
55
 
56
- confidences = {labels[i]: float(p) for i, p in zip(top5_indices, top5_prob)}
57
 
58
  return confidences
59
 
60
  # 5. Gradio Interface
61
  title = "Bird Species Classifier"
62
- description = "Upload an image of a bird to classify it into one of 200 species. This model is a ConvNeXt V2 Large, fine-tuned on the Caltech-UCSD Birds 200 (CUB-200) dataset."
63
 
64
  iface = gr.Interface(
65
  fn=predict,
66
  inputs=gr.Image(type="pil", label="Upload Bird Image"),
67
- outputs=gr.Label(num_top_classes=5, label="Predictions"),
68
  title=title,
69
  description=description,
70
  )
71
 
72
  if __name__ == "__main__":
73
- iface.launch()
 
41
 
42
  # 4. Prediction Function
43
  def predict(image):
44
+ """Takes a PIL image and returns a dictionary of top 3 predictions."""
45
  if model is None:
46
  return {"Error": "Model is not loaded. Please check the logs for errors."}
47
 
 
50
  outputs = model(image)
51
  probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
52
 
53
+ # Get top 3 predictions
54
+ top3_prob, top3_indices = torch.topk(probabilities, 3)
55
 
56
+ confidences = {labels[i]: float(p) for i, p in zip(top3_indices, top3_prob)}
57
 
58
  return confidences
59
 
60
  # 5. Gradio Interface
61
  title = "Bird Species Classifier"
62
+ description = "Upload an image of a bird to classify it into one of 200 species. This model is a ConvNeXt V2 Large, fine-tuned on a dataset of 200 bird species."
63
 
64
  iface = gr.Interface(
65
  fn=predict,
66
  inputs=gr.Image(type="pil", label="Upload Bird Image"),
67
+ outputs=gr.Label(num_top_classes=3, label="Predictions"),
68
  title=title,
69
  description=description,
70
  )
71
 
72
  if __name__ == "__main__":
73
+ iface.launch()