panda1835's picture
Update app.py
44d687b
raw
history blame
1.68 kB
import os
import glob
import json
import warnings
warnings.filterwarnings("ignore")
import torch
from PIL import Image
import gradio as gr
from datetime import datetime
import models
# Startup diagnostics: report whether inference can use a GPU.
cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {cuda_available}")
if cuda_available:
    # e.g. "Tesla T4" on the original deployment.
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    # torch.cuda.current_device() raises without CUDA, which would crash the
    # app at import time even though the rest of the script supports CPU.
    print("CUDA device: none (running on CPU)")
# Map from class index to species label. JSON object keys are always strings,
# so labels must be looked up with str(index).
# Explicit UTF-8 so species names decode the same on every platform.
with open("index_to_species.json", "r", encoding="utf-8") as file:
    index_to_species = json.load(file)
# NOTE(review): assumes the keys are exactly "0".."num_classes-1" matching the
# checkpoint's output layer — confirm against the training code.
num_classes = len(index_to_species)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model
classify_model = models.DinoVisionTransformerClassifier(num_classes)
# map_location lets a checkpoint saved on GPU be deserialized on a CPU-only host.
classify_model.load_state_dict(
    torch.load("best_dinov2_both_2023-11-21_07-44-35.pth", map_location=device)
)
# NOTE(review): classify_model itself is never moved to `device`; load_state_dict
# copies values into the existing parameters. Unless the model class places
# itself on the GPU in __init__, inference runs on CPU even when CUDA is
# available — TODO confirm before adding .to(device) (inputs would need to follow).
classify_model.eval()  # inference mode (e.g. disables dropout)
k = 5  # maximum number of predictions returned by classify()


def classify(image):
    """Return the top-``k`` species predictions for *image*.

    Args:
        image: The image from the Gradio input component (a PIL image, since
            the input is declared with ``type="pil"``). It is passed to
            ``classify_model`` unchanged; presumably the model performs its
            own preprocessing — confirm in ``models``.

    Returns:
        dict mapping species name (underscores replaced by spaces) to its
        softmax probability rounded to 2 decimals, highest first. At most
        ``k`` entries; entries whose rounded probability is 0 are omitted.
    """
    # Inference only: disabling autograd avoids building a graph per request.
    with torch.no_grad():
        # [0] takes the first element of the model output, presumably the
        # logit vector for the single input image — TODO confirm in models.
        output = classify_model(image)[0]

    # torch.topk raises if k exceeds the number of classes, so clamp it.
    top_indices = torch.topk(output, k=min(k, output.numel())).indices
    top_scores = torch.softmax(output, dim=0)[top_indices]

    result = {
        index_to_species[str(index.item())].replace("_", " "): round(score.item(), 2)
        for index, score in zip(top_indices, top_scores)
    }
    sorted_result = {
        name: prob
        for name, prob in sorted(result.items(), key=lambda item: item[1], reverse=True)
        if prob > 0
    }

    # Log every prediction with a local timestamp.
    formatted_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"{formatted_time} {sorted_result}")
    return sorted_result
title = "🐢"

# Web UI: one PIL image in, the dict returned by classify() rendered as JSON out.
image_input = gr.Image(type="pil", label="Input Image")
json_output = gr.JSON()
demo = gr.Interface(
    fn=classify,
    inputs=image_input,
    outputs=[json_output],
    title=title,
)
demo.launch()