Spaces:
Build error
Build error
| import random | |
| import os | |
| import numpy as np | |
| import gradio as gr | |
| from PIL import Image | |
| from groundingdino.util.inference import load_model as load_groundingdino_model | |
| from groundingdino.util.inference import predict as grounding_dino_predict | |
| import groundingdino.datasets.transforms as T | |
| import torch | |
| from torchvision.ops import box_convert | |
| from torchvision.transforms.functional import to_tensor | |
| from torchvision.transforms import GaussianBlur | |
| import time | |
| # ---------------------------- | |
| # DINOv2 Classifier Imports | |
| # ---------------------------- | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| import pandas as pd | |
| from typing import List, Tuple | |
| import copy | |
| import matplotlib.pyplot as plt | |
| # ---------------------------- | |
| # DINOv2 Classifier Definitions | |
| # ---------------------------- | |
| # 1. PadToSquare Class | |
class PadToSquare:
    """
    Callable transform that pads an image out to a square canvas.

    The shorter dimension is extended on both sides so the original content
    stays centred; when the missing amount is odd, the extra pixel goes to the
    right / bottom edge. An image that is already square receives zero padding.

    Args:
        fill: Value used for the added border (forwarded to the pad call).
    """

    def __init__(self, fill=0):
        self.fill = fill

    def __call__(self, img):
        # assumes img.size is (width, height), as on PIL images
        width, height = img.size
        side = max(width, height)

        # Total amount missing along each axis (zero for the longer one).
        missing_w = side - width
        missing_h = side - height

        left = missing_w // 2
        top = missing_h // 2
        right = missing_w - left
        bottom = missing_h - top

        return transforms.functional.pad(
            img,
            (left, top, right, bottom),
            fill=self.fill,
            padding_mode='constant',
        )
| # 2. DinoVisionTransformerClassifier Class (Modified to include entropy-based approach) | |
class DinoVisionTransformerClassifier(nn.Module):
    """
    DINOv2 ViT-S/14 backbone followed by a small MLP classification head.

    The backbone is fetched through ``torch.hub`` and its final ``norm`` layer
    is replaced by an identity; a ``BatchNorm1d`` is applied to the backbone
    output before the head. The class itself only produces logits and
    features; any "Unknown" decision has to be made by the caller.

    Args:
        num_classes: Number of output logits.
        hidden_size: Width of the hidden layer of the head.
        dropout_p: Dropout probability inside the head.
        negative_slope: Slope of the LeakyReLU in the head. The same value is
            used when initialising the head's linear layers.
    """

    # Feature width the head expects from the backbone.
    EMBED_DIM = 384

    def __init__(self, num_classes, hidden_size=256, dropout_p=0.5, negative_slope=0.01):
        super().__init__()
        # Kept so that weight initialisation matches the activation actually used.
        self.negative_slope = negative_slope

        # Load DINOv2 model from torch.hub (needs network access or a warm hub cache).
        self.transformer = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14', pretrained=True)
        # Drop the backbone's own final normalisation; batch_norm1 below takes its place.
        self.transformer.norm = nn.Identity()

        # Batch normalisation applied to the backbone output.
        self.batch_norm1 = nn.BatchNorm1d(self.EMBED_DIM)

        # Classification head.
        self.classifier = nn.Sequential(
            nn.Linear(self.EMBED_DIM, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.LeakyReLU(negative_slope=negative_slope, inplace=True),
            nn.Dropout(p=dropout_p),
            nn.Linear(hidden_size, num_classes),
        )

        self._initialize_weights()

    def forward(self, x):
        """
        Run the backbone, the batch norm and the head.

        Returns:
            tuple: ``(logits, features)`` where ``features`` is the
            batch-normalised backbone output that was fed to the head.
        """
        features = self.transformer(x)
        features = self.batch_norm1(features)
        logits = self.classifier(features)
        return logits, features

    def _initialize_weights(self):
        """Initialise the head: Kaiming-normal linear weights, zero biases, unit batch norms."""
        for m in self.classifier.modules():
            if isinstance(m, nn.Linear):
                # Use the configured slope rather than a hard-coded 0.01 so the
                # gain stays correct when a non-default negative_slope is passed.
                nn.init.kaiming_normal_(
                    m.weight, a=self.negative_slope, mode='fan_in', nonlinearity='leaky_relu'
                )
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
| # 3. Model Loading Function (Updated for Entropy-Based Classifier) | |
def load_model(model_path, device):
    """
    Loads the trained classifier and its class names from a saved checkpoint.

    The checkpoint is expected to be a mapping with the entries
    ``'class_names'`` and ``'model_state_dict'``.

    Args:
        model_path (str): Path to the saved .pth model file.
        device (torch.device): Device to load the model onto.

    Returns:
        model (nn.Module): The loaded model, moved to ``device`` and in eval mode.
        class_names (List[str]): List of class names (a fresh list).

    Raises:
        FileNotFoundError: If ``model_path`` is not an existing regular file.
        KeyError: If the checkpoint lacks one of the required entries.
    """
    # isfile rather than exists: a directory path should fail here with a clear
    # message instead of deep inside torch.load.
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"Model file '{model_path}' does not exist.")

    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)

    # Validate up front so a malformed checkpoint fails before the backbone is built.
    missing = [key for key in ('class_names', 'model_state_dict') if key not in checkpoint]
    if missing:
        raise KeyError(f"Checkpoint '{model_path}' is missing required entries: {missing}")

    # Copy into a list so callers can edit the names without touching the checkpoint.
    class_names = list(checkpoint['class_names'])
    num_classes = len(class_names)

    # Initialize the model architecture and restore the trained weights.
    model = DinoVisionTransformerClassifier(num_classes=num_classes)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()  # inference only: fixes dropout / batch-norm behaviour
    return model, class_names
