File size: 10,273 Bytes
1d08579
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90b6956
 
1d08579
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90b6956
cb187fe
1d08579
 
 
90b6956
 
 
 
 
 
 
 
 
 
 
1d08579
 
90b6956
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb187fe
 
 
 
 
 
 
 
 
90b6956
 
 
 
1d08579
 
 
90b6956
1d08579
 
 
 
 
 
 
 
 
 
90b6956
 
1d08579
 
 
 
 
 
 
 
 
 
 
90b6956
1d08579
 
cb187fe
 
1d08579
90b6956
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb187fe
 
 
 
90b6956
 
cb187fe
90b6956
 
cb187fe
90b6956
cb187fe
 
90b6956
 
cb187fe
1d08579
 
 
 
 
 
 
5b4cb04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb187fe
5b4cb04
cb187fe
 
5b4cb04
 
 
 
cb187fe
5b4cb04
 
 
cb187fe
 
5b4cb04
1d08579
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
import functools
import cv2
import numpy as np
import gradio as gr
import os
from types import MethodType
from ultralytics import YOLO
from huggingface_hub import hf_hub_download

# Import helper functions from the existing feature-extractor script
from yolov10_RoIFX import (
    _predict_once,
    get_result_with_features_yolov10_simple,
    draw_modern_bbox,
    draw_feature_heatmap,
)

# ---------------------------
#  Constants & Setup
# ---------------------------

# Hugging Face Hub repository hosting both checkpoints and the demo images.
REPO_ID = "HugoHE/X-YOLOv10"

# Local folder where checkpoints are looked up; created eagerly at import
# time so later path checks and downloads have somewhere to land.
MODELS_DIR = "models"
os.makedirs(MODELS_DIR, exist_ok=True)

