import io

import gradio as gr
import ibbi
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw, ImageFont

# --- Model Management ---
# Maps task name -> architecture key -> model name passed to ibbi.create_model.
MODEL_REGISTRY = {
    "Single-Class Detection": {
        "yolov10": "yolov10x_bb_detect_model",
        "yolov11": "yolov11x_bb_detect_model",
        "yolov9": "yolov9e_bb_detect_model",
        "yolov8": "yolov8x_bb_detect_model",
        "rtdetr": "rtdetrx_bb_detect_model",
    },
    "Multi-Class Detection": {
        "yolov10": "yolov10x_bb_multi_class_detect_model",
        "yolov11": "yolov11x_bb_multi_class_detect_model",
        "yolov9": "yolov9e_bb_multi_class_detect_model",
        "yolov8": "yolov8x_bb_multi_class_detect_model",
        "rtdetr": "rtdetrx_bb_multi_class_detect_model",
    },
    "Zero-Shot Detection": {
        "grounding_dino": "grounding_dino_detect_model",
    },
}

# Tasks whose predictions are drawn with draw_yolo_predictions.
_CLASS_DETECTION_TASKS = ("Single-Class Detection", "Multi-Class Detection")
_ZERO_SHOT_TASK = "Zero-Shot Detection"


# Caching is deliberately not used: a fresh model is loaded for every analysis
# request so that stateful changes from one run cannot affect the next.
def get_model(task, architecture):
    """Load a fresh model instance for the selected task and architecture.

    Args:
        task: A key of ``MODEL_REGISTRY``.
        architecture: A key of ``MODEL_REGISTRY[task]``. Ignored for
            "Zero-Shot Detection", which always uses ``grounding_dino``.

    Returns:
        The object returned by ``ibbi.create_model(model_name, pretrained=True)``.

    Raises:
        gr.Error: If the task/architecture pair is not in the registry, or if
            model creation fails for any other reason.
    """
    try:
        # For Zero-Shot, the architecture is always 'grounding_dino'.
        if task == _ZERO_SHOT_TASK:
            architecture = "grounding_dino"
        model_name = MODEL_REGISTRY[task][architecture]
        print(f"Loading a fresh model instance: {model_name}")
        model = ibbi.create_model(model_name, pretrained=True)
        print("Model loaded successfully!")
        return model
    except KeyError as e:
        raise gr.Error(
            f"Model lookup failed. Task: '{task}', Arch: '{architecture}'. Error: {e}"
        ) from e
    except Exception as e:
        # Top-level boundary: surface any loading failure in the UI.
        raise gr.Error(
            f"Failed to load model. Please check the model name and your connection. Error: {e}"
        ) from e


# --- Visualization and Drawing Functions ---
def _draw_labeled_box(draw, coords, label_text, font, color):
    """Draw one bounding box and its label on a filled background strip."""
    draw.rectangle(coords, outline=color, width=3)

    x0, y0 = coords[0], coords[1]
    left, top, right, bottom = draw.textbbox((x0, y0), label_text, font=font)
    text_width = right - left
    text_height = bottom - top

    # Put the label above the box when there is room, otherwise pin it to the
    # top edge of the image.
    bg_top = y0 - text_height if y0 > text_height else 0
    # The background always spans the full text height, so it never collapses
    # or inverts when the box touches (or starts above) the top edge.
    bg_coords = (x0, bg_top, x0 + text_width, bg_top + text_height)
    draw.rectangle(bg_coords, fill=color)
    draw.text((x0, bg_top), label_text, fill="white", font=font)


def draw_yolo_predictions(image, results, font, color="red"):
    """Draw YOLO-style predictions on a copy of ``image``.

    Args:
        image: PIL image; it is copied, not modified.
        results: Sequence whose first element exposes ``boxes`` and ``names``.
            Only ``results[0]`` is used.
        font: PIL font used for the labels.
        color: Outline and label-background colour.

    Returns:
        The annotated copy. It is an unmodified copy when there are no results
        or no boxes.
    """
    img_copy = image.copy()
    if not results or not results[0].boxes:
        return img_copy

    draw = ImageDraw.Draw(img_copy)
    res_for_img = results[0]
    class_names = res_for_img.names
    for box in res_for_img.boxes:
        # Skip boxes that carry no class or confidence value.
        if box.cls.numel() == 0 or box.conf.numel() == 0:
            continue
        coords = box.xyxy[0].tolist()
        score = box.conf[0].item()
        class_id = int(box.cls[0].item())
        label_text = f"{class_names.get(class_id, f'Unknown-{class_id}')}: {score:.2f}"
        _draw_labeled_box(draw, coords, label_text, font, color)
    return img_copy


