# vnturtle-api / app.py
import os
import json
import warnings
warnings.filterwarnings("ignore")
import torch
import torchvision.transforms as T
from PIL import Image
import gradio as gr
from datetime import datetime
from ultralytics import YOLO
import models
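# `models` is a local module; it defines the LinearClassifier used below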
print(f"Is CUDA available: {torch.cuda.is_available()}")
# print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# DINOv2
# Select checkpoint
dinov2_ckpt = ['dinov2_vits14', 'dinov2_vitb14', 'dinov2_vitl14', 'dinov2_vitg14'][1]
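# Index 1 selects ViT-B/14, whose 768-dim embeddings match the classifier's input_dim below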
dinov2 = torch.hub.load('facebookresearch/dinov2', dinov2_ckpt)
dinov2.to(device)
dinov2.eval()  # inference only: make sure dropout/stochastic layers are disabled
print()
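# Standard ImageNet preprocessing; the mean/std below are the ImageNet statistics
# that DINOv2's own transforms use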
transform_image = T.Compose([
T.Resize((224, 224)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def extract_embedding(image):
    """
    Extract a DINOv2 feature embedding from an image.
    Args:
        image: A PIL Image object.
    Returns:
        A dict with an "embedding" key holding the feature vector as a list of floats.
    """
    # Convert to a 3-channel tensor (PNG uploads may be RGBA), add a batch dim, move to device.
    transformed_img = transform_image(image.convert("RGB")).unsqueeze(0).to(device)
# Get the embedding of the image.
with torch.no_grad():
embedding = dinov2(transformed_img)
# print(embedding.shape)
embedding = embedding[0].cpu().numpy().tolist()
return {
"embedding": embedding
}
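# Example usage (hypothetical file path):
#   emb = extract_embedding(Image.open("turtle.jpg"))["embedding"]  # list of 768 floats for ViT-B/14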
with open("index_to_species.json", "r") as file:
    index_to_species = json.load(file)
num_classes = len(index_to_species)
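# Keys of index_to_species are class indices as strings; values are species names
# with underscores, e.g. {"0": "Cuora_amboinensis", ...} (example entry illustrative)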
# Load the classifier head and move it to the same device as its inputs
linear_model_name = 'linear_2025-08-12.pt'
classify_model = models.LinearClassifier(input_dim=768, output_dim=num_classes)
classify_model.load_state_dict(torch.load(linear_model_name, map_location=device))
classify_model.to(device)
classify_model.eval()
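# Turtle detector: custom YOLOv8m weights (the filename suggests a 2023-10-23 training run)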
detect_model = YOLO('yolov8m_2023-10-23_best.pt')
k = 5  # number of top species predictions to return
def detect(image):
    results = detect_model(image, conf=0.1)
    # Log the request time
    formatted_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(formatted_time)
    try:
        # Take the first (highest-confidence) box; xyxy format is [x1, y1, x2, y2]
        box = results[0].boxes.xyxy[0].cpu().numpy()
        return {
            "top": int(box[1]),
            "left": int(box[0]),
            "width": int(box[2] - box[0]),
            "height": int(box[3] - box[1])
        }
    except IndexError:
        # No detection: return an all-zero box as a sentinel
        return {
            "top": 0,
            "left": 0,
            "width": 0,
            "height": 0
        }
def classify(image):
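    """
    Detect the turtle in the image, crop to it, embed the crop, and return the
    top-k species predictions along with the detection box and image size.
    """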
image_width, image_height = image.size
detection = detect(image)
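    # detect() returns an all-zero box as its "no turtle found" sentinel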
if detection["top"] == 0 and detection["left"] == 0 and detection["width"] == 0 and detection["height"] == 0:
return {}
    # Crop the image to the detected box
    box = (detection['left'], detection['top'], detection['left'] + detection['width'], detection['top'] + detection['height'])
    image = image.crop(box)
# Perform the embedding search
embedding = extract_embedding(image)
embedding = embedding['embedding']
    with torch.no_grad():
        output = classify_model(torch.tensor(embedding).to(device))
    tops = torch.topk(output, k=k).indices
    scores = torch.softmax(output, dim=0)[tops]
    result = {index_to_species[str(tops[i].item())].replace("_", " "): round(scores[i].item(), 2) for i in range(len(tops))}
    # Sort by score (descending) and drop zero-probability entries
    sorted_result = {name: score for name, score in sorted(result.items(), key=lambda item: item[1], reverse=True) if score > 0}
    # Log the prediction with a timestamp
    formatted_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"{formatted_time} {sorted_result}")
return {
"detection": [detection],
"classification": [sorted_result],
"image_width": image_width,
"image_height": image_height
}
title = "🐢"
gr.Interface(
fn=classify,
inputs=gr.Image(type="pil", label="Input Image"),
outputs=[gr.JSON()],
title=title,
).launch()
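# A minimal client sketch for querying this Space (assumes the `gradio_client`
# package and the Space id "panda1835/vnturtle-api"; the exact predict()
# arguments depend on the installed gradio_client version):
#
#   from gradio_client import Client
#   client = Client("panda1835/vnturtle-api")
#   result = client.predict("path/to/turtle.jpg", api_name="/predict")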