DariusGiannoli committed on
Commit
5cc06e1
·
1 Parent(s): ace5a51

feat: multi-object, ORB, evaluation page, feature importance

Browse files
app.py CHANGED
@@ -16,26 +16,31 @@ st.divider()
16
  # ===================================================================
17
  st.header("πŸ—ΊοΈ Pipeline Overview")
18
  st.markdown("""
19
- The app is structured as a **5-stage sequential pipeline**.
20
  Complete each page in order β€” every stage feeds the next.
21
  """)
22
 
23
  stages = [
24
  ("πŸ§ͺ", "1 Β· Data Lab", "Upload a stereo image pair, camera calibration file, and two PFM ground-truth depth maps. "
25
- "Define an object ROI (bounding box), then apply live data augmentation "
26
  "(brightness, contrast, rotation, noise, blur, shift, flip). "
27
  "All assets are locked into session state β€” nothing is written to disk."),
28
  ("πŸ”¬", "2 Β· Feature Lab", "Toggle RCE physics modules (Intensity Β· Sobel Β· Spectral) to build a modular "
29
  "feature vector. Compare it live against CNN activation maps extracted from a "
30
  "frozen backbone via forward hooks. Lock your active module configuration."),
31
  ("βš™οΈ", "3 Β· Model Tuning", "Train lightweight **heads** on your session data (augmented crop = positives, "
32
- "random non-overlapping patches from the scene = negatives). "
33
- "Both RCE and CNN heads are trained identically with LogisticRegression "
34
- "and stored in session state only β€” no disk writes."),
35
- ("🎯", "4 · Real-Time Detection","Run a **sliding window** across the right image using both the RCE head and "
36
- "your chosen CNN head simultaneously. Watch the scan live, then compare "
37
- "bounding boxes, confidence heatmaps, and latency."),
38
- ("πŸ“", "5 Β· Stereo Geometry", "Compute a disparity map with **StereoSGBM**, convert it to metric depth "
 
 
 
 
 
39
  "using the stereo formula $Z = fB/(d+d_{\\text{offs}})$, then read depth "
40
  "directly at every detected bounding box. Compare against PFM ground truth."),
41
  ]
 
16
  # ===================================================================
17
  st.header("πŸ—ΊοΈ Pipeline Overview")
18
  st.markdown("""
19
+ The app is structured as a **7-stage sequential pipeline**.
20
  Complete each page in order β€” every stage feeds the next.
21
  """)
22
 
23
  stages = [
24
  ("πŸ§ͺ", "1 Β· Data Lab", "Upload a stereo image pair, camera calibration file, and two PFM ground-truth depth maps. "
25
+ "Define one or more object ROIs (bounding boxes) with class labels, then apply live data augmentation "
26
  "(brightness, contrast, rotation, noise, blur, shift, flip). "
27
  "All assets are locked into session state β€” nothing is written to disk."),
28
  ("πŸ”¬", "2 Β· Feature Lab", "Toggle RCE physics modules (Intensity Β· Sobel Β· Spectral) to build a modular "
29
  "feature vector. Compare it live against CNN activation maps extracted from a "
30
  "frozen backbone via forward hooks. Lock your active module configuration."),
31
  ("βš™οΈ", "3 Β· Model Tuning", "Train lightweight **heads** on your session data (augmented crop = positives, "
32
+ "random non-overlapping patches = negatives). Compare three paradigms side by side: "
33
+ "RCE (with feature importance), CNN (with activation overlay), and ORB (keypoint matching)."),
34
+ ("πŸ”", "4 Β· Localization Lab", "Compare **five localization strategies** on top of your trained head: "
35
+ "Exhaustive Sliding Window, Image Pyramid (multi-scale), Coarse-to-Fine "
36
+ "hierarchical search, Contour Proposals (edge-driven), and Template "
37
+ "Matching (cross-correlation)."),
38
+ ("🎯", "5 · Real-Time Detection","Run a **sliding window** across the right image using RCE, CNN, and ORB "
39
+ "simultaneously. Watch the scan live, then compare bounding boxes, "
40
+ "confidence heatmaps, and latency across all three methods."),
41
+ ("πŸ“ˆ", "6 Β· Evaluation", "Quantitative evaluation with **confusion matrices**, **precision-recall curves**, "
42
+ "and **F1 scores** per method. Ground truth is derived from your Data Lab ROIs."),
43
+ ("πŸ“", "7 Β· Stereo Geometry", "Compute a disparity map with **StereoSGBM**, convert it to metric depth "
44
  "using the stereo formula $Z = fB/(d+d_{\\text{offs}})$, then read depth "
45
  "directly at every detected bounding box. Compare against PFM ground truth."),
46
  ]
pages/2_Data_Lab.py CHANGED
@@ -146,31 +146,85 @@ if up_l and up_r and up_conf and up_gt_l and up_gt_r:
146
  st.text_area("Raw Config", conf_content, height=200)
147
 
148
  # -----------------------------------------------------------------------
149
- # Step 3 β€” Crop ROI from Left Image
150
  # -----------------------------------------------------------------------
151
  st.divider()
152
- st.subheader("Step 3: Crop Region of Interest")
153
- st.write("Define the bounding box of the object you want to recognise (pixels).")
154
 
155
  H, W = img_l.shape[:2]
156
- cr1, cr2, cr3, cr4 = st.columns(4)
157
- x0 = cr1.number_input("X start", 0, W - 2, 0, step=1)
158
- y0 = cr2.number_input("Y start", 0, H - 2, 0, step=1)
159
- x1 = cr3.number_input("X end", int(x0) + 1, W, min(W, int(x0) + 100), step=1)
160
- y1 = cr4.number_input("Y end", int(y0) + 1, H, min(H, int(y0) + 100), step=1)
161
 
162
- x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
163
-
164
- # Overlay rectangle on left image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  overlay = img_l.copy()
166
- cv2.rectangle(overlay, (x0, y0), (x1, y1), (0, 255, 0), 2)
167
- crop_bgr = img_l[y0:y1, x0:x1].copy()
168
-
169
- ov1, ov2 = st.columns([3, 1])
 
 
 
 
 
 
170
  ov1.image(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB),
171
- caption="Left Image β€” ROI highlighted", use_container_width=True)
172
- ov2.image(cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB),
173
- caption="Crop", use_container_width=True)
 
 
 
 
 
 
 
 
 
 
174
 
175
  # -----------------------------------------------------------------------
176
  # Step 4 β€” Data Augmentation
@@ -195,28 +249,49 @@ if up_l and up_r and up_conf and up_gt_l and up_gt_r:
195
  aug = augment(crop_bgr, brightness, contrast, rotation,
196
  flip_h, flip_v, noise, blur, shift_x, shift_y)
197
 
 
 
 
 
 
198
  aug_col1, aug_col2 = st.columns(2)
199
  aug_col1.image(cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB),
200
- caption="Original Crop", use_container_width=True)
201
  aug_col2.image(cv2.cvtColor(aug, cv2.COLOR_BGR2RGB),
202
- caption="Augmented Crop", use_container_width=True)
 
 
 
203
 
204
  # -----------------------------------------------------------------------
205
  # Step 5 β€” Lock & Store
206
  # -----------------------------------------------------------------------
207
  st.divider()
208
  if st.button("πŸš€ Lock Data & Proceed to Benchmark"):
 
 
 
 
 
 
 
 
 
209
  st.session_state["pipeline_data"] = {
210
  "left": img_l,
211
  "right": img_r,
212
  "gt_left": gt_depth_l,
213
  "gt_right": gt_depth_r,
214
  "conf_raw": conf_content,
 
215
  "crop": crop_bgr,
216
  "crop_aug": aug,
217
  "crop_bbox": (x0, y0, x1, y1),
 
 
218
  }
219
- st.success("Data stored in session! Move to the 'Recognition' or 'Tuning' page.")
 
220
 
221
  else:
222
  st.info("Please upload all 5 files (left image, right image, config, left GT, right GT) to proceed.")
 
146
  st.text_area("Raw Config", conf_content, height=200)
147
 
148
  # -----------------------------------------------------------------------
149
+ # Step 3 β€” Crop ROI(s) from Left Image (Multi-Object)
150
  # -----------------------------------------------------------------------
151
  st.divider()
152
+ st.subheader("Step 3: Crop Region(s) of Interest")
153
+ st.write("Define one or more bounding boxes β€” each becomes a separate class for recognition.")
154
 
155
  H, W = img_l.shape[:2]
 
 
 
 
 
156
 
157
+ # Manage list of ROIs in session state
158
+ if "rois" not in st.session_state:
159
+ st.session_state["rois"] = [{"label": "object", "x0": 0, "y0": 0,
160
+ "x1": min(W, 100), "y1": min(H, 100)}]
161
+
162
+ def _add_roi():
163
+ st.session_state["rois"].append(
164
+ {"label": f"object_{len(st.session_state['rois'])+1}",
165
+ "x0": 0, "y0": 0,
166
+ "x1": min(W, 100), "y1": min(H, 100)})
167
+
168
+ def _remove_roi(idx):
169
+ if len(st.session_state["rois"]) > 1:
170
+ st.session_state["rois"].pop(idx)
171
+
172
+ ROI_COLORS = [(0,255,0), (255,0,0), (0,0,255), (255,255,0),
173
+ (255,0,255), (0,255,255), (128,255,0), (255,128,0)]
174
+
175
+ for i, roi in enumerate(st.session_state["rois"]):
176
+ color = ROI_COLORS[i % len(ROI_COLORS)]
177
+ color_hex = "#{:02x}{:02x}{:02x}".format(*color)
178
+ with st.container(border=True):
179
+ hc1, hc2, hc3 = st.columns([3, 6, 1])
180
+ hc1.markdown(f"**ROI {i+1}** <span style='color:{color_hex}'>β– </span>",
181
+ unsafe_allow_html=True)
182
+ roi["label"] = hc2.text_input("Class Label", roi["label"],
183
+ key=f"roi_lbl_{i}")
184
+ if len(st.session_state["rois"]) > 1:
185
+ hc3.button("βœ•", key=f"roi_del_{i}",
186
+ on_click=_remove_roi, args=(i,))
187
+
188
+ cr1, cr2, cr3, cr4 = st.columns(4)
189
+ roi["x0"] = int(cr1.number_input("X start", 0, W-2, int(roi["x0"]),
190
+ step=1, key=f"roi_x0_{i}"))
191
+ roi["y0"] = int(cr2.number_input("Y start", 0, H-2, int(roi["y0"]),
192
+ step=1, key=f"roi_y0_{i}"))
193
+ roi["x1"] = int(cr3.number_input("X end", roi["x0"]+1, W,
194
+ min(W, int(roi["x1"])),
195
+ step=1, key=f"roi_x1_{i}"))
196
+ roi["y1"] = int(cr4.number_input("Y end", roi["y0"]+1, H,
197
+ min(H, int(roi["y1"])),
198
+ step=1, key=f"roi_y1_{i}"))
199
+
200
+ st.button("βž• Add Another ROI", on_click=_add_roi)
201
+
202
+ # Draw all ROIs on the image
203
  overlay = img_l.copy()
204
+ crops = []
205
+ for i, roi in enumerate(st.session_state["rois"]):
206
+ color = ROI_COLORS[i % len(ROI_COLORS)]
207
+ x0, y0, x1, y1 = roi["x0"], roi["y0"], roi["x1"], roi["y1"]
208
+ cv2.rectangle(overlay, (x0, y0), (x1, y1), color, 2)
209
+ cv2.putText(overlay, roi["label"], (x0, y0 - 6),
210
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
211
+ crops.append(img_l[y0:y1, x0:x1].copy())
212
+
213
+ ov1, ov2 = st.columns([3, 2])
214
  ov1.image(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB),
215
+ caption="Left Image β€” ROIs highlighted", use_container_width=True)
216
+ with ov2:
217
+ for i, (c, roi) in enumerate(zip(crops, st.session_state["rois"])):
218
+ st.image(cv2.cvtColor(c, cv2.COLOR_BGR2RGB),
219
+ caption=f"{roi['label']} ({c.shape[1]}Γ—{c.shape[0]})",
220
+ width=160)
221
+
222
+ # For backward compatibility: first ROI is the "primary"
223
+ crop_bgr = crops[0]
224
+ x0, y0, x1, y1 = (st.session_state["rois"][0]["x0"],
225
+ st.session_state["rois"][0]["y0"],
226
+ st.session_state["rois"][0]["x1"],
227
+ st.session_state["rois"][0]["y1"])
228
 
229
  # -----------------------------------------------------------------------
230
  # Step 4 β€” Data Augmentation
 
249
  aug = augment(crop_bgr, brightness, contrast, rotation,
250
  flip_h, flip_v, noise, blur, shift_x, shift_y)
251
 
