panda1835's picture
Update app.py
206c7b2
raw
history blame
1.14 kB
import os
import glob
import json
import warnings
warnings.filterwarnings("ignore")
import torch
from PIL import Image
import gradio as gr
import models
# Label lookup table: classify() indexes it with str(<class index>) and uses
# the value as the displayed species name. Read explicitly as UTF-8 so
# non-ASCII names decode the same on every platform (open() otherwise uses
# the locale's default encoding).
with open("index_to_species.json", "r", encoding="utf-8") as file:
    index_to_species_data = file.read()
index_to_species = json.loads(index_to_species_data)
# The classifier gets one output per entry in the lookup table.
num_classes = len(index_to_species)
# Load the model.
# `map_location` is an argument of torch.load (it remaps tensors that were
# saved on another device, e.g. CUDA, onto the CPU); Module.load_state_dict
# does not accept it, so passing it there raised a TypeError at startup.
# The loaded state dict is passed inline so no extra copy of the weights
# stays referenced at module level.
# NOTE(review): torch.load unpickles the checkpoint -- only load trusted files.
classify_model = models.DinoVisionTransformerClassifier(num_classes)
classify_model.load_state_dict(
    torch.load(
        "best_dinov2_both_2023-11-21_07-44-35.pth",
        map_location=torch.device("cpu"),
    )
)
# Switch the module to evaluation mode for inference.
classify_model.eval()
def classify(image, k=5):
    """Return the top-``k`` species predictions for ``image``.

    Args:
        image: Value supplied by the ``gr.Image(type="pil")`` input, i.e. a
            PIL image. NOTE(review): it is passed to ``classify_model``
            unchanged; confirm the model performs its own preprocessing
            (resize / to-tensor / normalize).
        k: Maximum number of predictions to return. Clamped to the number
            of model outputs so ``torch.topk`` cannot fail.

    Returns:
        dict mapping species name (looked up in ``index_to_species`` via the
        stringified class index) to its softmax probability rounded to two
        decimals, ordered from most to least likely. Entries that round to
        0 are omitted.
    """
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # assumes the model returns a batch of logits and [0] selects the
        # single image's 1-D logit vector -- TODO confirm output shape
        output = classify_model(image)[0]

    k = min(int(k), output.shape[0])
    probs = torch.softmax(output, dim=0)
    # softmax is monotonic, so the top-k of probabilities matches the
    # top-k of the raw logits.
    top_scores, top_indices = torch.topk(probs, k=k)

    result = {
        index_to_species[str(idx)]: round(score, 2)
        for idx, score in zip(top_indices.tolist(), top_scores.tolist())
    }
    # Re-sort after rounding and drop predictions that rounded to zero.
    return {
        name: score
        for name, score in sorted(result.items(), key=lambda item: item[1], reverse=True)
        if score > 0
    }
# Gradio UI: one PIL image in, classify()'s species -> score dict shown as JSON.
title = "🐢"
demo = gr.Interface(
    fn=classify,
    inputs=gr.Image(label="Input Image", type="pil"),
    outputs=[gr.JSON()],
    title=title,
)
demo.launch()