File size: 11,206 Bytes
bafa380
9c80281
 
 
38723be
9c80281
 
f4a6ba2
9c80281
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
929de5c
 
 
9c80281
929de5c
 
 
 
9c80281
929de5c
 
 
 
9c80281
929de5c
 
 
 
9c80281
 
 
 
 
f4a6ba2
9c80281
f4a6ba2
 
9c80281
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4a6ba2
 
9c80281
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4a6ba2
9c80281
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
929de5c
9c80281
929de5c
9c80281
 
f4a6ba2
 
 
 
 
 
 
 
929de5c
9c80281
 
 
 
f4a6ba2
 
9c80281
 
 
929de5c
9c80281
 
 
 
 
 
 
 
 
f4a6ba2
9c80281
929de5c
9c80281
38723be
929de5c
9c80281
 
 
 
38723be
929de5c
9c80281
 
f4a6ba2
9c80281
f4a6ba2
9c80281
 
 
3e96378
9c80281
 
 
 
 
 
 
3e96378
9c80281
 
 
 
 
 
f4a6ba2
9c80281
 
 
 
 
 
f4a6ba2
9c80281
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e96378
9c80281
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38723be
f4a6ba2
9c80281
 
 
 
 
38723be
9c80281
 
 
 
 
38723be
3e96378
 
 
 
 
 
 
 
 
 
 
9c80281
3e96378
f4a6ba2
 
38723be
 
3e96378
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
import gradio as gr
import ibbi
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import io

# --- Model Management ---
# Task name -> {architecture key -> ibbi model identifier passed to
# ibbi.create_model}. Dict key order is significant: the UI lists the
# architecture keys of a task in this order and defaults to the first one.
MODEL_REGISTRY = {
    task_name: {
        arch_key: f"{variant}_{model_suffix}"
        for arch_key, variant in (
            ("yolov10", "yolov10x"),
            ("yolov11", "yolov11x"),
            ("yolov9", "yolov9e"),
            ("yolov8", "yolov8x"),
            ("rtdetr", "rtdetrx"),
        )
    }
    for task_name, model_suffix in (
        ("Single-Class Detection", "bb_detect_model"),
        ("Multi-Class Detection", "bb_multi_class_detect_model"),
    )
}
# Zero-shot detection has exactly one backing model.
MODEL_REGISTRY["Zero-Shot Detection"] = {"grounding_dino": "grounding_dino_detect_model"}

# --- Model loading ---
# Models are deliberately NOT cached: reusing a stateful model instance let one
# request affect the next, so every analysis request loads a fresh instance.
def get_model(task, architecture):
    """Load a fresh ``ibbi`` model for the selected task and architecture.

    Args:
        task: A key of ``MODEL_REGISTRY``.
        architecture: A key of ``MODEL_REGISTRY[task]``. Ignored for
            "Zero-Shot Detection", which always uses ``"grounding_dino"``.

    Returns:
        The model returned by ``ibbi.create_model(model_name, pretrained=True)``.

    Raises:
        gr.Error: If the task/architecture pair is not in ``MODEL_REGISTRY``,
            or if ``ibbi.create_model`` raises. The underlying exception is
            chained as ``__cause__``.
    """
    # Zero-shot has a single backing model regardless of the dropdown value.
    if task == "Zero-Shot Detection":
        architecture = "grounding_dino"

    # The lookup gets its own try so that a KeyError raised *inside* ibbi is not
    # misreported as an invalid task/architecture selection.
    try:
        model_name = MODEL_REGISTRY[task][architecture]
    except KeyError as e:
        raise gr.Error(f"Model lookup failed. Task: '{task}', Arch: '{architecture}'. Error: {e}") from e

    print(f"Loading a fresh model instance: {model_name}")
    try:
        model = ibbi.create_model(model_name, pretrained=True)
    except Exception as e:  # UI boundary: surface any loader failure to the user
        raise gr.Error(f"Failed to load model. Please check the model name and your connection. Error: {e}") from e
    print("Model loaded successfully!")
    return model

# --- Visualization and Drawing Functions ---

