Spaces:
Running
Running
File size: 4,167 Bytes
c06ec80 3c17cec c06ec80 619b8a9 44d687b 0db6636 c06ec80 80bb40d dbc83fe 80bb40d ff14199 209677b ff14199 cef7422 619b8a9 c06ec80 80bb40d c06ec80 27906be ca3b812 20d8a9c 0db6636 631df19 619b8a9 0db6636 619b8a9 803d45d 0db6636 ca3b812 c06ec80 ca3b812 c8c5272 c06ec80 44d687b c06ec80 44d687b c670de6 8157dfb 803d45d c670de6 619b8a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import os
import glob
import json
import warnings
warnings.filterwarnings("ignore")
import torch
import torchvision.transforms as T
from PIL import Image
import gradio as gr
from datetime import datetime
from ultralytics import YOLO
import models
print(f"Is CUDA available: {torch.cuda.is_available()}")
# print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DINOv2 backbone: pick one of the published checkpoints (S/B/L/g, patch 14).
# Index 1 selects 'dinov2_vitb14'; the classifier below is built with
# input_dim=768, which assumes this ViT-B/14 choice — keep them in sync.
dinov2_ckpt = ['dinov2_vits14', 'dinov2_vitb14', 'dinov2_vitl14', 'dinov2_vitg14'][1]
dinov2 = torch.hub.load('facebookresearch/dinov2', dinov2_ckpt)
dinov2.to(device)
print()

# Standard ImageNet preprocessing; 224 is a multiple of the ViT patch size (14).
transform_image = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
])
def extract_embedding(image):
    """
    Compute a DINOv2 feature embedding for an image.

    Args:
        image: A PIL Image object (any mode ToTensor accepts; channels past
            the first 3 — e.g. alpha — are dropped).

    Returns:
        A dict {"embedding": list[float]} holding the backbone's feature
        vector for the image, as a JSON-serializable list.
    """
    # Preprocess, keep only the first 3 channels (drops alpha from RGBA
    # inputs), and add a batch dimension.
    transformed_img = transform_image(image)[:3].unsqueeze(0).to(device)
    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        embedding = dinov2(transformed_img)
    # Strip the batch dimension and convert to a plain Python list.
    embedding = embedding[0].cpu().numpy().tolist()
    return {
        "embedding": embedding
    }
# Map class index (stored as a string key) -> species name.
with open("index_to_species.json", "r") as file:
    index_to_species = json.load(file)
num_classes = len(index_to_species)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Linear classification head over DINOv2 embeddings (768 = ViT-B/14 dim).
linear_model_name = 'linear_2025-08-12.pt'
classify_model = models.LinearClassifier(input_dim=768, output_dim=num_classes)
# map_location lets a checkpoint saved on GPU load on a CPU-only host.
classify_model.load_state_dict(torch.load(linear_model_name, map_location=device))
# classify() sends input tensors to `device`, so the head must live there too.
classify_model.to(device)
classify_model.eval()

# YOLOv8 detector used to localize the animal before classification.
detect_model = YOLO('yolov8m_2023-10-23_best.pt')

# Number of top classes to report.
k = 5
def detect(image):
    """
    Run the YOLO detector and return the first detected bounding box.

    Args:
        image: A PIL Image object.

    Returns:
        A dict with integer pixel keys "top", "left", "width", "height".
        All zeros when nothing is detected.
    """
    results = detect_model(image, conf=0.1)
    # Log the time of each detection request.
    print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    try:
        # xyxy corners of the first detection of the first (only) image.
        box = results[0].boxes.xyxy[0].cpu().numpy()
    except IndexError:
        # No detections — indexing the empty box tensor raises IndexError.
        # Signal the caller with an all-zero box.
        return {"top": 0, "left": 0, "width": 0, "height": 0}
    x1, y1, x2, y2 = box
    return {
        "top": int(y1),
        "left": int(x1),
        "width": int(x2 - x1),
        "height": int(y2 - y1),
    }
def classify(image):
    """
    Detect, crop, embed, and classify an animal in an image.

    Args:
        image: A PIL Image object.

    Returns:
        A dict with the detection box, the top-k species scores (highest
        first, zero-rounded scores dropped), and the original image size.
        An empty dict when nothing was detected.
    """
    image_width, image_height = image.size
    detection = detect(image)
    # detect() returns an all-zero box when it found nothing.
    if all(detection[key] == 0 for key in ("top", "left", "width", "height")):
        return {}
    # Crop to the detected box before embedding.
    image = image.crop((
        detection['left'],
        detection['top'],
        detection['left'] + detection['width'],
        detection['top'] + detection['height'],
    ))
    embedding = extract_embedding(image)['embedding']
    # Inference only — no autograd graph needed for the linear head.
    with torch.no_grad():
        output = classify_model(torch.Tensor(embedding).to(device))
    tops = torch.topk(output, k=k).indices
    scores = torch.softmax(output, dim=0)[tops]
    # Use distinct loop names (idx/score) so the module-level `k` isn't shadowed.
    result = {
        index_to_species[str(idx.item())].replace("_", " "): round(score.item(), 2)
        for idx, score in zip(tops, scores)
    }
    # Highest score first; drop entries that rounded down to 0.
    sorted_result = {
        name: score
        for name, score in sorted(result.items(), key=lambda item: item[1], reverse=True)
        if score > 0
    }
    print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} {sorted_result}")
    return {
        "detection": [detection],
        "classification": [sorted_result],
        "image_width": image_width,
        "image_height": image_height,
    }
# Minimal Gradio UI: upload an image, receive detection + classification JSON.
title = "🐢"
gr.Interface(
    fn=classify,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[gr.JSON()],
    title=title,
).launch()