Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -41,7 +41,7 @@ transform = transforms.Compose([
|
|
| 41 |
|
| 42 |
# 4. Prediction Function
|
| 43 |
def predict(image):
|
| 44 |
-
"""Takes a PIL image and returns a dictionary of top
|
| 45 |
if model is None:
|
| 46 |
return {"Error": "Model is not loaded. Please check the logs for errors."}
|
| 47 |
|
|
@@ -50,24 +50,24 @@ def predict(image):
|
|
| 50 |
outputs = model(image)
|
| 51 |
probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
|
| 52 |
|
| 53 |
-
# Get top
|
| 54 |
-
|
| 55 |
|
| 56 |
-
confidences = {labels[i]: float(p) for i, p in zip(
|
| 57 |
|
| 58 |
return confidences
|
| 59 |
|
| 60 |
# 5. Gradio Interface
|
| 61 |
title = "Bird Species Classifier"
|
| 62 |
-
description = "Upload an image of a bird to classify it into one of 200 species. This model is a ConvNeXt V2 Large, fine-tuned on
|
| 63 |
|
| 64 |
iface = gr.Interface(
|
| 65 |
fn=predict,
|
| 66 |
inputs=gr.Image(type="pil", label="Upload Bird Image"),
|
| 67 |
-
outputs=gr.Label(num_top_classes=
|
| 68 |
title=title,
|
| 69 |
description=description,
|
| 70 |
)
|
| 71 |
|
| 72 |
if __name__ == "__main__":
|
| 73 |
-
iface.launch()
|
|
|
|
| 41 |
|
| 42 |
# 4. Prediction Function
|
| 43 |
def predict(image):
|
| 44 |
+
"""Takes a PIL image and returns a dictionary of top 3 predictions."""
|
| 45 |
if model is None:
|
| 46 |
return {"Error": "Model is not loaded. Please check the logs for errors."}
|
| 47 |
|
|
|
|
| 50 |
outputs = model(image)
|
| 51 |
probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
|
| 52 |
|
| 53 |
+
# Get top 3 predictions
|
| 54 |
+
top3_prob, top3_indices = torch.topk(probabilities, 3)
|
| 55 |
|
| 56 |
+
confidences = {labels[i]: float(p) for i, p in zip(top3_indices, top3_prob)}
|
| 57 |
|
| 58 |
return confidences
|
| 59 |
|
| 60 |
# 5. Gradio Interface
|
| 61 |
title = "Bird Species Classifier"
|
| 62 |
+
description = "Upload an image of a bird to classify it into one of 200 species. This model is a ConvNeXt V2 Large, fine-tuned on a dataset of 200 bird species."
|
| 63 |
|
| 64 |
iface = gr.Interface(
|
| 65 |
fn=predict,
|
| 66 |
inputs=gr.Image(type="pil", label="Upload Bird Image"),
|
| 67 |
+
outputs=gr.Label(num_top_classes=3, label="Predictions"),
|
| 68 |
title=title,
|
| 69 |
description=description,
|
| 70 |
)
|
| 71 |
|
| 72 |
if __name__ == "__main__":
|
| 73 |
+
iface.launch()
|