def draw_yolo_predictions(image, results, font, color="red"):
    """Draw YOLO / RT-DETR detections on a copy of ``image``.

    Only the first element of ``results`` is drawn. For each box it draws an
    outline plus a ``"<class name>: <score>"`` label on a filled background.

    Args:
        image: PIL image to annotate; the original is not modified.
        results: Sequence whose first element exposes ``.boxes`` (iterable of
            boxes with ``.xyxy``, ``.conf`` and ``.cls``) and ``.names``
            (a mapping of class id -> name). Falsy results, or a first result
            without boxes, draw nothing.
        font: Font used for the label text.
        color: Outline and label-background color.

    Returns:
        The annotated copy of ``image``.
    """
    img_copy = image.copy()
    draw = ImageDraw.Draw(img_copy)
    if not results or not results[0].boxes:
        return img_copy
    res_for_img = results[0]
    class_names = res_for_img.names
    for box in res_for_img.boxes:
        # Skip malformed entries that carry no class or no confidence.
        if box.cls.numel() == 0 or box.conf.numel() == 0:
            continue
        x1, y1, x2, y2 = box.xyxy[0].tolist()
        score = box.conf[0].item()
        class_id = int(box.cls[0].item())
        label_text = f"{class_names.get(class_id, f'Unknown-{class_id}')}: {score:.2f}"
        draw.rectangle((x1, y1, x2, y2), outline=color, width=3)

        # Measure the label at the origin. (left, top) is the offset of the inked
        # glyphs from the draw anchor; it is subtracted below so the text lands
        # exactly inside its background instead of overflowing below it.
        left, top, right, bottom = draw.textbbox((0, 0), label_text, font=font)
        text_w, text_h = right - left, bottom - top
        # Place the label just above the box. If there is no room (the box touches
        # the top of the image), place it just inside the box's top edge instead
        # of collapsing the background against y=0.
        label_y = y1 - text_h if y1 >= text_h else y1
        draw.rectangle((x1, label_y, x1 + text_w, label_y + text_h), fill=color)
        draw.text((x1 - left, label_y - top), label_text, fill="white", font=font)
    return img_copy

def draw_dino_predictions(image, results, font, color="green"):
    """Draw Grounding DINO detections on a copy of ``image``.

    Args:
        image: PIL image to annotate; the original is not modified.
        results: Mapping read via ``.get`` for three parallel sequences:
            ``"boxes"`` (each with ``.tolist()`` giving four coordinates
            x1, y1, x2, y2), ``"scores"`` and ``"text_labels"``. Missing keys
            are treated as empty; falsy results draw nothing.
        font: Font used for the label text.
        color: Outline and label-background color.

    Returns:
        The annotated copy of ``image``.
    """
    img_copy = image.copy()
    if not results:
        return img_copy
    draw = ImageDraw.Draw(img_copy)
    detections = zip(
        results.get("boxes", []),
        results.get("scores", []),
        results.get("text_labels", []),
    )
    for box, score, label in detections:
        x1, y1, x2, y2 = box.tolist()
        label_text = f"{label}: {score:.2f}"
        draw.rectangle((x1, y1, x2, y2), outline=color, width=3)

        # Measure the label at the origin. (left, top) is the offset of the inked
        # glyphs from the draw anchor; it is subtracted below so the text lands
        # exactly inside its background instead of overflowing below it.
        left, top, right, bottom = draw.textbbox((0, 0), label_text, font=font)
        text_w, text_h = right - left, bottom - top
        # Place the label just above the box. If there is no room (the box touches
        # the top of the image), place it just inside the box's top edge instead
        # of collapsing the background against y=0.
        label_y = y1 - text_h if y1 >= text_h else y1
        draw.rectangle((x1, label_y, x1 + text_w, label_y + text_h), fill=color)
        draw.text((x1 - left, label_y - top), label_text, fill="white", font=font)
    return img_copy

def visualize_embedding(embedding):
    """Render a feature embedding as a heatmap and return it as a PIL image.

    Args:
        embedding: A tensor-like object exposing ``.cpu()`` / ``.detach()``
            (e.g. a torch tensor). Any rank is accepted:
            a scalar becomes a 1x1 map, a vector becomes a single row, a 2-D
            tensor is shown as-is, and higher-rank tensors are flattened to
            ``(-1, last_dim)`` so each row is one feature vector.

    Returns:
        A PIL PNG image of the plot, or ``None`` when ``embedding`` is ``None``
        or does not expose ``.cpu()``.
    """
    if embedding is None or not hasattr(embedding, "cpu"):
        return None

    # Cast to float32 for plotting: numpy cannot represent e.g. bfloat16.
    matrix = embedding.detach().cpu().float()
    if matrix.ndim == 0:
        matrix = matrix.reshape(1, 1)
    elif matrix.ndim == 1:
        matrix = matrix.unsqueeze(0)
    elif matrix.ndim > 2:
        # imshow needs 2-D scalar data here (a 3-D array would be read as an
        # RGB(A) image or rejected), so collapse all leading axes.
        matrix = matrix.reshape(-1, matrix.shape[-1])

    fig, ax = plt.subplots(figsize=(10, 2))
    buf = io.BytesIO()
    try:
        ax.imshow(matrix.numpy(), cmap="viridis", aspect="auto")
        ax.set_title("Feature Embedding Visualization")
        ax.set_xlabel("Feature Dimension")
        ax.set_yticks([])
        fig.tight_layout()
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure, even if plotting fails, so pyplot's global
        # figure registry does not grow across requests.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

