File size: 4,167 Bytes
c06ec80
 
 
 
 
 
 
 
3c17cec
c06ec80
619b8a9
44d687b
0db6636
c06ec80
 
80bb40d
dbc83fe
80bb40d
ff14199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209677b
 
 
ff14199
 
 
 
cef7422
 
 
619b8a9
c06ec80
 
80bb40d
 
c06ec80
27906be
ca3b812
20d8a9c
0db6636
631df19
619b8a9
0db6636
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
619b8a9
803d45d
0db6636
 
 
 
 
 
 
 
ca3b812
 
 
c06ec80
 
ca3b812
c8c5272
c06ec80
44d687b
 
c06ec80
44d687b
 
 
 
 
c670de6
8157dfb
 
803d45d
 
c670de6
619b8a9
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
import glob
import json
import warnings

warnings.filterwarnings("ignore")

import torch
import torchvision.transforms as T
from PIL import Image
import gradio as gr
from datetime import datetime
from ultralytics import YOLO
import models

# Log whether a GPU is visible; helps diagnose slow inference in deployment logs.
print(f"Is CUDA available: {torch.cuda.is_available()}")
# print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# Single global device used by the backbone, the classifier head, and all inputs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# DINOv2

# Select checkpoint
# Index 1 picks the ViT-B/14 variant (768-dim embeddings — matches the
# input_dim=768 of the linear classifier loaded below).
dinov2_ckpt = ['dinov2_vits14', 'dinov2_vitb14', 'dinov2_vitl14', 'dinov2_vitg14'][1]
# Downloads/loads the backbone from the torch hub cache at import time.
dinov2 = torch.hub.load('facebookresearch/dinov2', dinov2_ckpt)

dinov2.to(device)
print()

# Standard ImageNet preprocessing: 224x224 resize, tensor conversion,
# per-channel normalization with ImageNet mean/std.
transform_image = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
])

def extract_embedding(image):
    """
    Compute a DINOv2 feature embedding for a PIL image.

    Args:
        image: A PIL Image object.

    Returns:
        A dict with a single key "embedding" whose value is the
        backbone's feature vector as a plain Python list of floats.
    """
    # Preprocess, keep only the first 3 channels (drops any alpha channel),
    # and add a batch dimension before moving to the model's device.
    batch = transform_image(image)[:3].unsqueeze(0).to(device)

    # Inference only — no autograd graph needed.
    with torch.no_grad():
        features = dinov2(batch)[0]

    return {"embedding": features.cpu().numpy().tolist()}
    
# Load the index -> species-name mapping produced at training time.
with open("index_to_species.json", "r") as file:
    index_to_species = json.load(file)

# One classifier output per species.
num_classes = len(index_to_species)

# Load the linear classification head trained on 768-dim DINOv2 embeddings.
# map_location lets a GPU-trained checkpoint load on a CPU-only host, and
# .to(device) keeps the head on the same device as the embeddings fed to it
# in classify() (they are moved with .to(device) there).
linear_model_name = 'linear_2025-08-12.pt'
classify_model = models.LinearClassifier(input_dim=768, output_dim=num_classes)
classify_model.load_state_dict(torch.load(linear_model_name, map_location=device))
classify_model.to(device)
classify_model.eval()

# YOLOv8 detector used to localize the animal before classification.
detect_model = YOLO('yolov8m_2023-10-23_best.pt')

# Number of top-scoring classes reported by classify().
k = 5

def detect(image):
    """
    Run the YOLO detector on an image and return the first detected box.

    Args:
        image: A PIL Image (or any input type accepted by ultralytics YOLO).

    Returns:
        A dict with integer pixel keys "top", "left", "width", "height".
        All four are 0 when nothing was detected.
    """
    results = detect_model(image, conf=0.1)

    # Timestamp each request for lightweight request logging.
    print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

    try:
        # First box of the first result, as [x1, y1, x2, y2] pixel coords.
        # Indexing an empty xyxy tensor raises IndexError when there are
        # no detections — the only failure mode we expect here, so catch
        # narrowly instead of a bare except that would hide real bugs.
        box = results[0].boxes.xyxy[0].cpu().numpy()
    except IndexError:
        # No detections: signal with an all-zero box (consumed by classify()).
        return {"top": 0, "left": 0, "width": 0, "height": 0}

    return {
        "top": int(box[1]),
        "left": int(box[0]),
        "width": int(box[2] - box[0]),
        "height": int(box[3] - box[1]),
    }

def classify(image):
    """
    Detect, crop, embed, and classify the animal in an image.

    Args:
        image: A PIL Image.

    Returns:
        {} when nothing is detected; otherwise a dict with the detection
        box, the top-k species scores (sorted descending, zero-rounded
        entries dropped), and the original image dimensions.
    """
    image_width, image_height = image.size
    detection = detect(image)

    # detect() returns an all-zero box when no animal was found.
    if all(detection[key] == 0 for key in ("top", "left", "width", "height")):
        return {}

    # Crop to the detected box before embedding.
    image = image.crop((
        detection['left'],
        detection['top'],
        detection['left'] + detection['width'],
        detection['top'] + detection['height'],
    ))

    # Embed the crop and score it with the linear head.
    embedding = extract_embedding(image)['embedding']
    with torch.no_grad():  # inference only; avoid building an autograd graph
        output = classify_model(torch.Tensor(embedding).to(device))
    tops = torch.topk(output, k=k).indices
    scores = torch.softmax(output, dim=0)[tops]

    # Map class indices to human-readable species names with rounded scores.
    result = {
        index_to_species[str(tops[i].item())].replace("_", " "): round(scores[i].item(), 2)
        for i in range(len(tops))
    }
    # Sort by score descending and drop entries that rounded down to zero.
    # Deliberately avoid `k`/`v` here: `k` is the module-level top-k count.
    sorted_result = {
        name: score
        for name, score in sorted(result.items(), key=lambda item: item[1], reverse=True)
        if score > 0
    }

    # Log timestamped predictions for lightweight request logging.
    formatted_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"{formatted_time} {sorted_result}")

    return {
        "detection": [detection],
        "classification": [sorted_result],
        "image_width": image_width,
        "image_height": image_height
    }

    
# Page title for the Gradio UI.
title = "🐢"

# Wire the classifier into a simple image-in / JSON-out web interface
# and start serving.
demo = gr.Interface(
    fn=classify,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[gr.JSON()],
    title=title,
)
demo.launch()