# panda1835's picture
# Update app.py
# dbc83fe
import os
import glob
import json
import warnings
warnings.filterwarnings("ignore")
import torch
from PIL import Image
import gradio as gr
from datetime import datetime
import models
# Report at start-up whether a CUDA device is visible to PyTorch.
print(f"Is CUDA available: {torch.cuda.is_available()}")
# print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# Label map: `classify` looks entries up by the stringified class index.
# JSON is UTF-8 by specification, so decode it explicitly instead of relying
# on the platform's default encoding.
with open("index_to_species.json", "r", encoding="utf-8") as file:
    index_to_species = json.load(file)
num_classes = len(index_to_species)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model
classify_model = models.DinoVisionTransformerClassifier(num_classes)
classify_model = classify_model.to(device)
# map_location remaps the checkpoint's tensors onto the selected device, so a
# checkpoint saved on a GPU machine still loads on a CPU-only host.
classify_model.load_state_dict(
    torch.load("best_dinov2_both_2023-11-21_07-44-35.pth", map_location=device)
)
# Switch to evaluation mode before serving predictions.
classify_model.eval()

# Number of top-scoring classes reported per prediction.
k = 5
def classify(image):
    """Classify an image and return the most likely species with their scores.

    Args:
        image: The input handed unchanged to ``classify_model``. The Gradio
            interface in this file is configured to supply a PIL image; it
            may be ``None`` (presumably when the form is submitted without
            an image).

    Returns:
        A dict mapping species name (underscores replaced by spaces) to its
        softmax probability rounded to two decimals, ordered from highest to
        lowest score. Entries whose rounded score is 0 are omitted. An empty
        dict is returned when ``image`` is ``None``.
    """
    # Nothing to classify: return an empty result instead of passing None
    # into the model.
    if image is None:
        return {}

    # Pure inference: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        output = classify_model(image)[0]
        # torch.topk raises if asked for more entries than exist, so never
        # request more than the number of classes the model emits.
        top_count = min(k, output.shape[-1])
        tops = torch.topk(output, k=top_count).indices
        scores = torch.softmax(output, dim=0)[tops]

    result = {
        index_to_species[str(class_index)].replace("_", " "): round(score, 2)
        for class_index, score in zip(tops.tolist(), scores.tolist())
    }
    sorted_result = {
        species: score
        for species, score in sorted(
            result.items(), key=lambda item: item[1], reverse=True
        )
        if score > 0
    }

    # Log every prediction with a timestamp so requests can be traced later.
    formatted_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"{formatted_time} {sorted_result}")
    return sorted_result
# Title displayed by the Gradio interface.
title = "🐢"

# Components are created in the same order as before: image input first,
# then the JSON output.
image_input = gr.Image(type="pil", label="Input Image")
json_output = gr.JSON()

# Wire `classify` to the components and start serving the app.
demo = gr.Interface(
    fn=classify,
    inputs=image_input,
    outputs=[json_output],
    title=title,
)
demo.launch()