# Download models from Hugging Face Hub
def download_models():
    """Fetch the vanilla and fine-tuned checkpoints from the Hub when missing.

    Each file is requested as ``models/<name>`` from ``REPO_ID`` with
    ``local_dir="."``. A file is skipped when it already exists under
    ``MODELS_DIR``. Download errors are printed and swallowed (best effort),
    so a failed fetch does not stop the app from importing.

    NOTE(review): the existence check uses ``MODELS_DIR`` but the repo path is
    hard-coded to ``models/``; the two only agree while MODELS_DIR == "models".
    """
    for model_file in ("vanilla.pt", "finetune.pt"):
        if os.path.exists(os.path.join(MODELS_DIR, model_file)):
            continue  # already cached locally
        try:
            hf_hub_download(
                repo_id=REPO_ID,
                filename=f"models/{model_file}",
                local_dir=".",
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            print(f"Error downloading {model_file}: {e}")

# Download example images from Hugging Face Hub
def download_examples():
    """Fetch the demo images ``1.png`` and ``2.png`` into the working directory.

    Images that already exist in the current directory are left alone.
    Any download error is printed and skipped (best effort).
    """
    missing = [name for name in ("1.png", "2.png") if not os.path.exists(name)]
    for img_file in missing:
        try:
            hf_hub_download(
                repo_id=REPO_ID,
                filename=img_file,
                local_dir=".",
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            print(f"Error downloading {img_file}: {e}")

# Download required files.
# NOTE: this runs at import time, so importing this module may trigger network
# downloads; each helper skips files that already exist locally and only prints
# (does not raise) on failure.
download_models()
download_examples()

# Display label -> checkpoint filename, resolved under MODELS_DIR by load_model().
AVAILABLE_MODELS = {
    "Vanilla VOC": "vanilla.pt",
    "Finetune VOC": "finetune.pt",
}

# Rows for gr.Examples, matching its inputs: [image path, confidence threshold].
EXAMPLES = [[image_path, 0.25] for image_path in ("1.png", "2.png")]

# ---------------------------
#  Model loading & caching
# ---------------------------

@functools.lru_cache(maxsize=2)
def _load_model_cached(name: str):
    """Build, patch and warm up the YOLO model registered under *name*.

    Cached at module level, so each checkpoint is loaded only once per process.
    (Defining the cached function inside ``load_model`` would create a fresh,
    empty cache on every call.)

    Returns:
        ``(model, embed_layers)`` where ``embed_layers`` is a tuple of layer
        indices used for feature extraction.

    Raises:
        KeyError: if *name* is not a key of ``AVAILABLE_MODELS`` (not cached).
    """
    model_path = os.path.join(MODELS_DIR, AVAILABLE_MODELS[name])
    model = YOLO(model_path)
    # Monkey-patch the predictor so we can extract feature maps on demand.
    model.model._predict_once = MethodType(_predict_once, model.model)
    # Run one dummy inference on a blank uint8 frame to initialise internals.
    model(np.zeros((640, 640, 3), dtype=np.uint8), verbose=False)

    # Use the first layer whose class name contains "Detect" together with the
    # layers feeding it; fall back to fixed indices if none is found.
    detect_layer_idx = next(
        (i for i, m in enumerate(model.model.model) if "Detect" in type(m).__name__),
        -1,
    )
    if detect_layer_idx != -1:
        sources = model.model.model[detect_layer_idx].f
        # `.f` may be a single int rather than a list; normalise before sorting.
        if isinstance(sources, int):
            sources = [sources]
        embed_layers = sorted(sources) + [detect_layer_idx]
    else:
        embed_layers = [16, 19, 22, 23]  # fallback

    return model, tuple(embed_layers)


def load_model(model_name: str):
    """Return ``(model, embed_layers)`` for *model_name*, loading it at most once.

    *model_name* must be a key of ``AVAILABLE_MODELS``. The same model object
    is returned on subsequent calls.
    """
    return _load_model_cached(model_name)


# ---------------------------
#  Composite heat-map layout
# ---------------------------

def _to_bgr(img: np.ndarray) -> np.ndarray:
    """Convert a Gradio-style RGB, RGBA or grayscale array to 3-channel BGR."""
    if img.ndim == 2 or img.shape[2] == 1:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    if img.shape[2] == 4:
        return cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)


def _model_heatmaps(img_bgr: np.ndarray, model_name: str, conf: float, label: str):
    """Run one model on *img_bgr* and return its captioned heatmap snippets.

    Returns an empty list when the model produced no result or no boxes.
    """
    model, embed_layers = load_model(model_name)
    results = get_result_with_features_yolov10_simple(
        model, img_bgr, embed_layers, conf=conf
    )
    # Only the first result is used (a single image is passed in).
    if not results or not hasattr(results[0], "boxes") or len(results[0].boxes) == 0:
        return []
    result = results[0]
    names = [model.model.names[int(cls)] for cls in result.boxes.cls]
    return create_heatmap_snippets(img_bgr, result, names, label)


def generate_heatmap_layout(img_rgb: np.ndarray, conf: float = 0.25):
    """Return separate XAI heatmap layouts for vanilla and fine-tuned models.

    Args:
        img_rgb: Input image in RGB order (Gradio's default). Grayscale and RGBA
            arrays are also accepted.
        conf: Detection confidence threshold passed to both models.

    Returns:
        ``(vanilla_rgb, finetune_rgb)``: one RGB layout image per model. A model
        with no detections yields a title-only card.
    """
    # Convert to BGR (OpenCV default) once; both models share the same input.
    img_bgr = _to_bgr(img_rgb)

    vanilla_heatmaps = _model_heatmaps(img_bgr, "Vanilla VOC", conf, "Vanilla")
    finetune_heatmaps = _model_heatmaps(img_bgr, "Finetune VOC", conf, "Fine-tuned")

    # Title colours are BGR: dark green for vanilla, red for fine-tuned.
    vanilla_layout = create_model_layout(vanilla_heatmaps, "Vanilla Model", (0, 100, 0))
    finetune_layout = create_model_layout(finetune_heatmaps, "Fine-tuned Model", (0, 0, 200))

    # Convert BGR back to RGB for display.
    vanilla_output = cv2.cvtColor(vanilla_layout, cv2.COLOR_BGR2RGB) if vanilla_layout is not None else None
    finetune_output = cv2.cvtColor(finetune_layout, cv2.COLOR_BGR2RGB) if finetune_layout is not None else None
    return vanilla_output, finetune_output


def create_heatmap_snippets(img_bgr, result, names, model_type):
    """Create a captioned feature-heatmap tile for each detected object.

    Args:
        img_bgr: Source BGR image; it is copied before drawing, never modified.
        result: Detection result exposing ``boxes.xyxy`` and ``pooled_feats``.
            The last entry of ``pooled_feats`` is indexed per detection (assumed
            to align with ``boxes`` order — TODO confirm in yolov10_RoIFX).
        names: Class label per detection, in the same order as ``result.boxes``.
        model_type: Caption prefix, e.g. ``"Vanilla"``.

    Returns:
        List of uint8 tiles: the heatmap crop, centred above a caption strip,
        framed by a light-grey border. Boxes whose image-clipped area is empty
        are skipped. The list is empty when ``pooled_feats`` is missing or empty.
    """
    snippets = []
    pooled_feats = getattr(result, "pooled_feats", None)
    if not pooled_feats:
        return snippets

    last_pooled = pooled_feats[-1]
    img_h, img_w = img_bgr.shape[:2]
    xyxy = result.boxes.xyxy  # hoisted: read once, not per detection
    font, scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1

    for i, box in enumerate(xyxy):
        x1, y1, x2, y2 = box.cpu().numpy().astype(int)
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(img_w, x2), min(img_h, y2)
        # Reject boxes with no visible area *before* rendering the full-image
        # heatmap, which would otherwise be computed and then thrown away.
        if x2 <= x1 or y2 <= y1:
            continue
        heatmap_full = draw_feature_heatmap(img_bgr.copy(), box, last_pooled[i])
        snippet = heatmap_full[y1:y2, x1:x2]

        # Add caption with model type and object info.
        caption = f"{model_type}: {names[i]}"
        (tw, th), _ = cv2.getTextSize(caption, font, scale, thickness)
        canvas = np.full((snippet.shape[0] + th + 15, max(snippet.shape[1], tw + 10), 3), 255, np.uint8)
        # Centre the snippet horizontally at the top of the tile.
        cx = (canvas.shape[1] - snippet.shape[1]) // 2
        canvas[0 : snippet.shape[0], cx : cx + snippet.shape[1]] = snippet
        # Caption goes in the strip below the snippet, centred.
        tx = (canvas.shape[1] - tw) // 2
        cv2.putText(canvas, caption, (tx, snippet.shape[0] + th + 5), font, scale, (0, 0, 0), thickness, cv2.LINE_AA)
        cv2.rectangle(canvas, (0, 0), (canvas.shape[1] - 1, canvas.shape[0] - 1), (180, 180, 180), 1)
        snippets.append(canvas)
    return snippets


def create_model_layout(heatmaps, title, color):
    """Compose one model's heatmap tiles into a single titled uint8 canvas.

    With no tiles, a compact white card holding only *title* is returned.
    Otherwise the tiles are placed left-to-right, 10 px apart and vertically
    centred in a band below the title, all inside a 20 px white margin.
    *color* is the title's text colour, passed straight to ``cv2.putText``.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    (text_w, text_h), _ = cv2.getTextSize(title, font, 1.0, 2)

    if not heatmaps:
        card = np.full((text_h + 40, text_w + 20, 3), 255, np.uint8)
        cv2.putText(card, title, (10, text_h + 20), font, 1.0, color, 2, cv2.LINE_AA)
        return card

    margin, gap = 20, 10
    tallest = max(tile.shape[0] for tile in heatmaps)
    row_width = sum(tile.shape[1] for tile in heatmaps) + gap * (len(heatmaps) - 1)

    canvas = np.full(
        (tallest + text_h + 40 + 2 * margin, max(row_width, text_w + 20) + 2 * margin, 3),
        255,
        np.uint8,
    )
    cv2.putText(canvas, title, (margin + 10, margin + text_h + 20), font, 1.0, color, 2, cv2.LINE_AA)

    band_top = margin + text_h + 30
    left = margin
    for tile in heatmaps:
        tile_h, tile_w = tile.shape[0], tile.shape[1]
        top = band_top + (tallest - tile_h) // 2
        canvas[top : top + tile_h, left : left + tile_w] = tile
        left += tile_w + gap
    return canvas


# ---------------------------
#  Gradio UI definition
# ---------------------------

def build_demo():
    """Assemble the Gradio Blocks app.

    The left column holds the input image, the confidence slider and the
    clickable examples. The right column shows one heatmap panel per model.
    Any change to the image or the slider recomputes both panels.
    """

    def update_heatmap(image, confidence):
        # No image (e.g. the input was cleared): blank both panels.
        if image is None:
            return None, None
        return generate_heatmap_layout(image, confidence)

    with gr.Blocks(title="YOLOv10 XAI Heatmap Comparison") as demo:
        gr.Markdown("# YOLOv10 XAI Heatmap Comparison")
        gr.Markdown("Upload an image to compare XAI heatmaps between vanilla and fine-tuned YOLOv10 models.")

        with gr.Row():
            with gr.Column(scale=1):  # inputs
                image_input = gr.Image(type="numpy", label="Input Image")
                conf_input = gr.Slider(
                    minimum=0.05, maximum=1.0, step=0.05, value=0.25, label="Confidence Threshold"
                )
                gr.Markdown("### Example Images")
                gr.Examples(
                    examples=EXAMPLES,
                    inputs=[image_input, conf_input],
                    label="Click to load example",
                )

            with gr.Column(scale=2):  # outputs, stacked vertically
                vanilla_output = gr.Image(type="numpy", label="Vanilla Model Heatmap")
                finetune_output = gr.Image(type="numpy", label="Fine-tuned Model Heatmap")

        # Both triggers share the same handler and wiring.
        for trigger in (image_input.change, conf_input.change):
            trigger(
                fn=update_heatmap,
                inputs=[image_input, conf_input],
                outputs=[vanilla_output, finetune_output],
            )

    return demo


def main():
    """Build the Gradio demo and launch it with default settings."""
    app = build_demo()
    app.launch()


if __name__ == "__main__":
    main()