| # 4. Image Preprocessing Function (Updated to accept PIL Image directly) | |
def preprocess_image_pil(pil_image: Image.Image, transform: transforms.Compose) -> torch.Tensor:
    """
    Run a PIL image through a transformation pipeline.

    Args:
        pil_image (PIL.Image.Image): Image to be transformed.
        transform (transforms.Compose): Callable pipeline applied to the image.

    Returns:
        torch.Tensor: Whatever ``transform`` produces for ``pil_image``.
    """
    processed = transform(pil_image)
    return processed
| # ---------------------------- | |
| # Gradio App Definitions | |
| # ---------------------------- | |
| # Automatically set device based on availability | |
# Prefer the GPU whenever CUDA is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# Text prompt given to the object detector when predict_beetle runs.
PROMPT = "bug"
| # Define a custom transform for Gaussian blur (Unused in current context) | |
def gaussian_blur(x, p=0.5, kernel_size_min=3, kernel_size_max=20, sigma_min=0.1, sigma_max=3):
    """
    Randomly blur the items of a 4-D batch, in place.

    Each item ``x[i]`` is blurred with probability ``p`` using a kernel size
    and sigma drawn from the given ranges. Inputs that are not 4-D are
    returned untouched.

    Args:
        x: Tensor to process; only ``x.ndim == 4`` inputs are modified.
        p: Per-item probability of applying the blur.
        kernel_size_min: Lower bound for the kernel size. An even value is
            bumped to the next odd one, since the blur only accepts odd sizes.
        kernel_size_max: Inclusive upper bound for the kernel size.
        sigma_min: Lower bound for the Gaussian sigma.
        sigma_max: Upper bound for the Gaussian sigma.

    Returns:
        The same object ``x`` (modified in place when blurring happened).
    """
    if x.ndim != 4:
        return x

    # GaussianBlur rejects even kernel sizes, so start the step-2 range on an
    # odd number even when the caller passes an even minimum.
    first_odd = kernel_size_min | 1

    for i in range(x.shape[0]):
        if random.random() < p:
            kernel_size = random.randrange(first_odd, kernel_size_max + 1, 2)
            sigma = random.uniform(sigma_min, sigma_max)
            x[i] = GaussianBlur(kernel_size=kernel_size, sigma=sigma)(x[i])
    return x
| # Custom Label Function (Unused in current context) | |
def custom_label_func(fpath):
    """
    Derive a label from a file path.

    The label is the name of ``fpath.parents[2]``, i.e. the directory two
    levels above the folder that directly contains the file.

    Args:
        fpath: Path object exposing ``.parents`` (e.g. a ``pathlib`` path).

    Returns:
        str: Name of that ancestor directory.
    """
    return fpath.parents[2].name
| # Image loading function for GroundingDINO | |
def load_image(image_source):
    """
    Prepare an image for the GroundingDINO detector.

    The image is converted to RGB, passed through a resize step
    (``RandomResize([800], max_size=1333)``), turned into a tensor and
    normalised with the mean/std values listed below.

    Args:
        image_source: Image object providing ``convert("RGB")`` (a PIL image).

    Returns:
        The transformed image produced by the pipeline.
    """
    pipeline = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    rgb_image = image_source.convert("RGB")
    # The pipeline takes (image, target) and returns a pair; there is no target here.
    tensor, _unused_target = pipeline(rgb_image, None)
    return tensor
| # Load GroundingDINO object detection model | |
# GroundingDINO detector: config and weights are expected next to this script.
od_model = load_groundingdino_model(
    model_config_path="GroundingDINO_SwinT_OGC.cfg.py",
    model_checkpoint_path="groundingdino_swint_ogc.pth",
    device=DEVICE,
)
print("Object detection model loaded")