252
+ # Apply same augmentation to all crops
253
+ all_augs = [augment(c, brightness, contrast, rotation,
254
+ flip_h, flip_v, noise, blur, shift_x, shift_y)
255
+ for c in crops]
256
+
257
  aug_col1, aug_col2 = st.columns(2)
258
  aug_col1.image(cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB),
259
+ caption="Original Crop (ROI 1)", use_container_width=True)
260
  aug_col2.image(cv2.cvtColor(aug, cv2.COLOR_BGR2RGB),
261
+ caption="Augmented Crop (ROI 1)", use_container_width=True)
262
+
263
+ if len(crops) > 1:
264
+ st.caption(f"Augmentation applied identically to all {len(crops)} ROIs.")
265
 
266
  # -----------------------------------------------------------------------
267
  # Step 5 β€” Lock & Store
268
  # -----------------------------------------------------------------------
269
  st.divider()
270
  if st.button("πŸš€ Lock Data & Proceed to Benchmark"):
271
+ rois_data = []
272
+ for i, roi in enumerate(st.session_state["rois"]):
273
+ rois_data.append({
274
+ "label": roi["label"],
275
+ "bbox": (roi["x0"], roi["y0"], roi["x1"], roi["y1"]),
276
+ "crop": crops[i],
277
+ "crop_aug": all_augs[i],
278
+ })
279
+
280
  st.session_state["pipeline_data"] = {
281
  "left": img_l,
282
  "right": img_r,
283
  "gt_left": gt_depth_l,
284
  "gt_right": gt_depth_r,
285
  "conf_raw": conf_content,
286
+ # Backward compatibility: first ROI
287
  "crop": crop_bgr,
288
  "crop_aug": aug,
289
  "crop_bbox": (x0, y0, x1, y1),
290
+ # Multi-object
291
+ "rois": rois_data,
292
  }
293
+ st.success(f"Data stored with **{len(rois_data)} ROI(s)**! "
294
+ f"Move to Feature Lab.")
295
 
296
  else:
297
  st.info("Please upload all 5 files (left image, right image, config, left GT, right GT) to proceed.")
pages/4_Model_Tuning.py CHANGED
@@ -20,41 +20,52 @@ if "pipeline_data" not in st.session_state or "crop" not in st.session_state.get
20
  st.stop()
21
 
22
  assets = st.session_state["pipeline_data"]
23
- crop = assets["crop"] # original crop from Data Lab
24
- crop_aug = assets.get("crop_aug", crop) # augmented crop from Data Lab
25
- left_img = assets["left"] # full left image
26
  bbox = assets.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
 
 
27
  active_modules = st.session_state.get("active_modules", {k: True for k in REGISTRY})
28
 
 
29
 
30
  # ---------------------------------------------------------------------------
31
  # Build training set from session data (no disk reads)
32
  # ---------------------------------------------------------------------------
33
- def build_training_set(augment_fn=None):
34
  """
35
- Positive samples: original crop + augmented crop from Data Lab.
36
- Negative samples: random patches from the left image that do NOT
37
- overlap with the crop bounding box.
38
- Returns (images_list, labels_list).
39
  """
40
- positives = [crop, crop_aug]
41
- if augment_fn is not None:
42
- positives.append(augment_fn(crop))
43
 
44
- # --- Generate negatives from left image margins ---
45
- x0, y0, x1, y1 = bbox
 
 
 
 
 
46
  H, W = left_img.shape[:2]
47
- ch, cw = y1 - y0, x1 - x0 # crop height/width
48
- negatives = []
49
  rng = np.random.default_rng(42)
50
 
 
51
  attempts = 0
52
- while len(negatives) < len(positives) * 2 and attempts < 200:
53
- # Random patch of same size as crop
54
  rx = rng.integers(0, max(W - cw, 1))
55
  ry = rng.integers(0, max(H - ch, 1))
56
- # Reject if it overlaps the crop bbox (IoU > 0)
57
- if rx < x1 and rx + cw > x0 and ry < y1 and ry + ch > y0:
 
 
 
 
58
  attempts += 1
59
  continue
60
  patch = left_img[ry:ry+ch, rx:rx+cw]
@@ -62,13 +73,12 @@ def build_training_set(augment_fn=None):
62
  negatives.append(patch)
63
  attempts += 1
64
 
65
- images = positives + negatives
66
- labels = ["object"] * len(positives) + ["background"] * len(negatives)
67
  return images, labels
68
 
69
 
70
  def build_rce_vector(img_bgr):
71
- """Build the RCE feature vector from active modules."""
72
  gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
73
  vec = []
74
  for key, meta in REGISTRY.items():
@@ -79,20 +89,31 @@ def build_rce_vector(img_bgr):
79
 
80
 
81
  # ===================================================================
82
- # Show data used for training
83
  # ===================================================================
84
  st.subheader("Training Data (from Data Lab)")
85
- st.caption("Positives = your crop + augmented crop | Negatives = random non-overlapping patches from left image")
86
- td1, td2 = st.columns(2)
87
- td1.image(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB), caption="Original Crop (positive)", width=180)
88
- td2.image(cv2.cvtColor(crop_aug, cv2.COLOR_BGR2RGB), caption="Augmented Crop (positive)", width=180)
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  st.divider()
91
 
92
  # ===================================================================
93
- # LAYOUT: LEFT = RCE | RIGHT = CNN
94
  # ===================================================================
95
- col_rce, col_cnn = st.columns(2)
96
 
97
  # ---------------------------------------------------------------------------
98
  # LEFT β€” RCE Training
@@ -103,68 +124,121 @@ with col_rce:
103
  active_names = [REGISTRY[k]["label"] for k in active_modules if active_modules[k]]
104
  if not active_names:
105
  st.error("No RCE modules selected. Go back to Feature Lab.")
106
- st.stop()
107
- st.write(f"**Active modules:** {', '.join(active_names)}")
108
-
109
- st.subheader("Training Parameters")
110
- rce_C = st.slider("Regularization (C)", 0.01, 10.0, 1.0, step=0.01,
111
- help="Higher = less regularization, may overfit")
112
- rce_max_iter = st.slider("Max Iterations", 100, 5000, 1000, step=100)
113
-
114
- if st.button("πŸš€ Train RCE Head"):
115
- images, labels = build_training_set()
116
- from sklearn.metrics import accuracy_score
117
-
118
- progress = st.progress(0, text="Extracting RCE features...")
119
- n = len(images)
120
- X = []
121
- for i, img in enumerate(images):
122
- X.append(build_rce_vector(img))
123
- progress.progress((i + 1) / n, text=f"Feature extraction: {i+1}/{n}")
124
-
125
- X = np.array(X)
126
- progress.progress(1.0, text="Fitting Logistic Regression...")
127
-
128
- t0 = time.perf_counter()
129
- head = RecognitionHead(C=rce_C, max_iter=rce_max_iter).fit(X, labels)
130
- train_time = time.perf_counter() - t0
131
- progress.progress(1.0, text="βœ… Training complete!")
132
-
133
- preds = head.model.predict(X)
134
- train_acc = accuracy_score(labels, preds)
135
-
136
- st.success(f"Trained in **{train_time:.2f}s**")
137
- m1, m2, m3 = st.columns(3)
138
- m1.metric("Train Accuracy", f"{train_acc:.1%}")
139
- m2.metric("Vector Size", f"{X.shape[1]} floats")
140
- m3.metric("Samples", f"{len(images)}")
141
-
142
- probs = head.predict_proba(X)
143
- fig = go.Figure()
144
- for ci, cls in enumerate(head.classes_):
145
- fig.add_trace(go.Histogram(x=probs[:, ci], name=cls, opacity=0.7, nbinsx=20))
146
- fig.update_layout(title="Confidence Distribution", barmode="overlay",
147
- template="plotly_dark", height=280,
148
- xaxis_title="Confidence", yaxis_title="Count")
149
- st.plotly_chart(fig, use_container_width=True)
150
-
151
- # Store head in session (no disk save)
152
- st.session_state["rce_head"] = head
153
- st.session_state["rce_train_acc"] = train_acc
154
-
155
- if "rce_head" in st.session_state:
156
- st.divider()
157
- st.subheader("Quick Predict (Crop)")
158
- head = st.session_state["rce_head"]
159
- t0 = time.perf_counter()
160
- vec = build_rce_vector(crop_aug)
161
- label, conf = head.predict(vec)
162
- dt = (time.perf_counter() - t0) * 1000
163
- st.write(f"**{label}** β€” {conf:.1%} confidence β€” {dt:.1f} ms")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
 
166
  # ---------------------------------------------------------------------------
167
- # RIGHT β€” CNN Fine-Tuning
168
  # ---------------------------------------------------------------------------
169
  with col_cnn:
170
  st.header("🧠 CNN Fine-Tuning")
@@ -181,7 +255,7 @@ with col_cnn:
181
 
182
  if st.button(f"πŸš€ Train {selected} Head"):
183
  images, labels = build_training_set()
184
- backbone = meta["loader"]() # cached frozen backbone
185
 
186
  from sklearn.metrics import accuracy_score
187
 
@@ -208,24 +282,46 @@ with col_cnn:
208
  m1.metric("Train Accuracy", f"{train_acc:.1%}")
209
  m2.metric("Vector Size", f"{X.shape[1]}D")
210
  m3.metric("Samples", f"{len(images)}")
 
 
211
 
212
  probs = head.predict_proba(X)
213
  fig = go.Figure()
214
  for ci, cls in enumerate(head.classes_):
215
- fig.add_trace(go.Histogram(x=probs[:, ci], name=cls, opacity=0.7, nbinsx=20))
 
216
  fig.update_layout(title="Confidence Distribution", barmode="overlay",
217
  template="plotly_dark", height=280,
218
  xaxis_title="Confidence", yaxis_title="Count")
219
  st.plotly_chart(fig, use_container_width=True)
220
 
221
- # Store head in session (no disk save)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  st.session_state[f"cnn_head_{selected}"] = head
223
  st.session_state[f"cnn_acc_{selected}"] = train_acc
224
 
225
  if f"cnn_head_{selected}" in st.session_state:
226
  st.divider()
227
  st.subheader("Quick Predict (Crop)")
228
- backbone = meta["loader"]() # cached frozen backbone
229
  head = st.session_state[f"cnn_head_{selected}"]
230
  t0 = time.perf_counter()
231
  feats = backbone.get_features(crop_aug)
@@ -234,22 +330,110 @@ with col_cnn:
234
  st.write(f"**{label}** β€” {conf:.1%} confidence β€” {dt:.1f} ms")
235
 
236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  # ===========================================================================
238
  # Bottom β€” Side-by-side comparison table
239
  # ===========================================================================
240
  st.divider()
241
  st.subheader("πŸ“Š Training Comparison")
242
 
243
- rce_acc = st.session_state.get("rce_train_acc")
244
  rows = []
 
245
  if rce_acc is not None:
246
- rows.append({"Model": "RCE", "Train Accuracy": f"{rce_acc:.1%}",
 
247
  "Vector Size": str(sum(10 for k in active_modules if active_modules[k]))})
248
  for name in BACKBONES:
249
  acc = st.session_state.get(f"cnn_acc_{name}")
250
  if acc is not None:
251
- rows.append({"Model": name, "Train Accuracy": f"{acc:.1%}",
 
252
  "Vector Size": f"{BACKBONES[name]['dim']}D"})
 
 
 
 
 
253
 
254
  if rows:
255
  import pandas as pd
 
20
  st.stop()
21
 
22
  assets = st.session_state["pipeline_data"]
23
+ crop = assets["crop"]
24
+ crop_aug = assets.get("crop_aug", crop)
25
+ left_img = assets["left"]
26
  bbox = assets.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
27
+ rois = assets.get("rois", [{"label": "object", "bbox": bbox,
28
+ "crop": crop, "crop_aug": crop_aug}])
29
  active_modules = st.session_state.get("active_modules", {k: True for k in REGISTRY})
30
 
31
+ is_multi = len(rois) > 1
32
 
33
  # ---------------------------------------------------------------------------
34
  # Build training set from session data (no disk reads)
35
  # ---------------------------------------------------------------------------
36
+ def build_training_set():
37
  """
38
+ Multi-class aware training set builder.
39
+ Positive samples per class: original crop + augmented crop.
40
+ Negative samples: random patches that don't overlap ANY ROI.
 
41
  """
42
+ images = []
43
+ labels = []
 
44
 
45
+ for roi in rois:
46
+ images.append(roi["crop"])
47
+ labels.append(roi["label"])
48
+ images.append(roi["crop_aug"])
49
+ labels.append(roi["label"])
50
+
51
+ all_bboxes = [roi["bbox"] for roi in rois]
52
  H, W = left_img.shape[:2]
