Spaces:
Sleeping
Sleeping
| import os | |
| import glob | |
| import json | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| import torch | |
| from PIL import Image | |
| import gradio as gr | |
| import models | |
# Load the label mapping. Its keys are class indices stored as strings
# (JSON object keys are always strings); its values are species names.
with open("index_to_species.json", "r", encoding="utf-8") as file:
    index_to_species_data = file.read()
index_to_species = json.loads(index_to_species_data)
num_classes = len(index_to_species)

# Load the model
classify_model = models.DinoVisionTransformerClassifier(num_classes)
# `map_location` is an argument of torch.load(), not of load_state_dict():
# it remaps the checkpoint's tensors onto the CPU while deserializing, so a
# checkpoint saved on a GPU machine can be opened on a CPU-only host.
state_dict = torch.load(
    "best_dinov2_both_2023-11-21_07-44-35.pth",
    map_location=torch.device("cpu"),
)
classify_model.load_state_dict(state_dict)
# Switch to inference mode (affects layers such as dropout / batch norm).
classify_model.eval()
def classify(image, k=5):
    """Return the top-``k`` predicted species for ``image`` with their scores.

    Args:
        image: The input handed over by the Gradio ``Image`` component.
            NOTE(review): it is forwarded to ``classify_model`` unchanged;
            this presumably relies on the model doing its own preprocessing
            (resize / tensor conversion) -- confirm against ``models``.
        k: Maximum number of predictions to return. It is clamped to the
            number of classes the model outputs.

    Returns:
        A dict mapping species name to its softmax probability rounded to two
        decimals, ordered from most to least likely. Entries whose rounded
        probability is 0 are omitted.
    """
    # Gradients are never needed for inference; skipping them saves memory.
    with torch.no_grad():
        # Take the first (only) item of the model's output batch.
        # NOTE(review): assumes this is a 1-D vector of class logits -- confirm.
        output = classify_model(image)[0]

    # torch.topk raises if k exceeds the size of the dimension it selects from.
    k = min(k, output.shape[0])
    tops = torch.topk(output, k=k).indices
    scores = torch.softmax(output, dim=0)[tops]

    # The mapping is keyed by the class index as a string.
    result = {
        index_to_species[str(index.item())]: round(score.item(), 2)
        for index, score in zip(tops, scores)
    }
    ranked = sorted(result.items(), key=lambda item: item[1], reverse=True)
    return {species: score for species, score in ranked if score > 0}
# Page title displayed at the top of the web UI.
title = "🐢"

# One image in, one JSON document (species -> score) out.
image_input = gr.Image(type="pil", label="Input Image")
json_output = gr.JSON()

demo = gr.Interface(
    fn=classify,
    inputs=image_input,
    outputs=[json_output],
    title=title,
)
# Start the web server for the interface.
demo.launch()