# DINOv2 classifier checkpoint; change MODEL_PATH if the file lives elsewhere.
MODEL_PATH = 'dinov2_classifier_with_vos_unsure.pth'
dinov2_model, class_names = load_model(MODEL_PATH, torch.device(DEVICE))
print(f"DINOv2 Classification model loaded with {len(class_names)} classes.")

# No "Unknown" entry is appended to class_names; that case is decided at
# prediction time from the model's output instead.

# Patch a class name stored in the checkpoint: the first occurrence of
# `target` is replaced in place, anything else is left as it is.
target = "Scolotodes_schwarzi"
try:
    idx = class_names.index(target)
except ValueError:
    print(f"'{target}' not found in class names. No replacement made.")
else:
    class_names[idx] = "Scolytodes_glaber"
    print(f"Replaced '{target}' with 'Scolytodes_glaber' in class names.")

# Preprocessing for the DINOv2 classifier.
dinov2_transform = transforms.Compose(
    [
        # Bring the shorter edge to 224 while keeping the aspect ratio.
        transforms.Resize(224),
        # Pad the shorter side so the image becomes square.
        PadToSquare(),
        # Scale the square image to the final 224x224 input size.
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # ImageNet channel means and standard deviations.
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ]
)
| # Object Detection Function | |
def detect_objects(og_image, model=od_model, prompt="bug . insect", device="cpu"):
    """
    Detect objects matching a text prompt and return them as cropped images.

    Args:
        og_image: Input image, either a numpy array or a PIL image.
        model: GroundingDINO model used for detection.
        prompt (str): Text prompt describing what to detect.
        device (str): Device passed to the detector ("cuda" or "cpu").

    Returns:
        list: One RGB PIL crop per detected box that has a non-empty area
        inside the image; empty when nothing is detected.
    """
    box_threshold = 0.15
    text_threshold = 0.15

    # Convert numpy array to PIL Image if needed.
    if isinstance(og_image, np.ndarray):
        og_image_obj = Image.fromarray(og_image)
    else:
        og_image_obj = og_image  # assumed to already be a PIL image

    # The crops are later normalised with 3-channel statistics, so grayscale or
    # RGBA inputs must be brought to RGB before cropping.
    og_image_obj = og_image_obj.convert("RGB")

    image_transformed = load_image(image_source=og_image_obj)

    boxes, _logits, _phrases = grounding_dino_predict(
        model=model,
        image=image_transformed,
        caption=prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device=device,
    )

    # Boxes are scaled here by the image size and treated as (cx, cy, w, h);
    # convert them to pixel-space corner coordinates.
    width, height = og_image_obj.size
    scale = torch.tensor([width, height, width, height], dtype=boxes.dtype)
    xyxy = box_convert(
        boxes=boxes.cpu() * scale,
        in_fmt="cxcywh",
        out_fmt="xyxy",
    ).numpy()

    # Round to whole pixels and clamp to the image so a crop never reaches
    # outside the frame (which would add black borders) and never has zero area.
    xyxy = np.rint(xyxy).astype(int)
    xyxy[:, [0, 2]] = np.clip(xyxy[:, [0, 2]], 0, width)
    xyxy[:, [1, 3]] = np.clip(xyxy[:, [1, 3]], 0, height)

    img_lst = [
        og_image_obj.crop((left, top, right, bottom))
        for left, top, right, bottom in xyxy.tolist()
        if right > left and bottom > top
    ]
    print(f"Detected {len(img_lst)} objects.")
    return img_lst