53
+ x0r, y0r, x1r, y1r = rois[0]["bbox"]
54
+ ch, cw = y1r - y0r, x1r - x0r
55
  rng = np.random.default_rng(42)
56
 
57
+ n_neg_target = len(images) * 2
58
  attempts = 0
59
+ negatives = []
60
+ while len(negatives) < n_neg_target and attempts < 300:
61
  rx = rng.integers(0, max(W - cw, 1))
62
  ry = rng.integers(0, max(H - ch, 1))
63
+ overlaps = False
64
+ for bx0, by0, bx1, by1 in all_bboxes:
65
+ if rx < bx1 and rx + cw > bx0 and ry < by1 and ry + ch > by0:
66
+ overlaps = True
67
+ break
68
+ if overlaps:
69
  attempts += 1
70
  continue
71
  patch = left_img[ry:ry+ch, rx:rx+cw]
 
73
  negatives.append(patch)
74
  attempts += 1
75
 
76
+ images.extend(negatives)
77
+ labels.extend(["background"] * len(negatives))
78
  return images, labels
79
 
80
 
81
  def build_rce_vector(img_bgr):
 
82
  gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
83
  vec = []
84
  for key, meta in REGISTRY.items():
 
89
 
90
 
91
  # ===================================================================
92
+ # Show training data
93
  # ===================================================================
94
  st.subheader("Training Data (from Data Lab)")
95
+ if is_multi:
96
+ st.caption(f"**{len(rois)} classes** defined β€” each ROI becomes a separate class.")
97
+ roi_cols = st.columns(min(len(rois), 4))
98
+ for i, roi in enumerate(rois):
99
+ with roi_cols[i % len(roi_cols)]:
100
+ st.image(cv2.cvtColor(roi["crop"], cv2.COLOR_BGR2RGB),
101
+ caption=f"βœ… {roi['label']}", width=140)
102
+ else:
103
+ st.caption("Positives = your crop + augmented crop | "
104
+ "Negatives = random non-overlapping patches")
105
+ td1, td2 = st.columns(2)
106
+ td1.image(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB),
107
+ caption="Original Crop (positive)", width=180)
108
+ td2.image(cv2.cvtColor(crop_aug, cv2.COLOR_BGR2RGB),
109
+ caption="Augmented Crop (positive)", width=180)
110
 
111
  st.divider()
112
 
113
  # ===================================================================
114
+ # LAYOUT: RCE | CNN | ORB
115
  # ===================================================================
116
+ col_rce, col_cnn, col_orb = st.columns(3)
117
 
118
  # ---------------------------------------------------------------------------
119
  # LEFT β€” RCE Training
 
124
  active_names = [REGISTRY[k]["label"] for k in active_modules if active_modules[k]]
125
  if not active_names:
126
  st.error("No RCE modules selected. Go back to Feature Lab.")
127
+ else:
128
+ st.write(f"**Active modules:** {', '.join(active_names)}")
129
+
130
+ st.subheader("Training Parameters")
131
+ rce_C = st.slider("Regularization (C)", 0.01, 10.0, 1.0, step=0.01,
132
+ help="Higher = less regularization, may overfit")
133
+ rce_max_iter = st.slider("Max Iterations", 100, 5000, 1000, step=100)
134
+
135
+ if st.button("πŸš€ Train RCE Head"):
136
+ images, labels = build_training_set()
137
+ from sklearn.metrics import accuracy_score
138
+
139
+ progress = st.progress(0, text="Extracting RCE features...")
140
+ n = len(images)
141
+ X = []
142
+ for i, img in enumerate(images):
143
+ X.append(build_rce_vector(img))
144
+ progress.progress((i + 1) / n, text=f"Feature extraction: {i+1}/{n}")
145
+
146
+ X = np.array(X)
147
+ progress.progress(1.0, text="Fitting Logistic Regression...")
148
+
149
+ t0 = time.perf_counter()
150
+ head = RecognitionHead(C=rce_C, max_iter=rce_max_iter).fit(X, labels)
151
+ train_time = time.perf_counter() - t0
152
+ progress.progress(1.0, text="βœ… Training complete!")
153
+
154
+ preds = head.model.predict(X)
155
+ train_acc = accuracy_score(labels, preds)
156
+
157
+ st.success(f"Trained in **{train_time:.2f}s**")
158
+ m1, m2, m3 = st.columns(3)
159
+ m1.metric("Train Accuracy", f"{train_acc:.1%}")
160
+ m2.metric("Vector Size", f"{X.shape[1]} floats")
161
+ m3.metric("Samples", f"{len(images)}")
162
+ if is_multi:
163
+ st.caption(f"Classes: {', '.join(head.classes_)}")
164
+
165
+ probs = head.predict_proba(X)
166
+ fig = go.Figure()
167
+ for ci, cls in enumerate(head.classes_):
168
+ fig.add_trace(go.Histogram(x=probs[:, ci], name=cls,
169
+ opacity=0.7, nbinsx=20))
170
+ fig.update_layout(title="Confidence Distribution", barmode="overlay",
171
+ template="plotly_dark", height=280,
172
+ xaxis_title="Confidence", yaxis_title="Count")
173
+ st.plotly_chart(fig, use_container_width=True)
174
+
175
+ # ---- Feature Importance (RCE) ----
176
+ st.subheader("πŸ” Feature Importance")
177
+ coefs = head.model.coef_
178
+ feat_names = []
179
+ for key, meta_r in REGISTRY.items():
180
+ if active_modules.get(key, False):
181
+ for b in range(10):
182
+ feat_names.append(f"{meta_r['label']}[{b}]")
183
+
184
+ if coefs.shape[0] == 1:
185
+ importance = np.abs(coefs[0])
186
+ fig_imp = go.Figure(go.Bar(
187
+ x=feat_names, y=importance,
188
+ marker_color=["#00d4ff" if "Intensity" in fn
189
+ else "#ff6600" if "Sobel" in fn
190
+ else "#aa00ff" for fn in feat_names]))
191
+ fig_imp.update_layout(title="LogReg Coefficient Magnitude",
192
+ template="plotly_dark", height=300,
193
+ xaxis_title="Feature", yaxis_title="|Coefficient|")
194
+ else:
195
+ fig_imp = go.Figure()
196
+ for ci, cls in enumerate(head.classes_):
197
+ if cls == "background":
198
+ continue
199
+ fig_imp.add_trace(go.Bar(
200
+ x=feat_names, y=np.abs(coefs[ci]),
201
+ name=cls, opacity=0.8))
202
+ fig_imp.update_layout(title="LogReg Coefficients per Class",
203
+ template="plotly_dark", height=300,
204
+ barmode="group",
205
+ xaxis_title="Feature", yaxis_title="|Coefficient|")
206
+ st.plotly_chart(fig_imp, use_container_width=True)
207
+
208
+ # Module-level aggregation
209
+ module_importance = {}
210
+ idx = 0
211
+ for key, meta_r in REGISTRY.items():
212
+ if active_modules.get(key, False):
213
+ module_importance[meta_r["label"]] = float(
214
+ np.abs(coefs[:, idx:idx+10]).mean())
215
+ idx += 10
216
+
217
+ if module_importance:
218
+ fig_mod = go.Figure(go.Pie(
219
+ labels=list(module_importance.keys()),
220
+ values=list(module_importance.values()),
221
+ hole=0.4))
222
+ fig_mod.update_layout(title="Module Contribution (avg |coef|)",
223
+ template="plotly_dark", height=280)
224
+ st.plotly_chart(fig_mod, use_container_width=True)
225
+
226
+ st.session_state["rce_head"] = head
227
+ st.session_state["rce_train_acc"] = train_acc
228
+
229
+ if "rce_head" in st.session_state:
230
+ st.divider()
231
+ st.subheader("Quick Predict (Crop)")
232
+ head = st.session_state["rce_head"]
233
+ t0 = time.perf_counter()
234
+ vec = build_rce_vector(crop_aug)
235
+ label, conf = head.predict(vec)
236
+ dt = (time.perf_counter() - t0) * 1000
237
+ st.write(f"**{label}** β€” {conf:.1%} confidence β€” {dt:.1f} ms")
238
 
239
 
240
  # ---------------------------------------------------------------------------
241
+ # MIDDLE β€” CNN Fine-Tuning
242
  # ---------------------------------------------------------------------------
243
  with col_cnn:
244
  st.header("🧠 CNN Fine-Tuning")
 
255
 
256
  if st.button(f"πŸš€ Train {selected} Head"):
257
  images, labels = build_training_set()
258
+ backbone = meta["loader"]()
259
 
260
  from sklearn.metrics import accuracy_score
261
 
 
282
  m1.metric("Train Accuracy", f"{train_acc:.1%}")
283
  m2.metric("Vector Size", f"{X.shape[1]}D")
284
  m3.metric("Samples", f"{len(images)}")
285
+ if is_multi:
286
+ st.caption(f"Classes: {', '.join(head.classes_)}")
287
 
288
  probs = head.predict_proba(X)
289
  fig = go.Figure()
290
  for ci, cls in enumerate(head.classes_):
291
+ fig.add_trace(go.Histogram(x=probs[:, ci], name=cls,
292
+ opacity=0.7, nbinsx=20))
293
  fig.update_layout(title="Confidence Distribution", barmode="overlay",
294
  template="plotly_dark", height=280,
295
  xaxis_title="Confidence", yaxis_title="Count")
296
  st.plotly_chart(fig, use_container_width=True)
297
 
