"""Gradio inference app: detect a sea turtle with YOLOv8, crop it, embed the
crop with DINOv2 (ViT-B/14), and classify the species with a linear head.

Pipeline per request:  image -> detect() -> crop -> extract_embedding()
                              -> linear classifier -> top-k species dict.
"""

import os
import glob
import json
import warnings
from datetime import datetime

warnings.filterwarnings("ignore")

import torch
import torchvision.transforms as T
from PIL import Image
import gradio as gr
from ultralytics import YOLO

import models

print(f"Is CUDA available: {torch.cuda.is_available()}")
# print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DINOv2 backbone checkpoint: index 1 selects ViT-B/14 (768-dim embeddings,
# matching the linear head's input_dim below).
dinov2_ckpt = ['dinov2_vits14', 'dinov2_vitb14', 'dinov2_vitl14', 'dinov2_vitg14'][1]
dinov2 = torch.hub.load('facebookresearch/dinov2', dinov2_ckpt)
dinov2.to(device)
dinov2.eval()  # inference only — disable any train-mode layers
print()

# Resize + ImageNet normalization expected by the DINOv2 ViT.
transform_image = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def extract_embedding(image):
    """Embed a PIL image with DINOv2.

    Args:
        image: A PIL Image. Only the first 3 channels are kept, so RGBA
            input is handled by dropping alpha.

    Returns:
        dict with key "embedding": the image embedding as a list of floats.
    """
    # [:3] drops any alpha channel; unsqueeze adds the batch dimension.
    transformed_img = transform_image(image)[:3].unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = dinov2(transformed_img)
    # print(embedding.shape)
    return {"embedding": embedding[0].cpu().numpy().tolist()}


# Class-index -> species-name mapping; its size defines the head's output dim.
with open("index_to_species.json", "r") as file:
    index_to_species = json.load(file)
num_classes = len(index_to_species)

# Load the linear classification head.
linear_model_name = 'linear_2025-08-12.pt'
classify_model = models.LinearClassifier(input_dim=768, output_dim=num_classes)
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
classify_model.load_state_dict(torch.load(linear_model_name, map_location=device))
# The embedding tensor is moved to `device` in classify(), so the head must
# live on the same device; eval() for inference mode.
classify_model.to(device)
classify_model.eval()

detect_model = YOLO('yolov8m_2023-10-23_best.pt')

k = 5  # number of top species to report


def detect(image):
    """Run YOLO detection and return the highest-confidence bounding box.

    Args:
        image: A PIL Image.

    Returns:
        dict {"top", "left", "width", "height"} in pixels; all zeros is the
        sentinel for "nothing detected".
    """
    results = detect_model(image, conf=0.1)

    # Log the request timestamp.
    print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

    try:
        # First (highest-confidence) box as [x1, y1, x2, y2].
        box = results[0].boxes.xyxy[0].cpu().numpy()
        return {
            "top": int(box[1]),
            "left": int(box[0]),
            "width": int(box[2] - box[0]),
            "height": int(box[3] - box[1]),
        }
    except IndexError:
        # No detections: indexing boxes.xyxy[0] raises IndexError.
        return {"top": 0, "left": 0, "width": 0, "height": 0}


def classify(image):
    """Detect, crop, embed, and classify a turtle image.

    Args:
        image: A PIL Image.

    Returns:
        {} when nothing is detected; otherwise a dict with the detection box,
        the top-k species scores (softmax, rounded, zeros dropped), and the
        original image size.
    """
    image_width, image_height = image.size

    detection = detect(image)
    # All-zero box is detect()'s "nothing found" sentinel.
    if not any(detection.values()):
        return {}

    # Crop to the detected box before embedding.
    image = image.crop((
        detection['left'],
        detection['top'],
        detection['left'] + detection['width'],
        detection['top'] + detection['height'],
    ))

    embedding = extract_embedding(image)['embedding']
    with torch.no_grad():  # inference only — no autograd graph needed
        output = classify_model(torch.Tensor(embedding).to(device))

    tops = torch.topk(output, k=k).indices
    scores = torch.softmax(output, dim=0)[tops]
    result = {
        index_to_species[str(tops[i].item())].replace("_", " "): round(scores[i].item(), 2)
        for i in range(len(tops))
    }
    # Sort by descending score and drop entries that rounded to zero.
    # (Loop names chosen so they don't shadow the module-level top-k `k`.)
    sorted_result = {
        name: score
        for name, score in sorted(result.items(), key=lambda item: item[1], reverse=True)
        if score > 0
    }

    print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} {sorted_result}")

    return {
        "detection": [detection],
        "classification": [sorted_result],
        "image_width": image_width,
        "image_height": image_height,
    }


title = "🐢"
gr.Interface(
    fn=classify,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[gr.JSON()],
    title=title,
).launch()