| # Inference/Class Prediction Function using the Entropy-Based DINOv2 Classifier | |
def classify_beetle(img: Image.Image, threshold=75.0):
    """
    Classifies one image with the DINOv2 classifier and reports the top classes.

    Softmax probabilities are converted to percentages. When the best class
    scores below ``threshold``, an additional "Unknown" entry is added whose
    value is the normalised entropy of the distribution, also in percent
    (0 = fully confident, 100 = uniform over all classes).

    Args:
        img (PIL.Image.Image): The image to classify.
        threshold (float): Minimum top-class confidence, in percent, below
            which the "Unknown" entry is added.

    Returns:
        dict: Up to three class names mapped to their confidence in percent,
        plus "Unknown" when applicable.
    """
    # Preprocess the image and add the batch dimension.
    input_tensor = preprocess_image_pil(img, dinov2_transform).unsqueeze(0).to(torch.device(DEVICE))
    print(f"Input tensor shape: {input_tensor.shape}")

    with torch.no_grad():
        outputs, _ = dinov2_model(input_tensor)
    print(f"Model outputs: {outputs}")

    probabilities = torch.softmax(outputs, dim=1).cpu().numpy()[0]  # p(x) in [0,1]
    print(f"Probabilities (0-1 scale): {probabilities}")

    # Shannon entropy; epsilon avoids log(0).
    epsilon = 1e-12
    entropy = -np.sum(probabilities * np.log(probabilities + epsilon))

    # Entropy of the uniform distribution over n classes is log(n). With a
    # single class it is 0, so there is nothing to normalise by.
    max_entropy = np.log(len(probabilities))
    if max_entropy > 0:
        normalized_entropy = float(np.clip(entropy / max_entropy, 0.0, 1.0))
    else:
        normalized_entropy = 0.0
    unknown_prob = normalized_entropy
    print(f"Entropy: {entropy}, Normalized Entropy: {normalized_entropy}, Unknown Probability: {unknown_prob}")

    # Convert probabilities to percentage for display.
    probabilities_percent = np.around(probabilities * 100, decimals=1)
    print(f"Probabilities (Percentage): {probabilities_percent}")

    # Rank on the unrounded probabilities so rounding ties cannot reorder classes.
    top_indices = np.argsort(probabilities)[-3:][::-1]
    top_probs = probabilities_percent[top_indices]
    top_classes = [class_names[i] for i in top_indices]

    conf_dict = {name: float(prob) for name, prob in zip(top_classes, top_probs)}

    # Report "Unknown" on the same percent scale as the class confidences.
    if top_probs[0] < threshold:
        conf_dict["Unknown"] = float(np.around(unknown_prob * 100, decimals=1))

    print(f"Conf_dict: {conf_dict}")
    return conf_dict
| # Main Prediction Function for Gradio | |
def predict_beetle(img):
    """
    Detect objects in an image and classify every detected crop.

    Args:
        img: Input image as delivered by the UI, or ``None`` when no image
            was supplied.

    Returns:
        list: One ``[crop, conf_dict]`` pair per detected object; empty when
        ``img`` is ``None`` or nothing was detected.
    """
    # Without an image there is nothing to detect; return an empty result
    # instead of failing inside the detector.
    if img is None:
        print("No image provided; nothing to classify.")
        return []

    print("Detecting objects in the image...")
    start_time = time.perf_counter()

    image_lst = detect_objects(og_image=img, model=od_model, prompt=PROMPT, device=DEVICE)
    print(f"Detected {len(image_lst)} objects.")

    output_lst = []
    img_cnt = len(image_lst)
    for position, crop in enumerate(image_lst, start=1):
        print(f"Classifying object {position}/{img_cnt}...")
        conf_dict = classify_beetle(crop)
        output_lst.append([crop, conf_dict])
        print(f"Object {position} classified.")

    processing_time = time.perf_counter() - start_time
    print(f"Total processing duration: {processing_time:.2f} seconds")
    return output_lst
| # ---------------------------- | |
| # Gradio Interface Setup | |
| # ---------------------------- | |
# Folder holding the bundled example pictures.
sample_images_dir = "example_images"

# Paths of the example pictures, in display order.
example_images = [
    os.path.join(sample_images_dir, file_name)
    for file_name in ("example1.jpg", "example2.jpg", "example3.jpg", "mixed.jpg")
]

# Captions for the examples; position i belongs to example_images[i].
example_labels = [
    "Example Beetles 1",
    "Example Beetles 2",
    "Example Beetles 3",
    "Example Beetles 4",
]
# NOTE(review): the source this was restored from had lost its indentation, so
# the nesting of Column / Row and their children below is a reconstruction --
# confirm the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("<h1><center>Intelligent Bark Beetle Identifier (IBBI)</center></h1>")

    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            inputs = gr.Image(label="Input Image")

        # Clickable examples that fill the image input.
        gr.Examples(
            examples=example_images,
            inputs=inputs,
            label="Select an example below if you have no images to upload.",
            example_labels=example_labels,
            examples_per_page=4,
        )

        btn = gr.Button("Classify", variant="primary")

        # Gallery layout and height are configured through the constructor.
        gallery = gr.Gallery(
            label="Classified Objects",
            show_label=True,
            elem_id="gallery",
            columns=4,
            height="auto",
        )

    def format_gallery(results):
        """
        Turn (image, confidence-dict) pairs into (image, caption) tuples.

        Each caption lists the dictionary entries as "name: value%" with one
        decimal, joined by ", ".
        """
        return [
            (img, ", ".join(f"{k}: {v:.1f}%" for k, v in conf.items()))
            for img, conf in results
        ]

    # Button click: run detection + classification, then format for the gallery.
    btn.click(
        lambda img: format_gallery(predict_beetle(img)),
        inputs,
        gallery,
    )

# Launch the Gradio app
demo.launch(share=True, inline=True, debug=True, show_error=True)