HugoHE committed on
Commit
1d08579
·
0 Parent(s):

Initial commit with code files

Browse files
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
1.png ADDED

Git LFS Details

  • SHA256: 24c16cfd753a55f1d427ffbfe77560580717ab224f91e7ef8db393fc50b000f4
  • Pointer size: 132 Bytes
  • Size of remote file: 2.02 MB
2.png ADDED

Git LFS Details

  • SHA256: 3cca99200d1f53acdc32fa8ecd729bcc8ed6b8e50a50ea80661cc22efdbfc982
  • Pointer size: 132 Bytes
  • Size of remote file: 1.43 MB
README.md ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv10 Saliency Heat-map Visualiser
2
+
3
+ This Gradio app demonstrates object detection and saliency visualization using YOLOv10 models trained on the VOC dataset. The app allows users to:
4
+
5
+ 1. Choose between vanilla and finetuned YOLOv10 models
6
+ 2. Upload custom images or use provided examples
7
+ 3. Visualize object detections with bounding boxes
8
+ 4. See saliency heat-maps for each detected object
9
+
10
+ ## Models
11
+
12
+ - **Vanilla VOC**: Base YOLOv10 model trained on VOC dataset
13
+ - **Finetune VOC**: Fine-tuned YOLOv10 model with enhanced performance
14
+
15
+ ## Features
16
+
17
+ - Interactive web interface
18
+ - Real-time object detection
19
+ - Saliency heat-map generation
20
+ - Adjustable confidence threshold
21
+ - Example images included
22
+
23
+ ## Usage
24
+
25
+ 1. Select a model from the dropdown menu
26
+ 2. Upload an image or use one of the example images
27
+ 3. Adjust the confidence threshold if needed
28
+ 4. View the detection results and saliency heat-maps
29
+
30
+ ## Technical Details
31
+
32
+ The app uses:
33
+ - Gradio for the web interface
34
+ - YOLOv10 for object detection
35
+ - Custom feature extraction for saliency visualization
36
+ - OpenCV for image processing
37
+
38
+ ## Examples
39
+
40
+ The app includes two example images demonstrating the capabilities of the vanilla model.
__pycache__/yolov10_RoIFX.cpython-310.pyc ADDED
Binary file (12.8 kB). View file
 
app.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import cv2
3
+ import numpy as np
4
+ import gradio as gr
5
+ import torch
6
+ import os
7
+ from types import MethodType
8
+ from ultralytics import YOLO
9
+ from huggingface_hub import hf_hub_download
10
+
11
+ # Import helper functions from the existing feature-extractor script
12
+ from yolov10_RoIFX import (
13
+ _predict_once,
14
+ get_result_with_features_yolov10_simple,
15
+ draw_modern_bbox,
16
+ draw_feature_heatmap,
17
+ )
18
+
19
# ---------------------------
# Constants & Setup
# ---------------------------