298
+ # ---- Activation Overlay (Grad-CAM style) ----
299
+ st.subheader("πŸ” Activation Overlay")
300
+ st.caption("Highest-activation spatial regions from the hooked layer, "
301
+ "overlaid on the crop as a Grad-CAM–style heatmap.")
302
+ try:
303
+ act_maps = backbone.get_activation_maps(crop_aug, n_maps=1)
304
+ if act_maps:
305
+ cam = act_maps[0]
306
+ cam_resized = cv2.resize(cam, (crop_aug.shape[1], crop_aug.shape[0]))
307
+ cam_color = cv2.applyColorMap(
308
+ (cam_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
309
+ overlay_img = cv2.addWeighted(crop_aug, 0.5, cam_color, 0.5, 0)
310
+ gc1, gc2 = st.columns(2)
311
+ gc1.image(cv2.cvtColor(crop_aug, cv2.COLOR_BGR2RGB),
312
+ caption="Input Crop", use_container_width=True)
313
+ gc2.image(cv2.cvtColor(overlay_img, cv2.COLOR_BGR2RGB),
314
+ caption="Activation Overlay", use_container_width=True)
315
+ except Exception:
316
+ pass
317
+
318
  st.session_state[f"cnn_head_{selected}"] = head
319
  st.session_state[f"cnn_acc_{selected}"] = train_acc
320
 
321
  if f"cnn_head_{selected}" in st.session_state:
322
  st.divider()
323
  st.subheader("Quick Predict (Crop)")
324
+ backbone = meta["loader"]()
325
  head = st.session_state[f"cnn_head_{selected}"]
326
  t0 = time.perf_counter()
327
  feats = backbone.get_features(crop_aug)
 
330
  st.write(f"**{label}** β€” {conf:.1%} confidence β€” {dt:.1f} ms")
331
 
332
 
333
# ---------------------------------------------------------------------------
# RIGHT β€” ORB Training
# ---------------------------------------------------------------------------
# Builds per-label ORB keypoint references from the session ROIs and stores
# them (plus the matching thresholds) in st.session_state for the detection
# and evaluation pages. No recognition head is trained here β€” ORB "training"
# is just descriptor extraction.
with col_orb:
    st.header("πŸ›οΈ ORB Matching")
    st.caption("Keypoint-based matching β€” a fundamentally different paradigm. "
               "Extracts ORB descriptors from each ROI crop and matches them "
               "against image patches using brute-force Hamming distance.")

    # NOTE(review): import inside the `with` block β€” works, but top-of-file
    # placement would match the rest of the page; confirm intent.
    from src.detectors.orb import ORBDetector

    # Matching hyper-parameters, persisted below so the detection page
    # uses the same values the reference was built with.
    orb_dist_thresh = st.slider("Match Distance Threshold", 10, 100, 70,
                                key="orb_dist")
    orb_min_matches = st.slider("Min Good Matches", 1, 20, 5, key="orb_min")

    if st.button("πŸš€ Train ORB Reference"):
        orb = ORBDetector()
        progress = st.progress(0, text="Extracting ORB descriptors...")

        # One reference entry per ROI label.
        # NOTE(review): if two ROIs share a label, the later one silently
        # overwrites the earlier reference β€” confirm labels are unique.
        orb_refs = {}
        for i, roi in enumerate(rois):
            gray = cv2.cvtColor(roi["crop_aug"], cv2.COLOR_BGR2GRAY)
            # CLAHE contrast normalization before keypointing; the detection
            # page applies the same preprocessing to each scanned window.
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            gray = clahe.apply(gray)
            kp, des = orb.orb.detectAndCompute(gray, None)
            # detectAndCompute returns des=None when no keypoints are found.
            n_feat = 0 if des is None else len(des)
            orb_refs[roi["label"]] = {
                "descriptors": des,
                "n_features": n_feat,
                "keypoints": kp,
                "crop": roi["crop_aug"],
            }
            progress.progress((i + 1) / len(rois),
                              text=f"ROI {i+1}/{len(rois)}: {n_feat} features")

        progress.progress(1.0, text="βœ… ORB references extracted!")

        # Visual feedback: draw the extracted keypoints on each crop.
        for lbl, ref in orb_refs.items():
            if ref["keypoints"]:
                vis = cv2.drawKeypoints(ref["crop"], ref["keypoints"],
                                        None, color=(0, 255, 0))
                st.image(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB),
                         caption=f"{lbl}: {ref['n_features']} keypoints",
                         use_container_width=True)
            else:
                st.warning(f"{lbl}: No keypoints detected")

        # Persist everything other pages need β€” session state only, no disk.
        st.session_state["orb_detector"] = orb
        st.session_state["orb_refs"] = orb_refs
        st.session_state["orb_dist_thresh"] = orb_dist_thresh
        st.session_state["orb_min_matches"] = orb_min_matches
        st.success("ORB references stored in session!")

    # Quick sanity check: match the (augmented) crop itself against every
    # stored reference using the persisted thresholds.
    if "orb_refs" in st.session_state:
        st.divider()
        st.subheader("Quick Predict (Crop)")
        orb = st.session_state["orb_detector"]
        refs = st.session_state["orb_refs"]
        dt_thresh = st.session_state["orb_dist_thresh"]
        min_m = st.session_state["orb_min_matches"]

        # Same grayscale + CLAHE preprocessing as at reference time.
        gray = cv2.cvtColor(crop_aug, cv2.COLOR_BGR2GRAY)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        gray = clahe.apply(gray)
        kp, des = orb.orb.detectAndCompute(gray, None)

        if des is not None:
            for lbl, ref in refs.items():
                if ref["descriptors"] is None:
                    st.write(f"**{lbl}:** no reference features")
                    continue
                matches = orb.bf.match(ref["descriptors"], des)
                # "Good" = Hamming distance below the slider threshold.
                good = [m for m in matches if m.distance < dt_thresh]
                # Confidence saturates at 1.0 once min_m good matches exist.
                conf = min(len(good) / max(min_m, 1), 1.0)
                verdict = lbl if len(good) >= min_m else "background"
                st.write(f"**{verdict}** β€” {len(good)} matches β€” "
                         f"{conf:.0%} confidence")
        else:
            st.write("No keypoints in test image.")
414
# ===========================================================================
# Bottom β€” Side-by-side comparison table
# ===========================================================================
st.divider()
st.subheader("πŸ“Š Training Comparison")

# One row per trained approach; rows are only added for models that have
# actually been trained in this session.
rows = []
rce_acc = st.session_state.get("rce_train_acc")
if rce_acc is not None:
    # NOTE(review): assumes every active RCE module contributes exactly 10
    # feature dimensions β€” confirm against the module REGISTRY.
    rows.append({"Model": "RCE", "Type": "Feature Engineering",
                 "Train Accuracy": f"{rce_acc:.1%}",
                 "Vector Size": str(sum(10 for k in active_modules if active_modules[k]))})
for name in BACKBONES:
    acc = st.session_state.get(f"cnn_acc_{name}")
    if acc is not None:
        rows.append({"Model": name, "Type": "CNN Backbone",
                     "Train Accuracy": f"{acc:.1%}",
                     "Vector Size": f"{BACKBONES[name]['dim']}D"})
# ORB has no trainable head, so accuracy is not applicable; report the
# total descriptor count across all references instead.
if "orb_refs" in st.session_state:
    total_kp = sum(r["n_features"] for r in st.session_state["orb_refs"].values())
    rows.append({"Model": "ORB", "Type": "Keypoint Matching",
                 "Train Accuracy": "N/A (matching)",
                 "Vector Size": f"{total_kp} descriptors"})
437
 
438
  if rows:
439
  import pandas as pd
pages/5_Localization_Lab.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import cv2
3
+ import numpy as np
4
+ import pandas as pd
5
+ import plotly.graph_objects as go
6
+ import sys, os
7
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8
+
9
+ from src.detectors.rce.features import REGISTRY
10
+ from src.models import BACKBONES, RecognitionHead
11
+ from src.localization import (
12
+ exhaustive_sliding_window,
13
+ image_pyramid,
14
+ coarse_to_fine,
15
+ contour_proposals,
16
+ template_matching,
17
+ STRATEGIES,
18
+ )
19
+
20
+ st.set_page_config(page_title="Localization Lab", layout="wide")
21
+ st.title("πŸ” Localization Lab")
22
+ st.markdown(
23
+ "Compare **localization strategies** β€” algorithms that decide *where* "
24
+ "to look in the image. The recognition head stays the same; only the "
25
+ "search method changes."
26
+ )
27
+
28
+ # ===================================================================
29
+ # Guard
30
+ # ===================================================================
31
+ if "pipeline_data" not in st.session_state or \
32
+ "crop" not in st.session_state.get("pipeline_data", {}):
33
+ st.error("Complete **Data Lab** first (upload assets & define a crop).")
34
+ st.stop()
35
+
36
+ assets = st.session_state["pipeline_data"]
37
+ right_img = assets["right"]
38
+ crop = assets["crop"]
39
+ crop_aug = assets.get("crop_aug", crop)
40
+ bbox = assets.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
41
+ active_mods = st.session_state.get("active_modules",
42
+ {k: True for k in REGISTRY})
43
+
44
+ x0, y0, x1, y1 = bbox
45
+ win_h, win_w = y1 - y0, x1 - x0
46
+
47
+ rce_head = st.session_state.get("rce_head")
48
+ has_any_cnn = any(f"cnn_head_{n}" in st.session_state for n in BACKBONES)
49
+
50
+ if rce_head is None and not has_any_cnn:
51
+ st.warning("No trained heads found. Go to **Model Tuning** first.")
52
+ st.stop()
53
+
54
+
55
+ # ===================================================================
56
+ # RCE feature function
57
+ # ===================================================================
58
def rce_feature_fn(patch_bgr):
    """Build the RCE feature vector for a BGR image patch.

    Converts the patch to grayscale, then concatenates the feature values
    of every module enabled in ``active_mods``, in REGISTRY order.
    Returns a 1-D float32 numpy array (empty when no module is active).
    """
    gray = cv2.cvtColor(patch_bgr, cv2.COLOR_BGR2GRAY)
    # Each module's fn returns (values, extra); only the values are kept.
    parts = [
        meta["fn"](gray)[0]
        for key, meta in REGISTRY.items()
        if active_mods.get(key, False)
    ]
    flat = [value for part in parts for value in part]
    return np.array(flat, dtype=np.float32)
66
+
67
+
68
+ # ===================================================================
69
+ # Algorithm Reference (collapsible)
70
+ # ===================================================================
71
+ st.divider()
72
+ with st.expander("πŸ“š **Algorithm Reference** β€” click to expand", expanded=False):
73
+ tabs = st.tabs([f"{v['icon']} {k}" for k, v in STRATEGIES.items()])
74
+ for tab, (name, meta) in zip(tabs, STRATEGIES.items()):
75
+ with tab:
76
+ st.markdown(f"### {meta['icon']} {name}")
77
+ st.caption(meta["short"])
78
+ st.markdown(meta["detail"])
79
+
80
+
81
+ # ===================================================================
82
+ # Configuration
83
+ # ===================================================================
84
+ st.divider()
85
+ st.header("βš™οΈ Configuration")
86
+
87
+ # --- Head selection ---
88
+ col_head, col_info = st.columns([2, 3])
89
+ with col_head:
90
+ head_options = []
91
+ if rce_head is not None:
92
+ head_options.append("RCE")
93
+ trained_cnns = [n for n in BACKBONES if f"cnn_head_{n}" in st.session_state]
94
+ head_options.extend(trained_cnns)
95
+ selected_head = st.selectbox("Recognition Head", head_options,
96
+ key="loc_head")
97
+
98
+ if selected_head == "RCE":
99
+ feature_fn = rce_feature_fn
100
+ head = rce_head
101
+ else:
102
+ bmeta = BACKBONES[selected_head]
103
+ backbone = bmeta["loader"]()
104
+ feature_fn = backbone.get_features
105
+ head = st.session_state[f"cnn_head_{selected_head}"]
106
+
107
+ with col_info:
108
+ if selected_head == "RCE":
109
+ mods = [REGISTRY[k]["label"] for k in active_mods if active_mods[k]]
110
+ st.info(f"**RCE** β€” Modules: {', '.join(mods)}")
111
+ else:
112
+ st.info(f"**{selected_head}** β€” "
113
+ f"{BACKBONES[selected_head]['dim']}D feature vector")
114
+
115
+ # --- Algorithm checkboxes ---
116
+ st.subheader("Select Algorithms to Compare")
117
+ algo_cols = st.columns(5)
118
+ algo_names = list(STRATEGIES.keys())
119
+ algo_checks = {}
120
+ for col, name in zip(algo_cols, algo_names):
121
+ algo_checks[name] = col.checkbox(
122
+ f"{STRATEGIES[name]['icon']} {name}",
123
+ value=(name != "Template Matching"), # default all on except TM
124
+ key=f"chk_{name}")
125
+
126
+ any_selected = any(algo_checks.values())
127
+
128
+ # --- Shared parameters ---
129
+ st.subheader("Parameters")
130
+ sp1, sp2, sp3 = st.columns(3)
131
+ stride = sp1.slider("Base Stride (px)", 4, max(win_w, win_h),
132
+ max(win_w // 4, 4), step=2, key="loc_stride")
133
+ conf_thresh = sp2.slider("Confidence Threshold", 0.5, 1.0, 0.7, 0.05,
134
+ key="loc_conf")
135
+ nms_iou = sp3.slider("NMS IoU Threshold", 0.1, 0.9, 0.3, 0.05,
136
+ key="loc_nms")
137
+
138
+ # --- Per-algorithm settings ---
139
+ with st.expander("πŸ”§ Per-Algorithm Settings"):
140
+ pa1, pa2, pa3 = st.columns(3)
141
+ with pa1:
142
+ st.markdown("**Image Pyramid**")
143
+ pyr_min = st.slider("Min Scale", 0.3, 1.0, 0.5, 0.05, key="pyr_min")
144
+ pyr_max = st.slider("Max Scale", 1.0, 2.0, 1.5, 0.1, key="pyr_max")
145
+ pyr_n = st.slider("Number of Scales", 3, 7, 5, key="pyr_n")
146
+ with pa2:
147
+ st.markdown("**Coarse-to-Fine**")
148
+ c2f_factor = st.slider("Coarse Factor", 2, 8, 4, key="c2f_factor")
149
+ c2f_radius = st.slider("Refine Radius (strides)", 1, 5, 2,
150
+ key="c2f_radius")
151
+ with pa3:
152
+ st.markdown("**Contour Proposals**")
153
+ cnt_low = st.slider("Canny Low", 10, 100, 50, key="cnt_low")
154
+ cnt_high = st.slider("Canny High", 50, 300, 150, key="cnt_high")
155
+ cnt_tol = st.slider("Area Tolerance", 1.5, 10.0, 3.0, 0.5,
156
+ key="cnt_tol")
157
+
158
+ st.caption(
159
+ f"Window: **{win_w}Γ—{win_h} px** Β· "
160
+ f"Image: **{right_img.shape[1]}Γ—{right_img.shape[0]} px** Β· "
161
+ f"Stride: **{stride} px**"
162
+ )
163
+
164
+
165
+ # ===================================================================
166
+ # Run
167
+ # ===================================================================
168
+ st.divider()
169
+ run_btn = st.button("β–Ά Run Comparison", type="primary",
170
+ disabled=not any_selected, use_container_width=True)
171
+
172
+ if run_btn:
173
+ selected_algos = [n for n in algo_names if algo_checks[n]]
174
+ progress = st.progress(0, text="Starting…")
175
+ results = {}
176
+ edge_maps = {} # for contour visualisation
177
+
178
+ for i, name in enumerate(selected_algos):
179
+ progress.progress(i / len(selected_algos), text=f"Running **{name}**…")
180
+
181
+ if name == "Exhaustive Sliding Window":
182
+ dets, n, ms, hmap = exhaustive_sliding_window(
183
+ right_img, win_h, win_w, feature_fn, head,
184
+ stride, conf_thresh, nms_iou)
185
+
186
+ elif name == "Image Pyramid":
187
+ scales = np.linspace(pyr_min, pyr_max, pyr_n).tolist()
188
+ dets, n, ms, hmap = image_pyramid(
189
+ right_img, win_h, win_w, feature_fn, head,
190
+ stride, conf_thresh, nms_iou, scales=scales)
191
+
192
+ elif name == "Coarse-to-Fine":
193
+ dets, n, ms, hmap = coarse_to_fine(
194
+ right_img, win_h, win_w, feature_fn, head,
195
+ stride, conf_thresh, nms_iou,
196
+ coarse_factor=c2f_factor, refine_radius=c2f_radius)
197
+
198
+ elif name == "Contour Proposals":
199
+ dets, n, ms, hmap, edges = contour_proposals(
200
+ right_img, win_h, win_w, feature_fn, head,
201
+ conf_thresh, nms_iou,
202
+ canny_low=cnt_low, canny_high=cnt_high,
203
+ area_tolerance=cnt_tol)
204
+ edge_maps[name] = edges
205
+
206
+ elif name == "Template Matching":
207
+ dets, n, ms, hmap = template_matching(
208
+ right_img, crop_aug, conf_thresh, nms_iou)
209
+
210
+ results[name] = {
211
+ "dets": dets, "n_proposals": n,
212
+ "time_ms": ms, "heatmap": hmap,
213
+ }
214
+
215
+ progress.progress(1.0, text="Done!")
216
+
217
+ # ===============================================================
218
+ # Summary Table
219
+ # ===============================================================
220
+ st.header("πŸ“Š Results")
221
+
222
+ baseline_ms = results.get("Exhaustive Sliding Window", {}).get("time_ms")
223
+ rows = []
224
+ for name, r in results.items():
225
+ speedup = (baseline_ms / r["time_ms"]
226
+ if baseline_ms and r["time_ms"] > 0 else None)
227
+ rows.append({
228
+ "Algorithm": name,
229
+ "Proposals": r["n_proposals"],
230
+ "Time (ms)": round(r["time_ms"], 1),
231
+ "Detections": len(r["dets"]),
232
+ "ms / Proposal": round(r["time_ms"] / max(r["n_proposals"], 1), 4),
233
+ "Speedup": f"{speedup:.1f}Γ—" if speedup else "β€”",
234
+ })
235
+
236
+ st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
237
+
238
+ # ===============================================================
239
+ # Detection Images & Heatmaps (one tab per algorithm)
240
+ # ===============================================================
241
+ st.subheader("Detection Results")
242
+ COLORS = {
243
+ "Exhaustive Sliding Window": (0, 255, 0),
244
+ "Image Pyramid": (255, 128, 0),
245
+ "Coarse-to-Fine": (0, 128, 255),
246
+ "Contour Proposals": (255, 0, 255),
247
+ "Template Matching": (0, 255, 255),
248
+ }
249
+
250
+ result_tabs = st.tabs(
251
+ [f"{STRATEGIES[n]['icon']} {n}" for n in results])
252
+
253
+ for tab, (name, r) in zip(result_tabs, results.items()):
254
+ with tab:
255
+ c1, c2 = st.columns(2)
256
+ color = COLORS.get(name, (0, 255, 0))
257
+
258
+ # --- Detection overlay ---
259
+ vis = right_img.copy()
260
+ for x1d, y1d, x2d, y2d, _, cf in r["dets"]:
261
+ cv2.rectangle(vis, (x1d, y1d), (x2d, y2d), color, 2)
262
+ cv2.putText(vis, f"{cf:.0%}", (x1d, y1d - 6),
263
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
264
+ c1.image(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB),
265
+ caption=f"{name} β€” {len(r['dets'])} detections",
266
+ use_container_width=True)
267
+
268
+ # --- Heatmap ---
269
+ hmap = r["heatmap"]
270
+ if hmap.max() > 0:
271
+ hmap_color = cv2.applyColorMap(
272
+ (hmap / hmap.max() * 255).astype(np.uint8),
273
+ cv2.COLORMAP_JET)
274
+ blend = cv2.addWeighted(right_img, 0.5, hmap_color, 0.5, 0)
275
+ c2.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
276
+ caption=f"{name} β€” Confidence Heatmap",
277
+ use_container_width=True)
278
+ else:
279
+ c2.info("No positive responses above threshold.")
280
+
281
+ # --- Contour edge map (extra) ---
282
+ if name in edge_maps:
283
+ st.image(edge_maps[name],
284
+ caption="Canny Edge Map (proposals derived from these contours)",
285
+ use_container_width=True, clamp=True)
286
+
287
+ # --- Per-algorithm metrics ---
288
+ m1, m2, m3, m4 = st.columns(4)
289
+ m1.metric("Proposals", r["n_proposals"])
290
+ m2.metric("Time", f"{r['time_ms']:.0f} ms")
291
+ m3.metric("Detections", len(r["dets"]))
292
+ m4.metric("ms / Proposal",
293
+ f"{r['time_ms'] / max(r['n_proposals'], 1):.3f}")
294
+
295
+ # --- Detection table ---
296
+ if r["dets"]:
297
+ df = pd.DataFrame(r["dets"],
298
+ columns=["x1","y1","x2","y2","label","conf"])
299
+ st.dataframe(df, use_container_width=True, hide_index=True)
300
+
301
+ # ===============================================================
302
+ # Performance Charts
303
+ # ===============================================================
304
+ st.subheader("πŸ“ˆ Performance Comparison")
305
+ ch1, ch2 = st.columns(2)
306
+
307
+ names = list(results.keys())
308
+ times = [results[n]["time_ms"] for n in names]
309
+ props = [results[n]["n_proposals"] for n in names]
310
+ n_dets = [len(results[n]["dets"]) for n in names]
311
+ colors_hex = ["#00cc66", "#ff8800", "#0088ff", "#ff00ff", "#00cccc"]
312
+
313
+ with ch1:
314
+ fig = go.Figure(go.Bar(
315
+ x=names, y=times,
316
+ text=[f"{t:.0f}" for t in times], textposition="auto",
317
+ marker_color=colors_hex[:len(names)]))
318
+ fig.update_layout(title="Total Time (ms)",
319
+ yaxis_title="ms", height=400)
320
+ st.plotly_chart(fig, use_container_width=True)
321
+
322
+ with ch2:
323
+ fig = go.Figure(go.Bar(
324
+ x=names, y=props,
325
+ text=[str(p) for p in props], textposition="auto",
326
+ marker_color=colors_hex[:len(names)]))
327
+ fig.update_layout(title="Proposals Evaluated",
328
+ yaxis_title="Count", height=400)
329
+ st.plotly_chart(fig, use_container_width=True)
330
+
331
+ # --- Scatter: proposals vs time (marker = detections) ---
332
+ fig = go.Figure()
333
+ for i, name in enumerate(names):
334
+ fig.add_trace(go.Scatter(
335
+ x=[props[i]], y=[times[i]],
336
+ mode="markers+text",
337
+ marker=dict(size=max(n_dets[i] * 12, 18),
338
+ color=colors_hex[i % len(colors_hex)]),
339
+ text=[name], textposition="top center",
340
+ name=name,
341
+ ))
342
+ fig.update_layout(
343
+ title="Proposals vs Time (marker size ∝ detections)",
344
+ xaxis_title="Proposals Evaluated",
345
+ yaxis_title="Time (ms)",
346
+ height=500,
347
+ )
348
+ st.plotly_chart(fig, use_container_width=True)
pages/{5_RealTime_Detection.py β†’ 6_RealTime_Detection.py} RENAMED
@@ -24,15 +24,22 @@ right_img = assets["right"]
24
  crop = assets["crop"]
25
  crop_aug = assets.get("crop_aug", crop)
26
  bbox = assets.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
 
 
27
  active_mods = st.session_state.get("active_modules", {k: True for k in REGISTRY})
28
 
29
  x0, y0, x1, y1 = bbox
30
  win_h, win_w = y1 - y0, x1 - x0 # window = same size as crop
31
 
 
 
 
 
32
  rce_head = st.session_state.get("rce_head")
33
  has_any_cnn = any(f"cnn_head_{n}" in st.session_state for n in BACKBONES)
 
34
 
35
- if rce_head is None and not has_any_cnn:
36
  st.warning("No trained heads found. Go to **Model Tuning** and train at least one head.")
37
  st.stop()
38
 
@@ -77,8 +84,8 @@ def sliding_window_detect(
77
  feats = feature_fn(patch)
78
  label, conf = head.predict(feats)
79
 
80
- # Fill heatmap with object confidence
81
- if label == "object":
82
  heatmap[y:y+win_h, x:x+win_w] = np.maximum(
83
  heatmap[y:y+win_h, x:x+win_w], conf)
84
  if conf >= conf_thresh:
@@ -167,7 +174,7 @@ st.divider()
167
  # ===================================================================
168
  # Side-by-side layout
169
  # ===================================================================
170
- col_rce, col_cnn = st.columns(2)
171
 
172
  # -------------------------------------------------------------------
173
  # LEFT β€” RCE Detection
@@ -194,10 +201,13 @@ with col_rce:
194
 
195
  # Final image with boxes
196
  final = right_img.copy()
 
197
  for x1d, y1d, x2d, y2d, lbl, cf in dets:
198
- cv2.rectangle(final, (x1d, y1d), (x2d, y2d), (0, 255, 0), 2)
199
- cv2.putText(final, f"{cf:.0%}", (x1d, y1d - 6),
200
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
 
 
201
  rce_live.image(cv2.cvtColor(final, cv2.COLOR_BGR2RGB),
202
  caption="RCE β€” Final Detections",
203
  use_container_width=True)
@@ -263,10 +273,13 @@ with col_cnn:
263
 
264
  # Final image
265
  final = right_img.copy()
 
266
  for x1d, y1d, x2d, y2d, lbl, cf in dets:
267
- cv2.rectangle(final, (x1d, y1d), (x2d, y2d), (0, 0, 255), 2)
268
- cv2.putText(final, f"{cf:.0%}", (x1d, y1d - 6),
269
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
 
 
270
  cnn_live.image(cv2.cvtColor(final, cv2.COLOR_BGR2RGB),
271
  caption=f"{selected} β€” Final Detections",
272
  use_container_width=True)
@@ -297,42 +310,146 @@ with col_cnn:
297
  st.session_state["cnn_det_ms"] = ms
298
 
299
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  # ===================================================================
301
- # Bottom β€” Comparison (if both have run)
302
  # ===================================================================
303
  rce_dets = st.session_state.get("rce_dets")
304
  cnn_dets = st.session_state.get("cnn_dets")
 
 
 
 
 
 
 
 
 
305
 
306
- if rce_dets is not None and cnn_dets is not None:
307
  st.divider()
308
  st.subheader("πŸ“Š Side-by-Side Comparison")
309
 
310
  import pandas as pd
311
- comp = pd.DataFrame({
312
- "Metric": ["Detections", "Best Confidence", "Total Time (ms)"],
313
- "RCE": [
314
- len(rce_dets),
315
- f"{max((d[5] for d in rce_dets), default=0):.1%}",
316
- f"{st.session_state.get('rce_det_ms', 0):.0f}",
317
- ],
318
- "CNN": [
319
- len(cnn_dets),
320
- f"{max((d[5] for d in cnn_dets), default=0):.1%}",
321
- f"{st.session_state.get('cnn_det_ms', 0):.0f}",
322
- ],
323
- })
324
- st.dataframe(comp, use_container_width=True, hide_index=True)
325
-
326
- # Overlay both on one image
327
  overlay = right_img.copy()
328
- for x1d, y1d, x2d, y2d, _, cf in rce_dets:
329
- cv2.rectangle(overlay, (x1d, y1d), (x2d, y2d), (0, 255, 0), 2)
330
- cv2.putText(overlay, f"RCE {cf:.0%}", (x1d, y1d - 6),
331
- cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)
332
- for x1d, y1d, x2d, y2d, _, cf in cnn_dets:
333
- cv2.rectangle(overlay, (x1d, y1d), (x2d, y2d), (0, 0, 255), 2)
334
- cv2.putText(overlay, f"CNN {cf:.0%}", (x1d, y2d + 12),
335
- cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1)
336
  st.image(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB),
337
- caption="Green = RCE | Blue = CNN",
338
- use_container_width=True)
 
24
  crop = assets["crop"]
25
  crop_aug = assets.get("crop_aug", crop)
26
  bbox = assets.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
27
+ rois = assets.get("rois", [{"label": "object", "bbox": bbox,
28
+ "crop": crop, "crop_aug": crop_aug}])
29
  active_mods = st.session_state.get("active_modules", {k: True for k in REGISTRY})
30
 
31
  x0, y0, x1, y1 = bbox
32
  win_h, win_w = y1 - y0, x1 - x0 # window = same size as crop
33
 
34
+ # Color palette for multi-class drawing
35
+ CLASS_COLORS = [(0,255,0),(0,0,255),(255,165,0),(255,0,255),(0,255,255),
36
+ (128,255,0),(255,128,0),(0,128,255)]
37
+
38
  rce_head = st.session_state.get("rce_head")
39
  has_any_cnn = any(f"cnn_head_{n}" in st.session_state for n in BACKBONES)
40
+ has_orb = "orb_refs" in st.session_state
41
 
42
+ if rce_head is None and not has_any_cnn and not has_orb:
43
  st.warning("No trained heads found. Go to **Model Tuning** and train at least one head.")
44
  st.stop()
45
 
 
84
  feats = feature_fn(patch)
85
  label, conf = head.predict(feats)
86
 
87
+ # Fill heatmap with non-background confidence
88
+ if label != "background":
89
  heatmap[y:y+win_h, x:x+win_w] = np.maximum(
90
  heatmap[y:y+win_h, x:x+win_w], conf)
91
  if conf >= conf_thresh:
 
174
  # ===================================================================
175
  # Side-by-side layout
176
  # ===================================================================
177
+ col_rce, col_cnn, col_orb = st.columns(3)
178
 
179
  # -------------------------------------------------------------------
180
  # LEFT β€” RCE Detection
 
201
 
202
  # Final image with boxes
203
  final = right_img.copy()
204
+ class_labels = sorted(set(d[4] for d in dets)) if dets else []
205
  for x1d, y1d, x2d, y2d, lbl, cf in dets:
206
+ ci = class_labels.index(lbl) if lbl in class_labels else 0
207
+ clr = CLASS_COLORS[ci % len(CLASS_COLORS)]
208
+ cv2.rectangle(final, (x1d, y1d), (x2d, y2d), clr, 2)
209
+ cv2.putText(final, f"{lbl} {cf:.0%}", (x1d, y1d - 6),
210
+ cv2.FONT_HERSHEY_SIMPLEX, 0.4, clr, 1)
211
  rce_live.image(cv2.cvtColor(final, cv2.COLOR_BGR2RGB),
212
  caption="RCE β€” Final Detections",
213
  use_container_width=True)
 
273
 
274
  # Final image
275
  final = right_img.copy()
276
+ class_labels = sorted(set(d[4] for d in dets)) if dets else []
277
  for x1d, y1d, x2d, y2d, lbl, cf in dets:
278
+ ci = class_labels.index(lbl) if lbl in class_labels else 0
279
+ clr = CLASS_COLORS[ci % len(CLASS_COLORS)]
280
+ cv2.rectangle(final, (x1d, y1d), (x2d, y2d), clr, 2)
281
+ cv2.putText(final, f"{lbl} {cf:.0%}", (x1d, y1d - 6),
282
+ cv2.FONT_HERSHEY_SIMPLEX, 0.4, clr, 1)
283
  cnn_live.image(cv2.cvtColor(final, cv2.COLOR_BGR2RGB),
284
  caption=f"{selected} β€” Final Detections",
285
  use_container_width=True)
 
310
  st.session_state["cnn_det_ms"] = ms
311
 
312
 
313
# -------------------------------------------------------------------
# RIGHT β€” ORB Detection
# -------------------------------------------------------------------
# Sliding-window ORB scan over the right image: each window is matched
# against every stored reference; the best-matching label (if it clears
# the match thresholds) contributes to the heatmap and detection list.
with col_orb:
    st.header("πŸ›οΈ ORB Detection")
    if not has_orb:
        st.info("No ORB reference trained. Train one in **Model Tuning**.")
    else:
        orb_det = st.session_state["orb_detector"]
        orb_refs = st.session_state["orb_refs"]
        # Fall back to the training-page defaults if thresholds are missing.
        dt_thresh = st.session_state.get("orb_dist_thresh", 70)
        min_m = st.session_state.get("orb_min_matches", 5)
        st.caption(f"References: {', '.join(orb_refs.keys())} | "
                   f"dist<{dt_thresh}, min {min_m} matches")
        orb_run = st.button("β–Ά Run ORB Scan", key="orb_run")

        # Placeholders so progress/image can be updated in place during the scan.
        orb_progress = st.empty()
        orb_live = st.empty()
        orb_results = st.container()

        if orb_run:
            H, W = right_img.shape[:2]
            # All top-left window positions; window size equals the crop size.
            positions = [(x, y)
                         for y in range(0, H - win_h + 1, stride)
                         for x in range(0, W - win_w + 1, stride)]
            n_total = len(positions)
            heatmap = np.zeros((H, W), dtype=np.float32)
            detections = []
            t0 = time.perf_counter()

            # Same CLAHE preprocessing the ORB references were built with;
            # hoisted out of the loop since it is input-independent.
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

            for idx, (px, py) in enumerate(positions):
                patch = right_img[py:py+win_h, px:px+win_w]
                gray = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY)
                gray = clahe.apply(gray)
                kp, des = orb_det.orb.detectAndCompute(gray, None)

                if des is not None:
                    # Pick the single best reference label for this window.
                    best_label, best_conf = "background", 0.0
                    for lbl, ref in orb_refs.items():
                        if ref["descriptors"] is None:
                            continue
                        matches = orb_det.bf.match(ref["descriptors"], des)
                        good = [m for m in matches if m.distance < dt_thresh]
                        # Confidence saturates at 1.0 once min_m good matches exist.
                        conf = min(len(good) / max(min_m, 1), 1.0)
                        if len(good) >= min_m and conf > best_conf:
                            best_label, best_conf = lbl, conf

                    if best_label != "background":
                        # Heatmap keeps the max confidence seen per pixel.
                        heatmap[py:py+win_h, px:px+win_w] = np.maximum(
                            heatmap[py:py+win_h, px:px+win_w], best_conf)
                        if best_conf >= conf_thresh:
                            detections.append(
                                (px, py, px+win_w, py+win_h, best_label, best_conf))

                # Throttle progress-bar updates to every 5th window.
                if idx % 5 == 0 or idx == n_total - 1:
                    orb_progress.progress((idx+1)/n_total,
                                          text=f"Window {idx+1}/{n_total}")

            total_ms = (time.perf_counter() - t0) * 1000
            # Suppress overlapping boxes with the shared NMS helper.
            if detections:
                detections = _nms(detections, nms_iou)

            # Draw final boxes, one color per detected class label.
            final = right_img.copy()
            cls_labels = sorted(set(d[4] for d in detections)) if detections else []
            for x1d, y1d, x2d, y2d, lbl, cf in detections:
                ci = cls_labels.index(lbl) if lbl in cls_labels else 0
                clr = CLASS_COLORS[ci % len(CLASS_COLORS)]
                cv2.rectangle(final, (x1d, y1d), (x2d, y2d), clr, 2)
                cv2.putText(final, f"{lbl} {cf:.0%}", (x1d, y1d - 6),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.4, clr, 1)
            orb_live.image(cv2.cvtColor(final, cv2.COLOR_BGR2RGB),
                           caption="ORB β€” Final Detections",
                           use_container_width=True)
            orb_progress.empty()

            with orb_results:
                om1, om2, om3, om4 = st.columns(4)
                om1.metric("Detections", len(detections))
                om2.metric("Windows", n_total)
                om3.metric("Total Time", f"{total_ms:.0f} ms")
                om4.metric("Per Window", f"{total_ms/max(n_total,1):.2f} ms")

                # Confidence heatmap blended over the scene (skipped when
                # no window produced a positive response).
                if heatmap.max() > 0:
                    hmap_color = cv2.applyColorMap(
                        (heatmap / heatmap.max() * 255).astype(np.uint8),
                        cv2.COLORMAP_JET)
                    blend = cv2.addWeighted(right_img, 0.5, hmap_color, 0.5, 0)
                    st.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
                             caption="ORB β€” Confidence Heatmap",
                             use_container_width=True)

                if detections:
                    import pandas as pd
                    df = pd.DataFrame(detections,
                                      columns=["x1","y1","x2","y2","label","conf"])
                    st.dataframe(df, use_container_width=True, hide_index=True)

            # Persist results for the bottom comparison and evaluation pages.
            st.session_state["orb_dets"] = detections
            st.session_state["orb_det_ms"] = total_ms
413
+ st.session_state["orb_det_ms"] = total_ms
414
+
415
+
416
# ===================================================================
# Bottom β€” Comparison (if any two have run)
# ===================================================================
# Pull whatever detection runs exist in the session. Each method maps to
# (detections, elapsed ms, BGR draw color).
rce_dets = st.session_state.get("rce_dets")
cnn_dets = st.session_state.get("cnn_dets")
orb_dets = st.session_state.get("orb_dets")

# BGR color + legend name per method. Fix: the previous colors did not
# match the legend after the BGR->RGB conversion done for display β€”
# CNN (0,0,255) rendered red and ORB (255,165,0) rendered light blue
# while the legend claimed "blue" and "orange". These BGR values render
# as the named colors.
METHOD_STYLE = {
    "RCE": ((0, 255, 0), "green"),
    "CNN": ((255, 0, 0), "blue"),
    "ORB": ((0, 165, 255), "orange"),
}

methods = {}
if rce_dets is not None:
    methods["RCE"] = (rce_dets, st.session_state.get("rce_det_ms", 0),
                      METHOD_STYLE["RCE"][0])
if cnn_dets is not None:
    methods["CNN"] = (cnn_dets, st.session_state.get("cnn_det_ms", 0),
                      METHOD_STYLE["CNN"][0])
if orb_dets is not None:
    methods["ORB"] = (orb_dets, st.session_state.get("orb_det_ms", 0),
                      METHOD_STYLE["ORB"][0])

if len(methods) >= 2:
    st.divider()
    st.subheader("πŸ“Š Side-by-Side Comparison")

    import pandas as pd
    # Metrics table: one column per method that has run.
    comp = {"Metric": ["Detections", "Best Confidence", "Total Time (ms)"]}
    for name, (dets, ms, _) in methods.items():
        comp[name] = [
            len(dets),
            f"{max((d[5] for d in dets), default=0):.1%}",
            f"{ms:.0f}",
        ]
    st.dataframe(pd.DataFrame(comp), use_container_width=True, hide_index=True)

    # Overlay every method's boxes on one image, color-coded per method.
    overlay = right_img.copy()
    for name, (dets, _, clr) in methods.items():
        for x1d, y1d, x2d, y2d, lbl, cf in dets:
            cv2.rectangle(overlay, (x1d, y1d), (x2d, y2d), clr, 2)
            cv2.putText(overlay, f"{name}:{lbl} {cf:.0%}", (x1d, y1d - 6),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.35, clr, 1)
    # Legend is derived from METHOD_STYLE so it cannot drift out of sync
    # with the draw colors (the old chained tuple comparisons mislabeled
    # any unknown color as "orange").
    legend = " | ".join(f"{n}={METHOD_STYLE[n][1]}" for n in methods)
    st.image(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB),
             caption=legend, use_container_width=True)
 
pages/7_Evaluation.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import cv2
3
+ import numpy as np
4
+ import plotly.graph_objects as go
5
+ import plotly.figure_factory as ff
6
+ import sys, os
7
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8
+
9
+ from src.detectors.rce.features import REGISTRY
10
+ from src.models import BACKBONES
11
+
12
st.set_page_config(page_title="Evaluation", layout="wide")
st.title("📈 Evaluation: Confusion Matrix & PR Curves")

# ---------------------------------------------------------------------------
# Guard — this page needs Data Lab assets and at least one detection run
# ---------------------------------------------------------------------------
if "pipeline_data" not in st.session_state:
    st.error("Complete the **Data Lab** first.")
    st.stop()

assets = st.session_state["pipeline_data"]
crop = assets["crop"]
crop_aug = assets.get("crop_aug", crop)
bbox = assets.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
# Fall back to a single-ROI list when the Data Lab stored no explicit ROIs.
rois = assets.get("rois", [{"label": "object", "bbox": bbox,
                            "crop": crop, "crop_aug": crop_aug}])

rce_dets = st.session_state.get("rce_dets")
cnn_dets = st.session_state.get("cnn_dets")
orb_dets = st.session_state.get("orb_dets")

if all(d is None for d in (rce_dets, cnn_dets, orb_dets)):
    st.warning("Run detection on at least one method in **Real-Time Detection** first.")
    st.stop()
36
+
37
+
38
# ---------------------------------------------------------------------------
# Ground truth from ROIs
# ---------------------------------------------------------------------------
gt_boxes = [(roi["bbox"], roi["label"]) for roi in rois]

st.sidebar.subheader("Evaluation Settings")
iou_thresh = st.sidebar.slider("IoU Threshold", 0.1, 0.9, 0.5, 0.05,
                               help="Minimum IoU to count a detection as TP")

st.subheader("Ground Truth (from Data Lab ROIs)")
st.caption(f"{len(gt_boxes)} ground-truth ROIs defined")

# Draw every GT box in yellow on a copy of the right image.
gt_vis = assets["right"].copy()
for (gx0, gy0, gx1, gy1), gt_name in gt_boxes:
    cv2.rectangle(gt_vis, (gx0, gy0), (gx1, gy1), (0, 255, 255), 2)
    cv2.putText(gt_vis, gt_name, (gx0, gy0 - 6),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 1)
st.image(cv2.cvtColor(gt_vis, cv2.COLOR_BGR2RGB),
         caption="Ground Truth Annotations", use_container_width=True)

st.divider()
58
+
59
+
60
+ # ---------------------------------------------------------------------------
61
+ # Matching helpers
62
+ # ---------------------------------------------------------------------------
63
+ def _iou(a, b):
64
+ xi1 = max(a[0], b[0]); yi1 = max(a[1], b[1])
65
+ xi2 = min(a[2], b[2]); yi2 = min(a[3], b[3])
66
+ inter = max(0, xi2 - xi1) * max(0, yi2 - yi1)
67
+ aa = (a[2] - a[0]) * (a[3] - a[1])
68
+ ab = (b[2] - b[0]) * (b[3] - b[1])
69
+ return inter / (aa + ab - inter + 1e-6)
70
+
71
+
72
+ def match_detections(dets, gt_list, iou_thr):
73
+ """
74
+ Match detections to GT boxes.
75
+ Returns list of (det, matched_gt_label_or_None, iou) sorted by confidence.
76
+ """
77
+ dets_sorted = sorted(dets, key=lambda d: d[5], reverse=True)
78
+ matched_gt = set()
79
+ results = []
80
+
81
+ for det in dets_sorted:
82
+ det_box = det[:4]
83
+ det_label = det[4]
84
+ best_iou = 0.0
85
+ best_gt_idx = -1
86
+ best_gt_label = None
87
+
88
+ for gi, (gt_box, gt_label) in enumerate(gt_list):
89
+ if gi in matched_gt:
90
+ continue
91
+ iou_val = _iou(det_box, gt_box)
92
+ if iou_val > best_iou:
93
+ best_iou = iou_val
94
+ best_gt_idx = gi
95
+ best_gt_label = gt_label
96
+
97
+ if best_iou >= iou_thr and best_gt_idx >= 0:
98
+ matched_gt.add(best_gt_idx)
99
+ results.append((det, best_gt_label, best_iou))
100
+ else:
101
+ results.append((det, None, best_iou))
102
+
103
+ return results, len(gt_list) - len(matched_gt)
104
+
105
+
106
def compute_pr_curve(dets, gt_list, iou_thr, steps=50):
    """
    Sweep confidence thresholds from 0 to 1 and compute precision,
    recall, and F1 at each step.

    Returns (thresholds, precisions, recalls, f1s) as plain lists;
    all four are empty when *dets* is empty.
    """
    if not dets:
        return [], [], [], []

    thresholds = np.linspace(0.0, 1.0, steps)
    precisions, recalls, f1s = [], [], []

    for thr in thresholds:
        kept = [d for d in dets if d[5] >= thr]
        if not kept:
            # Nothing predicted: perfect precision, zero recall by convention.
            precisions.append(1.0)
            recalls.append(0.0)
            f1s.append(0.0)
            continue

        pairs, fn = match_detections(kept, gt_list, iou_thr)
        tp = sum(1 for _, gt_lbl, _ in pairs if gt_lbl is not None)
        fp = len(pairs) - tp

        prec = tp / (tp + fp) if (tp + fp) > 0 else 1.0
        rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        harm = prec + rec
        f1 = 2 * prec * rec / harm if harm > 0 else 0.0
        precisions.append(prec)
        recalls.append(rec)
        f1s.append(f1)

    return thresholds.tolist(), precisions, recalls, f1s
138
+
139
+
140
def build_confusion_matrix(dets, gt_list, iou_thr):
    """
    Build a confusion matrix: rows = predicted class, cols = actual class.
    Classes are every GT label plus 'background'.
    """
    gt_labels = sorted(set(lbl for _, lbl in gt_list))
    all_labels = gt_labels + ["background"]

    n = len(all_labels)
    matrix = np.zeros((n, n), dtype=int)
    label_to_idx = {lbl: i for i, lbl in enumerate(all_labels)}
    bg = label_to_idx["background"]

    pairs, _ = match_detections(dets, gt_list, iou_thr)

    # Tally matched detections (TP / mislabels) and false positives,
    # tracking how many GT boxes of each label were claimed.
    claimed = {lbl: 0 for lbl in gt_labels}
    for det, gt_lbl, _ in pairs:
        pi = label_to_idx.get(det[4], bg)
        if gt_lbl is None:
            # FP: the detector fired on background.
            matrix[pi][bg] += 1
        else:
            # TP or mislabel against a real GT box.
            matrix[pi][label_to_idx[gt_lbl]] += 1
            claimed[gt_lbl] += 1

    # FN: GT boxes of each label never claimed by any detection.
    for lbl in gt_labels:
        total = sum(1 for _, g in gt_list if g == lbl)
        matrix[bg][label_to_idx[lbl]] += total - claimed[lbl]

    return matrix, all_labels
179
+
180
+
181
# ---------------------------------------------------------------------------
# Collect all methods with detections
# ---------------------------------------------------------------------------
methods = {}
for method_name, method_dets in (("RCE", rce_dets),
                                 ("CNN", cnn_dets),
                                 ("ORB", orb_dets)):
    if method_dets is not None:
        methods[method_name] = method_dets


# ===================================================================
# 1. Confusion Matrices
# ===================================================================
st.subheader("🔲 Confusion Matrices")
cm_cols = st.columns(len(methods))

for col, (name, dets) in zip(cm_cols, methods.items()):
    with col:
        st.markdown(f"**{name}**")
        matrix, labels = build_confusion_matrix(dets, gt_boxes, iou_thresh)

        fig_cm = ff.create_annotated_heatmap(
            z=matrix.tolist(),
            x=labels, y=labels,
            colorscale="Blues",
            showscale=True)
        fig_cm.update_layout(
            title=f"{name} Confusion Matrix",
            xaxis_title="Actual",
            yaxis_title="Predicted",
            template="plotly_dark",
            height=350)
        # Flip y so the first predicted class sits at the top row.
        fig_cm.update_yaxes(autorange="reversed")
        st.plotly_chart(fig_cm, use_container_width=True)

        # Point metrics at the sidebar IoU threshold.
        pairs, fn = match_detections(dets, gt_boxes, iou_thresh)
        tp = sum(1 for _, g, _ in pairs if g is not None)
        fp = len(pairs) - tp
        prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0

        m1, m2, m3 = st.columns(3)
        m1.metric("Precision", f"{prec:.1%}")
        m2.metric("Recall", f"{rec:.1%}")
        m3.metric("F1 Score", f"{f1:.1%}")
231
+
232
+
233
# ===================================================================
# 2. Precision-Recall Curves
# ===================================================================
st.divider()
st.subheader("📉 Precision-Recall Curves")

method_colors = {"RCE": "#00ff88", "CNN": "#4488ff", "ORB": "#ff8800"}
fig_pr = go.Figure()
fig_f1 = go.Figure()

summary_rows = []

for name, dets in methods.items():
    thrs, precs, recs, f1s = compute_pr_curve(dets, gt_boxes, iou_thresh)
    clr = method_colors.get(name, "#ffffff")

    fig_pr.add_trace(go.Scatter(
        x=recs, y=precs, mode="lines+markers",
        name=name, line=dict(color=clr, width=2),
        marker=dict(size=4)))

    fig_f1.add_trace(go.Scatter(
        x=thrs, y=f1s, mode="lines",
        name=name, line=dict(color=clr, width=2)))

    # AP = area under the PR curve; recall decreases along the sweep,
    # so the raw integral is negative and is displayed as |AP| below.
    ap = float(np.trapz(precs, recs)) if recs and precs else 0.0

    if f1s:
        best = int(np.argmax(f1s))
        best_f1 = f"{f1s[best]:.3f}"
        best_thr = f"{thrs[best]:.2f}"
    else:
        best_f1 = "N/A"
        best_thr = "N/A"
    summary_rows.append({
        "Method": name,
        "AP": f"{abs(ap):.3f}",
        "Best F1": best_f1,
        "@ Threshold": best_thr,
        "Detections": len(dets),
    })

fig_pr.update_layout(
    title="Precision vs Recall",
    xaxis_title="Recall", yaxis_title="Precision",
    template="plotly_dark", height=400,
    xaxis=dict(range=[0, 1.05]), yaxis=dict(range=[0, 1.05]))

fig_f1.update_layout(
    title="F1 Score vs Confidence Threshold",
    xaxis_title="Confidence Threshold", yaxis_title="F1 Score",
    template="plotly_dark", height=400,
    xaxis=dict(range=[0, 1.05]), yaxis=dict(range=[0, 1.05]))

pc1, pc2 = st.columns(2)
pc1.plotly_chart(fig_pr, use_container_width=True)
pc2.plotly_chart(fig_f1, use_container_width=True)
288
+
289
+
290
# ===================================================================
# 3. Summary Table
# ===================================================================
st.divider()
st.subheader("📊 Summary")

import pandas as pd

summary_df = pd.DataFrame(summary_rows)
st.dataframe(summary_df, use_container_width=True, hide_index=True)

st.caption(f"All metrics computed at IoU threshold = **{iou_thresh:.2f}**. "
           "Adjust in the sidebar to explore sensitivity.")
pages/{6_Stereo_Geometry.py β†’ 8_Stereo_Geometry.py} RENAMED
File without changes
src/localization.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/localization.py β€” Localization Strategy Library
3
+ =====================================================
4
+ Five strategies that decide WHERE to evaluate a recognition head.
5
+ The head stays the same β€” only the search method changes.
6
+
7
+ Strategies
8
+ ----------
9
+ 1. Exhaustive Sliding Window β€” brute-force grid scan
10
+ 2. Image Pyramid β€” multi-scale resize + sliding window
11
+ 3. Coarse-to-Fine Search β€” two-pass hierarchical refinement
12
+ 4. Contour Proposals β€” edge-driven candidate regions
13
+ 5. Template Matching β€” OpenCV cross-correlation (no head)
14
+
15
+ Every function returns the same tuple:
16
+ (detections, n_proposals, elapsed_ms, heatmap)
17
+ """
18
+
19
+ import cv2
20
+ import numpy as np
21
+ import time
22
+
23
+
24
+ # ===================================================================
25
+ # Shared utilities
26
+ # ===================================================================
27
+
28
+ def nms(dets, iou_thresh):
29
+ """Greedy NMS on list of (x1, y1, x2, y2, label, conf)."""
30
+ dets = sorted(dets, key=lambda d: d[5], reverse=True)
31
+ keep = []
32
+ while dets:
33
+ best = dets.pop(0)
34
+ keep.append(best)
35
+ dets = [d for d in dets if _iou(best, d) < iou_thresh]
36
+ return keep
37
+
38
+
39
+ def _iou(a, b):
40
+ xi1, yi1 = max(a[0], b[0]), max(a[1], b[1])
41
+ xi2, yi2 = min(a[2], b[2]), min(a[3], b[3])
42
+ inter = max(0, xi2 - xi1) * max(0, yi2 - yi1)
43
+ aa = (a[2] - a[0]) * (a[3] - a[1])
44
+ ab = (b[2] - b[0]) * (b[3] - b[1])
45
+ return inter / (aa + ab - inter + 1e-6)
46
+
47
+
48
+ # ===================================================================
49
+ # 1. Exhaustive Sliding Window
50
+ # ===================================================================
51
+
52
def exhaustive_sliding_window(image, win_h, win_w, feature_fn, head,
                              stride, conf_thresh, nms_iou):
    """
    Brute-force grid scan: evaluate the head at **every** window
    position spaced by *stride* pixels.

    Returns (detections, n_proposals, elapsed_ms, heatmap).
    """
    img_h, img_w = image.shape[:2]
    heat = np.zeros((img_h, img_w), dtype=np.float32)
    found = []
    evaluated = 0
    start = time.perf_counter()

    for top in range(0, img_h - win_h + 1, stride):
        for left in range(0, img_w - win_w + 1, stride):
            window = image[top:top + win_h, left:left + win_w]
            label, conf = head.predict(feature_fn(window))
            evaluated += 1
            if label != "object":
                continue
            # Paint the window's footprint with its best confidence so far.
            region = heat[top:top + win_h, left:left + win_w]
            np.maximum(region, conf, out=region)
            if conf >= conf_thresh:
                found.append((left, top, left + win_w, top + win_h,
                              label, conf))

    elapsed = (time.perf_counter() - start) * 1000
    return (nms(found, nms_iou) if found else found), evaluated, elapsed, heat
80
+
81
+
82
+ # ===================================================================
83
+ # 2. Image Pyramid
84
+ # ===================================================================
85
+
86
def image_pyramid(image, win_h, win_w, feature_fn, head,
                  stride, conf_thresh, nms_iou,
                  scales=(0.5, 0.75, 1.0, 1.25, 1.5)):
    """
    Multi-scale sliding window: resize the image to each scale, scan it,
    and project every hit back into original-image coordinates so that
    objects larger or smaller than the training crop can still be found.

    Returns (detections, n_proposals, elapsed_ms, heatmap).
    """
    img_h, img_w = image.shape[:2]
    heat = np.zeros((img_h, img_w), dtype=np.float32)
    found = []
    evaluated = 0
    start = time.perf_counter()

    for scale in scales:
        new_h, new_w = int(img_h * scale), int(img_w * scale)
        if new_h < win_h or new_w < win_w:
            continue  # the window would not fit at this pyramid level
        level = cv2.resize(image, (new_w, new_h))

        for top in range(0, new_h - win_h + 1, stride):
            for left in range(0, new_w - win_w + 1, stride):
                window = level[top:top + win_h, left:left + win_w]
                label, conf = head.predict(feature_fn(window))
                evaluated += 1
                if label != "object":
                    continue
                # Back-project the window into original coordinates.
                ox, oy = int(left / scale), int(top / scale)
                ox2 = min(int((left + win_w) / scale), img_w)
                oy2 = min(int((top + win_h) / scale), img_h)
                heat[oy:oy2, ox:ox2] = np.maximum(heat[oy:oy2, ox:ox2], conf)
                if conf >= conf_thresh:
                    found.append((ox, oy, ox2, oy2, label, conf))

    elapsed = (time.perf_counter() - start) * 1000
    return (nms(found, nms_iou) if found else found), evaluated, elapsed, heat
127
+
128
+
129
+ # ===================================================================
130
+ # 3. Coarse-to-Fine Search
131
+ # ===================================================================
132
+
133
def coarse_to_fine(image, win_h, win_w, feature_fn, head,
                   fine_stride, conf_thresh, nms_iou,
                   coarse_factor=4, refine_radius=2):
    """
    Two-pass hierarchical search.

    Pass 1 scans at *coarse_factor x fine_stride* with a relaxed
    threshold (0.7 x conf_thresh) to flag hot regions cheaply.
    Pass 2 re-scans only the neighbourhood of each hot spot at
    *fine_stride*, up to *refine_radius* steps in every direction.

    Returns (detections, n_proposals, elapsed_ms, heatmap).
    """
    img_h, img_w = image.shape[:2]
    heat = np.zeros((img_h, img_w), dtype=np.float32)
    found = []
    evaluated = 0
    start = time.perf_counter()

    coarse_stride = fine_stride * coarse_factor

    # --- Pass 1: coarse grid ---
    hot_spots = []
    for top in range(0, img_h - win_h + 1, coarse_stride):
        for left in range(0, img_w - win_w + 1, coarse_stride):
            window = image[top:top + win_h, left:left + win_w]
            label, conf = head.predict(feature_fn(window))
            evaluated += 1
            if label == "object" and conf >= conf_thresh * 0.7:
                hot_spots.append((left, top))
                heat[top:top + win_h, left:left + win_w] = np.maximum(
                    heat[top:top + win_h, left:left + win_w], conf)

    # --- Pass 2: fine scan around each hot spot ---
    seen = set()
    for hx, hy in hot_spots:
        for dy in range(-refine_radius, refine_radius + 1):
            for dx in range(-refine_radius, refine_radius + 1):
                left = hx + dx * fine_stride
                top = hy + dy * fine_stride
                if (left, top) in seen:
                    continue
                if left < 0 or top < 0 or left + win_w > img_w \
                        or top + win_h > img_h:
                    continue
                seen.add((left, top))
                window = image[top:top + win_h, left:left + win_w]
                label, conf = head.predict(feature_fn(window))
                evaluated += 1
                if label != "object":
                    continue
                heat[top:top + win_h, left:left + win_w] = np.maximum(
                    heat[top:top + win_h, left:left + win_w], conf)
                if conf >= conf_thresh:
                    found.append((left, top, left + win_w, top + win_h,
                                  label, conf))

    elapsed = (time.perf_counter() - start) * 1000
    return (nms(found, nms_iou) if found else found), evaluated, elapsed, heat
192
+
193
+
194
+ # ===================================================================
195
+ # 4. Contour Proposals
196
+ # ===================================================================
197
+
198
def contour_proposals(image, win_h, win_w, feature_fn, head,
                      conf_thresh, nms_iou,
                      canny_low=50, canny_high=150,
                      area_tolerance=3.0):
    """
    Structure-driven proposals: Canny edges -> morphological closing ->
    external contours. Contours whose area lies within *area_tolerance* x
    of the window area get a window centred on them and scored by the head.

    Note: unlike the other strategies, this returns a 5-tuple — the raw
    edge map is appended for visualisation by the caller.
    """
    img_h, img_w = image.shape[:2]
    heat = np.zeros((img_h, img_w), dtype=np.float32)
    found = []
    evaluated = 0
    start = time.perf_counter()

    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(cv2.GaussianBlur(gray, (5, 5), 0),
                      canny_low, canny_high)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    edges = cv2.morphologyEx(edges, cv2.MORPH_CLOSE, kernel)

    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL,
                                   cv2.CHAIN_APPROX_SIMPLE)

    target_area = win_h * win_w
    lo_area = target_area / area_tolerance
    hi_area = target_area * area_tolerance

    for cnt in contours:
        if not (lo_area <= cv2.contourArea(cnt) <= hi_area):
            continue
        bx, by, bw, bh = cv2.boundingRect(cnt)
        # Centre a window on the contour centre, clamped inside the image.
        cx, cy = bx + bw // 2, by + bh // 2
        px = max(0, min(cx - win_w // 2, img_w - win_w))
        py = max(0, min(cy - win_h // 2, img_h - win_h))

        patch = image[py:py + win_h, px:px + win_w]
        if patch.shape[0] != win_h or patch.shape[1] != win_w:
            continue

        label, conf = head.predict(feature_fn(patch))
        evaluated += 1

        if label == "object":
            heat[py:py + win_h, px:px + win_w] = np.maximum(
                heat[py:py + win_h, px:px + win_w], conf)
            if conf >= conf_thresh:
                found.append((px, py, px + win_w, py + win_h,
                              label, conf))

    elapsed = (time.perf_counter() - start) * 1000
    return (nms(found, nms_iou) if found else found), evaluated, elapsed, heat, edges
259
+
260
+
261
+ # ===================================================================
262
+ # 5. Template Matching
263
+ # ===================================================================
264
+
265
def template_matching(image, template, conf_thresh, nms_iou,
                      method=cv2.TM_CCOEFF_NORMED):
    """
    Classical OpenCV cross-correlation: slide *template* over *image*
    and score pixel similarity at every offset. No trained head is
    involved; fast but sensitive to rotation, scale, and illumination.

    Returns (detections, n_proposals, elapsed_ms, heatmap).
    """
    img_h, img_w = image.shape[:2]
    tpl_h, tpl_w = template.shape[:2]
    start = time.perf_counter()

    raw = cv2.matchTemplate(image, template, method)

    # Normalised methods already score in [-1, 1]: clip to [0, 1].
    # Other methods are min-max rescaled so thresholds stay comparable.
    if method in (cv2.TM_CCOEFF_NORMED, cv2.TM_CCORR_NORMED):
        scores = np.clip(raw, 0, 1).astype(np.float32)
    else:
        lo, hi = raw.min(), raw.max()
        scores = ((raw - lo) / (hi - lo + 1e-6)).astype(np.float32)

    # Stretch the (smaller) score map to image size for visualisation.
    heat = cv2.resize(scores, (img_w, img_h), interpolation=cv2.INTER_LINEAR)

    # Every position above threshold becomes a detection.
    found = [(int(x), int(y), int(x + tpl_w), int(y + tpl_h),
              "object", float(scores[y, x]))
             for y, x in zip(*np.where(scores >= conf_thresh))]

    evaluated = scores.shape[0] * scores.shape[1]
    elapsed = (time.perf_counter() - start) * 1000

    return (nms(found, nms_iou) if found else found), evaluated, elapsed, heat
301
+
302
+
303
+ # ===================================================================
304
+ # Registry β€” metadata used by the Streamlit page
305
+ # ===================================================================
306
+
307
# Registry consumed by the Streamlit localization page. Each entry maps a
# display name to: "icon" (UI emoji), "fn" (the strategy function above),
# "needs_head" (whether a trained recognition head must be supplied),
# "short" (one-line summary) and "detail" (markdown explainer).
STRATEGIES = {
    "Exhaustive Sliding Window": {
        "icon": "🔲",
        "fn": exhaustive_sliding_window,
        "needs_head": True,
        "short": "Brute-force grid scan at every stride position.",
        "detail": (
            "The simplest approach: a fixed-size window slides across the "
            "**entire image** at regular intervals. At every position the "
            "patch is extracted, features are computed, and the head classifies it.\n\n"
            "**Complexity:** $O\\!\\left(\\frac{W}{s} \\times \\frac{H}{s}\\right)$ "
            "where $s$ = stride.\n\n"
            "**Pro:** Guaranteed to evaluate every location — nothing is missed.\n\n"
            "**Con:** Extremely slow on large images or small strides."
        ),
    },
    "Image Pyramid": {
        "icon": "🔺",
        "fn": image_pyramid,
        "needs_head": True,
        "short": "Multi-scale resize + sliding window.",
        "detail": (
            "Builds a **Gaussian pyramid** by resizing the image to several "
            "scales (e.g. 50 %, 75 %, 100 %, 125 %, 150 %). A sliding-window "
            "scan runs at each level and detections are mapped back to original "
            "coordinates.\n\n"
            "**Why:** The training crop has a fixed size. If the real object "
            "appears larger or smaller in the scene, a single-scale scan will "
            "miss it. The pyramid handles **scale variation**.\n\n"
            "**Cost:** Multiplies the number of proposals by the number of "
            "scales — slower than single-scale exhaustive."
        ),
    },
    "Coarse-to-Fine": {
        "icon": "🎯",
        "fn": coarse_to_fine,
        "needs_head": True,
        "short": "Two-pass hierarchical refinement.",
        "detail": (
            "**Pass 1 — Coarse:** Scans the image with a large stride "
            "(coarse\\_factor × fine\\_stride) using a relaxed confidence "
            "threshold (70 % of the target) to cheaply identify *hot regions*.\n\n"
            "**Pass 2 — Fine:** Re-scans **only** the neighbourhood around "
            "each coarse hit at the fine stride, within *refine\\_radius* steps "
            "in each direction.\n\n"
            "**Speedup:** Typically **3–10×** faster than exhaustive when the "
            "object is spatially sparse (i.e. most of the image is background)."
        ),
    },
    # NOTE(review): contour_proposals returns a 5-tuple (extra edge map) —
    # callers selecting this entry must handle the additional element.
    "Contour Proposals": {
        "icon": "✏️",
        "fn": contour_proposals,
        "needs_head": True,
        "short": "Edge-driven candidate regions scored by head.",
        "detail": (
            "Instead of scanning everywhere, this method lets **image "
            "structure** drive the search:\n\n"
            "1. Canny edge detection\n"
            "2. Morphological closing to bridge nearby edges\n"
            "3. External contour extraction\n"
            "4. Filter contours whose area falls within *area\\_tolerance* "
            "of the window area\n"
            "5. Centre a window on each surviving contour and score with "
            "the trained head\n\n"
            "**Proposals evaluated:** Typically 10–100× fewer than exhaustive. "
            "Speed depends on scene complexity (more edges → more proposals)."
        ),
    },
    # The only head-free strategy: its signature also differs (template
    # instead of win_h/win_w/feature_fn/head).
    "Template Matching": {
        "icon": "📋",
        "fn": template_matching,
        "needs_head": False,
        "short": "OpenCV cross-correlation — no head needed.",
        "detail": (
            "Classical **normalised cross-correlation** (NCC). Slides the "
            "crop template over the image computing pixel-level similarity "
            "at every position. No trained head is involved.\n\n"
            "**Speed:** Runs entirely in OpenCV's optimised C++ backend — "
            "orders of magnitude faster than Python-level loops.\n\n"
            "**Limitation:** Not invariant to rotation, scale, or illumination "
            "changes. Works best when the object appears at the **exact same "
            "size and orientation** as the crop."
        ),
    },
}