vnturtle-api / app.py
panda1835's picture
Update app.py
c8c5272
raw
history blame
1.17 kB
import os
import glob
import json
import warnings
warnings.filterwarnings("ignore")
import torch
from PIL import Image
import gradio as gr
import models
# Mapping from class index (stored as a string key, since JSON keys are strings)
# to the species label used in predictions.
with open("index_to_species.json", "r", encoding="utf-8") as file:
    index_to_species = json.load(file)
num_classes = len(index_to_species)

# Load the model
classify_model = models.DinoVisionTransformerClassifier(num_classes)
# The checkpoint is only ever fed to load_state_dict, i.e. it is a plain state
# dict, so weights_only=True suffices and avoids unpickling arbitrary objects.
# (weights_only requires torch>=1.13.)
state_dict = torch.load(
    "best_dinov2_both_2023-11-21_07-44-35.pth",
    map_location=torch.device("cpu"),
    weights_only=True,
)
classify_model.load_state_dict(state_dict)
# Eval mode: disables training-only behaviour (e.g. dropout) for inference.
classify_model.eval()

# Number of top predictions `classify` returns (clamped to num_classes there).
k = 5
def classify(image):
    """Return the top-``k`` predicted species for ``image`` with their scores.

    ``image`` is the value delivered by the Gradio ``gr.Image(type="pil")``
    input and is passed unchanged to ``classify_model``.
    NOTE(review): this assumes the model's forward pass performs its own
    PIL-to-tensor preprocessing -- confirm in ``models``.

    Returns:
        dict mapping species name (underscores replaced by spaces) to its
        softmax probability rounded to 2 decimals, ordered from highest to
        lowest. Entries whose rounded probability is 0 are omitted.
    """
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = classify_model(image)[0]
        probs = torch.softmax(logits, dim=0)
        # topk raises if asked for more items than exist, so clamp to the
        # number of classes along the softmax dimension.
        top_k = min(k, logits.shape[0])
        top_indices = torch.topk(logits, k=top_k).indices.tolist()

    result = {}
    for idx in top_indices:
        name = index_to_species[str(idx)].replace("_", " ")
        result[name] = round(probs[idx].item(), 2)

    # Highest score first; drop predictions that round down to zero.
    return {
        name: score
        for name, score in sorted(result.items(), key=lambda item: item[1], reverse=True)
        if score > 0
    }
# Page title displayed at the top of the Gradio UI.
title = "🐢"

# One PIL image in; the species -> score dict from `classify` rendered as JSON.
image_input = gr.Image(type="pil", label="Input Image")
json_output = gr.JSON()
demo = gr.Interface(fn=classify, inputs=image_input, outputs=[json_output], title=title)
demo.launch()