def draw_dino_predictions(image, results, font, color="green"):
    """Draw Grounding DINO predictions on a copy of ``image``.

    Args:
        image: PIL image; it is copied, not modified.
        results: Mapping read through its ``"boxes"``, ``"scores"`` and
            ``"text_labels"`` keys; missing keys are treated as empty.
        font: PIL font used for the labels.
        color: Outline and label-background colour.

    Returns:
        The annotated copy. It is an unmodified copy when ``results`` is empty.
    """
    img_copy = image.copy()
    if not results:
        return img_copy

    draw = ImageDraw.Draw(img_copy)
    detections = zip(
        results.get("boxes", []),
        results.get("scores", []),
        results.get("text_labels", []),
    )
    for box, score, label in detections:
        _draw_labeled_box(draw, box.tolist(), f"{label}: {score:.2f}", font, color)
    return img_copy


def visualize_embedding(embedding):
    """Render a feature embedding as a heat-map image.

    Args:
        embedding: Tensor-like object exposing ``cpu()``; assumed to be a
            torch tensor. Any rank is accepted.

    Returns:
        A PIL image of the plot, or ``None`` when ``embedding`` is ``None`` or
        has no ``cpu`` attribute.
    """
    if embedding is None or not hasattr(embedding, "cpu"):
        return None

    # Cast to float so half-precision tensors can be converted to NumPy.
    data = embedding.detach().cpu().float()
    # imshow only accepts 2-D scalar data, so normalise the rank.
    if len(data.shape) < 2:
        data = data.reshape(1, -1)
    elif len(data.shape) > 2:
        data = data.reshape(-1, data.shape[-1])

    fig, ax = plt.subplots(figsize=(10, 2))
    try:
        ax.imshow(data.numpy(), cmap="viridis", aspect="auto")
        ax.set_title("Feature Embedding Visualization")
        ax.set_xlabel("Feature Dimension")
        ax.set_yticks([])
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def _load_font(size):
    """Return Arial at ``size`` if available, otherwise PIL's default font."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except IOError:
        pass
    try:
        return ImageFont.load_default(size=size)
    except TypeError:
        # Older Pillow releases do not accept ``size`` here.
        return ImageFont.load_default()


# --- Main Processing Function ---
def comprehensive_analysis(image, task, architecture, text_prompt, box_threshold, text_threshold):
    """Run detection and feature extraction for one image.

    Args:
        image: PIL image from the upload widget, or ``None``.
        task: One of the ``MODEL_REGISTRY`` task names.
        architecture: Architecture key for the chosen task.
        text_prompt: Prompt for Zero-Shot Detection; unused by other tasks.
        box_threshold: Passed to the zero-shot model's ``predict``.
        text_threshold: Passed to the zero-shot model's ``predict``.

    Returns:
        A tuple ``(annotated_image, model_info, classes_info, embedding_plot)``.

    Raises:
        gr.Error: If no image was supplied, if the zero-shot prompt is empty,
            or if the model cannot be loaded.
    """
    if image is None:
        raise gr.Error("Please upload an image first!")

    is_class_detection = task in _CLASS_DETECTION_TASKS

    # Validate the prompt before the expensive model load.
    if not is_class_detection and (not text_prompt or not text_prompt.strip()):
        raise gr.Error("Please provide a text prompt for Zero-Shot Detection.")

    # Scale the label font with the image width.
    dynamic_font_size = max(15, int(image.width * 0.04))
    font = _load_font(dynamic_font_size)

    # Get a fresh model instance to avoid stateful errors.
    model = get_model(task, architecture)

    if is_class_detection:
        results = model.predict(image)
        annotated_image = draw_yolo_predictions(image, results, font=font)
        features = model.extract_features(image)
        model_info = f"Architecture: {architecture.upper()}\nTask: {task}\nDevice: {model.device}"
        classes_info = f"Classes: {model.get_classes()}"
    else:  # Zero-Shot Detection
        results = model.predict(
            image,
            text_prompt=text_prompt,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
        )
        annotated_image = draw_dino_predictions(image, results, font=font)
        features = model.extract_features(image, text_prompt=text_prompt)
        model_info = (
            f"Architecture: GROUNDING_DINO\nTask: {task}\nDevice: {model.device}"
            f"\nHF Model ID: {model.model.config._name_or_path}"
        )
        classes_info = f"Prompt: '{text_prompt}'"

    # Features may be a dict of tensors or a tensor.
    if isinstance(features, dict):
        embedding_plot = visualize_embedding(features.get("last_hidden_state"))
    else:
        embedding_plot = visualize_embedding(features)

    return annotated_image, model_info, classes_info, embedding_plot


# --- Gradio UI ---
def update_ui_for_task(task):
    """Return component updates that show only the inputs relevant to ``task``.

    Class-detection tasks show the architecture dropdown and hide (and clear)
    the zero-shot inputs. Any other task hides the dropdown and shows the
    prompt box and both threshold sliders.
    """
    arch_choices = list(MODEL_REGISTRY[task].keys())
    if task in _CLASS_DETECTION_TASKS:
        return {
            arch_dropdown: gr.update(
                choices=arch_choices, value=arch_choices[0], visible=True, interactive=True
            ),
            prompt_textbox: gr.update(visible=False, value=""),
            box_threshold_slider: gr.update(visible=False),
            text_threshold_slider: gr.update(visible=False),
        }
    # Zero-Shot Detection
    return {
        arch_dropdown: gr.update(choices=arch_choices, value=arch_choices[0], visible=False),
        prompt_textbox: gr.update(visible=True),
        box_threshold_slider: gr.update(visible=True),
        text_threshold_slider: gr.update(visible=True),
    }


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# IBBI - Intelligent Bark Beetle Identifier")
    gr.Markdown("An all-in-one interface to analyze images using the `ibbi` library. Upload an image, select a task and model, and view the complete analysis.")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Inputs")
            image_input = gr.Image(type="pil", label="Upload Image")
            task_selector = gr.Radio(
                choices=["Single-Class Detection", "Multi-Class Detection", "Zero-Shot Detection"],
                value="Single-Class Detection",
                label="Choose Task",
            )
            arch_dropdown = gr.Dropdown(
                choices=list(MODEL_REGISTRY["Single-Class Detection"].keys()),
                value="yolov10",
                label="Choose Model Architecture",
            )
            prompt_textbox = gr.Textbox(
                label="Enter Text Prompt (for Zero-Shot)",
                placeholder="e.g., insect . circle . metal ball",
                visible=False,
            )
            box_threshold_slider = gr.Slider(
                minimum=0.05,
                maximum=1.0,
                value=0.25,
                step=0.05,
                label="Box Threshold (Zero-Shot)",
                info="Lower to detect more objects, even with low confidence.",
                visible=False,
            )
            text_threshold_slider = gr.Slider(
                minimum=0.05,
                maximum=1.0,
                value=0.25,
                step=0.05,
                label="Text Threshold (Zero-Shot)",
                info="Lower to allow more labels to match detected objects.",
                visible=False,
            )
            analyze_btn = gr.Button("Analyze Image", variant="primary")

        with gr.Column(scale=2):
            gr.Markdown("### 2. Analysis Results")
            output_image = gr.Image(label="Annotated Image")
            with gr.Accordion("Details", open=True):
                model_details_output = gr.Textbox(label="Model Details", lines=4)
                classes_output = gr.Textbox(label="Classes / Prompt")
                embedding_output = gr.Image(label="Feature Embedding Visualization")

    # --- Event Handlers ---
    task_selector.change(
        fn=update_ui_for_task,
        inputs=task_selector,
        outputs=[arch_dropdown, prompt_textbox, box_threshold_slider, text_threshold_slider],
    )
    analyze_btn.click(
        fn=comprehensive_analysis,
        inputs=[
            image_input,
            task_selector,
            arch_dropdown,
            prompt_textbox,
            box_threshold_slider,
            text_threshold_slider,
        ],
        outputs=[output_image, model_details_output, classes_output, embedding_output],
    )

    gr.Markdown("---")
    gr.Markdown("### 3. Or Start with an Example Image")
    example_list = [
        ["example_images/example1.jpg"],
        ["example_images/example2.jpg"],
        ["example_images/example3.jpg"],
        ["example_images/example4.jpg"],
        ["example_images/example5.jpg"],
    ]
    gr.Examples(
        examples=example_list,
        inputs=image_input,
        label="Select an image to load it",
    )

if __name__ == "__main__":
    demo.launch(share=True, inline=True, debug=True, show_error=True)