# --- Main Processing Function ---
def _load_label_font(image_width):
    """Return a label font sized to about 4% of ``image_width`` (at least 15 px).

    Tries Arial first and falls back to Pillow's built-in font.
    ``ImageFont.load_default`` accepts ``size`` only on newer Pillow releases;
    older ones raise ``TypeError``, and then the fixed-size default font is used.
    """
    font_size = max(15, int(image_width * 0.04))
    try:
        return ImageFont.truetype("arial.ttf", font_size)
    except OSError:  # Arial not installed (common on Linux hosts)
        pass
    try:
        return ImageFont.load_default(size=font_size)
    except TypeError:
        return ImageFont.load_default()


def comprehensive_analysis(image, task, architecture, text_prompt, box_threshold, text_threshold):
    """Run the selected detector on ``image`` and build every UI output.

    Args:
        image: Uploaded PIL image, or ``None`` if nothing was uploaded.
        task: Selected task name (a key of ``MODEL_REGISTRY``). Any task other
            than the two detection tasks is handled as Zero-Shot Detection.
        architecture: Selected architecture key (ignored for Zero-Shot).
        text_prompt: Prompt for Zero-Shot Detection.
        box_threshold: Zero-Shot box threshold, forwarded to ``model.predict``.
        text_threshold: Zero-Shot text threshold, forwarded to ``model.predict``.

    Returns:
        Tuple of (annotated image, model-details text, classes/prompt text,
        embedding plot image or ``None``).

    Raises:
        gr.Error: If no image was uploaded, if Zero-Shot Detection is selected
            with an empty or whitespace-only prompt, or if the model cannot
            be loaded.
    """
    if image is None:
        raise gr.Error("Please upload an image first!")

    is_zero_shot = task not in ["Single-Class Detection", "Multi-Class Detection"]
    # Validate before loading: get_model builds a fresh model on every call, so
    # rejecting a missing prompt first avoids a wasted, expensive model load.
    if is_zero_shot and not (text_prompt or "").strip():
        raise gr.Error("Please provide a text prompt for Zero-Shot Detection.")

    font = _load_label_font(image.width)
    model = get_model(task, architecture)

    if is_zero_shot:
        results = model.predict(
            image,
            text_prompt=text_prompt,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
        )
        annotated_image = draw_dino_predictions(image, results, font=font)
        features = model.extract_features(image, text_prompt=text_prompt)
        model_info = (
            f"Architecture: GROUNDING_DINO\nTask: {task}\nDevice: {model.device}"
            f"\nHF Model ID: {model.model.config._name_or_path}"
        )
        classes_info = f"Prompt: '{text_prompt}'"
    else:
        results = model.predict(image)
        annotated_image = draw_yolo_predictions(image, results, font=font)
        features = model.extract_features(image)
        model_info = f"Architecture: {architecture.upper()}\nTask: {task}\nDevice: {model.device}"
        classes_info = f"Classes: {model.get_classes()}"

    # Feature extraction may return a bare tensor or a dict of outputs; for a
    # dict, only 'last_hidden_state' is plotted.
    if isinstance(features, dict):
        features = features.get("last_hidden_state")
    embedding_plot = visualize_embedding(features)

    return annotated_image, model_info, classes_info, embedding_plot

# --- Gradio UI ---
def update_ui_for_task(task):
    """Reconfigure the input controls when the selected task changes.

    For the two detection tasks, the architecture dropdown is shown and made
    interactive, and the zero-shot prompt and sliders are hidden (the prompt
    is also cleared). For any other task, the dropdown is hidden and the
    prompt and both threshold sliders are shown. In both cases the dropdown
    choices are reset to that task's architectures, with the first selected.

    Returns:
        A dict mapping each affected component to its ``gr.update(...)``.
    """
    zero_shot = task not in ("Single-Class Detection", "Multi-Class Detection")
    choices = list(MODEL_REGISTRY[task])

    if zero_shot:
        arch_update = gr.update(choices=choices, value=choices[0], visible=False)
        prompt_update = gr.update(visible=True)
    else:
        arch_update = gr.update(choices=choices, value=choices[0], visible=True, interactive=True)
        prompt_update = gr.update(visible=False, value="")

    # The two threshold sliders are only meaningful for zero-shot detection.
    return {
        arch_dropdown: arch_update,
        prompt_textbox: prompt_update,
        box_threshold_slider: gr.update(visible=zero_shot),
        text_threshold_slider: gr.update(visible=zero_shot),
    }

