# Hugging Face Space page header ("Spaces: Running") — scrape artifact, not code.
| import os | |
| import glob | |
| import json | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| import torch | |
| import torchvision.transforms as T | |
| from PIL import Image | |
| import gradio as gr | |
| from datetime import datetime | |
| from ultralytics import YOLO | |
| import models | |
| print(f"Is CUDA available: {torch.cuda.is_available()}") | |
| # print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}") | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # DINOv2 | |
| # Select checkpoint | |
| dinov2_ckpt = ['dinov2_vits14', 'dinov2_vitb14', 'dinov2_vitl14', 'dinov2_vitg14'][1] | |
| dinov2 = torch.hub.load('facebookresearch/dinov2', dinov2_ckpt) | |
| dinov2.to(device) | |
| print() | |
| transform_image = T.Compose([ | |
| T.Resize((224, 224)), | |
| T.ToTensor(), | |
| T.Normalize(mean=[0.485, 0.456, 0.406], | |
| std=[0.229, 0.224, 0.225]) | |
| ]) | |
def extract_embedding(image):
    """Compute a DINOv2 feature embedding for an image.

    Args:
        image: A PIL Image object (RGB or RGBA; any alpha channel is dropped).

    Returns:
        A dict ``{"embedding": list[float]}`` holding the backbone's feature
        vector for the (resized, normalized) image.
    """
    # [:3] keeps only the RGB channels in case the input image is RGBA.
    transformed_img = transform_image(image)[:3].unsqueeze(0).to(device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        embedding = dinov2(transformed_img)
    # Batch of one: take the single row, move to CPU, convert to plain floats.
    return {"embedding": embedding[0].cpu().numpy().tolist()}
| with open("index_to_species.json", "r") as file: | |
| index_to_species_data = file.read() | |
| index_to_species = json.loads(index_to_species_data) | |
| num_classes = len(list(index_to_species.keys())) | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # Load the model | |
| linear_model_name = 'linear_2025-08-12.pt' | |
| classify_model = models.LinearClassifier(input_dim=768, output_dim=num_classes) | |
| classify_model.load_state_dict(torch.load(linear_model_name)) | |
| detect_model = YOLO('yolov8m_2023-10-23_best.pt') | |
| k = 5 | |
def detect(image):
    """Run the YOLO detector and return the first detected bounding box.

    Args:
        image: A PIL Image (or any input type ultralytics accepts).

    Returns:
        A dict with integer "top", "left", "width", "height" for the first
        detection, or an all-zero dict when nothing is detected.
    """
    results = detect_model(image, conf=0.1)
    # Timestamp the request for the deployment logs.
    print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    try:
        # xyxy corners of the first detection.
        x1, y1, x2, y2 = results[0].boxes.xyxy[0].cpu().numpy()
    except IndexError:
        # Narrowed from a bare `except:` — only "no detections" is expected
        # here; any other error should surface instead of being swallowed.
        return {"top": 0, "left": 0, "width": 0, "height": 0}
    return {
        "top": int(y1),
        "left": int(x1),
        "width": int(x2 - x1),
        "height": int(y2 - y1),
    }
def classify(image):
    """Detect the animal in an image, then classify its species.

    Args:
        image: A PIL Image.

    Returns:
        ``{}`` when nothing is detected; otherwise a dict with the detection
        box, the top-k species scores (descending, zeros dropped), and the
        original image dimensions.
    """
    image_width, image_height = image.size
    detection = detect(image)
    # An all-zero box is detect()'s "nothing found" sentinel.
    if all(detection[key] == 0 for key in ("top", "left", "width", "height")):
        return {}
    # Crop to the detected box before embedding.
    image = image.crop((
        detection['left'],
        detection['top'],
        detection['left'] + detection['width'],
        detection['top'] + detection['height'],
    ))
    embedding = extract_embedding(image)['embedding']
    # Inference only: no autograd needed for the linear head.
    # torch.tensor (not the torch.Tensor constructor) is the supported way
    # to build a tensor from Python data.
    with torch.no_grad():
        output = classify_model(torch.tensor(embedding).to(device))
        tops = torch.topk(output, k=k).indices
        scores = torch.softmax(output, dim=0)[tops]
    result = {
        index_to_species[str(tops[i].item())].replace("_", " "): round(scores[i].item(), 2)
        for i in range(len(tops))
    }
    # Sort descending by score and drop zero scores; loop variables renamed so
    # they no longer shadow the module-level top-k constant `k`.
    sorted_result = {
        name: score
        for name, score in sorted(result.items(), key=lambda item: item[1], reverse=True)
        if score > 0
    }
    formatted_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"{formatted_time} {sorted_result}")
    return {
        "detection": [detection],
        "classification": [sorted_result],
        "image_width": image_width,
        "image_height": image_height,
    }
| title = "🐢" | |
| gr.Interface( | |
| fn=classify, | |
| inputs=gr.Image(type="pil", label="Input Image"), | |
| outputs=[gr.JSON()], | |
| title=title, | |
| ).launch() |