# Set up model and example paths.
# REPO_ID is the Hugging Face Hub repo that hosts both the model checkpoints
# (under models/) and the example images; MODELS_DIR is the local cache dir.
REPO_ID = "HugoHE/X-YOLOv10"
MODELS_DIR = "models"
os.makedirs(MODELS_DIR, exist_ok=True)
27
+
28
# Download models from Hugging Face Hub
def download_models():
    """Fetch the vanilla and finetuned checkpoints from the Hub if not present locally."""
    for checkpoint in ("vanilla.pt", "finetune.pt"):
        # Skip files we already have on disk
        if os.path.exists(os.path.join(MODELS_DIR, checkpoint)):
            continue
        try:
            hf_hub_download(
                repo_id=REPO_ID,
                filename=f"models/{checkpoint}",
                local_dir=".",
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            # Best-effort: report and keep going so the app can still start
            print(f"Error downloading {checkpoint}: {e}")
41
+
42
# Download example images from Hugging Face Hub
def download_examples():
    """Fetch the bundled example images from the Hub if not present locally."""
    for example in ("1.png", "2.png"):
        # Skip files we already have on disk
        if os.path.exists(example):
            continue
        try:
            hf_hub_download(
                repo_id=REPO_ID,
                filename=example,
                local_dir=".",
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            # Best-effort: report and keep going so the app can still start
            print(f"Error downloading {example}: {e}")
55
+
56
# Download required files at import time so the UI has everything it needs
download_models()
download_examples()

# Mapping of UI dropdown label -> checkpoint filename inside MODELS_DIR
AVAILABLE_MODELS = {
    "Vanilla VOC": "vanilla.pt",
    "Finetune VOC": "finetune.pt"
}

# Example images with their descriptions.
# Each row matches the Interface inputs: [image path, model label, confidence].
EXAMPLES = [
    ["1.png", "Vanilla VOC", 0.25],
    ["2.png", "Vanilla VOC", 0.25]
]
70
+
71
# ---------------------------
# Model loading & caching
# ---------------------------

@functools.lru_cache(maxsize=2)
def load_model(model_name: str):
    """Load a YOLOv10 model and cache it so subsequent calls are fast.

    BUG FIX: the original defined a fresh ``lru_cache``-decorated inner
    ``_loader`` on every invocation, so each call got a brand-new, empty cache
    and the model was reloaded from disk every time. Decorating ``load_model``
    itself makes the cache persist across calls, as intended.

    Args:
        model_name: key into AVAILABLE_MODELS ("Vanilla VOC" / "Finetune VOC").

    Returns:
        Tuple of (YOLO model, tuple of layer indices to embed feature maps from).
    """
    model_path = os.path.join(MODELS_DIR, AVAILABLE_MODELS[model_name])
    model = YOLO(model_path)
    # Monkey-patch the predictor so we can extract feature maps on demand
    model.model._predict_once = MethodType(_predict_once, model.model)
    # Run a dummy inference to initialise internals
    model(np.zeros((640, 640, 3)), verbose=False)

    # Automatically determine which layers to use for feature extraction:
    # the layers feeding the Detect head, plus the Detect layer itself.
    detect_layer_idx = -1
    for i, m in enumerate(model.model.model):
        if "Detect" in type(m).__name__:
            detect_layer_idx = i
            break
    if detect_layer_idx != -1:
        input_layer_idxs = model.model.model[detect_layer_idx].f
        embed_layers = sorted(input_layer_idxs) + [detect_layer_idx]
    else:
        embed_layers = [16, 19, 22, 23]  # fallback for unexpected architectures

    # Return a tuple (hashable) so callers can pass it around freely
    return model, tuple(embed_layers)
102
+
103
+
104
# ---------------------------
# Composite heat-map layout
# ---------------------------

def generate_heatmap_layout(img_rgb: np.ndarray, model_name: str, conf: float = 0.25):
    """Return a composite saliency layout image for a given input image & model.

    The layout is: the input image with styled bounding boxes on top, and a
    horizontal row of per-object heat-map snippets (with captions) underneath.

    Args:
        img_rgb: input image as an RGB numpy array (Gradio's default format).
        model_name: key into AVAILABLE_MODELS.
        conf: detection confidence threshold.

    Returns:
        An RGB numpy image: the composite layout, or the unmodified input if
        nothing was detected.
    """

    model, embed_layers = load_model(model_name)

    # Convert RGB (Gradio default) ➜ BGR (OpenCV default)
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)

    # Run detection + feature extraction
    results = get_result_with_features_yolov10_simple(
        model, img_bgr, embed_layers, conf=conf
    )

    if not results or len(results) == 0 or not hasattr(results[0], "boxes"):
        return img_rgb  # nothing detected, return original

    result = results[0]
    if len(result.boxes) == 0:
        return img_rgb

    num_objects = len(result.boxes)

    # -------------- Step-1: main image with bboxes --------------
    main_img = img_bgr.copy()
    names = [model.model.names[int(cls)] for cls in result.boxes.cls]
    # BGR colours cycled per object index
    palette = [
        (71, 224, 253),
        (159, 128, 255),
        (159, 227, 128),
        (255, 191, 0),
        (255, 165, 0),
        (255, 0, 255),
    ]
    for i in range(num_objects):
        lbl = f"{names[i]} {result.boxes.conf[i]:.2f}"
        draw_modern_bbox(main_img, result.boxes.xyxy[i].cpu().numpy(), lbl, palette[i % len(palette)])

    # -------------- Step-2: heat-map snippets ------------------
    snippets = []
    if hasattr(result, "pooled_feats") and result.pooled_feats:
        # pooled_feats is per embed level; use the last (deepest) level
        last_pooled = result.pooled_feats[-1]
        for i in range(num_objects):
            box = result.boxes.xyxy[i]
            fmap = last_pooled[i]
            heatmap_full = draw_feature_heatmap(img_bgr.copy(), box, fmap)
            # Clip box to image bounds and skip degenerate boxes
            x1, y1, x2, y2 = box.cpu().numpy().astype(int)
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(img_bgr.shape[1], x2), min(img_bgr.shape[0], y2)
            if x2 <= x1 or y2 <= y1:
                continue
            snippet = heatmap_full[y1:y2, x1:x2]

            # Add a small caption under each snippet
            caption = f"Obj #{i}: {names[i]}"
            font = cv2.FONT_HERSHEY_SIMPLEX
            (tw, th), _ = cv2.getTextSize(caption, font, 0.6, 1)
            # White canvas tall enough for snippet + caption, wide enough for both
            canvas = np.full((snippet.shape[0] + th + 15, max(snippet.shape[1], tw + 10), 3), 255, np.uint8)
            # center the snippet
            cx = (canvas.shape[1] - snippet.shape[1]) // 2
            canvas[0 : snippet.shape[0], cx : cx + snippet.shape[1]] = snippet
            # put caption
            tx = (canvas.shape[1] - tw) // 2
            cv2.putText(canvas, caption, (tx, snippet.shape[0] + th + 5), font, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
            # light grey border around the whole snippet card
            cv2.rectangle(canvas, (0, 0), (canvas.shape[1] - 1, canvas.shape[0] - 1), (180, 180, 180), 1)
            snippets.append(canvas)

    if not snippets:
        # just return the main image if no heatmaps were produced
        return cv2.cvtColor(main_img, cv2.COLOR_BGR2RGB)

    # -------------- Step-3: assemble composite canvas ----------
    main_h, main_w = main_img.shape[:2]
    pad = 20
    row_h = max(s.shape[0] for s in snippets)
    # 10 px gap between snippets
    total_row_w = sum(s.shape[1] for s in snippets) + (len(snippets) - 1) * 10

    row_canvas = np.full((row_h, total_row_w, 3), 255, np.uint8)
    cur_x = 0
    for s in snippets:
        h, w = s.shape[:2]
        y_off = (row_h - h) // 2
        row_canvas[y_off : y_off + h, cur_x : cur_x + w] = s
        cur_x += w + 10

    canvas_h = main_h + row_h + 3 * pad
    canvas_w = max(main_w, total_row_w) + 2 * pad
    final = np.full((canvas_h, canvas_w, 3), 255, np.uint8)

    # paste main image (top-center)
    x_main = (canvas_w - main_w) // 2
    final[pad : pad + main_h, x_main : x_main + main_w] = main_img

    # paste snippets row (bottom-center)
    x_row = (canvas_w - total_row_w) // 2
    final[main_h + 2 * pad : main_h + 2 * pad + row_h, x_row : x_row + total_row_w] = row_canvas

    # convert back to RGB for display
    return cv2.cvtColor(final, cv2.COLOR_BGR2RGB)
206
+
207
+
208
# ---------------------------
# Gradio UI definition
# ---------------------------

def build_demo():
    """Assemble and return the Gradio Interface for the visualiser."""
    model_labels = list(AVAILABLE_MODELS.keys())
    return gr.Interface(
        fn=generate_heatmap_layout,
        inputs=[
            gr.Image(type="numpy", label="Input Image"),
            gr.Dropdown(
                choices=model_labels,
                value=model_labels[0],
                label="Select Model"
            ),
            gr.Slider(minimum=0.05, maximum=1.0, step=0.05, value=0.25, label="Confidence Threshold"),
        ],
        outputs=gr.Image(type="numpy", label="Saliency Heat-map Layout"),
        title="YOLOv10 Saliency Heat-map Visualiser",
        description="Select a model (vanilla-voc or finetune-voc) and upload an image. The app will overlay bounding boxes and generate saliency heat-maps for each detected object.",
        examples=EXAMPLES,
        cache_examples=True,
    )
232
+
233
+
234
def main():
    """Build the Gradio demo and start serving it."""
    build_demo().launch()


if __name__ == "__main__":
    main()
models/finetune.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:71e09e27011f99e9a34df19be89a4ffb0167790871c23e6549c24ddec194cbba
3
+ size 98072713
models/vanilla.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:823126b7af91ebf5ca4a5926a94e10a32c3e95981f264809245d9ba7b197be0c
3
+ size 65543615
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ ultralytics>=8.0.0
3
+ opencv-python-headless>=4.8.0
4
+ numpy>=1.24.0
5
+ torch>=2.0.0
6
+ torchvision>=0.15.0
7
+ huggingface-hub>=0.20.0
yolov10_RoIFX.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ YOLOv10 Single Object Feature Extractor
4
+
5
+ This script extracts features for a specific detected object by its index.
6
+ It can be used to build feature databases or for targeted object analysis.
7
+ """
8
+
9
+ from ultralytics import YOLO
10
+ from ultralytics.utils.ops import xywh2xyxy, scale_boxes
11
+ from ultralytics.engine.results import Results
12
+ import torch
13
+ import time
14
+ from torch.nn.functional import cosine_similarity
15
+ import cv2
16
+ import matplotlib.pyplot as plt
17
+ import numpy as np
18
+ from pathlib import Path
19
+ import urllib.request
20
+ import argparse
21
+ import json
22
+
23
+ from torchvision.ops import RoIAlign as ROIAlign
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+
27
+ from types import MethodType
28
+ import torchvision
29
+ import collections
30
+
31
+
32
# Monkey patch method to get feature maps
def _predict_once(self, x, profile=False, visualize=False, embed=None):
    """Forward pass that can additionally return intermediate feature maps.

    This replaces ultralytics' ``BaseModel._predict_once`` via MethodType so
    that, when ``embed`` is a collection of layer indices, the outputs of those
    layers are collected and returned instead of the detection head output.

    Args:
        x: network input (or list of inputs for multi-source layers).
        profile: if True, profile each layer via ``self._profile_one_layer``.
        visualize: if truthy, save per-layer feature visualisations to that dir.
        embed: optional collection of layer indices whose outputs to return.

    Returns:
        The final layer output, or the list of embedded feature maps when
        ``embed`` is given (returned early at the deepest requested layer).
    """
    y, dt, embeddings = [], [], []  # outputs
    for m in self.model:
        if m.f != -1:  # if not from previous layer
            x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
        if profile:
            self._profile_one_layer(m, x, dt)
        x = m(x)  # run
        y.append(x if m.i in self.save else None)  # save output
        if visualize:
            # BUG FIX: feature_visualization was not imported at module level,
            # so any visualize=True call raised NameError. Import lazily here.
            from ultralytics.utils.plotting import feature_visualization
            feature_visualization(x, m.type, m.i, save_dir=visualize)

        if embed and m.i in embed:
            embeddings.append(x)
            if m.i == max(embed):
                # Deepest requested layer reached: stop early
                return embeddings
    return x
50
+
51
+
52
def get_yolov10_object_features_with_pooler(feat_list, idxs, boxes, orig_img_shape):
    """
    Extract per-object features from YOLOv10 feature maps using RoIAlign,
    concatenating the channel-averaged RoI features of every pyramid level.

    Args:
        feat_list: feature maps [P3, P4, P5] (strides 8/16/32) from the backbone.
        idxs: unused; kept for backward-compatible interface.
        boxes: (N, 4) xyxy boxes in network-input coordinates.
        orig_img_shape: unused; kept for backward-compatible interface.

    Returns:
        ([final_feats], pooled_feats): final_feats is an (N, C_total) tensor of
        concatenated per-level averaged features; pooled_feats is the list of
        raw (N, C, 7, 7) RoIAlign outputs, one per level.
    """
    # Downsampling ratio per feature map: P3 stride 8, P4 stride 16, P5 stride 32
    spatial_scales = [1.0 / 8, 1.0 / 16, 1.0 / 32]

    num_rois = len(boxes)
    if num_rois == 0:
        # No detections: empty feature tensor, no pooled maps
        return [torch.empty(0)], []

    # RoIAlign expects rois as (batch_idx, x1, y1, x2, y2); single image -> 0
    zeros = torch.full((num_rois, 1), 0, device=boxes.device, dtype=boxes.dtype)
    rois = torch.cat((zeros, boxes), dim=1)

    poolers = [
        ROIAlign(output_size=[7, 7], spatial_scale=ss, sampling_ratio=2) for ss in spatial_scales
    ]

    pooled_feats = [pooler(feat_map, rois) for feat_map, pooler in zip(feat_list, poolers)]

    # Collapse each 7x7 RoI to a single vector per level, then concatenate levels
    avg_pool = nn.AdaptiveAvgPool2d((1, 1))
    pooled_feats_flat = [avg_pool(pf).view(num_rois, -1) for pf in pooled_feats]
    final_feats = torch.cat(pooled_feats_flat, dim=1)

    return [final_feats], pooled_feats
88
+
89
+
90
def get_result_with_features_yolov10_simple(model, imgs, embed_layers, conf=0.25):
    """
    Run standard YOLO inference, then attach per-object RoI features to results.

    For every result with detections, the predictor is temporarily switched into
    embedding mode to re-run inference and capture intermediate feature maps,
    from which RoIAlign features are extracted for each detected box.

    Args:
        model: a YOLO model whose ``_predict_once`` has been monkey-patched.
        imgs: a single image or a list of images (any format YOLO accepts).
        embed_layers: layer indices to capture feature maps from.
        conf: detection confidence threshold.

    Returns:
        The list of Results objects, each augmented (when it has boxes) with
        ``result.feats`` (concatenated per-object features) and
        ``result.pooled_feats`` (raw per-level RoI tensors).
    """
    if not isinstance(imgs, list):
        imgs = [imgs]

    # First, run standard inference to get proper Results objects
    results = model(imgs, verbose=False, conf=conf)

    # Then extract features for each detected object
    for i, result in enumerate(results):
        if hasattr(result, 'boxes') and len(result.boxes) > 0:
            # Get the preprocessed image that was used for inference
            prepped = model.predictor.preprocess([result.orig_img])

            # --- Temporarily set the embed layers ---
            # Leaving a non-None value in `model.predictor.args.embed` would
            # make the model return raw feature maps (instead of standard
            # detection outputs) on the *next* call, losing detections for every
            # subsequent image. BUG FIX: the restore now lives in a `finally`
            # block so the predictor is healed even if inference raises.
            prev_embed = getattr(model.predictor.args, "embed", None)
            model.predictor.args.embed = embed_layers
            try:
                # Call inference with embedding to get feature maps
                features = model.predictor.inference(prepped)
            finally:
                # Restore previous embed setting unconditionally
                model.predictor.args.embed = prev_embed

            # The feature maps are all but the last element of the result
            feature_maps = features[:-1]

            # Extract features for each detected box
            boxes_scaled = result.boxes.xyxy
            # Scale boxes to the preprocessed image size for feature extraction
            boxes_for_features = scale_boxes(result.orig_img.shape, boxes_scaled.clone(), prepped.shape[2:])

            # Create dummy indices (we're not using NMS indices here)
            dummy_idxs = [torch.arange(len(boxes_for_features))]

            # Get features
            obj_feats, pooled_feats = get_yolov10_object_features_with_pooler(feature_maps, dummy_idxs, boxes_for_features, result.orig_img.shape)

            # Add features to the result
            result.feats = obj_feats[0] if obj_feats else torch.empty(0)
            result.pooled_feats = pooled_feats

    return results
141
+
142
+
143
def draw_debug_image(img, boxes, class_names, save_path="debug_detections.png", highlight_idx=None):
    """Render all detection boxes onto a copy of *img*, write it to *save_path*,
    and return the annotated image. The box at *highlight_idx* is drawn in red
    and thicker; all others in green."""
    debug_img = img.copy()
    img_h, img_w = img.shape[0], img.shape[1]
    for i, box in enumerate(boxes):
        x1, y1, x2, y2 = box.cpu().numpy().astype(int)
        # Clamp coordinates to the image bounds
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(img_w, x2), min(img_h, y2)

        selected = i == highlight_idx
        color = (0, 0, 255) if selected else (0, 255, 0)  # red for selected, green otherwise
        thickness = 3 if selected else 2

        cv2.rectangle(debug_img, (x1, y1), (x2, y2), color, thickness)
        cv2.putText(debug_img, f"{class_names[i]} #{i}", (x1, y1-10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

    cv2.imwrite(save_path, debug_img)
    print(f"Debug image with bounding boxes saved to {save_path}")
    return debug_img
163
+
164
+
165
def draw_feature_heatmap(image, box, feature_map):
    """
    Blend a channel-averaged feature map, as a JET heatmap, into the boxed
    region of *image*.

    Args:
        image: BGR image; a modified copy is returned, the input is untouched.
        box: xyxy tensor for one detection.
        feature_map: (C, H, W) feature tensor for that detection.

    Returns:
        A copy of `image` with the heatmap blended (60% image / 40% heatmap)
        into the clipped box region, or `image` itself for degenerate boxes.
    """
    # Detach and move feature map to CPU
    feature_map = feature_map.detach().cpu()

    # Average features across channels to get a 2D heatmap
    heatmap = torch.mean(feature_map, dim=0).numpy()

    # Normalize heatmap to 0-255.
    # BUG FIX: a constant (flat) map skipped normalization and raw values were
    # multiplied by 255, wrapping around in uint8; zero it explicitly instead.
    hmin, hmax = np.min(heatmap), np.max(heatmap)
    if hmax > hmin:
        heatmap = (heatmap - hmin) / (hmax - hmin)
    else:
        heatmap = np.zeros_like(heatmap)
    heatmap = (heatmap * 255).astype(np.uint8)

    # Get bounding box coordinates, clipped to the image bounds
    x1, y1, x2, y2 = box.cpu().numpy().astype(int)
    x1, y1 = max(0, x1), max(0, y1)
    x2, y2 = min(image.shape[1], x2), min(image.shape[0], y2)

    bbox_w, bbox_h = x2 - x1, y2 - y1
    if bbox_w <= 0 or bbox_h <= 0:
        return image  # degenerate box: return original image

    # Resize heatmap to bounding box size
    heatmap_resized = cv2.resize(heatmap, (bbox_w, bbox_h), interpolation=cv2.INTER_LINEAR)

    # Apply colormap
    heatmap_colored = cv2.applyColorMap(heatmap_resized, cv2.COLORMAP_JET)

    # Get the region of interest from the original image
    roi = image[y1:y2, x1:x2]

    # Blend heatmap with ROI (60% original pixels, 40% heatmap)
    overlay = cv2.addWeighted(roi, 0.6, heatmap_colored, 0.4, 0)

    # Place the overlay back onto a copy of the image
    output_image = image.copy()
    output_image[y1:y2, x1:x2] = overlay

    return output_image
206
+
207
+
208
def draw_filled_rounded_rectangle(img, pt1, pt2, color, radius):
    """Fill a rounded rectangle in place by composing corner circles with
    two overlapping axis-aligned rectangles."""
    x1, y1 = pt1
    x2, y2 = pt2

    # One filled circle per corner
    corners = (
        (x1 + radius, y1 + radius),
        (x2 - radius, y1 + radius),
        (x1 + radius, y2 - radius),
        (x2 - radius, y2 - radius),
    )
    for center in corners:
        cv2.circle(img, center, radius, color, -1)

    # Two filled rectangles cover the straight edges and the interior
    cv2.rectangle(img, (x1 + radius, y1), (x2 - radius, y2), color, -1)
    cv2.rectangle(img, (x1, y1 + radius), (x2, y2 - radius), color, -1)
222
+
223
+
224
def draw_modern_bbox(image, box, label, color):
    """Draw, in place, a bounding box plus a semi-transparent rounded label tag."""
    x1, y1, x2, y2 = box.astype(int)

    # Box outline
    cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness=2)

    # --- Label ---
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    font_thickness = 1
    (text_w, text_h), _ = cv2.getTextSize(label, font, font_scale, font_thickness)

    # Label background goes above the box unless that would leave the image,
    # in which case it is moved just inside the box's top edge.
    fits_above = y1 - text_h - 15 >= 0
    if fits_above:
        label_bg_pt1 = (x1, y1 - text_h - 15)
        label_bg_pt2 = (x1 + text_w + 10, y1)
    else:
        label_bg_pt1 = (x1, y1 + 5)
        label_bg_pt2 = (x1 + text_w + 10, y1 + text_h + 20)

    # Blend a filled rounded rectangle onto the image at 60% opacity
    overlay = image.copy()
    draw_filled_rounded_rectangle(overlay, label_bg_pt1, label_bg_pt2, color, radius=8)
    alpha = 0.6
    cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)

    # Finally the label text itself, on top of the blended background
    text_pt = (label_bg_pt1[0] + 5, label_bg_pt1[1] + text_h + 5)
    cv2.putText(image, label, text_pt, font, font_scale, (0, 0, 0), font_thickness, cv2.LINE_AA)
257
+
258
+
259
def generate_feature_heatmaps(model, img_path, embed_layers, output_dir="./", conf=0.25):
    """
    Generates a single composite image containing the main image with bounding boxes
    and separate heatmap snippets for each detected object, saved to output_dir.

    Args:
        model: YOLOv10 model (with the monkey-patched ``_predict_once``)
        img_path: Path to the input image
        embed_layers: List of layer indices to extract features from
        output_dir: Directory to save outputs
        conf: Object detection confidence threshold
    """

    # Load image
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image at {img_path}")

    print(f"Processing image: {img_path}")

    # Get results with features
    results_with_feat = get_result_with_features_yolov10_simple(model, img_path, embed_layers, conf=conf)

    if not results_with_feat or not isinstance(results_with_feat, list) or len(results_with_feat) == 0:
        print("No results returned.")
        return

    result = results_with_feat[0]
    if not hasattr(result, 'boxes') or len(result.boxes) == 0:
        print("No objects detected in the image.")
        return

    num_objects = len(result.boxes)
    print(f"Total objects detected: {num_objects}. Generating composite layout...")

    # Get class names
    all_class_names = [model.model.names[int(cls)] for cls in result.boxes.cls]

    # --- Step 1: Create the main image with modern bounding boxes ---
    main_image_with_boxes = img.copy()
    colors = [(71, 224, 253), (159, 128, 255), (159, 227, 128), (255, 191, 0), (255, 165, 0), (255, 0, 255)]
    for i in range(num_objects):
        label = f"{all_class_names[i]} {result.boxes.conf[i]:.2f}"
        color = colors[i % len(colors)]
        draw_modern_bbox(main_image_with_boxes, result.boxes.xyxy[i].cpu().numpy(), label, color)

    # --- Step 2: Generate individual heatmap snippets for each object ---
    heatmap_snippets = []
    if hasattr(result, 'pooled_feats') and result.pooled_feats:
        last_layer_pooled_feats = result.pooled_feats[-1]
        for i in range(num_objects):
            box = result.boxes.xyxy[i]
            feature_map = last_layer_pooled_feats[i]

            heatmap_on_full = draw_feature_heatmap(img.copy(), box, feature_map)
            # BUG FIX: clip coordinates to the image bounds before slicing
            # (negative coords would slice from the wrong end) and skip
            # degenerate boxes — consistent with the app.py layout code.
            x1, y1, x2, y2 = box.cpu().numpy().astype(int)
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(img.shape[1], x2), min(img.shape[0], y2)
            if x2 <= x1 or y2 <= y1:
                continue
            snippet = heatmap_on_full[y1:y2, x1:x2]

            label_text = f"Obj #{i}: {all_class_names[i]}"
            font = cv2.FONT_HERSHEY_SIMPLEX
            (text_w, text_h), _ = cv2.getTextSize(label_text, font, 0.6, 1)

            h, w, _ = snippet.shape

            # Make the snippet canvas wide enough for the text label
            new_w = max(w, text_w + 10)
            snippet_with_label = np.full((h + text_h + 15, new_w, 3), 255, dtype=np.uint8)

            # Paste the snippet (centered) onto the new canvas
            paste_x = (new_w - w) // 2
            snippet_with_label[0:h, paste_x:paste_x+w] = snippet

            # Draw the label text (centered) and a light grey border
            text_x = (new_w - text_w) // 2
            cv2.putText(snippet_with_label, label_text, (text_x, h + text_h + 5), font, 0.6, (0,0,0), 1, cv2.LINE_AA)
            cv2.rectangle(snippet_with_label, (0,0), (new_w-1, h+text_h+14), (180,180,180), 1)
            heatmap_snippets.append(snippet_with_label)

    if not heatmap_snippets:
        print("No heatmaps generated. Saving image with bounding boxes only.")
        image_name = Path(img_path).stem
        save_path = Path(output_dir) / f"{image_name}_layout.png"
        cv2.imwrite(str(save_path), main_image_with_boxes)
        return

    # --- Step 3: Arrange snippets and main image into a final composite image ---
    main_h, main_w, _ = main_image_with_boxes.shape
    padding = 20

    # Arrange snippets into a horizontal row, 10 px apart
    snippets_row_h = max(s.shape[0] for s in heatmap_snippets)
    total_snippets_w = sum(s.shape[1] for s in heatmap_snippets) + (len(heatmap_snippets) - 1) * 10

    snippets_row = np.full((snippets_row_h, total_snippets_w, 3), 255, dtype=np.uint8)
    current_x = 0
    for snippet in heatmap_snippets:
        h, w, _ = snippet.shape
        paste_y = (snippets_row_h - h) // 2
        snippets_row[paste_y:paste_y+h, current_x:current_x+w] = snippet
        current_x += w + 10

    # Create the final canvas and place the main image and the snippet row
    canvas_h = main_h + snippets_row_h + 3 * padding
    canvas_w = max(main_w, total_snippets_w) + 2 * padding
    final_image = np.full((canvas_h, canvas_w, 3), 255, dtype=np.uint8)

    # Paste main image at top-center
    x_offset_main = (canvas_w - main_w) // 2
    final_image[padding:padding+main_h, x_offset_main:x_offset_main+main_w] = main_image_with_boxes

    # Paste snippet row at bottom-center
    x_offset_snippets = (canvas_w - total_snippets_w) // 2
    y_offset_snippets = main_h + 2 * padding
    final_image[y_offset_snippets:y_offset_snippets+snippets_row_h, x_offset_snippets:x_offset_snippets+total_snippets_w] = snippets_row

    # --- Step 4: Save the final composite image ---
    image_name = Path(img_path).stem
    heatmap_path = Path(output_dir) / f"{image_name}_heatmap_layout.png"
    cv2.imwrite(str(heatmap_path), final_image)
    print(f" - Saved composite heatmap layout to: {heatmap_path}")
379
+
380
+
381
def main():
    """CLI entry point: load a YOLOv10 model, auto-detect its feature layers,
    and generate composite heatmap layouts for one image or a directory."""
    parser = argparse.ArgumentParser(description='Generate a composite feature heatmap for all detected objects in an image or a directory of images.')
    # Exactly one of --image / --input-dir must be supplied
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--image', '-i', type=str, help='Path to a single input image.')
    group.add_argument('--input-dir', '-d', type=str, help='Path to a directory of input images.')

    parser.add_argument('--model', '-m', type=str, default='yolov10n.pt', help='Path to YOLOv10 model')
    parser.add_argument('--output', '-o', type=str, default='./heatmaps', help='Output directory for generated layouts.')
    parser.add_argument('--conf', type=float, default=0.25, help='Object detection confidence threshold (e.g., 0.1 for more detections).')

    args = parser.parse_args()

    # Create output directory if it doesn't exist
    Path(args.output).mkdir(parents=True, exist_ok=True)

    # Load YOLOv10 model
    print(f"Loading model: {args.model}")
    model = YOLO(args.model)

    # Monkey patch the model's prediction method so embed layers can be captured
    model.model._predict_once = MethodType(_predict_once, model.model)

    # Initialize the predictor by running a dummy inference
    model(np.zeros((640, 640, 3)), verbose=False)

    # Dynamically find the feature map layer indices from the model:
    # the inputs of the Detect head plus the Detect layer itself
    detect_layer_index = -1
    for i, m in enumerate(model.model.model):
        if 'Detect' in type(m).__name__:
            detect_layer_index = i
            break

    if detect_layer_index != -1:
        input_layers_indices = model.model.model[detect_layer_index].f
        embed_layers = sorted(input_layers_indices) + [detect_layer_index]
        print(f"Auto-detected feature layers at indices: {input_layers_indices}")
        print(f"Embedding features from layers: {embed_layers}")
    else:
        print("Could not find Detect layer, falling back to hardcoded indices")
        embed_layers = [16, 19, 22, 23]

    # Process either a single image or a directory of images
    if args.input_dir:
        input_path = Path(args.input_dir)
        # NOTE(review): glob patterns are case-sensitive; uppercase extensions
        # (e.g. .JPG) are not matched — confirm whether that is intended.
        image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tif', '*.tiff']
        image_files = []
        for ext in image_extensions:
            image_files.extend(input_path.glob(ext))

        if not image_files:
            print(f"No images found in '{args.input_dir}'.")
            return

        print(f"\nFound {len(image_files)} images in '{args.input_dir}'. Processing...")
        for img_path in image_files:
            generate_feature_heatmaps(
                model=model,
                img_path=str(img_path),
                embed_layers=embed_layers,
                output_dir=args.output,
                conf=args.conf
            )
    else:  # mutually exclusive group guarantees args.image is set here
        generate_feature_heatmaps(
            model=model,
            img_path=args.image,
            embed_layers=embed_layers,
            output_dir=args.output,
            conf=args.conf
        )

    print(f"\nProcessing complete. All layouts saved to '{args.output}'.")
453
+
454
+
455
if __name__ == "__main__":
    # If run without arguments, fall back to a self-test on a fixed image
    import sys
    if len(sys.argv) == 1:
        print("No arguments provided. Running heatmap generation on a test image.")

        # Load YOLOv10 model and apply the same monkey-patch main() uses
        print("Loading default model: yolov10n.pt")
        model = YOLO('yolov10n.pt')
        model.model._predict_once = MethodType(_predict_once, model.model)
        model(np.zeros((640, 640, 3)), verbose=False)

        # Auto-detect the feature layers feeding the Detect head
        detect_layer_index = -1
        for i, m in enumerate(model.model.model):
            if 'Detect' in type(m).__name__:
                detect_layer_index = i
                break

        if detect_layer_index != -1:
            input_layers_indices = model.model.model[detect_layer_index].f
            embed_layers = sorted(input_layers_indices) + [detect_layer_index]
            print(f"Auto-detected feature layers at indices: {input_layers_indices}")
        else:
            embed_layers = [16, 19, 22, 23]

        # Define test image path
        # NOTE(review): machine-specific absolute path — this self-test will
        # raise FileNotFoundError on any other machine; consider making it
        # configurable or removing it.
        img_path = "/home/hew/yolov10FX_obj/id-1.jpg"

        # Generate heatmaps for the test image (lower conf finds more objects)
        print("Using a lower confidence of 0.1 for test mode to find more objects.")
        generate_feature_heatmaps(
            model=model,
            img_path=img_path,
            embed_layers=embed_layers,
            output_dir="./",
            conf=0.1
        )
        print(f"\nHeatmap generation completed successfully for test image!")

    else:
        main()