# Build the Gradio UI: an inputs column beside a results column, wired to
# update_ui_for_task and comprehensive_analysis, plus example images.
# Component placement follows creation order inside the layout context managers.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# IBBI - Intelligent Bark Beetle Identifier")
    gr.Markdown("An all-in-one interface to analyze images using the `ibbi` library. Upload an image, select a task and model, and view the complete analysis.")

    with gr.Row():
        # Left column: every value that comprehensive_analysis receives.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Inputs")
            # type="pil" delivers a PIL image (comprehensive_analysis uses .width on it).
            image_input = gr.Image(type="pil", label="Upload Image")
            
            task_selector = gr.Radio(
                choices=["Single-Class Detection", "Multi-Class Detection", "Zero-Shot Detection"],
                value="Single-Class Detection",
                label="Choose Task"
            )
            # Initial choices match the default task above; update_ui_for_task
            # replaces them whenever the task selection changes.
            arch_dropdown = gr.Dropdown(
                choices=list(MODEL_REGISTRY["Single-Class Detection"].keys()),
                value="yolov10",
                label="Choose Model Architecture"
            )
            # Zero-shot-only controls start hidden; update_ui_for_task shows them
            # when a non-detection task is selected.
            prompt_textbox = gr.Textbox(
                label="Enter Text Prompt (for Zero-Shot)",
                placeholder="e.g., insect . circle . metal ball",
                visible=False
            )
            box_threshold_slider = gr.Slider(
                minimum=0.05, maximum=1.0, value=0.25, step=0.05,
                label="Box Threshold (Zero-Shot)",
                info="Lower to detect more objects, even with low confidence.",
                visible=False
            )
            text_threshold_slider = gr.Slider(
                minimum=0.05, maximum=1.0, value=0.25, step=0.05,
                label="Text Threshold (Zero-Shot)",
                info="Lower to allow more labels to match detected objects.",
                visible=False
            )
            analyze_btn = gr.Button("Analyze Image", variant="primary")
            
        # Right column: receives the 4-tuple returned by comprehensive_analysis.
        with gr.Column(scale=2):
            gr.Markdown("### 2. Analysis Results")
            output_image = gr.Image(label="Annotated Image")
            with gr.Accordion("Details", open=True):
                model_details_output = gr.Textbox(label="Model Details", lines=4)
                classes_output = gr.Textbox(label="Classes / Prompt")
                embedding_output = gr.Image(label="Feature Embedding Visualization")
    
    # --- Event Handlers ---
    # update_ui_for_task returns a dict keyed by exactly these four components.
    task_selector.change(
        fn=update_ui_for_task,
        inputs=task_selector,
        outputs=[arch_dropdown, prompt_textbox, box_threshold_slider, text_threshold_slider]
    )
    
    # Input order must match comprehensive_analysis's positional parameters
    # (image, task, architecture, text_prompt, box_threshold, text_threshold),
    # and output order must match its returned tuple.
    analyze_btn.click(
        fn=comprehensive_analysis,
        inputs=[image_input, task_selector, arch_dropdown, prompt_textbox, box_threshold_slider, text_threshold_slider],
        outputs=[output_image, model_details_output, classes_output, embedding_output]
    )
    
    gr.Markdown("---")
    gr.Markdown("### 3. Or Start with an Example Image")
    
    # Paths are relative; presumably resolved against the working directory the
    # app is launched from — confirm the example_images/ folder ships with the app.
    example_list = [
        ["example_images/example1.jpg"],
        ["example_images/example2.jpg"],
        ["example_images/example3.jpg"],
        ["example_images/example4.jpg"],
        ["example_images/example5.jpg"],
    ]

    # Examples only populate image_input; analysis still requires clicking the button.
    gr.Examples(
        examples=example_list,
        inputs=image_input,
        label="Select an image to load it"
    )

if __name__ == "__main__":
    # Keyword options forwarded unchanged to demo.launch().
    launch_options = {
        "share": True,
        "inline": True,
        "debug": True,
        "show_error": True,
    }
    demo.launch(**launch_options)