DariusGiannoli committed on
Commit
a51a1a7
·
1 Parent(s): 397a1b0

refactor: tab-based routing with two pipelines (Stereo+Depth & Generalisation)


- Rewrite app.py as sidebar routing controller (no more multi-page Streamlit)
- Add tabs/stereo/ (7 stages: data_lab, feature_lab, model_tuning, localization,
detection, evaluation, stereo_depth)
- Add tabs/generalisation/ (6 stages: same minus stereo_depth)
- Add utils/middlebury_loader.py (PFM + calib parsing, scene group scanner)
- Namespace session state into stereo_pipeline / gen_pipeline dicts
- Fix data leakage: train on LEFT/variant-A, detect on RIGHT/variant-B
- All widget keys prefixed per pipeline to prevent collisions
- Remove deprecated pages/ directory

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. CLAUDE_CODE_PROMPT.md +656 -0
  2. app.py +171 -155
  3. dataOLD/README.md +5 -0
  4. dataOLD/artroom/bird/yolo/bird_data.yaml +7 -0
  5. dataOLD/artroom/bird/yolo/train/images/bird_01_original.png +3 -0
  6. dataOLD/artroom/bird/yolo/train/images/bird_02_rot_pos5.png +3 -0
  7. dataOLD/artroom/bird/yolo/train/images/bird_03_rot_neg5.png +3 -0
  8. dataOLD/artroom/bird/yolo/train/images/bird_04_bright.png +3 -0
  9. dataOLD/artroom/bird/yolo/train/images/bird_05_dark.png +3 -0
  10. dataOLD/artroom/bird/yolo/train/images/bird_06_noisy.png +3 -0
  11. dataOLD/artroom/bird/yolo/train/images/bird_07_flip.png +3 -0
  12. dataOLD/artroom/bird/yolo/train/images/bird_08_blur.png +3 -0
  13. dataOLD/artroom/bird/yolo/train/images/bird_09_shift_x.png +3 -0
  14. dataOLD/artroom/bird/yolo/train/images/bird_10_shift_y.png +3 -0
  15. dataOLD/artroom/bird/yolo/train/images/room_1.png +3 -0
  16. dataOLD/artroom/bird/yolo/train/images/room_2.png +3 -0
  17. dataOLD/artroom/bird/yolo/train/images/room_3.png +3 -0
  18. dataOLD/artroom/bird/yolo/train/images/room_4.png +3 -0
  19. dataOLD/artroom/bird/yolo/train/images/room_5.png +3 -0
  20. dataOLD/artroom/bird/yolo/train/labels.cache +0 -0
  21. dataOLD/artroom/bird/yolo/train/labels/bird_01_original.txt +1 -0
  22. dataOLD/artroom/bird/yolo/train/labels/bird_02_rot_pos5.txt +1 -0
  23. dataOLD/artroom/bird/yolo/train/labels/bird_03_rot_neg5.txt +1 -0
  24. dataOLD/artroom/bird/yolo/train/labels/bird_04_bright.txt +1 -0
  25. dataOLD/artroom/bird/yolo/train/labels/bird_05_dark.txt +1 -0
  26. dataOLD/artroom/bird/yolo/train/labels/bird_06_noisy.txt +1 -0
  27. dataOLD/artroom/bird/yolo/train/labels/bird_07_flip.txt +1 -0
  28. dataOLD/artroom/bird/yolo/train/labels/bird_08_blur.txt +1 -0
  29. dataOLD/artroom/bird/yolo/train/labels/bird_09_shift_x.txt +1 -0
  30. dataOLD/artroom/bird/yolo/train/labels/bird_10_shift_y.txt +1 -0
  31. dataOLD/artroom/bird/yolo/train/labels/room_1.txt +0 -0
  32. dataOLD/artroom/bird/yolo/train/labels/room_2.txt +0 -0
  33. dataOLD/artroom/bird/yolo/train/labels/room_3.txt +0 -0
  34. dataOLD/artroom/bird/yolo/train/labels/room_4.txt +0 -0
  35. dataOLD/artroom/bird/yolo/train/labels/room_5.txt +0 -0
  36. dataOLD/artroom/im0.png +3 -0
  37. pages/2_Data_Lab.py +0 -321
  38. pages/3_Feature_Lab.py +0 -111
  39. pages/4_Model_Tuning.py +0 -475
  40. pages/5_Localization_Lab.py +0 -348
  41. pages/6_RealTime_Detection.py +0 -435
  42. pages/7_Evaluation.py +0 -295
  43. pages/8_Stereo_Geometry.py +0 -353
  44. tabs/__init__.py +0 -0
  45. tabs/generalisation/__init__.py +0 -0
  46. tabs/generalisation/data_lab.py +269 -0
  47. tabs/generalisation/detection.py +388 -0
  48. tabs/generalisation/evaluation.py +205 -0
  49. tabs/generalisation/feature_lab.py +102 -0
  50. tabs/generalisation/localization.py +302 -0
CLAUDE_CODE_PROMPT.md ADDED
@@ -0,0 +1,656 @@
1
+ # Claude Code Implementation Prompt
2
+ # Recognition-BenchMark — Full Restructure
3
+
4
+ ---
5
+
6
+ ## Context
7
+
8
+ This is a Streamlit-based stereo-vision benchmarking platform called **Recognition-BenchMark**. It compares a custom hand-crafted feature extractor called **RCE (Relative Contextual Encoding)** against CNN-based deep learning approaches for object recognition and depth estimation.
9
+
10
+ ### Current Project Structure
11
+
12
+ ```
13
+ app.py ← Landing page (home)
14
+ pages/
15
+ ├── 2_Data_Lab.py ← Stage 1
16
+ ├── 3_Feature_Lab.py ← Stage 2
17
+ ├── 4_Model_Tuning.py ← Stage 3
18
+ ├── 5_Localization_Lab.py ← Stage 4
19
+ ├── 6_RealTime_Detection.py ← Stage 5
20
+ ├── 7_Evaluation.py ← Stage 6
21
+ └── 8_Stereo_Geometry.py ← Stage 7
22
+ src/
23
+ ├── config.py ← App configuration constants
24
+ ├── detectors/
25
+ │ ├── base.py ← Base detector class
26
+ │ ├── rce/
27
+ │ │ ├── __init__.py
28
+ │ │ └── features.py ← RCE feature extractor (DO NOT MODIFY)
29
+ │ ├── mobilenet.py ← MobileNetV3 detector (DO NOT MODIFY)
30
+ │ ├── mobilevit.py ← MobileViT detector (DO NOT MODIFY)
31
+ │ ├── resnet.py ← ResNet-18 detector (DO NOT MODIFY)
32
+ │ ├── orb.py ← ORB detector (DO NOT MODIFY)
33
+ │ └── yolo.py ← YOLOv8 detector (DO NOT MODIFY)
34
+ ├── localization.py ← Localization strategies (DO NOT MODIFY)
35
+ └── models.py ← Model loading utilities (DO NOT MODIFY)
36
+ models/
37
+ ├── mobilenet_v3_head.pkl
38
+ ├── mobilenet_v3.pth
39
+ ├── mobilevit_head.pkl
40
+ ├── mobilevit_xxs.pth
41
+ ├── orb_reference.pkl
42
+ ├── resnet18_head.pkl
43
+ ├── resnet18.pth
44
+ └── yolov8n.pt
45
+ data/
46
+ └── middlebury/ ← Bundled dataset (already present)
47
+ ```
48
+
49
+ **The entire `src/` directory must not be modified.** All detector logic, feature extraction, localization strategies, and model loading are already implemented there. The pages in `pages/` import from `src/` and must be migrated to the new `tabs/` structure while continuing to import from `src/`.
50
+
51
+ ---
52
+
53
+ ## Critical Bug To Fix
54
+
55
+ **Data leakage through circular evaluation.** Currently the detection/recognition stage runs on the same left image used to define the training ROI. This is scientifically invalid — the model is tested on its own training source.
56
+
57
+ **The fix:**
58
+ - In the Stereo pipeline: train on LEFT image crop → detect on RIGHT image
59
+ - In the Generalisation pipeline: train on image 1 crop → detect on image 2
60
+
61
+ This must be propagated through session state so every stage after Data Lab knows which image is for training (source) and which is for testing (target).
62
+
63
+ ---
64
+
65
+ ## Target Architecture
66
+
67
+ ### File Structure After Refactor
68
+
69
+ ```
70
+ app.py ← REPLACE: routing controller + home page
71
+ tabs/
72
+ ├── stereo/
73
+ │ ├── __init__.py
74
+ │ ├── data_lab.py ← NEW: replaces pages/2_Data_Lab.py for stereo
75
+ │ ├── feature_lab.py ← MIGRATE: from pages/3_Feature_Lab.py
76
+ │ ├── model_tuning.py ← MIGRATE: from pages/4_Model_Tuning.py
77
+ │ ├── localization.py ← MIGRATE: from pages/5_Localization_Lab.py
78
+ │ ├── detection.py ← MIGRATE + FIX: from pages/6_RealTime_Detection.py
79
+ │ ├── evaluation.py ← MIGRATE: from pages/7_Evaluation.py
80
+ │ └── stereo_depth.py ← MIGRATE: from pages/8_Stereo_Geometry.py
81
+ ├── generalisation/
82
+ │ ├── __init__.py
83
+ │ ├── data_lab.py ← NEW: generalisation-specific data loading
84
+ │ ├── feature_lab.py ← ADAPT: stereo version with gen_pipeline keys
85
+ │ ├── model_tuning.py ← ADAPT: stereo version with gen_pipeline keys
86
+ │ ├── localization.py ← ADAPT: stereo version with gen_pipeline keys
87
+ │ ├── detection.py ← ADAPT + FIX: stereo version with gen_pipeline keys
88
+ │ └── evaluation.py ← ADAPT: stereo version with gen_pipeline keys
89
+ utils/
90
+ └── middlebury_loader.py ← NEW: dataset scanning, loading, parsing
91
+ src/ ← DO NOT TOUCH: all detector/model logic stays here
92
+ pages/ ← DELETE after migration is complete and verified
93
+ data/
94
+ └── middlebury/ ← Already present, do not modify
95
+ ```
96
+
97
+ ---
98
+
99
+ ## Part 1 — app.py Routing Controller
100
+
101
+ Replace the existing `app.py` entirely. The new `app.py` is a **routing controller** that:
102
+
103
+ 1. Sets page config (keep existing title/icon/layout)
104
+ 2. Builds the sidebar navigation manually using session state
105
+ 3. Renders the correct module based on navigation state
106
+ 4. Preserves the existing landing page content (pipeline overview, models, depth info)
107
+
108
+ ### Sidebar Logic
109
+
110
+ ```python
111
+ import streamlit as st
112
+
113
+ # Top-level navigation
114
+ st.sidebar.title("🦅 Recognition BenchMark")
115
+
116
+ top_section = st.sidebar.radio(
117
+ "Navigation",
118
+ ["🏠 Home", "📷 Stereo + Depth", "🌍 Generalisation"],
119
+ key="top_nav"
120
+ )
121
+
122
+ if top_section == "🏠 Home":
123
+ # render home/landing page content inline in app.py
124
+ render_home()
125
+
126
+ elif top_section == "📷 Stereo + Depth":
127
+ stereo_stage = st.sidebar.radio(
128
+ "Pipeline Stage",
129
+ [
130
+ "🧪 1 · Data Lab",
131
+ "🔬 2 · Feature Lab",
132
+ "⚙️ 3 · Model Tuning",
133
+ "🔍 4 · Localization",
134
+ "🎯 5 · Detection",
135
+ "📈 6 · Evaluation",
136
+ "📐 7 · Stereo Depth"
137
+ ],
138
+ key="stereo_stage"
139
+ )
140
+ # import and call the appropriate render() function from tabs/stereo/
141
+
142
+ elif top_section == "🌍 Generalisation":
143
+ gen_stage = st.sidebar.radio(
144
+ "Pipeline Stage",
145
+ [
146
+ "🧪 1 · Data Lab",
147
+ "🔬 2 · Feature Lab",
148
+ "⚙️ 3 · Model Tuning",
149
+ "🔍 4 · Localization",
150
+ "🎯 5 · Detection",
151
+ "📈 6 · Evaluation"
152
+ ],
153
+ key="gen_stage"
154
+ )
155
+ # import and call the appropriate render() function from tabs/generalisation/
156
+ ```
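The two `# import and call the appropriate render() ...` comments can be realised as a label-to-module dispatch table. A sketch (the emoji labels and module paths mirror the lists above; `dispatch` itself is a hypothetical helper, not existing code):

```python
import importlib

# Hypothetical mapping from sidebar label to stage module (stereo pipeline)
STEREO_STAGES = {
    "🧪 1 · Data Lab": "tabs.stereo.data_lab",
    "🔬 2 · Feature Lab": "tabs.stereo.feature_lab",
    "⚙️ 3 · Model Tuning": "tabs.stereo.model_tuning",
    "🔍 4 · Localization": "tabs.stereo.localization",
    "🎯 5 · Detection": "tabs.stereo.detection",
    "📈 6 · Evaluation": "tabs.stereo.evaluation",
    "📐 7 · Stereo Depth": "tabs.stereo.stereo_depth",
}

def dispatch(stage_label: str) -> None:
    """Import the selected stage module lazily and call its render()."""
    importlib.import_module(STEREO_STAGES[stage_label]).render()
```

The generalisation pipeline gets an analogous six-entry table keyed on `gen_stage`.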
157
+
158
+ ### Stage Guard Pattern
159
+
160
+ Every stage except Data Lab must check if the previous stage is complete. Use this pattern at the top of each stage's `render()` function:
161
+
162
+ ```python
163
+ def render():
164
+ pipe = st.session_state.get("stereo_pipeline", {})
165
+ if "train_image" not in pipe:
166
+ st.warning("⚠️ Complete **Data Lab** first before accessing this stage.")
167
+ st.stop()
168
+ # ... rest of stage logic
169
+ ```
170
+
171
+ ### Session State Namespacing
172
+
173
+ **Critical:** The two pipelines must never share session state keys.
174
+
175
+ - Stereo pipeline uses: `st.session_state["stereo_pipeline"]` — a dict containing all stereo stage data
176
+ - Generalisation pipeline uses: `st.session_state["gen_pipeline"]` — a dict containing all generalisation stage data
177
+
178
+ Within each dict, use consistent keys:
179
+ ```python
180
+ # Stereo pipeline dict keys
181
+ stereo_pipeline = {
182
+ "train_image": np.ndarray, # LEFT image — used for ROI + training
183
+ "test_image": np.ndarray, # RIGHT image — used for detection
184
+ "calib": dict, # parsed calibration parameters
185
+ "disparity_gt": np.ndarray, # ground truth disparity (optional, may be None)
186
+ "roi": dict, # {"x", "y", "w", "h", "label"}
187
+ "crop": np.ndarray, # cropped ROI from train_image
188
+ "crop_aug": list, # augmented crop variants
189
+ "active_modules": list, # RCE modules ["intensity", "sobel", "spectral"]
190
+ "rce_head": object, # trained LogisticRegression
191
+ "cnn_heads": dict, # {"ResNet-18": ..., "MobileNetV3": ..., "MobileViT-XXS": ...}
192
+ "rce_dets": list, # detection results on test_image
193
+ "cnn_dets": dict, # detection results per CNN model
194
+ "source": str, # "middlebury" or "custom"
195
+ "scene_name": str, # Middlebury scene name (if source == "middlebury")
196
+ }
197
+
198
+ # Generalisation pipeline dict keys — same structure minus calib/disparity_gt
199
+ gen_pipeline = {
200
+ "train_image": np.ndarray, # im0.png from training scene variant
201
+ "test_image": np.ndarray, # im0.png from test scene variant
202
+ "roi": dict,
203
+ "crop": np.ndarray,
204
+ "crop_aug": list,
205
+ "active_modules": list,
206
+ "rce_head": object,
207
+ "cnn_heads": dict,
208
+ "rce_dets": list,
209
+ "cnn_dets": dict,
210
+ "source": str, # "middlebury" or "custom"
211
+ "scene_group": str, # e.g. "artroom" (Middlebury only)
212
+ "train_scene": str, # e.g. "artroom1" (Middlebury only)
213
+ "test_scene": str, # e.g. "artroom2" (Middlebury only)
214
+ }
215
+ ```
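A small pair of helpers can enforce this namespacing. This is a sketch in which session state is modelled as a plain mapping for testability (in the app, pass `st.session_state`); `widget_key` implements the per-pipeline widget-key prefixing mentioned in the commit message:

```python
def get_pipeline(state, name: str) -> dict:
    """Fetch (lazily creating) one pipeline's namespaced dict.

    `state` is st.session_state in the app; any mutable mapping works here.
    """
    return state.setdefault(name, {})

def widget_key(pipeline: str, name: str) -> str:
    """Prefix widget keys per pipeline so the two pipelines never collide."""
    return f"{pipeline}__{name}"
```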
216
+
217
+ ---
218
+
219
+ ## Part 2 — Middlebury Loader Utility
220
+
221
+ Create `utils/middlebury_loader.py` with the following functions:
222
+
223
+ ### `scan_dataset_root(root_path: str) -> list[str]`
224
+ - Scan root directory for valid scene folders
225
+ - A valid scene must contain: `im0.png`, `im1.png`, `calib.txt`
226
+ - Return sorted list of scene names
227
+
228
+ ### `get_scene_groups(root_path: str) -> dict`
229
+ - Scan all valid scenes and group them by scene base name (strip trailing digit)
230
+ - e.g. `artroom1`, `artroom2` → group `"artroom"`
231
+ - Return dict: `{"artroom": ["artroom1", "artroom2"], "curule": ["curule1", "curule2", "curule3"], ...}`
232
+ - Used by Tab 2 to present scene group selection then variant selection
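Assuming scene validity is checked the same way as in `scan_dataset_root`, the grouping can be sketched as:

```python
import os
import re

def get_scene_groups(root_path: str) -> dict:
    """Group valid scene folders by base name (trailing digits stripped)."""
    required = {"im0.png", "im1.png", "calib.txt"}
    groups: dict[str, list[str]] = {}
    for name in sorted(os.listdir(root_path)):
        scene_dir = os.path.join(root_path, name)
        if not os.path.isdir(scene_dir):
            continue
        if not required.issubset(os.listdir(scene_dir)):
            continue  # not a valid scene folder
        base = re.sub(r"\d+$", "", name)  # "artroom1" -> "artroom"
        groups.setdefault(base, []).append(name)
    return groups
```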
233
+
234
+ ### `get_available_views(scene_path: str) -> list[dict]`
235
+ - This dataset has NO multi-exposure variants (no im0E.png etc.)
236
+ - Function kept for future compatibility but always returns a single entry:
237
+ `[{"suffix": "", "label": "Primary (im0/im1)"}]`
238
+
239
+ ### `load_stereo_pair(scene_path: str, view_suffix: str = '') -> dict`
240
+ - Load `im0{suffix}.png` as left image (train_image)
241
+ - Load `im1{suffix}.png` as right image (test_image)
242
+ - Load and parse `calib.txt`
243
+ - Load `disp0.pfm` if it exists (else None)
244
+ - Return dict with keys: `left`, `right`, `calib`, `disparity_gt`
245
+
246
+ ### `load_single_view(scene_path: str, view_suffix: str) -> np.ndarray`
247
+ - Load and return a single image: `im0{suffix}.png`
248
+ - Used by generalisation tab when selecting individual views
249
+
250
+ ### `parse_calib(calib_path: str) -> dict`
251
+ Parse Middlebury `calib.txt` format:
252
+ ```
253
+ cam0=[fx 0 cx; 0 fy cy; 0 0 1]
254
+ cam1=[fx 0 cx; 0 fy cy; 0 0 1]
255
+ doffs=x_offset
256
+ baseline=Bmm
257
+ width=W
258
+ height=H
259
+ ndisp=N
260
+ vmin=v
261
+ vmax=v
262
+ ```
263
+ Extract and return: `{"fx": float, "baseline": float, "doffs": float, "width": int, "height": int, "ndisp": int}`
264
+
265
+ Use regex to extract `fx` from the camera matrix string: first numeric value after `cam0=[`.
266
+
267
+ ### `load_pfm(filepath: str) -> np.ndarray`
268
+ Load PFM (Portable FloatMap) file:
269
+ - Read header line (`PF` = color, `Pf` = grayscale)
270
+ - Read dimensions line
271
+ - Read scale factor (negative = little-endian)
272
+ - Read float32 binary data
273
+ - Flip vertically (PFM origin is bottom-left)
274
+ - Return numpy array
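Following those steps, a minimal reader might look like the sketch below (assumes well-formed headers; a production version should validate them):

```python
import numpy as np

def load_pfm(filepath: str) -> np.ndarray:
    """Read a PFM (Portable FloatMap) file into a float32 array."""
    with open(filepath, "rb") as f:
        header = f.readline().decode().rstrip()
        channels = 3 if header == "PF" else 1      # "Pf" is grayscale
        width, height = map(int, f.readline().split())
        scale = float(f.readline())
        endian = "<" if scale < 0 else ">"         # negative scale = little-endian
        data = np.fromfile(f, dtype=endian + "f4")
    shape = (height, width, channels) if channels == 3 else (height, width)
    return np.flipud(data.reshape(shape))          # PFM origin is bottom-left
```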
275
+
276
+ ### Dataset Root Resolution
277
+
278
+ The dataset is **bundled directly in the repo** at `./data/middlebury/`. No user configuration needed.
279
+
280
+ ```python
281
+ DEFAULT_MIDDLEBURY_ROOT = "./data/middlebury"
282
+ ```
283
+
284
+ If the path does not exist or contains no valid scenes, show a clear error. This should not happen in normal deployment since the data is bundled.
285
+
286
+ ### Bundled Scenes Reference
287
+
288
+ The following 10 scene folders are bundled, forming 4 scene groups:
289
+
290
+ ```python
291
+ BUNDLED_SCENES = {
292
+ "artroom": ["artroom1", "artroom2"],
293
+ "curule": ["curule1", "curule2", "curule3"],
294
+ "skates": ["skates1", "skates2"],
295
+ "skiboots": ["skiboots1", "skiboots2", "skiboots3"],
296
+ }
297
+ ```
298
+
299
+ Each folder contains exactly: `im0.png`, `im1.png`, `disp0.pfm`, `disp1.pfm`, `calib.txt`.
300
+
301
+ There are **no multi-exposure variants** (no `im0E.png` etc.) — the scene groups ARE the multi-condition variants. `artroom1` and `artroom2` are different captures of the same artroom scene.
302
+
303
+ ---
304
+
305
+ ## Part 3 — Tab 1: Stereo Data Lab
306
+
307
+ Create `tabs/stereo/data_lab.py` with a `render()` function.
308
+
309
+ ### Data Source Selection
310
+
311
+ ```python
312
+ st.header("🧪 Data Lab — Stereo + Depth")
313
+ st.info("**How this works:** Define your object of interest in the LEFT image. The system trains on it and attempts to recognise it in the RIGHT image — a genuinely different viewpoint.")
314
+
315
+ source = st.radio(
316
+ "Data source",
317
+ ["📦 Middlebury Dataset", "📁 Upload your own files"],
318
+ horizontal=True
319
+ )
320
+ ```
321
+
322
+ ### If Middlebury Selected
323
+
324
+ ```
325
+ 1. Scan dataset root → show selectbox of available scenes
326
+ 2. Auto-load im0.png (train/left) and im1.png (test/right)
327
+ 3. Auto-load calib.txt → parse parameters
328
+ 4. Auto-load disp0.pfm if present
329
+ 5. Display LEFT image (train) and RIGHT image (test) side by side
330
+ 6. Show parsed calibration parameters in an expander
331
+ 7. Show ground truth disparity colormap if available
332
+ ```
333
+
334
+ Show a clear visual label:
335
+ - Left image labeled: `🟦 TRAIN IMAGE (Left)`
336
+ - Right image labeled: `🟥 TEST IMAGE (Right)`
337
+
338
+ ### If Custom Upload Selected
339
+
340
+ ```
341
+ - Left image uploader (png/jpg) → labeled as TRAIN IMAGE
342
+ - Right image uploader (png/jpg) → labeled as TEST IMAGE
343
+ - Calibration file uploader (txt) — REQUIRED for depth estimation
344
+ - PFM ground truth uploader (pfm) — optional, disables depth evaluation if missing
345
+ - If calibration file not provided: show warning "Depth estimation will be disabled"
346
+ ```
347
+
348
+ ### ROI Definition
349
+
350
+ After images are loaded (either source):
351
+
352
+ ```
353
+ 1. Display LEFT (train) image only for ROI definition
354
+ 2. Use streamlit-cropper or manual coordinate inputs for ROI selection
355
+ - If streamlit-cropper available: use it
356
+ - Fallback: four number_input widgets for x, y, w, h
357
+ 3. Text input for class label (default: "object")
358
+ 4. Show cropped ROI preview
359
+ 5. "Lock Data Lab" button → saves everything to st.session_state["stereo_pipeline"]
360
+ ```
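The fallback coordinate inputs should be clamped to the image before cropping; `crop_roi` below is a hypothetical helper illustrating that step:

```python
import numpy as np

def crop_roi(image: np.ndarray, x: int, y: int, w: int, h: int) -> np.ndarray:
    """Clamp the ROI to image bounds and return the cropped patch."""
    H, W = image.shape[:2]
    x0, y0 = max(0, x), max(0, y)
    x1, y1 = min(W, x + w), min(H, y + h)
    if x1 <= x0 or y1 <= y0:
        raise ValueError("ROI lies outside the image")
    return image[y0:y1, x0:x1]
```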
361
+
362
+ ### Data Augmentation
363
+
364
+ After ROI is locked, show augmentation controls (preserve existing augmentation logic):
365
+ - Rotation, brightness, contrast, noise, blur, flip
366
+ - Preview augmented crops
367
+ - "Apply Augmentation" button
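As an illustration of the kind of variants involved (a deliberately reduced subset; the existing augmentation logic remains the source of truth), brightness shifts and a horizontal flip can be generated as:

```python
import numpy as np

def augment_crop(crop: np.ndarray) -> list:
    """Return the original crop plus a few cheap illustrative variants."""
    bright = np.clip(crop.astype(np.int16) + 40, 0, 255).astype(np.uint8)
    dark = np.clip(crop.astype(np.int16) - 40, 0, 255).astype(np.uint8)
    flipped = crop[:, ::-1].copy()  # horizontal flip
    return [crop, bright, dark, flipped]
```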
368
+
369
+ ### What Gets Saved to Session State
370
+
371
+ ```python
372
+ st.session_state["stereo_pipeline"] = {
373
+ "train_image": left_image, # numpy array, BGR
374
+ "test_image": right_image, # numpy array, BGR ← KEY FIX
375
+ "calib": calib_dict, # parsed params or None
376
+ "disparity_gt": disp_gt, # numpy array or None
377
+ "roi": {"x":x, "y":y, "w":w, "h":h, "label":label},
378
+ "crop": cropped_roi,
379
+ "crop_aug": augmented_list,
380
+ "source": "middlebury" or "custom",
381
+ "scene_name": scene_name or "",
382
+ }
383
+ ```
384
+
385
+ ---
386
+
387
+ ## Part 4 — Tab 2: Generalisation Data Lab
388
+
389
+ Create `tabs/generalisation/data_lab.py` with a `render()` function.
390
+
391
+ ### Key Difference From Stereo Data Lab
392
+
393
+ - No calibration file
394
+ - No depth estimation
395
+ - Two images can be completely independent OR different views from Middlebury
396
+ - Goal is testing appearance generalisation, not stereo geometry
397
+
398
+ ### Data Source Selection
399
+
400
+ ```python
401
+ st.header("🧪 Data Lab — Generalisation")
402
+ st.info("**How this works:** Train on one image, test on a completely different image of the same object. No stereo geometry — pure recognition generalisation.")
403
+
404
+ source = st.radio(
405
+ "Data source",
406
+ ["📦 Middlebury Multi-View", "📁 Upload your own files"],
407
+ horizontal=True
408
+ )
409
+ ```
410
+
411
+ ### If Middlebury Selected
412
+
413
+ ```
414
+ 1. Show scene group selector: ["artroom", "curule", "skates", "skiboots"]
415
+ 2. Based on selected group, show available variants:
416
+ - artroom → [artroom1, artroom2]
417
+ - curule → [curule1, curule2, curule3]
418
+ - skates → [skates1, skates2]
419
+ - skiboots → [skiboots1, skiboots2, skiboots3]
420
+ 3. Two selectboxes:
421
+ - "Training scene" → user picks one variant (e.g. artroom1)
422
+ - "Test scene" → user picks a DIFFERENT variant (e.g. artroom2)
423
+ - Validate: training scene ≠ test scene, show error if same selected
424
+ 4. Load train_scene/im0.png as train_image
425
+ 5. Load test_scene/im0.png as test_image
426
+ (NOTE: both are LEFT images im0.png, from different scene variants)
427
+ 6. Display both side by side with clear labels
428
+ ```
429
+
430
+ Show labels:
431
+ - Train image: `🟦 TRAIN IMAGE (artroom1)`
432
+ - Test image: `🟥 TEST IMAGE (artroom2)`
433
+
434
+ Also show an explanation: *"Both images show the same scene type captured under different conditions. The model trains on one variant and must recognise the same object class in the other — testing genuine appearance generalisation."*
435
+
436
+ ### If Custom Upload Selected
437
+
438
+ ```
439
+ - Train image uploader → labeled TRAIN IMAGE
440
+ - Test image uploader → labeled TEST IMAGE
441
+ - No calibration, no PFM needed
442
+ - Simple, low barrier
443
+ ```
444
+
445
+ ### ROI Definition and Augmentation
446
+
447
+ Same as stereo data lab but on the TRAIN image only. Save to `st.session_state["gen_pipeline"]` with same key structure (minus calib and disparity_gt).
448
+
449
+ ---
450
+
451
+ ## Part 5 — Migrate Existing Pipeline Stages
452
+
453
+ The existing pages (Feature Lab, Model Tuning, Localization, Detection, Evaluation, Stereo Depth) must be migrated into the new `tabs/` structure.
454
+
455
+ ### Migration Rules
456
+
457
+ 1. **Read each existing page file** before migrating it:
458
+ - `pages/2_Data_Lab.py` → split into `tabs/stereo/data_lab.py` and `tabs/generalisation/data_lab.py`
459
+ - `pages/3_Feature_Lab.py` → `tabs/stereo/feature_lab.py` (adapt for `tabs/generalisation/feature_lab.py`)
460
+ - `pages/4_Model_Tuning.py` → `tabs/stereo/model_tuning.py`
461
+ - `pages/5_Localization_Lab.py` → `tabs/stereo/localization.py`
462
+ - `pages/6_RealTime_Detection.py` → `tabs/stereo/detection.py` ← apply data leakage fix here
463
+ - `pages/7_Evaluation.py` → `tabs/stereo/evaluation.py`
464
+ - `pages/8_Stereo_Geometry.py` → `tabs/stereo/stereo_depth.py`
465
+
466
+ 2. **Each page becomes a module** with a `render()` function — wrap all existing page code inside `def render(): ...`
467
+
468
+ 3. **Update all session state reads** — the existing pages use `st.session_state.get("pipeline_data", {})` or similar flat keys. Replace with namespaced dicts:
469
+ - Stereo stages: `st.session_state.get("stereo_pipeline", {})`
470
+ - Generalisation stages: `st.session_state.get("gen_pipeline", {})`
471
+
472
+ 4. **Preserve all imports from `src/`** — every `from src.xxx import yyy` in the existing pages must be kept exactly as-is
473
+
474
+ 5. **Detection stage fix** — the critical data leakage fix:
475
+ - OLD: detection runs on the same image used for training ROI definition
476
+ - NEW: `image_to_scan = pipe["test_image"]` ← the other image
477
+
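The fix is easiest to see in a sliding-window sketch: the generator below is fed `pipe["test_image"]`, never the train image (window size and stride are illustrative, not the app's actual values):

```python
import numpy as np

def sliding_windows(image: np.ndarray, win: int, stride: int):
    """Yield (x, y, patch) windows over the image passed in.

    In the detection stage, call this on pipe["test_image"] only.
    """
    H, W = image.shape[:2]
    for y in range(0, H - win + 1, stride):
        for x in range(0, W - win + 1, stride):
            yield x, y, image[y:y + win, x:x + win]
```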
478
+ ### Stage Guard Template
479
+
480
+ ```python
481
+ def render():
482
+ pipe = st.session_state.get("stereo_pipeline", {}) # or gen_pipeline
483
+
484
+ required_keys = ["train_image", "test_image", "crop_aug"]
485
+ missing = [k for k in required_keys if k not in pipe]
486
+ if missing:
487
+ st.warning("⚠️ Complete the **Data Lab** stage first.")
488
+ st.info("Go to: 📷 Stereo + Depth → 🧪 Data Lab")
489
+ st.stop()
490
+ ```
491
+
492
+ ### Specific Migration Notes Per Stage
493
+
494
+ **Feature Lab (Stage 2):**
495
+ - Visualise features extracted from `pipe["crop"]` (from train_image)
496
+ - No changes needed beyond session state key updates
497
+
498
+ **Model Tuning (Stage 3):**
499
+ - Training data comes from `pipe["crop_aug"]` (augmented crops of train_image ROI)
500
+ - Negatives sampled from `pipe["train_image"]` (not test_image)
501
+ - No changes needed beyond session state key updates
502
+
503
+ **Detection (Stage 5) — CRITICAL FIX:**
504
+ - Run sliding window on `pipe["test_image"]` NOT `pipe["train_image"]`
505
+ - Add a visual reminder in the UI: *"Running detection on TEST image (right/second image)"*
506
+ - For stereo: show test_image (right) with detection results
507
+ - For generalisation: show test_image (the other scene variant) with detection results
508
+
509
+ **Stereo Depth (Stage 7 — stereo only):**
510
+ - Requires `pipe["calib"]` to be not None
511
+ - If calib is None (custom upload without calibration): show warning and disable depth computation
512
+ - If disparity_gt is None: skip ground truth comparison, show note
513
+ - Otherwise: preserve existing StereoSGBM logic entirely
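For reference, the metric depth recovered from `pipe["calib"]` follows the standard Middlebury relation Z = baseline · fx / (d + doffs); a sketch of that conversion (field names match the `parse_calib` output described in Part 2):

```python
import numpy as np

def disparity_to_depth(disparity: np.ndarray, calib: dict) -> np.ndarray:
    """Z = baseline * fx / (d + doffs); baseline in mm gives Z in mm."""
    d = disparity + calib["doffs"]
    with np.errstate(divide="ignore"):
        depth = calib["baseline"] * calib["fx"] / d
    depth[~np.isfinite(depth)] = 0.0  # mask zero-disparity / invalid pixels
    return depth
```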
514
+
515
+ ---
516
+
517
+ ## Part 6 — Session Status Widget Update
518
+
519
+ Update the session status display in `app.py` (home page) to show status for BOTH pipelines:
520
+
521
+ ```python
522
+ st.header("📋 Session Status")
523
+
524
+ col1, col2 = st.columns(2)
525
+
526
+ with col1:
527
+ st.subheader("📷 Stereo Pipeline")
528
+ stereo = st.session_state.get("stereo_pipeline", {})
529
+ stereo_checks = {
530
+ "Data loaded": "train_image" in stereo and "test_image" in stereo,
531
+ "ROI defined": "roi" in stereo,
532
+ "Augmentation done": "crop_aug" in stereo,
533
+ "Modules locked": "active_modules" in stereo,
534
+ "Models trained": "rce_head" in stereo,
535
+ "Detection run": "rce_dets" in stereo,
536
+ }
537
+ for label, done in stereo_checks.items():
538
+ st.markdown(f"{'✅' if done else '⬜'} {label}")
539
+
540
+ with col2:
541
+ st.subheader("🌍 Generalisation Pipeline")
542
+ gen = st.session_state.get("gen_pipeline", {})
543
+ gen_checks = {
544
+ "Data loaded": "train_image" in gen and "test_image" in gen,
545
+ "ROI defined": "roi" in gen,
546
+ "Augmentation done": "crop_aug" in gen,
547
+ "Modules locked": "active_modules" in gen,
548
+ "Models trained": "rce_head" in gen,
549
+ "Detection run": "rce_dets" in gen,
550
+ }
551
+ for label, done in gen_checks.items():
552
+ st.markdown(f"{'✅' if done else '⬜'} {label}")
553
+ ```
554
+
555
+ ---
556
+
557
+ ## Part 7 — Shared Utility Modules
558
+
559
+ All core logic lives in `src/` and must be imported identically by both `tabs/stereo/` and `tabs/generalisation/` stages. Do not duplicate or move anything from `src/`.
560
+
561
+ Key imports used by the stage files:
562
+ ```python
563
+ from src.detectors.rce.features import RCEExtractor # RCE feature extraction
564
+ from src.detectors.resnet import ResNetDetector # ResNet-18
565
+ from src.detectors.mobilenet import MobileNetDetector # MobileNetV3
566
+ from src.detectors.mobilevit import MobileViTDetector # MobileViT-XXS
567
+ from src.detectors.orb import ORBDetector # ORB keypoint matching
568
+ from src.detectors.yolo import YOLODetector # YOLOv8
569
+ from src.localization import LocalizationStrategy # All 5 localization strategies
570
+ from src.models import load_model # Model loading from models/
571
+ from src.config import * # App constants
572
+ ```
573
+
574
+ The `models/` directory contains pre-trained weights referenced by `src/models.py`. Do not move or rename any files in `models/`.
575
+
576
+ ---
577
+
578
+ ## Part 8 — Landing Page (Home)
579
+
580
+ The existing landing page content in `app.py` must be preserved and rendered when `top_nav == "🏠 Home"`. Extract it into a `render_home()` function within `app.py`.
581
+
582
+ Update the Pipeline Overview section to reflect the new two-pipeline structure:
583
+
584
+ ```
585
+ 🗺️ Pipeline Overview
586
+
587
+ This platform provides two evaluation pipelines:
588
+
589
+ 📷 Stereo + Depth (7 stages)
590
+ Train on the LEFT image, detect in the RIGHT image, estimate metric depth.
591
+ Evaluates RCE in a constrained stereo-vision scenario.
592
+
593
+ 🌍 Generalisation (6 stages)
594
+ Train on one scene variant, detect in a different variant of the same scene.
595
+ Evaluates RCE's robustness to appearance variation.
596
+
597
+ Both pipelines compare: RCE · ResNet-18 · MobileNetV3-Small · MobileViT-XXS · ORB
598
+ ```
599
+
600
+ Update the bottom caption: *"Navigate using the sidebar → Choose a pipeline to begin"*
601
+
602
+ ---
603
+
604
+ ## Implementation Order
605
+
606
+ Implement in this exact order to avoid breaking dependencies:
607
+
608
+ 1. `utils/middlebury_loader.py` — no dependencies, can be tested in isolation
609
+ 2. `app.py` — routing shell, import stubs for tabs not yet created
610
+ 3. `tabs/stereo/data_lab.py` — foundation of stereo pipeline
611
+ 4. `tabs/generalisation/data_lab.py` — foundation of generalisation pipeline
612
+ 5. Migrate existing stages into `tabs/stereo/` — feature_lab, model_tuning, localization, detection (with fix), evaluation, stereo_depth
613
+ 6. Create `tabs/generalisation/` stages — reuse stereo logic with gen_pipeline session keys
614
+ 7. Update home page session status widget
615
+
616
+ ---
617
+
618
+ ## What NOT To Change
619
+
620
+ - **Entire `src/` directory** — all detector logic, RCE, CNN backbones, ORB, localization, model loading
621
+ - **Entire `models/` directory** — pre-trained weights
622
+ - **Entire `data/` directory** — Middlebury dataset
623
+ - **`notebooks/`, `training/`, `scripts/`** — development artifacts, leave untouched
624
+ - **`Dockerfile`, `packages.txt`, `requirements.txt`** — deployment config (add `streamlit-cropper` to requirements.txt only)
625
+ - The visual style and markdown descriptions of each stage
626
+ - The models description tabs on the home page (RCE, ResNet-18, MobileNetV3, MobileViT-XXS)
627
+ - The depth estimation explanation and LaTeX formula
628
+
629
+ ---
630
+
631
+ ## Dependencies To Add
632
+
633
+ Add to `requirements.txt` if not already present:
634
+ ```
635
+ streamlit-cropper # for ROI selection (optional but preferred)
636
+ ```
637
+
638
+ No other new dependencies are needed. The Middlebury loader uses only `numpy`, `opencv-python`, `re`, `os`, and `pathlib` — all already present.
639
+
640
+ ---
641
+
642
+ ## Testing Checklist
643
+
644
+ After implementation, verify:
645
+
646
+ - [ ] Home page renders with both pipeline status widgets
647
+ - [ ] Clicking "Stereo + Depth" shows 7-stage sub-navigation
648
+ - [ ] Clicking "Generalisation" shows 6-stage sub-navigation
649
+ - [ ] Clicking stages before Data Lab shows guard warning
650
+ - [ ] Middlebury loader finds scenes from `./data/middlebury/`
651
+ - [ ] Stereo Data Lab correctly assigns LEFT → train_image, RIGHT → test_image
652
+ - [ ] Generalisation Data Lab correctly assigns View1 → train_image, View2 → test_image
653
+ - [ ] Detection stage uses `pipe["test_image"]` in BOTH pipelines
654
+ - [ ] Stereo and Generalisation pipelines do not share session state
655
+ - [ ] Depth estimation gracefully disabled when calib is None
656
+ - [ ] All existing stage logic works after migration
app.py CHANGED
@@ -3,195 +3,211 @@ import streamlit as st
3
  st.set_page_config(page_title="Perception Benchmark", layout="wide", page_icon="🦅")
4
 
5
  # ===================================================================
6
- # Header
7
  # ===================================================================
8
- st.title("🦅 Recognition BenchMark")
9
- st.subheader("A stereo-vision pipeline for object recognition & depth estimation")
10
- st.caption("Compare classical feature engineering (RCE) against modern deep learning backbones — end-to-end, in your browser.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- st.divider()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  # ===================================================================
15
- # Pipeline Overview
16
  # ===================================================================
17
- st.header("🗺️ Pipeline Overview")
18
- st.markdown("""
19
- The app is structured as a **7-stage sequential pipeline**.
20
- Complete each page in order; every stage feeds the next.
21
- """)
22
-
23
- stages = [
24
- ("🧪", "1 · Data Lab", "Upload a stereo image pair, camera calibration file, and two PFM ground-truth depth maps. "
25
- "Define one or more object ROIs (bounding boxes) with class labels, then apply live data augmentation "
26
- "(brightness, contrast, rotation, noise, blur, shift, flip). "
27
- "All assets are locked into session state — nothing is written to disk."),
28
- ("🔬", "2 · Feature Lab", "Toggle RCE physics modules (Intensity · Sobel · Spectral) to build a modular "
29
- "feature vector. Compare it live against CNN activation maps extracted from a "
30
- "frozen backbone via forward hooks. Lock your active module configuration."),
31
- ("⚙️", "3 · Model Tuning", "Train lightweight **heads** on your session data (augmented crop = positives, "
32
- "random non-overlapping patches = negatives). Compare three paradigms side by side: "
33
- "RCE (with feature importance), CNN (with activation overlay), and ORB (keypoint matching)."),
34
- ("🔍", "4 · Localization Lab", "Compare **five localization strategies** on top of your trained head: "
35
- "Exhaustive Sliding Window, Image Pyramid (multi-scale), Coarse-to-Fine "
36
- "hierarchical search, Contour Proposals (edge-driven), and Template "
37
- "Matching (cross-correlation)."),
38
- ("🎯", "5 · Real-Time Detection","Run a **sliding window** across the right image using RCE, CNN, and ORB "
39
- "simultaneously. Watch the scan live, then compare bounding boxes, "
40
- "confidence heatmaps, and latency across all three methods."),
41
- ("📈", "6 · Evaluation", "Quantitative evaluation with **confusion matrices**, **precision-recall curves**, "
42
- "and **F1 scores** per method. Ground truth is derived from your Data Lab ROIs."),
43
- ("📐", "7 · Stereo Geometry", "Compute a disparity map with **StereoSGBM**, convert it to metric depth "
44
- "using the stereo formula $Z = fB/(d+d_{\\text{offs}})$, then read depth "
45
- "directly at every detected bounding box. Compare against PFM ground truth."),
46
- ]
47
-
48
- for icon, title, desc in stages:
49
- with st.container(border=True):
50
- c1, c2 = st.columns([1, 12])
51
- c1.markdown(f"## {icon}")
52
- c2.markdown(f"**{title}** \n{desc}")
53
-
54
- st.divider()
55
 
56
- # ===================================================================
57
- # Models
58
- # ===================================================================
59
- st.header("🧠 Models Used")
60
 
61
- tab_rce, tab_resnet, tab_mobilenet, tab_mobilevit = st.tabs(
62
- ["RCE Engine", "ResNet-18", "MobileNetV3-Small", "MobileViT-XXS"])
63
-
64
- with tab_rce:
65
- st.markdown("### 🧬 RCE — Relative Contextual Encoding")
66
  st.markdown("""
67
- **Type:** Modular hand-crafted feature extractor
68
- **Architecture:** Three physics-inspired modules, each producing a 10-bin histogram:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  | Module | Input | Operation |
71
  |--------|-------|-----------|
72
  | **Intensity** | Grayscale | Pixel-value histogram (global appearance) |
73
  | **Sobel** | Gradient magnitude | Edge strength distribution (texture) |
74
  | **Spectral** | FFT log-magnitude | Frequency content (pattern / structure) |
75
-
76
- **Strengths:**
77
- - Fully explainable: every dimension has a physical meaning
78
- - Extremely fast (µs per patch, no GPU needed)
79
- - Modular: disable any module and immediately see the effect on the vector
80
- - Zero pre-training needed
81
-
82
- **Weakness:** Less discriminative than deep features for complex visual scenes.
83
- """)
84
-
85
- with tab_resnet:
86
- st.markdown("### 🏗️ ResNet-18")
87
- st.markdown("""
88
- **Source:** PyTorch Hub (`torchvision.models.ResNet18_Weights.DEFAULT`)
89
- **Pre-training:** ImageNet-1k (1.28 M images, 1 000 classes)
90
- **Backbone output:** 512-dimensional embedding (after `avgpool`)
91
  **Head:** LogisticRegression trained on your session data
92
 
93
- **Architecture highlights:**
94
- - 18 layers with residual (skip) connections
95
- - Residual blocks prevent vanishing gradients in deeper networks
96
- - `layer4` is hooked for activation map visualisation
97
-
98
- **In this app:** The entire backbone is **frozen** (`requires_grad=False`).
99
  Only the lightweight head adapts to your specific object.
100
- """)
101
-
102
- with tab_mobilenet:
103
- st.markdown("### 📱 MobileNetV3-Small")
104
- st.markdown("""
105
- **Source:** PyTorch Hub (`torchvision.models.MobileNet_V3_Small_Weights.DEFAULT`)
106
- **Pre-training:** ImageNet-1k
107
- **Backbone output:** 576-dimensional embedding (classifier replaced with `Identity`)
108
  **Head:** LogisticRegression trained on your session data
109
 
110
- **Architecture highlights:**
111
- - Inverted residuals + linear bottlenecks (MobileNetV2 heritage)
112
- - Hard-Swish / Hard-Sigmoid activations (hardware-friendly)
113
- - Squeeze-and-Excitation (SE) blocks for channel attention
114
- - Designed for **edge / mobile inference** — ~2.5 M parameters
115
 
116
- **In this app:** Typically 3–5× faster than ResNet-18.
117
- `features[-1]` is hooked for activation maps.
118
- """)
119
-
120
- with tab_mobilevit:
121
- st.markdown("### 🤖 MobileViT-XXS")
122
- st.markdown("""
123
- **Source:** timm — `mobilevit_xxs.cvnets_in1k` (Apple Research, 2022)
124
- **Pre-training:** ImageNet-1k
125
- **Backbone output:** 320-dimensional embedding (`num_classes=0`)
126
  **Head:** LogisticRegression trained on your session data
127
 
128
- **Architecture highlights:**
129
- - **Hybrid CNN + Vision Transformer** — local convolutions for spatial features,
130
- global self-attention for long-range context
131
- - MobileNetV2 stem + MobileViT blocks (attention on non-overlapping patches)
132
- - Only ~1.3 M parameters — smallest of the three
133
-
134
- **In this app:** The final transformer stage `stages[-1]` is hooked.
135
- Slower than MobileNetV3 but captures global structure.
136
- """)
137
 
138
- st.divider()
139
 
140
- # ===================================================================
141
- # Depth Estimation
142
- # ===================================================================
143
- st.header("📐 Stereo Depth Estimation")
144
 
145
- col_d1, col_d2 = st.columns(2)
146
- with col_d1:
147
- st.markdown("""
148
  **Algorithm:** `cv2.StereoSGBM` (Semi-Global Block Matching)
149
 
150
  SGBM minimises a global energy function combining:
151
  - Data cost (pixel intensity difference)
152
  - Smoothness penalty (P1, P2 regularisation)
153
 
154
- It processes multiple horizontal and diagonal scan-line passes,
155
  making it significantly more accurate than basic block matching.
156
- """)
157
- with col_d2:
158
- st.markdown("""
159
- **Depth formula (Middlebury convention):**
160
- """)
161
- st.latex(r"Z = \frac{f \times B}{d + d_{\text{offs}}}")
162
- st.markdown("""
163
- - $f$ — focal length (pixels)
164
- - $B$ — baseline (mm, from calibration file)
165
- - $d$ — disparity (pixels)
166
  - $d_\\text{offs}$ — optical-center offset between cameras
167
- """)
 
 
 
168
 
169
- st.divider()
170
 
171
  # ===================================================================
172
- # Session Status
173
  # ===================================================================
174
- st.header("📋 Session Status")
175
-
176
- pipe = st.session_state.get("pipeline_data", {})
177
-
178
- checks = {
179
- "Data Lab locked": "left" in pipe,
180
- "Crop defined": "crop" in pipe,
181
- "Augmentation applied": "crop_aug" in pipe,
182
- "Active modules locked": "active_modules" in st.session_state,
183
- "RCE head trained": "rce_head" in st.session_state,
184
- "CNN head trained": any(f"cnn_head_{n}" in st.session_state
185
- for n in ["ResNet-18", "MobileNetV3", "MobileViT-XXS"]),
186
- "RCE detections ready": "rce_dets" in st.session_state,
187
- "CNN detections ready": "cnn_dets" in st.session_state,
188
- }
189
-
190
- cols = st.columns(4)
191
- for i, (label, done) in enumerate(checks.items()):
192
- cols[i % 4].markdown(
193
- f"{'✅' if done else '⬜'} {'~~' if not done else ''}{label}{'~~' if not done else ''}"
194
- )
195
-
196
- st.divider()
197
- st.caption("Navigate using the sidebar → Start with **🧪 Data Lab**")
 
3
  st.set_page_config(page_title="Perception Benchmark", layout="wide", page_icon="🦅")
4
 
5
  # ===================================================================
6
+ # Routing — Sidebar Navigation
7
  # ===================================================================
8
+ PIPELINES = {
9
+ "🏠 Home": None,
10
+ "📐 Stereo + Depth": {
11
+ "🧪 Data Lab": "tabs.stereo.data_lab",
12
+ "🔬 Feature Lab": "tabs.stereo.feature_lab",
13
+ "⚙️ Model Tuning": "tabs.stereo.model_tuning",
14
+ "🔍 Localization Lab": "tabs.stereo.localization",
15
+ "🎯 Real-Time Detection":"tabs.stereo.detection",
16
+ "📈 Evaluation": "tabs.stereo.evaluation",
17
+ "📐 Stereo Geometry": "tabs.stereo.stereo_depth",
18
+ },
19
+ "🌍 Generalisation": {
20
+ "🧪 Data Lab": "tabs.generalisation.data_lab",
21
+ "🔬 Feature Lab": "tabs.generalisation.feature_lab",
22
+ "⚙️ Model Tuning": "tabs.generalisation.model_tuning",
23
+ "🔍 Localization Lab": "tabs.generalisation.localization",
24
+ "🎯 Real-Time Detection":"tabs.generalisation.detection",
25
+ "📈 Evaluation": "tabs.generalisation.evaluation",
26
+ },
27
+ }
28
 
29
+ st.sidebar.title("🦅 Recognition BenchMark")
30
+ pipeline_choice = st.sidebar.radio("Pipeline", list(PIPELINES.keys()), key="nav_pipeline")
31
+
32
+ stage_module = None
33
+ if PIPELINES[pipeline_choice] is not None:
34
+ stages_map = PIPELINES[pipeline_choice]
35
+ stage_choice = st.sidebar.radio("Stage", list(stages_map.keys()), key="nav_stage")
36
+ module_path = stages_map[stage_choice]
37
+ # dynamic import
38
+ import importlib
39
+ stage_module = importlib.import_module(module_path)
40
+
41
+ # Session status widget (always visible in sidebar)
42
+ st.sidebar.divider()
43
+ st.sidebar.subheader("📋 Session Status")
44
+
45
+ for pipe_label, pipe_key in [("Stereo", "stereo_pipeline"), ("General", "gen_pipeline")]:
46
+ pipe = st.session_state.get(pipe_key, {})
47
+ checks = {
48
+ "Data locked": "train_image" in pipe,
49
+ "Crop defined": "crop" in pipe,
50
+ "Modules set": "active_modules" in pipe,
51
+ "RCE trained": "rce_head" in pipe,
52
+ "CNN trained": any(f"cnn_head_{n}" in pipe
53
+ for n in ["ResNet-18", "MobileNetV3", "MobileViT-XXS"]),
54
+ "Dets ready": "rce_dets" in pipe or "cnn_dets" in pipe,
55
+ }
56
+ with st.sidebar.expander(f"**{pipe_label}** — {sum(checks.values())}/{len(checks)}"):
57
+ for label, done in checks.items():
58
+ st.markdown(f"{'✅' if done else '⬜'} {label}")
59
 
60
  # ===================================================================
61
+ # Home Page
62
  # ===================================================================
63
+ def render_home():
64
+ st.title("🦅 Recognition BenchMark")
65
+ st.subheader("A stereo-vision pipeline for object recognition & depth estimation")
66
+ st.caption("Compare classical feature engineering (RCE) against modern deep learning backbones — end-to-end, in your browser.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
+ st.divider()
 
 
 
69
 
70
+ # -------------------------------------------------------------------
71
+ # Two Pipelines
72
+ # -------------------------------------------------------------------
73
+ st.header("🗺️ Two Pipelines")
 
74
  st.markdown("""
75
+ Choose a pipeline from the **sidebar**:
76
+
77
+ - **📐 Stereo + Depth** — 7 stages. Uses a stereo image pair (LEFT=train, RIGHT=test)
78
+ with calibration data and ground-truth disparities. Ends with metric depth estimation.
79
+ - **🌍 Generalisation** — 6 stages. Uses different scene *variants* from the Middlebury dataset
80
+ (train on one variant, test on another). Tests how well models generalise across viewpoints.
81
+ """)
82
+
83
+ col_s, col_g = st.columns(2)
84
+ with col_s:
85
+ st.markdown("### 📐 Stereo + Depth (7 stages)")
86
+ stereo_stages = [
87
+ ("🧪", "Data Lab", "Load stereo pair, calib, GT depth. Define ROIs."),
88
+ ("🔬", "Feature Lab", "Toggle RCE modules, compare CNN activations."),
89
+ ("⚙️", "Model Tuning", "Train RCE / CNN / ORB heads."),
90
+ ("🔍", "Localization Lab", "Compare 5 localization strategies."),
91
+ ("🎯", "Real-Time Detection","Sliding window on the TEST image."),
92
+ ("📈", "Evaluation", "Confusion matrices, PR curves, F1."),
93
+ ("📐", "Stereo Geometry", "StereoSGBM disparity → metric depth."),
94
+ ]
95
+ for icon, title, desc in stereo_stages:
96
+ st.markdown(f"{icon} **{title}** — {desc}")
97
+
98
+ with col_g:
99
+ st.markdown("### 🌍 Generalisation (6 stages)")
100
+ gen_stages = [
101
+ ("🧪", "Data Lab", "Pick scene group & variants (train ≠ test)."),
102
+ ("🔬", "Feature Lab", "Toggle RCE modules, compare CNN activations."),
103
+ ("⚙️", "Model Tuning", "Train RCE / CNN / ORB heads."),
104
+ ("🔍", "Localization Lab", "Compare 5 localization strategies."),
105
+ ("🎯", "Real-Time Detection","Sliding window on a different variant."),
106
+ ("📈", "Evaluation", "Confusion matrices, PR curves, F1."),
107
+ ]
108
+ for icon, title, desc in gen_stages:
109
+ st.markdown(f"{icon} **{title}** — {desc}")
110
+
111
+ st.divider()
112
+
113
+ # -------------------------------------------------------------------
114
+ # Models
115
+ # -------------------------------------------------------------------
116
+ st.header("🧠 Models Used")
117
+
118
+ tab_rce, tab_resnet, tab_mobilenet, tab_mobilevit = st.tabs(
119
+ ["RCE Engine", "ResNet-18", "MobileNetV3-Small", "MobileViT-XXS"])
120
+
121
+ with tab_rce:
122
+ st.markdown("### 🧬 RCE — Relative Contextual Encoding")
123
+ st.markdown("""
124
+ **Type:** Modular hand-crafted feature extractor
125
+ **Architecture:** Seven physics-inspired modules, each producing a 10-bin histogram:
126
 
127
  | Module | Input | Operation |
128
  |--------|-------|-----------|
129
  | **Intensity** | Grayscale | Pixel-value histogram (global appearance) |
130
  | **Sobel** | Gradient magnitude | Edge strength distribution (texture) |
131
  | **Spectral** | FFT log-magnitude | Frequency content (pattern / structure) |
132
+ | **Laplacian** | Laplacian response | Second-derivative focus / sharpness |
133
+ | **Gradient Orientation** | Sobel angles | Edge direction histogram |
134
+ | **Gabor** | Multi-kernel response | Texture at multiple orientations / scales |
135
+ | **LBP** | Local Binary Patterns | Micro-texture descriptor |
136
+
137
+ Max feature vector = **70D** (7 modules × 10 bins).
138
+ """)
139
+
140
+ with tab_resnet:
141
+ st.markdown("### 🏗️ ResNet-18")
142
+ st.markdown("""
143
+ **Source:** PyTorch Hub (`torchvision.models.ResNet18_Weights.DEFAULT`)
144
+ **Pre-training:** ImageNet-1k (1.28 M images, 1 000 classes)
145
+ **Backbone output:** 512-dimensional embedding (after `avgpool`)
 
 
146
  **Head:** LogisticRegression trained on your session data
147
 
148
+ **In this app:** The entire backbone is **frozen** (`requires_grad=False`).
 
 
 
 
 
149
  Only the lightweight head adapts to your specific object.
150
+ """)
151
+
152
+ with tab_mobilenet:
153
+ st.markdown("### 📱 MobileNetV3-Small")
154
+ st.markdown("""
155
+ **Source:** PyTorch Hub (`torchvision.models.MobileNet_V3_Small_Weights.DEFAULT`)
156
+ **Pre-training:** ImageNet-1k
157
+ **Backbone output:** 576-dimensional embedding
158
  **Head:** LogisticRegression trained on your session data
159
 
160
+ **In this app:** Typically 3–5× faster than ResNet-18.
161
+ """)
 
 
 
162
 
163
+ with tab_mobilevit:
164
+ st.markdown("### 🤖 MobileViT-XXS")
165
+ st.markdown("""
166
+ **Source:** timm — `mobilevit_xxs.cvnets_in1k` (Apple Research, 2022)
167
+ **Pre-training:** ImageNet-1k
168
+ **Backbone output:** 320-dimensional embedding (`num_classes=0`)
 
 
 
 
169
  **Head:** LogisticRegression trained on your session data
170
 
171
+ **In this app:** Hybrid CNN + Vision Transformer. Only ~1.3 M parameters.
172
+ """)
 
 
 
 
 
 
 
173
 
174
+ st.divider()
175
 
176
+ # -------------------------------------------------------------------
177
+ # Depth Estimation
178
+ # -------------------------------------------------------------------
179
+ st.header("📐 Stereo Depth Estimation")
180
 
181
+ col_d1, col_d2 = st.columns(2)
182
+ with col_d1:
183
+ st.markdown("""
184
  **Algorithm:** `cv2.StereoSGBM` (Semi-Global Block Matching)
185
 
186
  SGBM minimises a global energy function combining:
187
  - Data cost (pixel intensity difference)
188
  - Smoothness penalty (P1, P2 regularisation)
189
 
190
+ It processes multiple horizontal and diagonal scan-line passes,
191
  making it significantly more accurate than basic block matching.
192
+ """)
193
+ with col_d2:
194
+ st.markdown("**Depth formula (Middlebury convention):**")
195
+ st.latex(r"Z = \frac{f \times B}{d + d_{\text{offs}}}")
196
+ st.markdown("""
197
+ - $f$ — focal length (pixels)
198
+ - $B$ — baseline (mm, from calibration file)
199
+ - $d$ — disparity (pixels)
 
 
200
  - $d_\\text{offs}$ — optical-center offset between cameras
201
+ """)
202
+
203
+ st.divider()
204
+ st.caption("Select a pipeline from the **sidebar** to begin.")
205
 
 
206
 
207
  # ===================================================================
208
+ # Dispatch
209
  # ===================================================================
210
+ if stage_module is not None:
211
+ stage_module.render()
212
+ else:
213
+ render_home()
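
Each module under `tabs/` is assumed to expose a `render()` entry point for this dispatch to work. The two supporting conventions from the commit message — per-pipeline widget-key prefixing and guarded stage prerequisites — can be sketched with plain helpers (names hypothetical, not the repo's actual API):

```python
def widget_key(pipeline: str, name: str) -> str:
    """Prefix every widget key per pipeline so the two tabs never collide."""
    return f"{pipeline}_{name}"

def missing_prereqs(pipe: dict, required: tuple = ("train_image", "crop")) -> list:
    """Names of pipeline-state entries a stage still needs before it can run."""
    return [key for key in required if key not in pipe]
```

A stage would call `missing_prereqs(st.session_state["stereo_pipeline"])` at the top and show the guard warning if the list is non-empty.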
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dataOLD/README.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Two data views:
2
+
3
+ - Classification (ResNet, MobileNet, RCE)
4
+ - Detection (YOLO)
5
+
dataOLD/artroom/bird/yolo/bird_data.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ path: /Users/dariusgiannoli/Desktop/Recognition-BenchMark/data/artroom/bird/yolo
2
+ train: train/images
3
+ val: train/images
4
+
5
+ # Classes
6
+ nc: 1
7
+ names: ['bird']
dataOLD/artroom/bird/yolo/train/images/bird_01_original.png ADDED

Git LFS Details

  • SHA256: d025453acb490274beb548b07097fd044e98088b1369be16a1eff3061858cfcc
  • Pointer size: 129 Bytes
  • Size of remote file: 7.9 kB
dataOLD/artroom/bird/yolo/train/images/bird_02_rot_pos5.png ADDED

Git LFS Details

  • SHA256: a53253ca317d4690ef2874bf38ed5c7ae13c61d1cf326fe1b86c760e51dce22e
  • Pointer size: 129 Bytes
  • Size of remote file: 7.64 kB
dataOLD/artroom/bird/yolo/train/images/bird_03_rot_neg5.png ADDED

Git LFS Details

  • SHA256: 8f18ee9f0fd5140f72119db42765d0211d1ad1ba12a00dd0a2a5f25043b8f2f2
  • Pointer size: 129 Bytes
  • Size of remote file: 7.58 kB
dataOLD/artroom/bird/yolo/train/images/bird_04_bright.png ADDED

Git LFS Details

  • SHA256: 1870cd53de369fd6eaef0437f057b88f9c6e634ba156d1ab9e17c82015f3396c
  • Pointer size: 129 Bytes
  • Size of remote file: 6.69 kB
dataOLD/artroom/bird/yolo/train/images/bird_05_dark.png ADDED

Git LFS Details

  • SHA256: 94202124356fe3086efdac56ad6a5cb23578bcaf45b00818752162e14584d358
  • Pointer size: 129 Bytes
  • Size of remote file: 7.42 kB
dataOLD/artroom/bird/yolo/train/images/bird_06_noisy.png ADDED

Git LFS Details

  • SHA256: 14321f70346a4b8ba50223f367205e26f7312812ffb42e1c5f5e36284047120a
  • Pointer size: 130 Bytes
  • Size of remote file: 10.5 kB
dataOLD/artroom/bird/yolo/train/images/bird_07_flip.png ADDED

Git LFS Details

  • SHA256: 505c987465f02d9393ddc4a869027830f8cfc6c488426e9fd5710b68f6dc5d57
  • Pointer size: 129 Bytes
  • Size of remote file: 7.89 kB
dataOLD/artroom/bird/yolo/train/images/bird_08_blur.png ADDED

Git LFS Details

  • SHA256: 6937295d87d1601008337ca97924b149fb911561d226ff3f594bfe0387958a8c
  • Pointer size: 129 Bytes
  • Size of remote file: 7.22 kB
dataOLD/artroom/bird/yolo/train/images/bird_09_shift_x.png ADDED

Git LFS Details

  • SHA256: a522a5a057292e81d27e6ed18de2104c825e9a8d8ca78feb15fcb34ccce8f233
  • Pointer size: 129 Bytes
  • Size of remote file: 7.82 kB
dataOLD/artroom/bird/yolo/train/images/bird_10_shift_y.png ADDED

Git LFS Details

  • SHA256: 696de70c12b9f73c5dcf6de5aea76100fabed63c6667c513eb7fd435a10ce6d9
  • Pointer size: 129 Bytes
  • Size of remote file: 7.86 kB
dataOLD/artroom/bird/yolo/train/images/room_1.png ADDED

Git LFS Details

  • SHA256: e5ad4b77f43935da0ebd55a6e4218c861307788b1a00e4b4d688018b3edd7083
  • Pointer size: 131 Bytes
  • Size of remote file: 371 kB
dataOLD/artroom/bird/yolo/train/images/room_2.png ADDED

Git LFS Details

  • SHA256: ab100a7c64a1969c91dd66280486a39586e28413271e9a0735a20646b51301ec
  • Pointer size: 131 Bytes
  • Size of remote file: 611 kB
dataOLD/artroom/bird/yolo/train/images/room_3.png ADDED

Git LFS Details

  • SHA256: f14c31db76bff3c1425d1361cbac2564566c799735fa30576c71318dcfc467b9
  • Pointer size: 131 Bytes
  • Size of remote file: 411 kB
dataOLD/artroom/bird/yolo/train/images/room_4.png ADDED

Git LFS Details

  • SHA256: d5bab08dfc1f57c39dfcce979280fbcdc1217f14e21eabd68f2d2e3a7a420e10
  • Pointer size: 131 Bytes
  • Size of remote file: 313 kB
dataOLD/artroom/bird/yolo/train/images/room_5.png ADDED

Git LFS Details

  • SHA256: c5ed3722ca9ad4c5a57fbf28516c3f8dfe133749d9fa21d843a8bce288e33ecc
  • Pointer size: 131 Bytes
  • Size of remote file: 206 kB
dataOLD/artroom/bird/yolo/train/labels.cache ADDED
Binary file (3.76 kB). View file
 
dataOLD/artroom/bird/yolo/train/labels/bird_01_original.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 0 0.5 0.5 0.8 0.8
dataOLD/artroom/bird/yolo/train/labels/bird_02_rot_pos5.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 0 0.5 0.5 0.8 0.8
dataOLD/artroom/bird/yolo/train/labels/bird_03_rot_neg5.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 0 0.5 0.5 0.8 0.8
dataOLD/artroom/bird/yolo/train/labels/bird_04_bright.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 0 0.5 0.5 0.8 0.8
dataOLD/artroom/bird/yolo/train/labels/bird_05_dark.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 0 0.5 0.5 0.8 0.8
dataOLD/artroom/bird/yolo/train/labels/bird_06_noisy.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 0 0.5 0.5 0.8 0.8
dataOLD/artroom/bird/yolo/train/labels/bird_07_flip.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 0 0.5 0.5 0.8 0.8
dataOLD/artroom/bird/yolo/train/labels/bird_08_blur.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 0 0.5 0.5 0.8 0.8
dataOLD/artroom/bird/yolo/train/labels/bird_09_shift_x.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 0 0.5 0.5 0.8 0.8
dataOLD/artroom/bird/yolo/train/labels/bird_10_shift_y.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 0 0.5 0.5 0.8 0.8
dataOLD/artroom/bird/yolo/train/labels/room_1.txt ADDED
File without changes
dataOLD/artroom/bird/yolo/train/labels/room_2.txt ADDED
File without changes
dataOLD/artroom/bird/yolo/train/labels/room_3.txt ADDED
File without changes
dataOLD/artroom/bird/yolo/train/labels/room_4.txt ADDED
File without changes
dataOLD/artroom/bird/yolo/train/labels/room_5.txt ADDED
File without changes
dataOLD/artroom/im0.png ADDED

Git LFS Details

  • SHA256: 280be6eac4b525eee6d49f0afd32c11ef0b83d2cad3e77e946fe525fda16a355
  • Pointer size: 132 Bytes
  • Size of remote file: 2.76 MB
pages/2_Data_Lab.py DELETED
@@ -1,321 +0,0 @@
1
- import streamlit as st
2
- import cv2
3
- import numpy as np
4
- import io
5
-
6
- st.set_page_config(page_title="Data Lab", layout="wide")
7
-
8
- st.title("🧪 Data Lab: Stereo Asset Loader")
9
- st.write("Upload your stereo images, camera configuration, and ground truth depth maps.")
10
-
11
-
12
- # ---------------------------------------------------------------------------
13
- # Helpers
14
- # ---------------------------------------------------------------------------
15
- def read_pfm(file_bytes: bytes) -> np.ndarray:
16
- """Parse a PFM (Portable Float Map) and return a float32 ndarray."""
17
- buf = io.BytesIO(file_bytes)
18
- header = buf.readline().decode("ascii").strip()
19
- if header not in ("Pf", "PF"):
20
- raise ValueError(f"Not a valid PFM file (header: {header!r})")
21
- color = header == "PF"
22
- line = buf.readline().decode("ascii").strip()
23
- while line.startswith("#"):
24
- line = buf.readline().decode("ascii").strip()
25
- w, h = map(int, line.split())
26
- scale = float(buf.readline().decode("ascii").strip())
27
- endian = "<" if scale < 0 else ">"
28
- channels = 3 if color else 1
29
- data = np.frombuffer(buf.read(), dtype=np.dtype(endian + "f4"))
30
- data = data.reshape((h, w, channels) if color else (h, w))
31
- return np.flipud(data)
32
-
33
-
34
- def vis_depth(depth: np.ndarray) -> np.ndarray:
35
- """Normalise depth to [0,1] for display, ignoring non-finite values."""
36
- finite = depth[np.isfinite(depth)]
37
- d = np.nan_to_num(depth, nan=0.0, posinf=float(finite.max()))
38
- return (d / d.max()).astype(np.float32) if d.max() > 0 else d.astype(np.float32)
39
-
40
-
41
- def augment(img: np.ndarray, brightness: float, contrast: float,
42
- rotation: float, flip_h: bool, flip_v: bool,
43
- noise: float, blur: int, shift_x: int, shift_y: int) -> np.ndarray:
44
- """Apply a chain of augmentations to a BGR crop."""
45
- out = img.astype(np.float32)
46
-
47
- # Brightness / Contrast: out = contrast * out + brightness_offset
48
- out = np.clip(contrast * out + brightness, 0, 255)
49
-
50
- # Gaussian noise
51
- if noise > 0:
52
- out = np.clip(out + np.random.normal(0, noise, out.shape), 0, 255)
53
-
54
- out = out.astype(np.uint8)
55
-
56
- # Blur (kernel must be odd)
57
- k = blur * 2 + 1
58
- if k > 1:
59
- out = cv2.GaussianBlur(out, (k, k), 0)
60
-
61
- # Rotation
62
- if rotation != 0:
63
- h, w = out.shape[:2]
64
- M = cv2.getRotationMatrix2D((w / 2, h / 2), rotation, 1.0)
65
- out = cv2.warpAffine(out, M, (w, h), borderMode=cv2.BORDER_REFLECT)
66
-
67
- # Translation
68
- if shift_x != 0 or shift_y != 0:
69
- h, w = out.shape[:2]
70
- M = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
71
- out = cv2.warpAffine(out, M, (w, h), borderMode=cv2.BORDER_REFLECT)
72
-
73
- # Flips
74
- if flip_h:
75
- out = cv2.flip(out, 1)
76
- if flip_v:
77
- out = cv2.flip(out, 0)
78
-
79
- return out
80
-
81
-
82
- # --- Session State Initialization ---
83
- MAX_UPLOAD_BYTES = 50 * 1024 * 1024 # 50 MB
84
-
85
- if "pipeline_data" not in st.session_state:
86
- st.session_state["pipeline_data"] = {}
87
-
88
- # ---------------------------------------------------------------------------
89
- # Step 1 — Upload Assets
90
- # ---------------------------------------------------------------------------
91
- st.subheader("Step 1: Upload Assets")
92
- col1, col2 = st.columns(2)
93
-
94
- with col1:
95
- up_l = st.file_uploader("Left Image (Reference)", type=["png", "jpg", "jpeg"])
96
- if up_l:
97
- if up_l.size > MAX_UPLOAD_BYTES:
98
- st.error(f"❌ Left image too large ({up_l.size / 1e6:.1f} MB). Max 50 MB.")
99
- up_l = None
100
- else:
101
- img_l_preview = cv2.imdecode(np.frombuffer(up_l.read(), np.uint8), cv2.IMREAD_COLOR)
102
- up_l.seek(0)
103
- st.image(cv2.cvtColor(img_l_preview, cv2.COLOR_BGR2RGB), caption="Left Image Preview", use_container_width=True)
104
-
105
- up_conf = st.file_uploader("Camera Config (.txt or .conf)", type=["txt", "conf"])
106
-
107
- up_gt_l = st.file_uploader("Left Ground Truth Depth (.pfm)", type=["pfm"])
108
- if up_gt_l:
109
- try:
110
- gt_l_prev = read_pfm(up_gt_l.read()); up_gt_l.seek(0)
111
- st.image(vis_depth(gt_l_prev), caption="Left GT Depth Preview", use_container_width=True)
112
- except (ValueError, Exception) as e:
113
- st.error(f"❌ Invalid PFM file (left): {e}")
114
- up_gt_l = None
115
-
116
- with col2:
117
- up_r = st.file_uploader("Right Image (Stereo Match)", type=["png", "jpg", "jpeg"])
118
- if up_r:
119
- if up_r.size > MAX_UPLOAD_BYTES:
120
- st.error(f"❌ Right image too large ({up_r.size / 1e6:.1f} MB). Max 50 MB.")
121
- up_r = None
122
- else:
123
- img_r_preview = cv2.imdecode(np.frombuffer(up_r.read(), np.uint8), cv2.IMREAD_COLOR)
124
- up_r.seek(0)
125
- st.image(cv2.cvtColor(img_r_preview, cv2.COLOR_BGR2RGB), caption="Right Image Preview", use_container_width=True)
126
-
127
- up_gt_r = st.file_uploader("Right Ground Truth Depth (.pfm)", type=["pfm"])
128
- if up_gt_r:
129
- try:
130
- gt_r_prev = read_pfm(up_gt_r.read()); up_gt_r.seek(0)
131
- st.image(vis_depth(gt_r_prev), caption="Right GT Depth Preview", use_container_width=True)
132
- except (ValueError, Exception) as e:
133
- st.error(f"❌ Invalid PFM file (right): {e}")
134
- up_gt_r = None
135
-
136
- # ---------------------------------------------------------------------------
137
- # Step 2 — Full pipeline (requires all 5 files)
138
- # ---------------------------------------------------------------------------
139
- if up_l and up_r and up_conf and up_gt_l and up_gt_r:
140
- img_l = cv2.imdecode(np.frombuffer(up_l.read(), np.uint8), cv2.IMREAD_COLOR)
141
- img_r = cv2.imdecode(np.frombuffer(up_r.read(), np.uint8), cv2.IMREAD_COLOR)
142
- conf_content = up_conf.read().decode("utf-8")
143
- gt_depth_l = read_pfm(up_gt_l.read())
144
- gt_depth_r = read_pfm(up_gt_r.read())
145
-
146
- st.success("✅ All assets loaded successfully!")
147
-
148
- # --- Stereo pair ---
149
- st.divider()
150
- st.subheader("Step 2: Asset Visualization")
151
- st.write("### 📸 Stereo Pair")
152
- v1, v2 = st.columns(2)
- v1.image(cv2.cvtColor(img_l, cv2.COLOR_BGR2RGB), caption="Left View", use_container_width=True)
- v2.image(cv2.cvtColor(img_r, cv2.COLOR_BGR2RGB), caption="Right View", use_container_width=True)
-
- # --- Ground truth maps ---
- st.write("### 📊 Ground Truth Depth Maps")
- d1, d2 = st.columns(2)
- d1.image(vis_depth(gt_depth_l), caption="Left GT Depth (Normalized)", use_container_width=True)
- d2.image(vis_depth(gt_depth_r), caption="Right GT Depth (Normalized)", use_container_width=True)
-
- # --- Config ---
- with st.expander("📄 Camera Configuration"):
- st.text_area("Raw Config", conf_content, height=200)
-
- # -----------------------------------------------------------------------
- # Step 3 — Crop ROI(s) from Left Image (Multi-Object)
- # -----------------------------------------------------------------------
- st.divider()
- st.subheader("Step 3: Crop Region(s) of Interest")
- st.write("Define one or more bounding boxes — each becomes a separate class for recognition.")
-
- H, W = img_l.shape[:2]
-
- # Manage list of ROIs in session state
- if "rois" not in st.session_state:
- st.session_state["rois"] = [{"label": "object", "x0": 0, "y0": 0,
- "x1": min(W, 100), "y1": min(H, 100)}]
-
- def _add_roi():
- if len(st.session_state["rois"]) >= 20:
- return
- st.session_state["rois"].append(
- {"label": f"object_{len(st.session_state['rois'])+1}",
- "x0": 0, "y0": 0,
- "x1": min(W, 100), "y1": min(H, 100)})
-
- def _remove_roi(idx):
- if len(st.session_state["rois"]) > 1:
- st.session_state["rois"].pop(idx)
-
- ROI_COLORS = [(0,255,0), (255,0,0), (0,0,255), (255,255,0),
- (255,0,255), (0,255,255), (128,255,0), (255,128,0)]
-
- for i, roi in enumerate(st.session_state["rois"]):
- color = ROI_COLORS[i % len(ROI_COLORS)]
- color_hex = "#{:02x}{:02x}{:02x}".format(*color)
- with st.container(border=True):
- hc1, hc2, hc3 = st.columns([3, 6, 1])
- hc1.markdown(f"**ROI {i+1}** <span style='color:{color_hex}'>■</span>",
- unsafe_allow_html=True)
- roi["label"] = hc2.text_input("Class Label", roi["label"],
- key=f"roi_lbl_{i}")
- if len(st.session_state["rois"]) > 1:
- hc3.button("✕", key=f"roi_del_{i}",
- on_click=_remove_roi, args=(i,))
-
- cr1, cr2, cr3, cr4 = st.columns(4)
- roi["x0"] = int(cr1.number_input("X start", 0, W-2, int(roi["x0"]),
- step=1, key=f"roi_x0_{i}"))
- roi["y0"] = int(cr2.number_input("Y start", 0, H-2, int(roi["y0"]),
- step=1, key=f"roi_y0_{i}"))
- roi["x1"] = int(cr3.number_input("X end", roi["x0"]+1, W,
- min(W, int(roi["x1"])),
- step=1, key=f"roi_x1_{i}"))
- roi["y1"] = int(cr4.number_input("Y end", roi["y0"]+1, H,
- min(H, int(roi["y1"])),
- step=1, key=f"roi_y1_{i}"))
-
- st.button("➕ Add Another ROI", on_click=_add_roi,
- disabled=len(st.session_state["rois"]) >= 20)
-
- # Draw all ROIs on the image
- overlay = img_l.copy()
- crops = []
- for i, roi in enumerate(st.session_state["rois"]):
- color = ROI_COLORS[i % len(ROI_COLORS)]
- x0, y0, x1, y1 = roi["x0"], roi["y0"], roi["x1"], roi["y1"]
- cv2.rectangle(overlay, (x0, y0), (x1, y1), color, 2)
- cv2.putText(overlay, roi["label"], (x0, y0 - 6),
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
- crops.append(img_l[y0:y1, x0:x1].copy())
-
- ov1, ov2 = st.columns([3, 2])
- ov1.image(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB),
- caption="Left Image — ROIs highlighted", use_container_width=True)
- with ov2:
- for i, (c, roi) in enumerate(zip(crops, st.session_state["rois"])):
- st.image(cv2.cvtColor(c, cv2.COLOR_BGR2RGB),
- caption=f"{roi['label']} ({c.shape[1]}×{c.shape[0]})",
- width=160)
-
- # For backward compatibility: first ROI is the "primary"
- crop_bgr = crops[0]
- x0, y0, x1, y1 = (st.session_state["rois"][0]["x0"],
- st.session_state["rois"][0]["y0"],
- st.session_state["rois"][0]["x1"],
- st.session_state["rois"][0]["y1"])
-
- # -----------------------------------------------------------------------
- # Step 4 — Data Augmentation
- # -----------------------------------------------------------------------
- st.divider()
- st.subheader("Step 4: Data Augmentation")
- st.write("Tune the parameters below — the augmented crop updates live.")
-
- ac1, ac2 = st.columns(2)
- with ac1:
- brightness = st.slider("Brightness offset", -100, 100, 0, step=1)
- contrast = st.slider("Contrast scale", 0.5, 3.0, 1.0, step=0.05)
- rotation = st.slider("Rotation (°)", -180, 180, 0, step=1)
- noise = st.slider("Gaussian noise σ", 0, 50, 0, step=1)
- with ac2:
- blur = st.slider("Blur kernel (0 = off)", 0, 10, 0, step=1)
- shift_x = st.slider("Shift X (px)", -100, 100, 0, step=1)
- shift_y = st.slider("Shift Y (px)", -100, 100, 0, step=1)
- flip_h = st.checkbox("Flip Horizontal")
- flip_v = st.checkbox("Flip Vertical")
-
- aug = augment(crop_bgr, brightness, contrast, rotation,
- flip_h, flip_v, noise, blur, shift_x, shift_y)
-
- # Apply same augmentation to all crops
- all_augs = [augment(c, brightness, contrast, rotation,
- flip_h, flip_v, noise, blur, shift_x, shift_y)
- for c in crops]
-
- aug_col1, aug_col2 = st.columns(2)
- aug_col1.image(cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB),
- caption="Original Crop (ROI 1)", use_container_width=True)
- aug_col2.image(cv2.cvtColor(aug, cv2.COLOR_BGR2RGB),
- caption="Augmented Crop (ROI 1)", use_container_width=True)
-
- if len(crops) > 1:
- st.caption(f"Augmentation applied identically to all {len(crops)} ROIs.")
-
- # -----------------------------------------------------------------------
- # Step 5 — Lock & Store
- # -----------------------------------------------------------------------
- st.divider()
- if st.button("🚀 Lock Data & Proceed to Benchmark"):
- if not st.session_state.get("rois") or len(st.session_state["rois"]) == 0:
- st.error("❌ Define at least one ROI before locking!")
- st.stop()
- rois_data = []
- for i, roi in enumerate(st.session_state["rois"]):
- rois_data.append({
- "label": roi["label"],
- "bbox": (roi["x0"], roi["y0"], roi["x1"], roi["y1"]),
- "crop": crops[i],
- "crop_aug": all_augs[i],
- })
-
- st.session_state["pipeline_data"] = {
- "left": img_l,
- "right": img_r,
- "gt_left": gt_depth_l,
- "gt_right": gt_depth_r,
- "conf_raw": conf_content,
- # Backward compatibility: first ROI
- "crop": crop_bgr,
- "crop_aug": aug,
- "crop_bbox": (x0, y0, x1, y1),
- # Multi-object
- "rois": rois_data,
- }
- st.success(f"Data stored with **{len(rois_data)} ROI(s)**! "
- f"Move to Feature Lab.")
-
- else:
- st.info("Please upload all 5 files (left image, right image, config, left GT, right GT) to proceed.")
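Note on the deleted Data Lab page: one augmentation recipe was applied identically to every ROI crop (`all_augs = [augment(c, …) for c in crops]`), so all classes see the same perturbation. A minimal numpy-only sketch of that pattern, with a simplified hypothetical `augment` (the real helper also handles rotation, blur and shifts):

```python
import numpy as np

def augment(crop, brightness=0, contrast=1.0, flip_h=False, flip_v=False, noise_sigma=0):
    """Brightness/contrast jitter, optional flips, optional Gaussian noise (BGR uint8 in/out)."""
    out = crop.astype(np.float32) * contrast + brightness   # linear photometric transform
    if noise_sigma > 0:
        out += np.random.default_rng(0).normal(0, noise_sigma, out.shape)
    out = np.clip(out, 0, 255).astype(np.uint8)
    if flip_h:
        out = out[:, ::-1]                                   # mirror columns
    if flip_v:
        out = out[::-1, :]                                   # mirror rows
    return out

crops = [np.full((4, 4, 3), 100, np.uint8), np.full((2, 2, 3), 50, np.uint8)]
# Same parameters applied to every ROI crop, as in the deleted page:
all_augs = [augment(c, brightness=20, contrast=1.5, flip_h=True) for c in crops]
```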
pages/3_Feature_Lab.py DELETED
@@ -1,111 +0,0 @@
1
- import streamlit as st
2
- import cv2
3
- import numpy as np
4
- import plotly.graph_objects as go
5
- import sys, os
6
- sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7
- from src.detectors.rce.features import REGISTRY
8
- from src.models import BACKBONES
9
-
10
- st.set_page_config(page_title="Feature Lab", layout="wide")
11
-
12
- if "pipeline_data" not in st.session_state:
13
- st.error("Please complete the Data Lab first!")
14
- st.stop()
15
-
16
- assets = st.session_state["pipeline_data"]
17
- # Use augmented crop if available, otherwise fall back to original crop
18
- obj = assets.get("crop_aug", assets.get("crop"))
19
- if obj is None:
20
- st.error("No crop found. Please go back to Data Lab and define a ROI.")
21
- st.stop()
22
- gray = cv2.cvtColor(obj, cv2.COLOR_BGR2GRAY)
23
-
24
- st.title("🔬 Feature Lab: Physical Module Selection")
25
-
26
- col_rce, col_cnn = st.columns([3, 2])
27
-
28
- # ---------------------------------------------------------------------------
29
- # LEFT — RCE Modular Engine (pure UI — all math lives in features.py)
30
- # ---------------------------------------------------------------------------
31
- with col_rce:
32
- st.header("🧬 RCE: Modular Physics Engine")
33
- st.subheader("Select Active Modules")
34
-
35
- # Dynamically build checkboxes from the registry (rows of 4)
36
- active = {}
37
- items = list(REGISTRY.items())
38
- for row_start in range(0, len(items), 4):
39
- row_items = items[row_start:row_start + 4]
40
- m_cols = st.columns(4)
41
- for col, (key, meta) in zip(m_cols, row_items):
42
- active[key] = col.checkbox(meta["label"], value=(key in ("intensity", "sobel", "spectral")))
43
-
44
- # Build vector + collect visualizations by calling registry functions
45
- final_vector = []
46
- viz_images = []
47
- for key, meta in REGISTRY.items():
48
- if active[key]:
49
- vec, viz = meta["fn"](gray)
50
- final_vector.extend(vec)
51
- viz_images.append((meta["viz_title"], viz))
52
-
53
- # Visualizations (rows of 3)
54
- st.divider()
55
- if viz_images:
56
- for row_start in range(0, len(viz_images), 3):
57
- row = viz_images[row_start:row_start + 3]
58
- v_cols = st.columns(3)
59
- for col, (title, img) in zip(v_cols, row):
60
- col.image(img, caption=title, use_container_width=True)
61
- else:
62
- st.warning("No modules selected — vector is empty.")
63
-
64
- # DNA vector bar chart
65
- st.write(f"### Current DNA Vector Size: **{len(final_vector)}**")
66
- fig_vec = go.Figure(data=[go.Bar(y=final_vector, marker_color="#00d4ff")])
67
- fig_vec.update_layout(title="Active Feature Vector (RCE Input)",
68
- template="plotly_dark", height=300)
69
- st.plotly_chart(fig_vec, use_container_width=True)
70
-
71
- # ---------------------------------------------------------------------------
72
- # RIGHT — CNN comparison panel
73
- # ---------------------------------------------------------------------------
74
- with col_cnn:
75
- st.header("🧠 CNN: Static Architecture")
76
- selected_cnn = st.selectbox("Compare against Model", list(BACKBONES.keys()))
77
- st.info("CNN features are fixed by pre-trained weights. You cannot toggle them like the RCE.")
78
-
79
- with st.spinner(f"Loading {selected_cnn} and extracting activations..."):
80
- try:
81
- bmeta = BACKBONES[selected_cnn]
82
- backbone = bmeta["loader"]() # cached frozen backbone
83
- layer_name = bmeta["hook_layer"]
84
-
85
- act_maps = backbone.get_activation_maps(obj, n_maps=6)
86
- st.caption(f"Hooked layer: `{layer_name}` — showing 6 of {len(act_maps)} channels")
87
- act_cols = st.columns(3)
88
- for i, amap in enumerate(act_maps):
89
- act_cols[i % 3].image(amap, caption=f"Channel {i}", use_container_width=True)
90
-
91
- except Exception as e:
92
- st.error(f"Could not load model: {e}")
93
-
94
- st.divider()
95
- st.markdown(f"""
96
- **Analysis:**
97
- - **Modularity:** RCE is **High** | CNN is **Zero**
98
- - **Explainability:** RCE is **High** | CNN is **Low**
99
- - **Compute Cost:** {len(final_vector)} floats | 512+ floats
100
- """)
101
-
102
- # ---------------------------------------------------------------------------
103
- # Lock configuration
104
- # ---------------------------------------------------------------------------
105
- if st.button("🚀 Lock Modular Configuration"):
106
- if not final_vector:
107
- st.error("Please select at least one module!")
108
- else:
109
- st.session_state["pipeline_data"]["final_vector"] = np.array(final_vector)
110
- st.session_state["active_modules"] = {k: v for k, v in active.items()}
111
- st.success("Modular DNA Locked! Ready for Model Tuning.")
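Note on the deleted Feature Lab page: the checkbox grid and the final DNA vector are both driven entirely by `REGISTRY`, so adding a feature module never touches the UI code. A stripped-down sketch of that registry pattern (illustrative entries only — the real registry lives in `src/detectors/rce/features.py` and the `fn` callables there return richer visualizations):

```python
import numpy as np

def intensity_hist(gray):
    """10-bin normalized intensity histogram; returns a (vector, viz) pair."""
    hist, _ = np.histogram(gray, bins=10, range=(0, 256))
    vec = hist / max(hist.sum(), 1)
    return vec.tolist(), gray          # viz: just echo the input in this sketch

def gradient_energy(gray):
    """10-bin histogram of horizontal gradient magnitudes."""
    gx = np.abs(np.diff(gray.astype(np.float32), axis=1))
    hist, _ = np.histogram(gx, bins=10, range=(0, 256))
    vec = hist / max(hist.sum(), 1)
    return vec.tolist(), gx

REGISTRY = {
    "intensity": {"label": "Intensity", "fn": intensity_hist,  "viz_title": "Histogram"},
    "gradient":  {"label": "Gradient",  "fn": gradient_energy, "viz_title": "Gradients"},
}

# UI code only iterates the registry — identical to the deleted page's loop:
active = {"intensity": True, "gradient": True}
final_vector = []
for key, meta in REGISTRY.items():
    if active[key]:
        vec, _viz = meta["fn"](np.zeros((8, 8), np.uint8))
        final_vector.extend(vec)
```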
pages/4_Model_Tuning.py DELETED
@@ -1,475 +0,0 @@
1
- import streamlit as st
2
- import cv2
3
- import numpy as np
4
- import time
5
- import plotly.graph_objects as go
6
- import sys, os
7
- sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8
-
9
- from src.detectors.rce.features import REGISTRY
10
- from src.models import BACKBONES, RecognitionHead
11
- from src.utils import build_rce_vector
12
-
13
- st.set_page_config(page_title="Model Tuning", layout="wide")
14
- st.title("⚙️ Model Tuning: Train & Compare")
15
-
16
- # ---------------------------------------------------------------------------
17
- # Guard: require Data Lab completion
18
- # ---------------------------------------------------------------------------
19
- if "pipeline_data" not in st.session_state or "crop" not in st.session_state.get("pipeline_data", {}):
20
- st.error("Please complete the **Data Lab** first (upload assets & define a crop).")
21
- st.stop()
22
-
23
- assets = st.session_state["pipeline_data"]
24
- crop = assets["crop"]
25
- crop_aug = assets.get("crop_aug", crop)
26
- left_img = assets["left"]
27
- bbox = assets.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
28
- rois = assets.get("rois", [{"label": "object", "bbox": bbox,
29
- "crop": crop, "crop_aug": crop_aug}])
30
- active_modules = st.session_state.get("active_modules", {k: True for k in REGISTRY})
31
-
32
- is_multi = len(rois) > 1
33
-
34
- # ---------------------------------------------------------------------------
35
- # Build training set from session data (no disk reads)
36
- # ---------------------------------------------------------------------------
37
- def build_training_set():
38
- """
39
- Multi-class aware training set builder.
40
- Positive samples per class: original crop + augmented crop.
41
- Negative samples: random patches that don't overlap ANY ROI.
42
- """
43
- images = []
44
- labels = []
45
-
46
- for roi in rois:
47
- images.append(roi["crop"])
48
- labels.append(roi["label"])
49
- images.append(roi["crop_aug"])
50
- labels.append(roi["label"])
51
-
52
- all_bboxes = [roi["bbox"] for roi in rois]
53
- H, W = left_img.shape[:2]
54
- x0r, y0r, x1r, y1r = rois[0]["bbox"]
55
- ch, cw = y1r - y0r, x1r - x0r
56
- rng = np.random.default_rng(42)
57
-
58
- n_neg_target = len(images) * 2
59
- attempts = 0
60
- negatives = []
61
- while len(negatives) < n_neg_target and attempts < 300:
62
- rx = rng.integers(0, max(W - cw, 1))
63
- ry = rng.integers(0, max(H - ch, 1))
64
- overlaps = False
65
- for bx0, by0, bx1, by1 in all_bboxes:
66
- if rx < bx1 and rx + cw > bx0 and ry < by1 and ry + ch > by0:
67
- overlaps = True
68
- break
69
- if overlaps:
70
- attempts += 1
71
- continue
72
- patch = left_img[ry:ry+ch, rx:rx+cw]
73
- if patch.shape[0] > 0 and patch.shape[1] > 0:
74
- negatives.append(patch)
75
- attempts += 1
76
-
77
- images.extend(negatives)
78
- labels.extend(["background"] * len(negatives))
79
- return images, labels, len(negatives) < n_neg_target // 2
80
-
81
-
82
- # ===================================================================
83
- # Show training data
84
- # ===================================================================
85
- st.subheader("Training Data (from Data Lab)")
86
- if is_multi:
87
- st.caption(f"**{len(rois)} classes** defined — each ROI becomes a separate class.")
88
- roi_cols = st.columns(min(len(rois), 4))
89
- for i, roi in enumerate(rois):
90
- with roi_cols[i % len(roi_cols)]:
91
- st.image(cv2.cvtColor(roi["crop"], cv2.COLOR_BGR2RGB),
92
- caption=f"✅ {roi['label']}", width=140)
93
- else:
94
- st.caption("Positives = your crop + augmented crop | "
95
- "Negatives = random non-overlapping patches")
96
- td1, td2 = st.columns(2)
97
- td1.image(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB),
98
- caption="Original Crop (positive)", width=180)
99
- td2.image(cv2.cvtColor(crop_aug, cv2.COLOR_BGR2RGB),
100
- caption="Augmented Crop (positive)", width=180)
101
-
102
- st.divider()
103
-
104
- # ===================================================================
105
- # LAYOUT: RCE | CNN | ORB
106
- # ===================================================================
107
- col_rce, col_cnn, col_orb = st.columns(3)
108
-
109
- # ---------------------------------------------------------------------------
110
- # LEFT — RCE Training
111
- # ---------------------------------------------------------------------------
112
- with col_rce:
113
- st.header("🧬 RCE Training")
114
-
115
- active_names = [REGISTRY[k]["label"] for k in active_modules if active_modules[k]]
116
- if not active_names:
117
- st.error("No RCE modules selected. Go back to Feature Lab.")
118
- else:
119
- st.write(f"**Active modules:** {', '.join(active_names)}")
120
-
121
- st.subheader("Training Parameters")
122
- rce_C = st.slider("Regularization (C)", 0.01, 10.0, 1.0, step=0.01,
123
- help="Higher = less regularization, may overfit")
124
- rce_max_iter = st.slider("Max Iterations", 100, 5000, 1000, step=100)
125
-
126
- if st.button("🚀 Train RCE Head"):
127
- images, labels, neg_short = build_training_set()
128
- if neg_short:
129
- st.warning(f"⚠️ Only {sum(1 for l in labels if l == 'background')} "
130
- f"negatives collected (target was {sum(1 for l in labels if l != 'background') * 2}). "
131
- f"Training data may be imbalanced.")
132
- from sklearn.metrics import accuracy_score
133
- from sklearn.model_selection import cross_val_score
134
-
135
- progress = st.progress(0, text="Extracting RCE features...")
136
- n = len(images)
137
- X = []
138
- for i, img in enumerate(images):
139
- X.append(build_rce_vector(img, active_modules))
140
- progress.progress((i + 1) / n, text=f"Feature extraction: {i+1}/{n}")
141
-
142
- X = np.array(X)
143
- progress.progress(1.0, text="Fitting Logistic Regression...")
144
-
145
- t0 = time.perf_counter()
146
- try:
147
- head = RecognitionHead(C=rce_C, max_iter=rce_max_iter).fit(X, labels)
148
- except ValueError as e:
149
- st.error(f"Training failed: {e}")
150
- st.stop()
151
- train_time = time.perf_counter() - t0
152
- progress.progress(1.0, text="✅ Training complete!")
153
-
154
- preds = head.model.predict(X)
155
- train_acc = accuracy_score(labels, preds)
156
-
157
- st.success(f"Trained in **{train_time:.2f}s**")
158
- m1, m2, m3, m4 = st.columns(4)
159
- m1.metric("Train Accuracy", f"{train_acc:.1%}")
160
- # Cross-validation (only if enough samples)
161
- if len(images) >= 6:
162
- n_splits = min(5, len(set(labels)))
163
- if n_splits >= 2:
164
- cv_scores = cross_val_score(head.model, X, labels,
165
- cv=min(3, len(images) // 2))
166
- m2.metric("CV Accuracy", f"{cv_scores.mean():.1%}",
167
- delta=f"±{cv_scores.std():.1%}")
168
- else:
169
- m2.metric("CV Accuracy", "N/A")
170
- else:
171
- m2.metric("CV Accuracy", "N/A")
172
- m3.metric("Vector Size", f"{X.shape[1]} floats")
173
- m4.metric("Samples", f"{len(images)}")
174
- if len(images) < 10:
175
- st.warning("⚠️ Training set is small (<10 samples). "
176
- "Reported accuracy may not reflect real performance.")
177
- if is_multi:
178
- st.caption(f"Classes: {', '.join(head.classes_)}")
179
-
180
- probs = head.predict_proba(X)
181
- fig = go.Figure()
182
- for ci, cls in enumerate(head.classes_):
183
- fig.add_trace(go.Histogram(x=probs[:, ci], name=cls,
184
- opacity=0.7, nbinsx=20))
185
- fig.update_layout(title="Confidence Distribution", barmode="overlay",
186
- template="plotly_dark", height=280,
187
- xaxis_title="Confidence", yaxis_title="Count")
188
- st.plotly_chart(fig, use_container_width=True)
189
-
190
- # ---- Feature Importance (RCE) ----
191
- st.subheader("🔍 Feature Importance")
192
- coefs = head.model.coef_
193
- feat_names = []
194
- for key, meta_r in REGISTRY.items():
195
- if active_modules.get(key, False):
196
- for b in range(10):
197
- feat_names.append(f"{meta_r['label']}[{b}]")
198
-
199
- if coefs.shape[0] == 1:
200
- importance = np.abs(coefs[0])
201
- fig_imp = go.Figure(go.Bar(
202
- x=feat_names, y=importance,
203
- marker_color=["#00d4ff" if "Intensity" in fn
204
- else "#ff6600" if "Sobel" in fn
205
- else "#aa00ff" for fn in feat_names]))
206
- fig_imp.update_layout(title="LogReg Coefficient Magnitude",
207
- template="plotly_dark", height=300,
208
- xaxis_title="Feature", yaxis_title="|Coefficient|")
209
- else:
210
- fig_imp = go.Figure()
211
- for ci, cls in enumerate(head.classes_):
212
- if cls == "background":
213
- continue
214
- fig_imp.add_trace(go.Bar(
215
- x=feat_names, y=np.abs(coefs[ci]),
216
- name=cls, opacity=0.8))
217
- fig_imp.update_layout(title="LogReg Coefficients per Class",
218
- template="plotly_dark", height=300,
219
- barmode="group",
220
- xaxis_title="Feature", yaxis_title="|Coefficient|")
221
- st.plotly_chart(fig_imp, use_container_width=True)
222
-
223
- # Module-level aggregation
224
- module_importance = {}
225
- idx = 0
226
- for key, meta_r in REGISTRY.items():
227
- if active_modules.get(key, False):
228
- module_importance[meta_r["label"]] = float(
229
- np.abs(coefs[:, idx:idx+10]).mean())
230
- idx += 10
231
-
232
- if module_importance:
233
- fig_mod = go.Figure(go.Pie(
234
- labels=list(module_importance.keys()),
235
- values=list(module_importance.values()),
236
- hole=0.4))
237
- fig_mod.update_layout(title="Module Contribution (avg |coef|)",
238
- template="plotly_dark", height=280)
239
- st.plotly_chart(fig_mod, use_container_width=True)
240
-
241
- st.session_state["rce_head"] = head
242
- st.session_state["rce_train_acc"] = train_acc
243
-
244
- if "rce_head" in st.session_state:
245
- st.divider()
246
- st.subheader("Quick Predict (Crop)")
247
- head = st.session_state["rce_head"]
248
- t0 = time.perf_counter()
249
- vec = build_rce_vector(crop_aug, active_modules)
250
- label, conf = head.predict(vec)
251
- dt = (time.perf_counter() - t0) * 1000
252
- st.write(f"**{label}** — {conf:.1%} confidence — {dt:.1f} ms")
253
-
254
-
255
- # ---------------------------------------------------------------------------
256
- # MIDDLE — CNN Fine-Tuning
257
- # ---------------------------------------------------------------------------
258
- with col_cnn:
259
- st.header("🧠 CNN Fine-Tuning")
260
-
261
- selected = st.selectbox("Select Model", list(BACKBONES.keys()))
262
- meta = BACKBONES[selected]
263
- st.caption(f"Backbone embedding: **{meta['dim']}D** → Logistic Regression head")
264
-
265
- st.subheader("Training Parameters")
266
- cnn_C = st.slider("Regularization (C) ", 0.01, 10.0, 1.0, step=0.01,
267
- key="cnn_c", help="Higher = less regularization")
268
- cnn_max_iter = st.slider("Max Iterations ", 100, 5000, 1000, step=100,
269
- key="cnn_iter")
270
-
271
- if st.button(f"🚀 Train {selected} Head"):
272
- images, labels, neg_short = build_training_set()
273
- if neg_short:
274
- st.warning(f"⚠️ Negative sample shortfall — training may be imbalanced.")
275
- backbone = meta["loader"]()
276
-
277
- from sklearn.metrics import accuracy_score
278
- from sklearn.model_selection import cross_val_score
279
-
280
- progress = st.progress(0, text=f"Extracting {selected} features...")
281
- n = len(images)
282
- X = []
283
- for i, img in enumerate(images):
284
- X.append(backbone.get_features(img))
285
- progress.progress((i + 1) / n, text=f"Feature extraction: {i+1}/{n}")
286
-
287
- X = np.array(X)
288
- progress.progress(1.0, text="Fitting Logistic Regression...")
289
-
290
- t0 = time.perf_counter()
291
- try:
292
- head = RecognitionHead(C=cnn_C, max_iter=cnn_max_iter).fit(X, labels)
293
- except ValueError as e:
294
- st.error(f"Training failed: {e}")
295
- st.stop()
296
- train_time = time.perf_counter() - t0
297
- progress.progress(1.0, text="✅ Training complete!")
298
-
299
- preds = head.model.predict(X)
300
- train_acc = accuracy_score(labels, preds)
301
-
302
- st.success(f"Trained in **{train_time:.2f}s**")
303
- m1, m2, m3, m4 = st.columns(4)
304
- m1.metric("Train Accuracy", f"{train_acc:.1%}")
305
- if len(images) >= 6:
306
- n_splits = min(5, len(set(labels)))
307
- if n_splits >= 2:
308
- cv_scores = cross_val_score(head.model, X, labels,
309
- cv=min(3, len(images) // 2))
310
- m2.metric("CV Accuracy", f"{cv_scores.mean():.1%}",
311
- delta=f"±{cv_scores.std():.1%}")
312
- else:
313
- m2.metric("CV Accuracy", "N/A")
314
- else:
315
- m2.metric("CV Accuracy", "N/A")
316
- m3.metric("Vector Size", f"{X.shape[1]}D")
317
- m4.metric("Samples", f"{len(images)}")
318
- if is_multi:
319
- st.caption(f"Classes: {', '.join(head.classes_)}")
320
-
321
- probs = head.predict_proba(X)
322
- fig = go.Figure()
323
- for ci, cls in enumerate(head.classes_):
324
- fig.add_trace(go.Histogram(x=probs[:, ci], name=cls,
325
- opacity=0.7, nbinsx=20))
326
- fig.update_layout(title="Confidence Distribution", barmode="overlay",
327
- template="plotly_dark", height=280,
328
- xaxis_title="Confidence", yaxis_title="Count")
329
- st.plotly_chart(fig, use_container_width=True)
330
-
331
- # ---- Activation Overlay (Grad-CAM style) ----
332
- st.subheader("🔍 Activation Overlay")
333
- st.caption("Highest-activation spatial regions from the hooked layer, "
334
- "overlaid on the crop as a Grad-CAM–style heatmap.")
335
- try:
336
- act_maps = backbone.get_activation_maps(crop_aug, n_maps=1)
337
- if act_maps:
338
- cam = act_maps[0]
339
- cam_resized = cv2.resize(cam, (crop_aug.shape[1], crop_aug.shape[0]))
340
- cam_color = cv2.applyColorMap(
341
- (cam_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
342
- overlay_img = cv2.addWeighted(crop_aug, 0.5, cam_color, 0.5, 0)
343
- gc1, gc2 = st.columns(2)
344
- gc1.image(cv2.cvtColor(crop_aug, cv2.COLOR_BGR2RGB),
345
- caption="Input Crop", use_container_width=True)
346
- gc2.image(cv2.cvtColor(overlay_img, cv2.COLOR_BGR2RGB),
347
- caption="Activation Overlay", use_container_width=True)
348
- except Exception:
349
- pass
350
-
351
- st.session_state[f"cnn_head_{selected}"] = head
352
- st.session_state[f"cnn_acc_{selected}"] = train_acc
353
-
354
- if f"cnn_head_{selected}" in st.session_state:
355
- st.divider()
356
- st.subheader("Quick Predict (Crop)")
357
- backbone = meta["loader"]()
358
- head = st.session_state[f"cnn_head_{selected}"]
359
- t0 = time.perf_counter()
360
- feats = backbone.get_features(crop_aug)
361
- label, conf = head.predict(feats)
362
- dt = (time.perf_counter() - t0) * 1000
363
- st.write(f"**{label}** — {conf:.1%} confidence — {dt:.1f} ms")
364
-
365
-
366
- # ---------------------------------------------------------------------------
367
- # RIGHT — ORB Training
368
- # ---------------------------------------------------------------------------
369
- with col_orb:
370
- st.header("🏛️ ORB Matching")
371
- st.caption("Keypoint-based matching — a fundamentally different paradigm. "
372
- "Extracts ORB descriptors from each ROI crop and matches them "
373
- "against image patches using brute-force Hamming distance.")
374
-
375
- from src.detectors.orb import ORBDetector
376
-
377
- orb_dist_thresh = st.slider("Match Distance Threshold", 10, 100, 70,
378
- key="orb_dist")
379
- orb_min_matches = st.slider("Min Good Matches", 1, 20, 5, key="orb_min")
380
-
381
- if st.button("🚀 Train ORB Reference"):
382
- orb = ORBDetector()
383
- progress = st.progress(0, text="Extracting ORB descriptors...")
384
-
385
- orb_refs = {}
386
- for i, roi in enumerate(rois):
387
- gray = cv2.cvtColor(roi["crop_aug"], cv2.COLOR_BGR2GRAY)
388
- clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
389
- gray = clahe.apply(gray)
390
- kp, des = orb.orb.detectAndCompute(gray, None)
391
- n_feat = 0 if des is None else len(des)
392
- orb_refs[roi["label"]] = {
393
- "descriptors": des,
394
- "n_features": n_feat,
395
- "keypoints": kp,
396
- "crop": roi["crop_aug"],
397
- }
398
- progress.progress((i + 1) / len(rois),
399
- text=f"ROI {i+1}/{len(rois)}: {n_feat} features")
400
-
401
- progress.progress(1.0, text="✅ ORB references extracted!")
402
-
403
- for lbl, ref in orb_refs.items():
404
- if ref["keypoints"]:
405
- vis = cv2.drawKeypoints(ref["crop"], ref["keypoints"],
406
- None, color=(0, 255, 0))
407
- st.image(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB),
408
- caption=f"{lbl}: {ref['n_features']} keypoints",
409
- use_container_width=True)
410
- else:
411
- st.warning(f"{lbl}: No keypoints detected")
412
-
413
- st.session_state["orb_detector"] = orb
414
- st.session_state["orb_refs"] = orb_refs
415
- st.session_state["orb_dist_thresh"] = orb_dist_thresh
416
- st.session_state["orb_min_matches"] = orb_min_matches
417
- st.success("ORB references stored in session!")
418
-
419
- if "orb_refs" in st.session_state:
420
- st.divider()
421
- st.subheader("Quick Predict (Crop)")
422
- orb = st.session_state["orb_detector"]
423
- refs = st.session_state["orb_refs"]
424
- dt_thresh = st.session_state["orb_dist_thresh"]
425
- min_m = st.session_state["orb_min_matches"]
426
-
427
- gray = cv2.cvtColor(crop_aug, cv2.COLOR_BGR2GRAY)
428
- clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
429
- gray = clahe.apply(gray)
430
- kp, des = orb.orb.detectAndCompute(gray, None)
431
-
432
- if des is not None:
433
- for lbl, ref in refs.items():
434
- if ref["descriptors"] is None:
435
- st.write(f"**{lbl}:** no reference features")
436
- continue
437
- matches = orb.bf.match(ref["descriptors"], des)
438
- good = [m for m in matches if m.distance < dt_thresh]
439
- conf = min(len(good) / max(min_m, 1), 1.0)
440
- verdict = lbl if len(good) >= min_m else "background"
441
- st.write(f"**{verdict}** — {len(good)} matches — "
442
- f"{conf:.0%} confidence")
443
- else:
444
- st.write("No keypoints in test image.")
445
-
446
-
447
- # ===========================================================================
448
- # Bottom — Side-by-side comparison table
449
- # ===========================================================================
450
- st.divider()
451
- st.subheader("📊 Training Comparison")
452
-
453
- rows = []
454
- rce_acc = st.session_state.get("rce_train_acc")
455
- if rce_acc is not None:
456
- rows.append({"Model": "RCE", "Type": "Feature Engineering",
457
- "Train Accuracy": f"{rce_acc:.1%}",
458
- "Vector Size": str(sum(10 for k in active_modules if active_modules[k]))})
459
- for name in BACKBONES:
460
- acc = st.session_state.get(f"cnn_acc_{name}")
461
- if acc is not None:
462
- rows.append({"Model": name, "Type": "CNN Backbone",
463
- "Train Accuracy": f"{acc:.1%}",
464
- "Vector Size": f"{BACKBONES[name]['dim']}D"})
465
- if "orb_refs" in st.session_state:
466
- total_kp = sum(r["n_features"] for r in st.session_state["orb_refs"].values())
467
- rows.append({"Model": "ORB", "Type": "Keypoint Matching",
468
- "Train Accuracy": "N/A (matching)",
469
- "Vector Size": f"{total_kp} descriptors"})
470
-
471
- if rows:
472
- import pandas as pd
473
- st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
474
- else:
475
- st.info("Train at least one model to see the comparison.")
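Note on the deleted Model Tuning page: `build_training_set` rejects any candidate negative patch whose window intersects *any* ROI — a standard axis-aligned bounding-box (AABB) overlap test inside a rejection-sampling loop. The same logic isolated for clarity (numpy RNG seeded with 42 as on the page):

```python
import numpy as np

def overlaps_any(rx, ry, cw, ch, bboxes):
    """True if the (rx, ry) window of size cw×ch intersects any (x0, y0, x1, y1) box."""
    for bx0, by0, bx1, by1 in bboxes:
        # Two axis-aligned rectangles intersect iff they overlap on both axes
        if rx < bx1 and rx + cw > bx0 and ry < by1 and ry + ch > by0:
            return True
    return False

def sample_negatives(H, W, cw, ch, bboxes, n_target, max_attempts=300):
    """Rejection-sample top-left corners of negative windows clear of all ROIs."""
    rng = np.random.default_rng(42)
    out, attempts = [], 0
    while len(out) < n_target and attempts < max_attempts:
        rx = int(rng.integers(0, max(W - cw, 1)))
        ry = int(rng.integers(0, max(H - ch, 1)))
        if not overlaps_any(rx, ry, cw, ch, bboxes):
            out.append((rx, ry))
        attempts += 1
    return out

# One 50×50 ROI in the top-left of a 100×100 image; 10×10 negative windows
negs = sample_negatives(H=100, W=100, cw=10, ch=10, bboxes=[(0, 0, 50, 50)], n_target=5)
```

The attempt cap matters: with large or numerous ROIs the loop can terminate short, which is exactly why the page returns a shortfall flag and warns about class imbalance.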
pages/5_Localization_Lab.py DELETED
@@ -1,348 +0,0 @@
- import streamlit as st
- import cv2
- import numpy as np
- import pandas as pd
- import plotly.graph_objects as go
- import sys, os
- sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-
- from src.detectors.rce.features import REGISTRY
- from src.models import BACKBONES, RecognitionHead
- from src.utils import build_rce_vector
- from src.localization import (
-     exhaustive_sliding_window,
-     image_pyramid,
-     coarse_to_fine,
-     contour_proposals,
-     template_matching,
-     STRATEGIES,
- )
-
- st.set_page_config(page_title="Localization Lab", layout="wide")
- st.title("🔍 Localization Lab")
- st.markdown(
-     "Compare **localization strategies** — algorithms that decide *where* "
-     "to look in the image. The recognition head stays the same; only the "
-     "search method changes."
- )
-
- # ===================================================================
- # Guard
- # ===================================================================
- if "pipeline_data" not in st.session_state or \
-    "crop" not in st.session_state.get("pipeline_data", {}):
-     st.error("Complete **Data Lab** first (upload assets & define a crop).")
-     st.stop()
-
- assets = st.session_state["pipeline_data"]
- right_img = assets["right"]
- crop = assets["crop"]
- crop_aug = assets.get("crop_aug", crop)
- bbox = assets.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
- active_mods = st.session_state.get("active_modules",
-                                    {k: True for k in REGISTRY})
-
- x0, y0, x1, y1 = bbox
- win_h, win_w = y1 - y0, x1 - x0
-
- if win_h <= 0 or win_w <= 0:
-     st.error("Invalid window size from crop bbox. "
-              "Go back to **Data Lab** and redefine the ROI.")
-     st.stop()
-
- rce_head = st.session_state.get("rce_head")
- has_any_cnn = any(f"cnn_head_{n}" in st.session_state for n in BACKBONES)
-
- if rce_head is None and not has_any_cnn:
-     st.warning("No trained heads found. Go to **Model Tuning** first.")
-     st.stop()
-
-
- # ===================================================================
- # RCE feature function
- # ===================================================================
- def rce_feature_fn(patch_bgr):
-     return build_rce_vector(patch_bgr, active_mods)
-
-
- # ===================================================================
- # Algorithm Reference (collapsible)
- # ===================================================================
- st.divider()
- with st.expander("📚 **Algorithm Reference** — click to expand", expanded=False):
-     tabs = st.tabs([f"{v['icon']} {k}" for k, v in STRATEGIES.items()])
-     for tab, (name, meta) in zip(tabs, STRATEGIES.items()):
-         with tab:
-             st.markdown(f"### {meta['icon']} {name}")
-             st.caption(meta["short"])
-             st.markdown(meta["detail"])
-
-
- # ===================================================================
- # Configuration
- # ===================================================================
- st.divider()
- st.header("⚙️ Configuration")
-
- # --- Head selection ---
- col_head, col_info = st.columns([2, 3])
- with col_head:
-     head_options = []
-     if rce_head is not None:
-         head_options.append("RCE")
-     trained_cnns = [n for n in BACKBONES if f"cnn_head_{n}" in st.session_state]
-     head_options.extend(trained_cnns)
-     selected_head = st.selectbox("Recognition Head", head_options,
-                                  key="loc_head")
-
-     if selected_head == "RCE":
-         feature_fn = rce_feature_fn
-         head = rce_head
-     else:
-         bmeta = BACKBONES[selected_head]
-         backbone = bmeta["loader"]()
-         feature_fn = backbone.get_features
-         head = st.session_state[f"cnn_head_{selected_head}"]
-
- with col_info:
-     if selected_head == "RCE":
-         mods = [REGISTRY[k]["label"] for k in active_mods if active_mods[k]]
-         st.info(f"**RCE** — Modules: {', '.join(mods)}")
-     else:
-         st.info(f"**{selected_head}** — "
-                 f"{BACKBONES[selected_head]['dim']}D feature vector")
-
- # --- Algorithm checkboxes ---
- st.subheader("Select Algorithms to Compare")
- algo_cols = st.columns(5)
- algo_names = list(STRATEGIES.keys())
- algo_checks = {}
- for col, name in zip(algo_cols, algo_names):
-     algo_checks[name] = col.checkbox(
-         f"{STRATEGIES[name]['icon']} {name}",
-         value=(name != "Template Matching"),  # default all on except TM
-         key=f"chk_{name}")
-
- any_selected = any(algo_checks.values())
-
- # --- Shared parameters ---
- st.subheader("Parameters")
- sp1, sp2, sp3 = st.columns(3)
- stride = sp1.slider("Base Stride (px)", 4, max(win_w, win_h),
-                     max(win_w // 4, 4), step=2, key="loc_stride")
- conf_thresh = sp2.slider("Confidence Threshold", 0.5, 1.0, 0.7, 0.05,
-                          key="loc_conf")
- nms_iou = sp3.slider("NMS IoU Threshold", 0.1, 0.9, 0.3, 0.05,
-                      key="loc_nms")
-
- # --- Per-algorithm settings ---
- with st.expander("🔧 Per-Algorithm Settings"):
-     pa1, pa2, pa3 = st.columns(3)
-     with pa1:
-         st.markdown("**Image Pyramid**")
-         pyr_min = st.slider("Min Scale", 0.3, 1.0, 0.5, 0.05, key="pyr_min")
-         pyr_max = st.slider("Max Scale", 1.0, 2.0, 1.5, 0.1, key="pyr_max")
-         pyr_n = st.slider("Number of Scales", 3, 7, 5, key="pyr_n")
-     with pa2:
-         st.markdown("**Coarse-to-Fine**")
-         c2f_factor = st.slider("Coarse Factor", 2, 8, 4, key="c2f_factor")
-         c2f_radius = st.slider("Refine Radius (strides)", 1, 5, 2,
-                                key="c2f_radius")
-     with pa3:
-         st.markdown("**Contour Proposals**")
-         cnt_low = st.slider("Canny Low", 10, 100, 50, key="cnt_low")
-         cnt_high = st.slider("Canny High", 50, 300, 150, key="cnt_high")
-         cnt_tol = st.slider("Area Tolerance", 1.5, 10.0, 3.0, 0.5,
-                             key="cnt_tol")
-
- st.caption(
-     f"Window: **{win_w}×{win_h} px** · "
-     f"Image: **{right_img.shape[1]}×{right_img.shape[0]} px** · "
-     f"Stride: **{stride} px**"
- )
-
-
- # ===================================================================
- # Run
- # ===================================================================
- st.divider()
- run_btn = st.button("▶ Run Comparison", type="primary",
-                     disabled=not any_selected, use_container_width=True)
-
- if run_btn:
-     selected_algos = [n for n in algo_names if algo_checks[n]]
-     progress = st.progress(0, text="Starting…")
-     results = {}
-     edge_maps = {}  # for contour visualisation
-
-     for i, name in enumerate(selected_algos):
-         progress.progress(i / len(selected_algos), text=f"Running **{name}**…")
-
-         if name == "Exhaustive Sliding Window":
-             dets, n, ms, hmap = exhaustive_sliding_window(
-                 right_img, win_h, win_w, feature_fn, head,
-                 stride, conf_thresh, nms_iou)
-
-         elif name == "Image Pyramid":
-             scales = np.linspace(pyr_min, pyr_max, pyr_n).tolist()
-             dets, n, ms, hmap = image_pyramid(
-                 right_img, win_h, win_w, feature_fn, head,
-                 stride, conf_thresh, nms_iou, scales=scales)
-
-         elif name == "Coarse-to-Fine":
-             dets, n, ms, hmap = coarse_to_fine(
-                 right_img, win_h, win_w, feature_fn, head,
-                 stride, conf_thresh, nms_iou,
-                 coarse_factor=c2f_factor, refine_radius=c2f_radius)
-
-         elif name == "Contour Proposals":
-             dets, n, ms, hmap, edges = contour_proposals(
-                 right_img, win_h, win_w, feature_fn, head,
-                 conf_thresh, nms_iou,
-                 canny_low=cnt_low, canny_high=cnt_high,
-                 area_tolerance=cnt_tol)
-             edge_maps[name] = edges
-
-         elif name == "Template Matching":
-             dets, n, ms, hmap = template_matching(
-                 right_img, crop_aug, conf_thresh, nms_iou)
-
-         results[name] = {
-             "dets": dets, "n_proposals": n,
-             "time_ms": ms, "heatmap": hmap,
-         }
-
-     progress.progress(1.0, text="Done!")
-
-     # ===============================================================
-     # Summary Table
-     # ===============================================================
-     st.header("📊 Results")
-
-     baseline_ms = results.get("Exhaustive Sliding Window", {}).get("time_ms")
-     rows = []
-     for name, r in results.items():
-         speedup = (baseline_ms / r["time_ms"]
-                    if baseline_ms and r["time_ms"] > 0 else None)
-         rows.append({
-             "Algorithm": name,
-             "Proposals": r["n_proposals"],
-             "Time (ms)": round(r["time_ms"], 1),
-             "Detections": len(r["dets"]),
-             "ms / Proposal": round(r["time_ms"] / max(r["n_proposals"], 1), 4),
-             "Speedup": f"{speedup:.1f}×" if speedup else "—",
-         })
-
-     st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
-
-     # ===============================================================
-     # Detection Images & Heatmaps (one tab per algorithm)
-     # ===============================================================
-     st.subheader("Detection Results")
-     COLORS = {
-         "Exhaustive Sliding Window": (0, 255, 0),
-         "Image Pyramid": (255, 128, 0),
-         "Coarse-to-Fine": (0, 128, 255),
-         "Contour Proposals": (255, 0, 255),
-         "Template Matching": (0, 255, 255),
-     }
-
-     result_tabs = st.tabs(
-         [f"{STRATEGIES[n]['icon']} {n}" for n in results])
-
-     for tab, (name, r) in zip(result_tabs, results.items()):
-         with tab:
-             c1, c2 = st.columns(2)
-             color = COLORS.get(name, (0, 255, 0))
-
-             # --- Detection overlay ---
-             vis = right_img.copy()
-             for x1d, y1d, x2d, y2d, _, cf in r["dets"]:
-                 cv2.rectangle(vis, (x1d, y1d), (x2d, y2d), color, 2)
-                 cv2.putText(vis, f"{cf:.0%}", (x1d, y1d - 6),
-                             cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
-             c1.image(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB),
-                      caption=f"{name} — {len(r['dets'])} detections",
-                      use_container_width=True)
-
-             # --- Heatmap ---
-             hmap = r["heatmap"]
-             if hmap.max() > 0:
-                 hmap_color = cv2.applyColorMap(
-                     (hmap / hmap.max() * 255).astype(np.uint8),
-                     cv2.COLORMAP_JET)
-                 blend = cv2.addWeighted(right_img, 0.5, hmap_color, 0.5, 0)
-                 c2.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
-                          caption=f"{name} — Confidence Heatmap",
-                          use_container_width=True)
-             else:
-                 c2.info("No positive responses above threshold.")
-
-             # --- Contour edge map (extra) ---
-             if name in edge_maps:
-                 st.image(edge_maps[name],
-                          caption="Canny Edge Map (proposals derived from these contours)",
-                          use_container_width=True, clamp=True)
-
-             # --- Per-algorithm metrics ---
-             m1, m2, m3, m4 = st.columns(4)
-             m1.metric("Proposals", r["n_proposals"])
-             m2.metric("Time", f"{r['time_ms']:.0f} ms")
-             m3.metric("Detections", len(r["dets"]))
-             m4.metric("ms / Proposal",
-                       f"{r['time_ms'] / max(r['n_proposals'], 1):.3f}")
-
-             # --- Detection table ---
-             if r["dets"]:
-                 df = pd.DataFrame(r["dets"],
-                                   columns=["x1","y1","x2","y2","label","conf"])
-                 st.dataframe(df, use_container_width=True, hide_index=True)
-
-     # ===============================================================
-     # Performance Charts
-     # ===============================================================
-     st.subheader("📈 Performance Comparison")
-     ch1, ch2 = st.columns(2)
-
-     names = list(results.keys())
-     times = [results[n]["time_ms"] for n in names]
-     props = [results[n]["n_proposals"] for n in names]
-     n_dets = [len(results[n]["dets"]) for n in names]
-     colors_hex = ["#00cc66", "#ff8800", "#0088ff", "#ff00ff", "#00cccc"]
-
-     with ch1:
-         fig = go.Figure(go.Bar(
-             x=names, y=times,
-             text=[f"{t:.0f}" for t in times], textposition="auto",
-             marker_color=colors_hex[:len(names)]))
-         fig.update_layout(title="Total Time (ms)",
-                           yaxis_title="ms", height=400)
-         st.plotly_chart(fig, use_container_width=True)
-
-     with ch2:
-         fig = go.Figure(go.Bar(
-             x=names, y=props,
-             text=[str(p) for p in props], textposition="auto",
-             marker_color=colors_hex[:len(names)]))
-         fig.update_layout(title="Proposals Evaluated",
-                           yaxis_title="Count", height=400)
-         st.plotly_chart(fig, use_container_width=True)
-
-     # --- Scatter: proposals vs time (marker = detections) ---
-     fig = go.Figure()
-     for i, name in enumerate(names):
-         fig.add_trace(go.Scatter(
-             x=[props[i]], y=[times[i]],
-             mode="markers+text",
-             marker=dict(size=max(n_dets[i] * 12, 18),
-                         color=colors_hex[i % len(colors_hex)]),
-             text=[name], textposition="top center",
-             name=name,
-         ))
-     fig.update_layout(
-         title="Proposals vs Time (marker size ∝ detections)",
-         xaxis_title="Proposals Evaluated",
-         yaxis_title="Time (ms)",
-         height=500,
-     )
-     st.plotly_chart(fig, use_container_width=True)
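The deleted pages import `nms` and `_iou` from `src.localization`, which this diff does not show. As a reading aid, here is a minimal standalone sketch of what a greedy IoU-based NMS of that shape typically looks like; this is an illustrative assumption, not the repository's actual implementation, but it operates on the same `(x1, y1, x2, y2, label, conf)` tuples the pages pass around.

```python
def iou(a, b):
    """Intersection-over-union of two (x1, y1, x2, y2) boxes."""
    xi1, yi1 = max(a[0], b[0]), max(a[1], b[1])
    xi2, yi2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, xi2 - xi1) * max(0, yi2 - yi1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter + 1e-6)


def nms(dets, iou_thresh):
    """Greedy NMS over (x1, y1, x2, y2, label, conf) tuples.

    Keep detections in descending-confidence order, dropping any box
    that overlaps an already-kept box by more than iou_thresh.
    """
    keep = []
    for det in sorted(dets, key=lambda d: d[5], reverse=True):
        if all(iou(det[:4], k[:4]) < iou_thresh for k in keep):
            keep.append(det)
    return keep
```

The pages call this after the full scan, so the heatmaps still reflect every raw window response while the final box list is deduplicated.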
pages/6_RealTime_Detection.py DELETED
@@ -1,435 +0,0 @@
- import streamlit as st
- import cv2
- import numpy as np
- import time
- import plotly.graph_objects as go
- import sys, os
- sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-
- from src.detectors.rce.features import REGISTRY
- from src.models import BACKBONES, RecognitionHead
- from src.utils import build_rce_vector
- from src.localization import nms as _nms, _iou
-
- st.set_page_config(page_title="Real-Time Detection", layout="wide")
- st.title("🎯 Real-Time Detection")
-
- # ---------------------------------------------------------------------------
- # Guard
- # ---------------------------------------------------------------------------
- if "pipeline_data" not in st.session_state or "crop" not in st.session_state.get("pipeline_data", {}):
-     st.error("Complete **Data Lab** first (upload assets & define a crop).")
-     st.stop()
-
- assets = st.session_state["pipeline_data"]
- right_img = assets["right"]
- crop = assets["crop"]
- crop_aug = assets.get("crop_aug", crop)
- bbox = assets.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
- rois = assets.get("rois", [{"label": "object", "bbox": bbox,
-                             "crop": crop, "crop_aug": crop_aug}])
- active_mods = st.session_state.get("active_modules", {k: True for k in REGISTRY})
-
- x0, y0, x1, y1 = bbox
- win_h, win_w = y1 - y0, x1 - x0  # window = same size as crop
-
- if win_h <= 0 or win_w <= 0:
-     st.error("Invalid window size from crop bbox. "
-              "Go back to **Data Lab** and redefine the ROI.")
-     st.stop()
-
- # Color palette for multi-class drawing
- CLASS_COLORS = [(0,255,0),(0,0,255),(255,165,0),(255,0,255),(0,255,255),
-                 (128,255,0),(255,128,0),(0,128,255)]
-
- rce_head = st.session_state.get("rce_head")
- has_any_cnn = any(f"cnn_head_{n}" in st.session_state for n in BACKBONES)
- has_orb = "orb_refs" in st.session_state
-
- if rce_head is None and not has_any_cnn and not has_orb:
-     st.warning("No trained heads found. Go to **Model Tuning** and train at least one head.")
-     st.stop()
-
-
- # ===================================================================
- # Sliding Window Engine (shared by both sides)
- # ===================================================================
- def sliding_window_detect(
-     image: np.ndarray,
-     feature_fn,          # callable(patch_bgr) -> 1-D np.ndarray
-     head: RecognitionHead,
-     stride: int,
-     conf_thresh: float,
-     nms_iou: float,
-     progress_placeholder=None,
-     live_image_placeholder=None,
- ):
-     """
-     Slide a window of size (win_h, win_w) across *image* with *stride*.
-     At each position call *feature_fn* → *head.predict*.
-     Returns (detections, heatmap, total_time_ms, n_windows).
-
-     Each detection is (x, y, x+win_w, y+win_h, label, confidence).
-     heatmap is a float32 array same size as image (object confidence).
-     """
-     H, W = image.shape[:2]
-     heatmap = np.zeros((H, W), dtype=np.float32)
-     detections = []
-     t0 = time.perf_counter()
-
-     positions = []
-     for y in range(0, H - win_h + 1, stride):
-         for x in range(0, W - win_w + 1, stride):
-             positions.append((x, y))
-
-     n_total = len(positions)
-     if n_total == 0:
-         return [], heatmap, 0.0, 0
-
-     for idx, (x, y) in enumerate(positions):
-         patch = image[y:y+win_h, x:x+win_w]
-         feats = feature_fn(patch)
-         label, conf = head.predict(feats)
-
-         # Fill heatmap with non-background confidence
-         if label != "background":
-             heatmap[y:y+win_h, x:x+win_w] = np.maximum(
-                 heatmap[y:y+win_h, x:x+win_w], conf)
-             if conf >= conf_thresh:
-                 detections.append((x, y, x+win_w, y+win_h, label, conf))
-
-         # Live updates (every 5th window or last)
-         if live_image_placeholder is not None and (idx % 5 == 0 or idx == n_total - 1):
-             vis = image.copy()
-             # Draw current scan position
-             cv2.rectangle(vis, (x, y), (x+win_w, y+win_h), (255, 255, 0), 1)
-             # Draw current detections
-             for dx, dy, dx2, dy2, dl, dc in detections:
-                 cv2.rectangle(vis, (dx, dy), (dx2, dy2), (0, 255, 0), 2)
-                 cv2.putText(vis, f"{dc:.0%}", (dx, dy - 4),
-                             cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)
-             live_image_placeholder.image(
-                 cv2.cvtColor(vis, cv2.COLOR_BGR2RGB),
-                 caption=f"Scanning… {idx+1}/{n_total}",
-                 use_container_width=True)
-
-         if progress_placeholder is not None:
-             progress_placeholder.progress(
-                 (idx + 1) / n_total,
-                 text=f"Window {idx+1}/{n_total}")
-
-     total_ms = (time.perf_counter() - t0) * 1000
-
-     # --- Non-Maximum Suppression ---
-     if detections:
-         detections = _nms(detections, nms_iou)
-
-     return detections, heatmap, total_ms, n_total
-
-
- # ===================================================================
- # RCE feature function
- # ===================================================================
- def rce_feature_fn(patch_bgr):
-     return build_rce_vector(patch_bgr, active_mods)
-
-
- # ===================================================================
- # Controls
- # ===================================================================
- st.subheader("Sliding Window Parameters")
- p1, p2, p3 = st.columns(3)
- stride = p1.slider("Stride (px)", 4, max(win_w // 2, 4),
-                    max(win_w // 4, 4), step=2,
-                    help="Lower = more windows = slower but finer")
- conf_thresh = p2.slider("Confidence Threshold", 0.5, 1.0, 0.7, 0.05)
- nms_iou = p3.slider("NMS IoU Threshold", 0.1, 0.9, 0.3, 0.05)
-
- st.caption(f"Window size: **{win_w}×{win_h} px** | "
-            f"Right image: **{right_img.shape[1]}×{right_img.shape[0]} px** | "
-            f"≈ {((right_img.shape[0]-win_h)//stride + 1) * ((right_img.shape[1]-win_w)//stride + 1)} windows")
-
- st.divider()
-
- # ===================================================================
- # Side-by-side layout
- # ===================================================================
- col_rce, col_cnn, col_orb = st.columns(3)
-
- # -------------------------------------------------------------------
- # LEFT — RCE Detection
- # -------------------------------------------------------------------
- with col_rce:
-     st.header("🧬 RCE Detection")
-     if rce_head is None:
-         st.info("No RCE head trained. Train one in **Model Tuning**.")
-     else:
-         st.caption(f"Modules: {', '.join(REGISTRY[k]['label'] for k in active_mods if active_mods[k])}")
-         rce_run = st.button("▶ Run RCE Scan", key="rce_run")
-
-         rce_progress = st.empty()
-         rce_live = st.empty()
-         rce_results = st.container()
-
-         if rce_run:
-             dets, hmap, ms, nw = sliding_window_detect(
-                 right_img, rce_feature_fn, rce_head,
-                 stride, conf_thresh, nms_iou,
-                 progress_placeholder=rce_progress,
-                 live_image_placeholder=rce_live,
-             )
-
-             # Final image with boxes
-             final = right_img.copy()
-             class_labels = sorted(set(d[4] for d in dets)) if dets else []
-             for x1d, y1d, x2d, y2d, lbl, cf in dets:
-                 ci = class_labels.index(lbl) if lbl in class_labels else 0
-                 clr = CLASS_COLORS[ci % len(CLASS_COLORS)]
-                 cv2.rectangle(final, (x1d, y1d), (x2d, y2d), clr, 2)
-                 cv2.putText(final, f"{lbl} {cf:.0%}", (x1d, y1d - 6),
-                             cv2.FONT_HERSHEY_SIMPLEX, 0.4, clr, 1)
-             rce_live.image(cv2.cvtColor(final, cv2.COLOR_BGR2RGB),
-                            caption="RCE — Final Detections",
-                            use_container_width=True)
-             rce_progress.empty()
-
-             with rce_results:
-                 # Metrics
-                 rm1, rm2, rm3, rm4 = st.columns(4)
-                 rm1.metric("Detections", len(dets))
-                 rm2.metric("Windows", nw)
-                 rm3.metric("Total Time", f"{ms:.0f} ms")
-                 rm4.metric("Per Window", f"{ms/max(nw,1):.2f} ms")
-
-                 # Confidence heatmap
-                 if hmap.max() > 0:
-                     hmap_color = cv2.applyColorMap(
-                         (hmap / hmap.max() * 255).astype(np.uint8),
-                         cv2.COLORMAP_JET)
-                     blend = cv2.addWeighted(right_img, 0.5, hmap_color, 0.5, 0)
-                     st.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
-                              caption="RCE — Confidence Heatmap",
-                              use_container_width=True)
-
-                 # Detection table
-                 if dets:
-                     import pandas as pd
-                     df = pd.DataFrame(dets, columns=["x1","y1","x2","y2","label","conf"])
-                     st.dataframe(df, use_container_width=True, hide_index=True)
-
-             st.session_state["rce_dets"] = dets
-             st.session_state["rce_det_ms"] = ms
-
- # -------------------------------------------------------------------
- # RIGHT — CNN Detection
- # -------------------------------------------------------------------
- with col_cnn:
-     st.header("🧠 CNN Detection")
-
-     # Find which CNN heads are trained
-     trained_cnns = [n for n in BACKBONES if f"cnn_head_{n}" in st.session_state]
-     if not trained_cnns:
-         st.info("No CNN head trained. Train one in **Model Tuning**.")
-     else:
-         selected = st.selectbox("Select Model", trained_cnns, key="det_cnn_sel")
-         bmeta = BACKBONES[selected]
-         backbone = bmeta["loader"]()
-         head = st.session_state[f"cnn_head_{selected}"]
-
-         st.caption(f"Backbone: **{selected}** ({bmeta['dim']}D) — Head in session state")
-         cnn_run = st.button(f"▶ Run {selected} Scan", key="cnn_run")
-
-         cnn_progress = st.empty()
-         cnn_live = st.empty()
-         cnn_results = st.container()
-
-         if cnn_run:
-             dets, hmap, ms, nw = sliding_window_detect(
-                 right_img, backbone.get_features, head,
-                 stride, conf_thresh, nms_iou,
-                 progress_placeholder=cnn_progress,
-                 live_image_placeholder=cnn_live,
-             )
-
-             # Final image
-             final = right_img.copy()
-             class_labels = sorted(set(d[4] for d in dets)) if dets else []
-             for x1d, y1d, x2d, y2d, lbl, cf in dets:
-                 ci = class_labels.index(lbl) if lbl in class_labels else 0
-                 clr = CLASS_COLORS[ci % len(CLASS_COLORS)]
-                 cv2.rectangle(final, (x1d, y1d), (x2d, y2d), clr, 2)
-                 cv2.putText(final, f"{lbl} {cf:.0%}", (x1d, y1d - 6),
-                             cv2.FONT_HERSHEY_SIMPLEX, 0.4, clr, 1)
-             cnn_live.image(cv2.cvtColor(final, cv2.COLOR_BGR2RGB),
-                            caption=f"{selected} — Final Detections",
-                            use_container_width=True)
-             cnn_progress.empty()
-
-             with cnn_results:
-                 cm1, cm2, cm3, cm4 = st.columns(4)
-                 cm1.metric("Detections", len(dets))
-                 cm2.metric("Windows", nw)
-                 cm3.metric("Total Time", f"{ms:.0f} ms")
-                 cm4.metric("Per Window", f"{ms/max(nw,1):.2f} ms")
-
-                 if hmap.max() > 0:
-                     hmap_color = cv2.applyColorMap(
-                         (hmap / hmap.max() * 255).astype(np.uint8),
-                         cv2.COLORMAP_JET)
-                     blend = cv2.addWeighted(right_img, 0.5, hmap_color, 0.5, 0)
-                     st.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
-                              caption=f"{selected} — Confidence Heatmap",
-                              use_container_width=True)
-
-                 if dets:
-                     import pandas as pd
-                     df = pd.DataFrame(dets, columns=["x1","y1","x2","y2","label","conf"])
-                     st.dataframe(df, use_container_width=True, hide_index=True)
-
-             st.session_state["cnn_dets"] = dets
-             st.session_state["cnn_det_ms"] = ms
-
-
- # -------------------------------------------------------------------
- # RIGHT — ORB Detection
- # -------------------------------------------------------------------
- with col_orb:
-     st.header("🏛️ ORB Detection")
-     if not has_orb:
-         st.info("No ORB reference trained. Train one in **Model Tuning**.")
-     else:
-         orb_det = st.session_state["orb_detector"]
-         orb_refs = st.session_state["orb_refs"]
-         dt_thresh = st.session_state.get("orb_dist_thresh", 70)
-         min_m = st.session_state.get("orb_min_matches", 5)
-         st.caption(f"References: {', '.join(orb_refs.keys())} | "
-                    f"dist<{dt_thresh}, min {min_m} matches")
-         orb_run = st.button("▶ Run ORB Scan", key="orb_run")
-
-         orb_progress = st.empty()
-         orb_live = st.empty()
-         orb_results = st.container()
-
-         if orb_run:
-             H, W = right_img.shape[:2]
-             positions = [(x, y)
-                          for y in range(0, H - win_h + 1, stride)
-                          for x in range(0, W - win_w + 1, stride)]
-             n_total = len(positions)
-             heatmap = np.zeros((H, W), dtype=np.float32)
-             detections = []
-             t0 = time.perf_counter()
-
-             clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
-
-             for idx, (px, py) in enumerate(positions):
-                 patch = right_img[py:py+win_h, px:px+win_w]
-                 gray = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY)
-                 gray = clahe.apply(gray)
-                 kp, des = orb_det.orb.detectAndCompute(gray, None)
-
-                 if des is not None:
-                     best_label, best_conf = "background", 0.0
-                     for lbl, ref in orb_refs.items():
-                         if ref["descriptors"] is None:
-                             continue
-                         matches = orb_det.bf.match(ref["descriptors"], des)
-                         good = [m for m in matches if m.distance < dt_thresh]
-                         conf = min(len(good) / max(min_m, 1), 1.0)
-                         if len(good) >= min_m and conf > best_conf:
-                             best_label, best_conf = lbl, conf
-
-                     if best_label != "background":
-                         heatmap[py:py+win_h, px:px+win_w] = np.maximum(
-                             heatmap[py:py+win_h, px:px+win_w], best_conf)
-                         if best_conf >= conf_thresh:
-                             detections.append(
-                                 (px, py, px+win_w, py+win_h, best_label, best_conf))
-
-                 if idx % 5 == 0 or idx == n_total - 1:
-                     orb_progress.progress((idx+1)/n_total,
-                                           text=f"Window {idx+1}/{n_total}")
-
-             total_ms = (time.perf_counter() - t0) * 1000
-             if detections:
-                 detections = _nms(detections, nms_iou)
-
-             final = right_img.copy()
-             cls_labels = sorted(set(d[4] for d in detections)) if detections else []
-             for x1d, y1d, x2d, y2d, lbl, cf in detections:
-                 ci = cls_labels.index(lbl) if lbl in cls_labels else 0
-                 clr = CLASS_COLORS[ci % len(CLASS_COLORS)]
-                 cv2.rectangle(final, (x1d, y1d), (x2d, y2d), clr, 2)
-                 cv2.putText(final, f"{lbl} {cf:.0%}", (x1d, y1d - 6),
-                             cv2.FONT_HERSHEY_SIMPLEX, 0.4, clr, 1)
-             orb_live.image(cv2.cvtColor(final, cv2.COLOR_BGR2RGB),
-                            caption="ORB — Final Detections",
-                            use_container_width=True)
-             orb_progress.empty()
-
-             with orb_results:
-                 om1, om2, om3, om4 = st.columns(4)
-                 om1.metric("Detections", len(detections))
-                 om2.metric("Windows", n_total)
-                 om3.metric("Total Time", f"{total_ms:.0f} ms")
-                 om4.metric("Per Window", f"{total_ms/max(n_total,1):.2f} ms")
-
-                 if heatmap.max() > 0:
-                     hmap_color = cv2.applyColorMap(
-                         (heatmap / heatmap.max() * 255).astype(np.uint8),
-                         cv2.COLORMAP_JET)
-                     blend = cv2.addWeighted(right_img, 0.5, hmap_color, 0.5, 0)
-                     st.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
-                              caption="ORB — Confidence Heatmap",
-                              use_container_width=True)
-
-                 if detections:
-                     import pandas as pd
-                     df = pd.DataFrame(detections,
-                                       columns=["x1","y1","x2","y2","label","conf"])
-                     st.dataframe(df, use_container_width=True, hide_index=True)
-
-             st.session_state["orb_dets"] = detections
-             st.session_state["orb_det_ms"] = total_ms
-
-
- # ===================================================================
- # Bottom — Comparison (if any two have run)
- # ===================================================================
- rce_dets = st.session_state.get("rce_dets")
- cnn_dets = st.session_state.get("cnn_dets")
- orb_dets = st.session_state.get("orb_dets")
-
- methods = {}
- if rce_dets is not None:
-     methods["RCE"] = (rce_dets, st.session_state.get("rce_det_ms", 0), (0,255,0))
- if cnn_dets is not None:
-     methods["CNN"] = (cnn_dets, st.session_state.get("cnn_det_ms", 0), (0,0,255))
- if orb_dets is not None:
-     methods["ORB"] = (orb_dets, st.session_state.get("orb_det_ms", 0), (255,165,0))
-
- if len(methods) >= 2:
-     st.divider()
-     st.subheader("📊 Side-by-Side Comparison")
-
-     import pandas as pd
-     comp = {"Metric": ["Detections", "Best Confidence", "Total Time (ms)"]}
-     for name, (dets, ms, _) in methods.items():
-         comp[name] = [
-             len(dets),
-             f"{max((d[5] for d in dets), default=0):.1%}",
-             f"{ms:.0f}",
-         ]
-     st.dataframe(pd.DataFrame(comp), use_container_width=True, hide_index=True)
-
-     # Overlay all methods on one image
-     overlay = right_img.copy()
-     for name, (dets, _, clr) in methods.items():
-         for x1d, y1d, x2d, y2d, lbl, cf in dets:
-             cv2.rectangle(overlay, (x1d, y1d), (x2d, y2d), clr, 2)
-             cv2.putText(overlay, f"{name}:{lbl} {cf:.0%}", (x1d, y1d - 6),
-                         cv2.FONT_HERSHEY_SIMPLEX, 0.35, clr, 1)
-     legend = " | ".join(f"{n}={'green' if c==(0,255,0) else 'blue' if c==(0,0,255) else 'orange'}"
-                         for n, (_, _, c) in methods.items())
-     st.image(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB),
-              caption=legend, use_container_width=True)
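The deleted Real-Time Detection page estimates the window count in its caption as `((H - win_h)//stride + 1) * ((W - win_w)//stride + 1)`. A small standalone sketch (no Streamlit) showing that this closed form agrees with the nested `range` loops used by `sliding_window_detect`:

```python
def count_windows(H, W, win_h, win_w, stride):
    # Enumerate positions exactly as the page's nested loops do.
    positions = [(x, y)
                 for y in range(0, H - win_h + 1, stride)
                 for x in range(0, W - win_w + 1, stride)]
    return len(positions)


def count_windows_closed_form(H, W, win_h, win_w, stride):
    # The estimate shown in the page's caption.
    return ((H - win_h) // stride + 1) * ((W - win_w) // stride + 1)
```

Both assume the window fits inside the image; the page guards that case separately by validating `win_h` and `win_w` against the crop bbox.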
pages/7_Evaluation.py DELETED
@@ -1,295 +0,0 @@
- import streamlit as st
- import cv2
- import numpy as np
- import plotly.graph_objects as go
- import plotly.figure_factory as ff
- import sys, os
- sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-
- from src.detectors.rce.features import REGISTRY
- from src.models import BACKBONES
-
- st.set_page_config(page_title="Evaluation", layout="wide")
- st.title("📈 Evaluation: Confusion Matrix & PR Curves")
-
- # ---------------------------------------------------------------------------
- # Guard
- # ---------------------------------------------------------------------------
- if "pipeline_data" not in st.session_state:
-     st.error("Complete the **Data Lab** first.")
-     st.stop()
-
- assets = st.session_state["pipeline_data"]
- crop = assets["crop"]
- crop_aug = assets.get("crop_aug", crop)
- bbox = assets.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
- rois = assets.get("rois", [{"label": "object", "bbox": bbox,
-                             "crop": crop, "crop_aug": crop_aug}])
-
- rce_dets = st.session_state.get("rce_dets")
- cnn_dets = st.session_state.get("cnn_dets")
- orb_dets = st.session_state.get("orb_dets")
-
- if rce_dets is None and cnn_dets is None and orb_dets is None:
-     st.warning("Run detection on at least one method in **Real-Time Detection** first.")
-     st.stop()
-
-
- # ---------------------------------------------------------------------------
- # Ground truth from ROIs
- # ---------------------------------------------------------------------------
- gt_boxes = [(roi["bbox"], roi["label"]) for roi in rois]
-
- st.sidebar.subheader("Evaluation Settings")
- iou_thresh = st.sidebar.slider("IoU Threshold", 0.1, 0.9, 0.5, 0.05,
-                                help="Minimum IoU to count a detection as TP")
-
- st.subheader("Ground Truth (from Data Lab ROIs)")
- st.caption(f"{len(gt_boxes)} ground-truth ROIs defined")
- gt_vis = assets["right"].copy()
- for (bx0, by0, bx1, by1), lbl in gt_boxes:
-     cv2.rectangle(gt_vis, (bx0, by0), (bx1, by1), (0, 255, 255), 2)
-     cv2.putText(gt_vis, lbl, (bx0, by0 - 6),
-                 cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 1)
- st.image(cv2.cvtColor(gt_vis, cv2.COLOR_BGR2RGB),
-          caption="Ground Truth Annotations", use_container_width=True)
-
- st.divider()
-
-
- # ---------------------------------------------------------------------------
- # Matching helpers
- # ---------------------------------------------------------------------------
- def _iou(a, b):
-     xi1 = max(a[0], b[0]); yi1 = max(a[1], b[1])
-     xi2 = min(a[2], b[2]); yi2 = min(a[3], b[3])
-     inter = max(0, xi2 - xi1) * max(0, yi2 - yi1)
-     aa = (a[2] - a[0]) * (a[3] - a[1])
-     ab = (b[2] - b[0]) * (b[3] - b[1])
-     return inter / (aa + ab - inter + 1e-6)
-
-
- def match_detections(dets, gt_list, iou_thr):
-     """
-     Match detections to GT boxes.
-     Returns (results, n_missed, matched_gt_indices).
-     results = list of (det, matched_gt_label_or_None, iou) sorted by confidence.
-     matched_gt_indices = set of GT indices that were matched.
-     """
-     dets_sorted = sorted(dets, key=lambda d: d[5], reverse=True)
-     matched_gt = set()
-     results = []
-
-     for det in dets_sorted:
-         det_box = det[:4]
-         det_label = det[4]
-         best_iou = 0.0
-         best_gt_idx = -1
-         best_gt_label = None
-
-         for gi, (gt_box, gt_label) in enumerate(gt_list):
-             if gi in matched_gt:
-                 continue
-             iou_val = _iou(det_box, gt_box)
-             if iou_val > best_iou:
-                 best_iou = iou_val
-                 best_gt_idx = gi
-                 best_gt_label = gt_label
-
-         if best_iou >= iou_thr and best_gt_idx >= 0:
-             matched_gt.add(best_gt_idx)
-             results.append((det, best_gt_label, best_iou))
- results.append((det, best_gt_label, best_iou))
102
- else:
103
- results.append((det, None, best_iou))
104
-
105
- return results, len(gt_list) - len(matched_gt), matched_gt
106
-
107
-
108
- def compute_pr_curve(dets, gt_list, iou_thr, steps=50):
109
- """
110
- Sweep confidence thresholds and compute precision/recall.
111
- Returns (thresholds, precisions, recalls, f1s).
112
- """
113
- if not dets:
114
- return [], [], [], []
115
-
116
- thresholds = np.linspace(0.0, 1.0, steps)
117
- precisions, recalls, f1s = [], [], []
118
-
119
- for thr in thresholds:
120
- filtered = [d for d in dets if d[5] >= thr]
121
- if not filtered:
122
- precisions.append(1.0)
123
- recalls.append(0.0)
124
- f1s.append(0.0)
125
- continue
126
-
127
- matched, n_missed, _ = match_detections(filtered, gt_list, iou_thr)
128
- tp = sum(1 for _, gt_lbl, _ in matched if gt_lbl is not None)
129
- fp = sum(1 for _, gt_lbl, _ in matched if gt_lbl is None)
130
- fn = n_missed
131
-
132
- prec = tp / (tp + fp) if (tp + fp) > 0 else 1.0
133
- rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
134
- f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
135
- precisions.append(prec)
136
- recalls.append(rec)
137
- f1s.append(f1)
138
-
139
- return thresholds.tolist(), precisions, recalls, f1s
140
-
141
-
142
- def build_confusion_matrix(dets, gt_list, iou_thr):
143
- """
144
- Build a confusion matrix: rows = predicted, cols = actual.
145
- Classes = all GT labels + 'background'.
146
- """
147
- gt_labels = sorted(set(lbl for _, lbl in gt_list))
148
- all_labels = gt_labels + ["background"]
149
-
150
- n = len(all_labels)
151
- matrix = np.zeros((n, n), dtype=int)
152
- label_to_idx = {lbl: i for i, lbl in enumerate(all_labels)}
153
-
154
- matched, n_missed, matched_gt_indices = match_detections(dets, gt_list, iou_thr)
155
-
156
- for det, gt_lbl, _ in matched:
157
- pred_lbl = det[4]
158
- if gt_lbl is not None:
159
- # TP or mislabel
160
- pi = label_to_idx.get(pred_lbl, label_to_idx["background"])
161
- gi = label_to_idx[gt_lbl]
162
- matrix[pi][gi] += 1
163
- else:
164
- # FP
165
- pi = label_to_idx.get(pred_lbl, label_to_idx["background"])
166
- matrix[pi][label_to_idx["background"]] += 1
167
-
168
- # FN: unmatched GT
169
- for gi, (_, gt_lbl) in enumerate(gt_list):
170
- if gi not in matched_gt_indices:
171
- matrix[label_to_idx["background"]][label_to_idx[gt_lbl]] += 1
172
-
173
- return matrix, all_labels
174
-
175
-
176
- # ---------------------------------------------------------------------------
177
- # Collect all methods with detections
178
- # ---------------------------------------------------------------------------
179
- methods = {}
180
- if rce_dets is not None:
181
- methods["RCE"] = rce_dets
182
- if cnn_dets is not None:
183
- methods["CNN"] = cnn_dets
184
- if orb_dets is not None:
185
- methods["ORB"] = orb_dets
186
-
187
-
188
- # ===================================================================
189
- # 1. Confusion Matrices
190
- # ===================================================================
191
- st.subheader("🔲 Confusion Matrices")
192
- cm_cols = st.columns(len(methods))
193
-
194
- for col, (name, dets) in zip(cm_cols, methods.items()):
195
- with col:
196
- st.markdown(f"**{name}**")
197
- matrix, labels = build_confusion_matrix(dets, gt_boxes, iou_thresh)
198
-
199
- fig_cm = ff.create_annotated_heatmap(
200
- z=matrix.tolist(),
201
- x=labels, y=labels,
202
- colorscale="Blues",
203
- showscale=True)
204
- fig_cm.update_layout(
205
- title=f"{name} Confusion Matrix",
206
- xaxis_title="Actual",
207
- yaxis_title="Predicted",
208
- template="plotly_dark",
209
- height=350)
210
- fig_cm.update_yaxes(autorange="reversed")
211
- st.plotly_chart(fig_cm, use_container_width=True)
212
-
213
- # Summary metrics at this default threshold
214
- matched, n_missed, _ = match_detections(dets, gt_boxes, iou_thresh)
215
- tp = sum(1 for _, g, _ in matched if g is not None)
216
- fp = sum(1 for _, g, _ in matched if g is None)
217
- fn = n_missed
218
- prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
219
- rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
220
- f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
221
-
222
- m1, m2, m3 = st.columns(3)
223
- m1.metric("Precision", f"{prec:.1%}")
224
- m2.metric("Recall", f"{rec:.1%}")
225
- m3.metric("F1 Score", f"{f1:.1%}")
226
-
227
-
228
- # ===================================================================
229
- # 2. Precision-Recall Curves
230
- # ===================================================================
231
- st.divider()
232
- st.subheader("📉 Precision-Recall Curves")
233
-
234
- method_colors = {"RCE": "#00ff88", "CNN": "#4488ff", "ORB": "#ff8800"}
235
- fig_pr = go.Figure()
236
- fig_f1 = go.Figure()
237
-
238
- summary_rows = []
239
-
240
- for name, dets in methods.items():
241
- thrs, precs, recs, f1s = compute_pr_curve(dets, gt_boxes, iou_thresh)
242
- clr = method_colors.get(name, "#ffffff")
243
-
244
- fig_pr.add_trace(go.Scatter(
245
- x=recs, y=precs, mode="lines+markers",
246
- name=name, line=dict(color=clr, width=2),
247
- marker=dict(size=4)))
248
-
249
- fig_f1.add_trace(go.Scatter(
250
- x=thrs, y=f1s, mode="lines",
251
- name=name, line=dict(color=clr, width=2)))
252
-
253
- # AP (area under PR curve)
254
- if recs and precs:
255
- ap = float(np.trapz(precs, recs))
256
- else:
257
- ap = 0.0
258
-
259
- best_f1_idx = int(np.argmax(f1s)) if f1s else 0
260
- summary_rows.append({
261
- "Method": name,
262
- "AP": f"{abs(ap):.3f}",
263
- "Best F1": f"{f1s[best_f1_idx]:.3f}" if f1s else "N/A",
264
- "@ Threshold": f"{thrs[best_f1_idx]:.2f}" if thrs else "N/A",
265
- "Detections": len(dets),
266
- })
267
-
268
- fig_pr.update_layout(
269
- title="Precision vs Recall",
270
- xaxis_title="Recall", yaxis_title="Precision",
271
- template="plotly_dark", height=400,
272
- xaxis=dict(range=[0, 1.05]), yaxis=dict(range=[0, 1.05]))
273
-
274
- fig_f1.update_layout(
275
- title="F1 Score vs Confidence Threshold",
276
- xaxis_title="Confidence Threshold", yaxis_title="F1 Score",
277
- template="plotly_dark", height=400,
278
- xaxis=dict(range=[0, 1.05]), yaxis=dict(range=[0, 1.05]))
279
-
280
- pc1, pc2 = st.columns(2)
281
- pc1.plotly_chart(fig_pr, use_container_width=True)
282
- pc2.plotly_chart(fig_f1, use_container_width=True)
283
-
284
-
285
- # ===================================================================
286
- # 3. Summary Table
287
- # ===================================================================
288
- st.divider()
289
- st.subheader("📊 Summary")
290
-
291
- import pandas as pd
292
- st.dataframe(pd.DataFrame(summary_rows), use_container_width=True, hide_index=True)
293
-
294
- st.caption(f"All metrics computed at IoU threshold = **{iou_thresh:.2f}**. "
295
- "Adjust in the sidebar to explore sensitivity.")
pages/8_Stereo_Geometry.py DELETED
@@ -1,353 +0,0 @@
- import streamlit as st
- import cv2
- import numpy as np
- import re
- import pandas as pd
- import plotly.graph_objects as go
- import sys, os
- sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-
- st.set_page_config(page_title="Stereo Geometry", layout="wide")
- st.title("📐 Stereo Geometry: Distance Estimation")
-
- # ---------------------------------------------------------------------------
- # Guard
- # ---------------------------------------------------------------------------
- if "pipeline_data" not in st.session_state or "left" not in st.session_state.get("pipeline_data", {}):
-     st.error("Complete **Data Lab** first.")
-     st.stop()
-
- assets = st.session_state["pipeline_data"]
- img_l = assets["left"]
- img_r = assets["right"]
- gt_left = assets.get("gt_left")  # float32 disparity map from PFM
- gt_right = assets.get("gt_right")
- conf_raw = assets.get("conf_raw", "")
- crop_bbox = assets.get("crop_bbox")  # (x0, y0, x1, y1) on LEFT image
-
- rce_dets = st.session_state.get("rce_dets", [])
- cnn_dets = st.session_state.get("cnn_dets", [])
-
-
- # ===================================================================
- # Parse Middlebury-style camera config
- # ===================================================================
- def parse_config(text: str) -> dict:
-     """
-     Parse a Middlebury .txt / .conf calibration file.
-     Expected keys: cam0, cam1, doffs, baseline, width, height, ndisp, vmin, vmax
-     cam0 / cam1 are 3×3 matrices in bracket notation: [f 0 cx; 0 f cy; 0 0 1]
-     """
-     params = {}
-     if not text or not text.strip():
-         return params
-     for line in text.strip().splitlines():
-         line = line.strip()
-         if "=" not in line:
-             continue
-         key, val = line.split("=", 1)
-         key = key.strip()
-         val = val.strip()
-         if "[" in val:
-             nums = list(map(float, re.findall(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?", val)))
-             params[key] = np.array(nums).reshape(3, 3) if len(nums) == 9 else nums
-         else:
-             try:
-                 params[key] = float(val)
-             except ValueError:
-                 params[key] = val
-     return params
-
-
- calib = parse_config(conf_raw)
-
- # Extract intrinsics
- cam0 = calib.get("cam0")
- focal = float(cam0[0, 0]) if isinstance(cam0, np.ndarray) and cam0.shape == (3, 3) else 0.0
- doffs = float(calib.get("doffs", 0.0))
- baseline = float(calib.get("baseline", 1.0))
- ndisp = int(calib.get("ndisp", 128))
-
- if focal <= 0:
-     st.error("❌ Focal length is **0** — the camera config is missing or malformed. "
-              "Depth estimation cannot proceed. Return to **Data Lab** and upload a valid "
-              "Middlebury camera config.")
-     st.stop()
-
- if focal > 10000:
-     st.error(f"❌ Focal length ({focal:.0f} px) is suspiciously large. "
-              "Check your camera config file.")
-     st.stop()
-
- if baseline <= 0 or baseline > 1000:
-     st.error(f"❌ Invalid baseline ({baseline:.1f}). Expected 1–1000 mm.")
-     st.stop()
-
- st.subheader("Camera Calibration")
- cc1, cc2, cc3, cc4 = st.columns(4)
- cc1.metric("Focal Length (px)", f"{focal:.1f}")
- cc2.metric("Baseline (mm)", f"{baseline:.1f}")
- cc3.metric("Doffs (px)", f"{doffs:.2f}")
- cc4.metric("ndisp", str(ndisp))
-
- with st.expander("Full Calibration"):
-     st.json({k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in calib.items()})
-
- st.divider()
-
-
- # ===================================================================
- # Image-size validation
- # ===================================================================
- if img_l.shape[:2] != img_r.shape[:2]:
-     st.error(f"Left ({img_l.shape[1]}×{img_l.shape[0]}) and right "
-              f"({img_r.shape[1]}×{img_r.shape[0]}) images must be the same size.")
-     st.stop()
-
-
- # ===================================================================
- # Step 1 — Compute Disparity Map
- # ===================================================================
- st.subheader("Step 1: Disparity Map (StereoSGBM)")
-
- sc1, sc2, sc3 = st.columns(3)
- block_size = sc1.slider("Block Size", 3, 21, 5, step=2)
- p1_mult = sc2.slider("P1 multiplier", 1, 32, 8)
- p2_mult = sc3.slider("P2 multiplier", 1, 128, 32)
-
-
- @st.cache_data
- def compute_disparity(_left, _right, ndisp, block_size, p1m, p2m):
-     """StereoSGBM disparity. _left/_right are un-hashed (numpy arrays)."""
-     gray_l = cv2.cvtColor(_left, cv2.COLOR_BGR2GRAY)
-     gray_r = cv2.cvtColor(_right, cv2.COLOR_BGR2GRAY)
-
-     nd = max(16, (ndisp // 16) * 16)
-     sgbm = cv2.StereoSGBM_create(
-         minDisparity=0,
-         numDisparities=nd,
-         blockSize=block_size,
-         P1=p1m * 1 * block_size ** 2,
-         P2=p2m * 1 * block_size ** 2,
-         disp12MaxDiff=1,
-         uniquenessRatio=10,
-         speckleWindowSize=100,
-         speckleRange=32,
-         mode=cv2.STEREO_SGBM_MODE_SGBM_3WAY,
-     )
-     return sgbm.compute(gray_l, gray_r).astype(np.float32) / 16.0
-
-
- with st.spinner("Computing disparity…"):
-     try:
-         disp = compute_disparity(img_l, img_r, ndisp, block_size, p1_mult, p2_mult)
-     except cv2.error as e:
-         st.error(f"StereoSGBM failed: {e}")
-         st.stop()
-
- # Visualize disparity
- disp_vis = np.clip(disp, 0, None)
- disp_max = disp_vis.max() if disp_vis.max() > 0 else 1.0
- disp_norm = (disp_vis / disp_max * 255).astype(np.uint8)
- disp_color = cv2.applyColorMap(disp_norm, cv2.COLORMAP_INFERNO)
-
- dc1, dc2 = st.columns(2)
- dc1.image(cv2.cvtColor(img_l, cv2.COLOR_BGR2RGB), caption="Left Image", use_container_width=True)
- dc2.image(cv2.cvtColor(disp_color, cv2.COLOR_BGR2RGB), caption="Disparity Map (SGBM)", use_container_width=True)
-
-
- # ===================================================================
- # Step 2 — Depth Map from Disparity
- # ===================================================================
- st.divider()
- st.subheader("Step 2: Depth Map from Disparity")
-
- st.latex(r"Z = \frac{f \times B}{d + d_{\text{offs}}}")
- st.caption("Z = depth (mm), f = focal length (px), B = baseline (mm), d = disparity (px), d_offs = optical center offset (px)")
-
- # Compute depth from disparity
- valid = (disp + doffs) > 0
- depth_map = np.zeros_like(disp)
- if focal > 0:
-     depth_map[valid] = (focal * baseline) / (disp[valid] + doffs)
-
- # Visualize
- depth_vis = depth_map.copy()
- finite = depth_vis[depth_vis > 0]
- if len(finite) > 0:
-     clip_max = np.percentile(finite, 98)
-     depth_vis = np.clip(depth_vis, 0, clip_max)
-     depth_norm = (depth_vis / clip_max * 255).astype(np.uint8)
- else:
-     depth_norm = np.zeros_like(depth_map, dtype=np.uint8)
-
- depth_color = cv2.applyColorMap(depth_norm, cv2.COLORMAP_TURBO)
-
- zc1, zc2 = st.columns(2)
- zc1.image(cv2.cvtColor(depth_color, cv2.COLOR_BGR2RGB),
-           caption="Estimated Depth (SGBM)", use_container_width=True)
-
- # Ground truth comparison
- if gt_left is not None:
-     gt_vis = gt_left.copy()
-     gt_finite = gt_vis[np.isfinite(gt_vis) & (gt_vis > 0)]
-     if len(gt_finite) > 0:
-         gt_clip = np.percentile(gt_finite, 98)
-         gt_vis = np.clip(np.nan_to_num(gt_vis, nan=0), 0, gt_clip)
-         gt_norm = (gt_vis / gt_clip * 255).astype(np.uint8)
-     else:
-         gt_norm = np.zeros_like(gt_vis, dtype=np.uint8)
-     gt_color = cv2.applyColorMap(gt_norm, cv2.COLORMAP_TURBO)
-     zc2.image(cv2.cvtColor(gt_color, cv2.COLOR_BGR2RGB),
-               caption="Ground Truth Disparity (from PFM)", use_container_width=True)
-
-
- # ===================================================================
- # Step 3 — Error Map (SGBM vs Ground Truth)
- # ===================================================================
- if gt_left is not None:
-     st.divider()
-     st.subheader("Step 3: Error Analysis (SGBM vs Ground Truth)")
-
-     gt_disp = gt_left  # Middlebury standard: PFM = disparity map
-
-     # Ensure GT and SGBM disparity have the same shape
-     if gt_disp.shape[:2] != disp.shape[:2]:
-         st.warning(
-             f"Ground truth shape ({gt_disp.shape[1]}×{gt_disp.shape[0]}) differs from "
-             f"disparity shape ({disp.shape[1]}×{disp.shape[0]}). Resizing GT to match."
-         )
-         gt_disp = cv2.resize(gt_disp, (disp.shape[1], disp.shape[0]),
-                              interpolation=cv2.INTER_NEAREST)
-
-     gt_valid = np.isfinite(gt_disp) & (gt_disp > 0)
-     both_valid = valid & gt_valid
-
-     if both_valid.any():
-         # Disparity error
-         disp_err = np.abs(disp - gt_disp)
-         disp_err[~both_valid] = 0
-
-         # Stats
-         err_vals = disp_err[both_valid]
-         mae = float(np.mean(err_vals))
-         rmse = float(np.sqrt(np.mean(err_vals ** 2)))
-         bad_2 = float(np.mean(err_vals > 2.0)) * 100
-
-         em1, em2, em3 = st.columns(3)
-         em1.metric("MAE (px)", f"{mae:.2f}")
-         em2.metric("RMSE (px)", f"{rmse:.2f}")
-         em3.metric("Bad-2.0 (%)", f"{bad_2:.1f}%")
-
-         # Error heatmap
-         err_clip = np.clip(disp_err, 0, 10)
-         err_norm = (err_clip / 10 * 255).astype(np.uint8)
-         err_color = cv2.applyColorMap(err_norm, cv2.COLORMAP_HOT)
-         st.image(cv2.cvtColor(err_color, cv2.COLOR_BGR2RGB),
-                  caption="Disparity Error Map (red = high error, clipped at 10 px)",
-                  use_container_width=True)
-
-         # Histogram
-         fig = go.Figure(data=[go.Histogram(x=err_vals, nbinsx=50,
-                                            marker_color="#ff6361")])
-         fig.update_layout(title="Disparity Error Distribution",
-                           xaxis_title="Absolute Error (px)",
-                           yaxis_title="Pixel Count",
-                           template="plotly_dark", height=300)
-         st.plotly_chart(fig, use_container_width=True)
-     else:
-         st.warning("No overlapping valid pixels between SGBM disparity and ground truth.")
-
-
- # ===================================================================
- # Step 4 — Object Distance from Detections
- # ===================================================================
- st.divider()
- st.subheader("Step 4: Object Distance Estimation")
-
- all_dets = []
- all_dets.extend(("RCE", *d) for d in rce_dets)
- all_dets.extend(("CNN", *d) for d in cnn_dets)
-
- if not all_dets and crop_bbox is not None:
-     st.info("No detections from the Real-Time Detection page. Using the **crop bounding box** as a fallback.")
-     x0, y0, x1, y1 = crop_bbox
-     all_dets.append(("Crop", x0, y0, x1, y1, "object", 1.0))
- elif not all_dets:
-     st.warning("No detections found. Run **Real-Time Detection** first, or define a crop in **Data Lab**.")
-     st.stop()
-
- if focal <= 0:
-     st.warning("Focal length is 0 — cannot compute depth. Upload a valid config in **Data Lab**.")
-     st.stop()
-
- rows = []
- det_overlay = img_l.copy()
-
- for source, dx1, dy1, dx2, dy2, lbl, conf in all_dets:
-     dx1, dy1, dx2, dy2 = int(dx1), int(dy1), int(dx2), int(dy2)
-
-     # Clamp to image bounds
-     H, W = depth_map.shape[:2]
-     dx1c = max(0, min(dx1, W - 1))
-     dy1c = max(0, min(dy1, H - 1))
-     dx2c = max(0, min(dx2, W))
-     dy2c = max(0, min(dy2, H))
-
-     roi_depth = depth_map[dy1c:dy2c, dx1c:dx2c]
-     roi_disp = disp[dy1c:dy2c, dx1c:dx2c]
-     roi_valid = roi_depth[roi_depth > 0]
-
-     if len(roi_valid) > 0:
-         med_depth = float(np.median(roi_valid))
-         mean_depth = float(np.mean(roi_valid))
-         med_disp = float(np.median(roi_disp[roi_disp > 0])) if (roi_disp > 0).any() else 0
-     else:
-         med_depth = mean_depth = med_disp = 0.0
-
-     # Ground truth depth at this region
-     gt_depth_val = 0.0
-     if gt_left is not None:
-         gt_roi = gt_left[dy1c:dy2c, dx1c:dx2c]
-         gt_roi_valid = gt_roi[np.isfinite(gt_roi) & (gt_roi > 0)]
-         if len(gt_roi_valid) > 0:
-             gt_med_disp = float(np.median(gt_roi_valid))
-             gt_depth_val = (focal * baseline) / (gt_med_disp + doffs) if (gt_med_disp + doffs) > 0 else 0
-
-     error_mm = abs(med_depth - gt_depth_val) if gt_depth_val > 0 else float("nan")
-
-     rows.append({
-         "Source": source,
-         "Box": f"({dx1},{dy1})→({dx2},{dy2})",
-         "Confidence": f"{conf:.1%}" if isinstance(conf, float) else str(conf),
-         "Med Disparity": f"{med_disp:.1f} px",
-         "Med Depth": f"{med_depth:.0f} mm",
-         "Mean Depth": f"{mean_depth:.0f} mm",
-         "GT Depth": f"{gt_depth_val:.0f} mm" if gt_depth_val > 0 else "N/A",
-         "Error": f"{error_mm:.0f} mm" if not np.isnan(error_mm) else "N/A",
-     })
-
-     # Draw on overlay
-     color = (0, 255, 0) if "RCE" in source else (0, 0, 255) if "CNN" in source else (255, 255, 0)
-     cv2.rectangle(det_overlay, (dx1c, dy1c), (dx2c, dy2c), color, 2)
-     depth_str = f"{med_depth / 1000:.2f}m" if med_depth > 0 else "?"
-     cv2.putText(det_overlay, f"{source} {depth_str}",
-                 (dx1c, max(dy1c - 6, 12)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
-
- # Show overlay
- st.image(cv2.cvtColor(det_overlay, cv2.COLOR_BGR2RGB),
-          caption="Detections with Estimated Distance",
-          use_container_width=True)
-
- # Table
- st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
-
- # Primary detection summary
- if rows:
-     best = rows[0]
-     st.divider()
-     st.subheader("🎯 Primary Detection — Distance")
-     bc1, bc2, bc3 = st.columns(3)
-     bc1.metric("Estimated Depth", best["Med Depth"])
-     bc2.metric("Ground Truth", best["GT Depth"])
-     bc3.metric("Absolute Error", best["Error"])
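The removed page converts disparity to depth with the pinhole-stereo relation it shows in Step 2, Z = f·B / (d + d_offs). A minimal sketch of that conversion, decoupled from Streamlit and OpenCV (the calibration numbers below are illustrative, not from a real Middlebury file):

```python
import numpy as np

def depth_from_disparity(disp, focal_px, baseline_mm, doffs_px=0.0):
    """Z = f * B / (d + doffs). f in px, B in mm, d and doffs in px -> Z in mm.
    Pixels with non-positive (d + doffs) are left at depth 0 (invalid)."""
    disp = np.asarray(disp, dtype=np.float32)
    depth = np.zeros_like(disp)
    valid = (disp + doffs_px) > 0
    depth[valid] = (focal_px * baseline_mm) / (disp[valid] + doffs_px)
    return depth

# Illustrative numbers: f = 3000 px, B = 200 mm, doffs = 0
d = np.array([[100.0, 0.0], [50.0, 25.0]])
z = depth_from_disparity(d, focal_px=3000.0, baseline_mm=200.0)
print(z)  # 100 px -> 6000 mm; 50 px -> 12000 mm; 0 px stays 0 (invalid)
```

Note the inverse relationship: halving the disparity doubles the depth, which is why disparity errors translate into much larger metric errors for distant objects.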
tabs/__init__.py ADDED
File without changes
tabs/generalisation/__init__.py ADDED
File without changes
tabs/generalisation/data_lab.py ADDED
@@ -0,0 +1,269 @@
+ """Generalisation Data Lab — Stage 1 of the Generalisation pipeline."""
+
+ import streamlit as st
+ import cv2
+ import numpy as np
+ import os
+
+ from utils.middlebury_loader import (
+     DEFAULT_MIDDLEBURY_ROOT, get_scene_groups, load_single_view,
+     read_pfm_bytes,
+ )
+
+
+ # ------------------------------------------------------------------
+ # Helpers (shared with stereo data lab)
+ # ------------------------------------------------------------------
+
+ def _augment(img, brightness, contrast, rotation,
+              flip_h, flip_v, noise, blur, shift_x, shift_y):
+     out = img.astype(np.float32)
+     out = np.clip(contrast * out + brightness, 0, 255)
+     if noise > 0:
+         out = np.clip(out + np.random.normal(0, noise, out.shape), 0, 255)
+     out = out.astype(np.uint8)
+     k = blur * 2 + 1
+     if k > 1:
+         out = cv2.GaussianBlur(out, (k, k), 0)
+     if rotation != 0:
+         h, w = out.shape[:2]
+         M = cv2.getRotationMatrix2D((w / 2, h / 2), rotation, 1.0)
+         out = cv2.warpAffine(out, M, (w, h), borderMode=cv2.BORDER_REFLECT)
+     if shift_x != 0 or shift_y != 0:
+         h, w = out.shape[:2]
+         M = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
+         out = cv2.warpAffine(out, M, (w, h), borderMode=cv2.BORDER_REFLECT)
+     if flip_h:
+         out = cv2.flip(out, 1)
+     if flip_v:
+         out = cv2.flip(out, 0)
+     return out
+
+
+ ROI_COLORS = [(0,255,0),(255,0,0),(0,0,255),(255,255,0),
+               (255,0,255),(0,255,255),(128,255,0),(255,128,0)]
+ MAX_UPLOAD_BYTES = 50 * 1024 * 1024
+
+
+ def render():
+     st.header("🧪 Data Lab — Generalisation")
+     st.info("**How this works:** Train on one image, test on a completely "
+             "different image of the same object. No stereo geometry — "
+             "pure recognition generalisation.")
+
+     source = st.radio("Data source",
+                       ["📦 Middlebury Multi-View", "📁 Upload your own files"],
+                       horizontal=True, key="gen_source")
+
+     # ===================================================================
+     # Middlebury multi-view
+     # ===================================================================
+     if source == "📦 Middlebury Multi-View":
+         groups = get_scene_groups()
+         if not groups:
+             st.error("No valid Middlebury scenes found in ./data/middlebury/")
+             return
+
+         group_name = st.selectbox("Scene group", list(groups.keys()), key="gen_group")
+         variants = groups[group_name]
+
+         gc1, gc2 = st.columns(2)
+         train_scene = gc1.selectbox("Training scene", variants, key="gen_train_scene")
+         available_test = [v for v in variants if v != train_scene]
+         if not available_test:
+             st.error("Need at least 2 variants in a group.")
+             return
+         test_scene = gc2.selectbox("Test scene", available_test, key="gen_test_scene")
+
+         train_path = os.path.join(DEFAULT_MIDDLEBURY_ROOT, train_scene)
+         test_path = os.path.join(DEFAULT_MIDDLEBURY_ROOT, test_scene)
+
+         img_train = load_single_view(train_path)
+         img_test = load_single_view(test_path)
+
+         st.markdown("*Both images show the same scene type captured under different "
+                     "conditions. The model trains on one variant and must recognise "
+                     "the same object class in the other — testing genuine appearance "
+                     "generalisation.*")
+
+         c1, c2 = st.columns(2)
+         c1.image(cv2.cvtColor(img_train, cv2.COLOR_BGR2RGB),
+                  caption=f"🟦 TRAIN IMAGE ({train_scene})", use_container_width=True)
+         c2.image(cv2.cvtColor(img_test, cv2.COLOR_BGR2RGB),
+                  caption=f"🟥 TEST IMAGE ({test_scene})", use_container_width=True)
+
+         scene_group = group_name
+
+     # ===================================================================
+     # Custom upload
+     # ===================================================================
+     else:
+         uc1, uc2 = st.columns(2)
+         with uc1:
+             up_train = st.file_uploader("Train Image", type=["png","jpg","jpeg"],
+                                         key="gen_up_train")
+         with uc2:
+             up_test = st.file_uploader("Test Image", type=["png","jpg","jpeg"],
+                                        key="gen_up_test")
+
+         if not (up_train and up_test):
+             st.info("Upload a train and test image to proceed.")
+             return
+
+         if up_train.size > MAX_UPLOAD_BYTES or up_test.size > MAX_UPLOAD_BYTES:
+             st.error("Image too large (max 50 MB).")
+             return
+
+         img_train = cv2.imdecode(np.frombuffer(up_train.read(), np.uint8), cv2.IMREAD_COLOR); up_train.seek(0)
+         img_test = cv2.imdecode(np.frombuffer(up_test.read(), np.uint8), cv2.IMREAD_COLOR); up_test.seek(0)
+
+         c1, c2 = st.columns(2)
+         c1.image(cv2.cvtColor(img_train, cv2.COLOR_BGR2RGB),
+                  caption="🟦 TRAIN IMAGE", use_container_width=True)
+         c2.image(cv2.cvtColor(img_test, cv2.COLOR_BGR2RGB),
+                  caption="🟥 TEST IMAGE", use_container_width=True)
+
+         train_scene = "custom_train"
+         test_scene = "custom_test"
+         scene_group = "custom"
+
+     # ===================================================================
+     # ROI Definition (on TRAIN image)
+     # ===================================================================
+     st.divider()
+     st.subheader("Step 2: Crop Region(s) of Interest")
+     st.write("Define bounding boxes on the **TRAIN image**.")
+
+     H, W = img_train.shape[:2]
+
+     if "gen_rois" not in st.session_state:
+         st.session_state["gen_rois"] = [
+             {"label": "object", "x0": 0, "y0": 0,
+              "x1": min(W, 100), "y1": min(H, 100)}
+         ]
+
+     def _add_roi():
+         if len(st.session_state["gen_rois"]) >= 20:
+             return
+         st.session_state["gen_rois"].append(
+             {"label": f"object_{len(st.session_state['gen_rois'])+1}",
+              "x0": 0, "y0": 0,
+              "x1": min(W, 100), "y1": min(H, 100)})
+
+     def _remove_roi(idx):
+         if len(st.session_state["gen_rois"]) > 1:
+             st.session_state["gen_rois"].pop(idx)
+
+     for i, roi in enumerate(st.session_state["gen_rois"]):
+         color = ROI_COLORS[i % len(ROI_COLORS)]
+         color_hex = "#{:02x}{:02x}{:02x}".format(*color)
+         with st.container(border=True):
+             hc1, hc2, hc3 = st.columns([3, 6, 1])
+             hc1.markdown(f"**ROI {i+1}** <span style='color:{color_hex}'>■</span>",
+                          unsafe_allow_html=True)
+             roi["label"] = hc2.text_input("Class Label", roi["label"],
+                                           key=f"gen_roi_lbl_{i}")
+             if len(st.session_state["gen_rois"]) > 1:
+                 hc3.button("✕", key=f"gen_roi_del_{i}",
+                            on_click=_remove_roi, args=(i,))
+
+             cr1, cr2, cr3, cr4 = st.columns(4)
+             roi["x0"] = int(cr1.number_input("X start", 0, W-2, int(roi["x0"]),
+                                              step=1, key=f"gen_roi_x0_{i}"))
+             roi["y0"] = int(cr2.number_input("Y start", 0, H-2, int(roi["y0"]),
+                                              step=1, key=f"gen_roi_y0_{i}"))
+             roi["x1"] = int(cr3.number_input("X end", roi["x0"]+1, W,
+                                              min(W, int(roi["x1"])),
+                                              step=1, key=f"gen_roi_x1_{i}"))
+             roi["y1"] = int(cr4.number_input("Y end", roi["y0"]+1, H,
+                                              min(H, int(roi["y1"])),
+                                              step=1, key=f"gen_roi_y1_{i}"))
+
+     st.button("➕ Add Another ROI", on_click=_add_roi,
+               disabled=len(st.session_state["gen_rois"]) >= 20,
+               key="gen_add_roi")
+
+     overlay = img_train.copy()
+     crops = []
+     for i, roi in enumerate(st.session_state["gen_rois"]):
+         color = ROI_COLORS[i % len(ROI_COLORS)]
+         x0, y0, x1, y1 = roi["x0"], roi["y0"], roi["x1"], roi["y1"]
+         cv2.rectangle(overlay, (x0, y0), (x1, y1), color, 2)
+         cv2.putText(overlay, roi["label"], (x0, y0 - 6),
+                     cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
+         crops.append(img_train[y0:y1, x0:x1].copy())
+
+     ov1, ov2 = st.columns([3, 2])
+     ov1.image(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB),
+               caption="TRAIN image — ROIs highlighted", use_container_width=True)
+     with ov2:
+         for i, (c, roi) in enumerate(zip(crops, st.session_state["gen_rois"])):
+             st.image(cv2.cvtColor(c, cv2.COLOR_BGR2RGB),
+                      caption=f"{roi['label']} ({c.shape[1]}×{c.shape[0]})", width=160)
+
+     crop_bgr = crops[0]
+     x0 = st.session_state["gen_rois"][0]["x0"]
+     y0 = st.session_state["gen_rois"][0]["y0"]
+     x1 = st.session_state["gen_rois"][0]["x1"]
+     y1 = st.session_state["gen_rois"][0]["y1"]
+
+     # ===================================================================
+     # Augmentation
+     # ===================================================================
+     st.divider()
+     st.subheader("Step 3: Data Augmentation")
+     ac1, ac2 = st.columns(2)
+     with ac1:
+         brightness = st.slider("Brightness offset", -100, 100, 0, key="gen_bright")
+         contrast = st.slider("Contrast scale", 0.5, 3.0, 1.0, 0.05, key="gen_contrast")
+         rotation = st.slider("Rotation (°)", -180, 180, 0, key="gen_rot")
+         noise = st.slider("Gaussian noise σ", 0, 50, 0, key="gen_noise")
+     with ac2:
+         blur = st.slider("Blur kernel (0=off)", 0, 10, 0, key="gen_blur")
+         shift_x = st.slider("Shift X (px)", -100, 100, 0, key="gen_sx")
+         shift_y = st.slider("Shift Y (px)", -100, 100, 0, key="gen_sy")
+         flip_h = st.checkbox("Flip Horizontal", key="gen_fh")
+         flip_v = st.checkbox("Flip Vertical", key="gen_fv")
+
+     aug = _augment(crop_bgr, brightness, contrast, rotation,
+                    flip_h, flip_v, noise, blur, shift_x, shift_y)
+     all_augs = [_augment(c, brightness, contrast, rotation,
+                          flip_h, flip_v, noise, blur, shift_x, shift_y)
232
+ for c in crops]
233
+
234
+ ag1, ag2 = st.columns(2)
235
+ ag1.image(cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB),
236
+ caption="Original Crop (ROI 1)", use_container_width=True)
237
+ ag2.image(cv2.cvtColor(aug, cv2.COLOR_BGR2RGB),
238
+ caption="Augmented Crop (ROI 1)", use_container_width=True)
239
+
240
+ # ===================================================================
241
+ # Lock & Store
242
+ # ===================================================================
243
+ st.divider()
244
+ if st.button("🚀 Lock Data & Proceed", key="gen_lock"):
245
+ rois_data = []
246
+ for i, roi in enumerate(st.session_state["gen_rois"]):
247
+ rois_data.append({
248
+ "label": roi["label"],
249
+ "bbox": (roi["x0"], roi["y0"], roi["x1"], roi["y1"]),
250
+ "crop": crops[i],
251
+ "crop_aug": all_augs[i],
252
+ })
253
+
254
+ st.session_state["gen_pipeline"] = {
255
+ "train_image": img_train,
256
+ "test_image": img_test,
257
+ "roi": {"x": x0, "y": y0, "w": x1 - x0, "h": y1 - y0,
258
+ "label": st.session_state["gen_rois"][0]["label"]},
259
+ "crop": crop_bgr,
260
+ "crop_aug": aug,
261
+ "crop_bbox": (x0, y0, x1, y1),
262
+ "rois": rois_data,
263
+ "source": "middlebury" if source == "📦 Middlebury Multi-View" else "custom",
264
+ "scene_group": scene_group if "scene_group" in dir() else "",
265
+ "train_scene": train_scene if "train_scene" in dir() else "",
266
+ "test_scene": test_scene if "test_scene" in dir() else "",
267
+ }
268
+ st.success(f"✅ Data locked with **{len(rois_data)} ROI(s)**! "
269
+ f"Proceed to Feature Lab.")
tabs/generalisation/detection.py ADDED
@@ -0,0 +1,388 @@
+"""Generalisation Detection — Stage 5 of the Generalisation pipeline.
+
+CRITICAL: Detection runs on the TEST image (different scene variant).
+Training was done on the TRAIN image.
+This enforces the data-leakage fix.
+"""
+
+import streamlit as st
+import cv2
+import numpy as np
+import pandas as pd
+import time
+
+from src.detectors.rce.features import REGISTRY
+from src.models import BACKBONES
+from src.utils import build_rce_vector
+from src.localization import nms as _nms
+
+
+CLASS_COLORS = [(0,255,0),(0,0,255),(255,165,0),(255,0,255),(0,255,255),
+                (128,255,0),(255,128,0),(0,128,255)]
+
+
+def sliding_window_detect(image, feature_fn, head, win_h, win_w,
+                          stride, conf_thresh, nms_iou,
+                          progress_placeholder=None,
+                          live_image_placeholder=None):
+    H, W = image.shape[:2]
+    heatmap = np.zeros((H, W), dtype=np.float32)
+    detections = []
+    t0 = time.perf_counter()
+
+    positions = [(x, y)
+                 for y in range(0, H - win_h + 1, stride)
+                 for x in range(0, W - win_w + 1, stride)]
+    n_total = len(positions)
+    if n_total == 0:
+        return [], heatmap, 0.0, 0
+
+    for idx, (x, y) in enumerate(positions):
+        patch = image[y:y+win_h, x:x+win_w]
+        feats = feature_fn(patch)
+        label, conf = head.predict(feats)
+
+        if label != "background":
+            heatmap[y:y+win_h, x:x+win_w] = np.maximum(
+                heatmap[y:y+win_h, x:x+win_w], conf)
+            if conf >= conf_thresh:
+                detections.append((x, y, x+win_w, y+win_h, label, conf))
+
+        if live_image_placeholder is not None and (idx % 5 == 0 or idx == n_total - 1):
+            vis = image.copy()
+            cv2.rectangle(vis, (x, y), (x+win_w, y+win_h), (255, 255, 0), 1)
+            for dx, dy, dx2, dy2, dl, dc in detections:
+                cv2.rectangle(vis, (dx, dy), (dx2, dy2), (0, 255, 0), 2)
+                cv2.putText(vis, f"{dc:.0%}", (dx, dy - 4),
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)
+            live_image_placeholder.image(
+                cv2.cvtColor(vis, cv2.COLOR_BGR2RGB),
+                caption=f"Scanning… {idx+1}/{n_total}",
+                use_container_width=True)
+
+        if progress_placeholder is not None:
+            progress_placeholder.progress(
+                (idx + 1) / n_total, text=f"Window {idx+1}/{n_total}")
+
+    total_ms = (time.perf_counter() - t0) * 1000
+    if detections:
+        detections = _nms(detections, nms_iou)
+    return detections, heatmap, total_ms, n_total
+
+
+def render():
+    st.title("🎯 Real-Time Detection")
+
+    pipe = st.session_state.get("gen_pipeline")
+    if not pipe or "crop" not in pipe:
+        st.error("Complete **Data Lab** first (upload assets & define a crop).")
+        st.stop()
+
+    # CRITICAL: detect on TEST image, not TRAIN image
+    test_img = pipe["test_image"]
+    crop = pipe["crop"]
+    crop_aug = pipe.get("crop_aug", crop)
+    bbox = pipe.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
+    rois = pipe.get("rois", [{"label": "object", "bbox": bbox,
+                              "crop": crop, "crop_aug": crop_aug}])
+    active_mods = pipe.get("active_modules", {k: True for k in REGISTRY})
+
+    x0, y0, x1, y1 = bbox
+    win_h, win_w = y1 - y0, x1 - x0
+
+    if win_h <= 0 or win_w <= 0:
+        st.error("Invalid window size from crop bbox.")
+        st.stop()
+
+    rce_head = pipe.get("rce_head")
+    has_any_cnn = any(f"cnn_head_{n}" in pipe for n in BACKBONES)
+    has_orb = pipe.get("orb_refs") is not None
+
+    if rce_head is None and not has_any_cnn and not has_orb:
+        st.warning("No trained heads found. Go to **Model Tuning** first.")
+        st.stop()
+
+    def rce_feature_fn(patch_bgr):
+        return build_rce_vector(patch_bgr, active_mods)
+
+    # Controls
+    st.subheader("Sliding Window Parameters")
+    p1, p2, p3 = st.columns(3)
+    stride = p1.slider("Stride (px)", 4, max(win_w // 2, 4),
+                       max(win_w // 4, 4), step=2, key="gen_det_stride")
+    conf_thresh = p2.slider("Confidence Threshold", 0.5, 1.0, 0.7, 0.05,
+                            key="gen_det_conf")
+    nms_iou = p3.slider("NMS IoU Threshold", 0.1, 0.9, 0.3, 0.05,
+                        key="gen_det_nms")
+
+    st.caption(f"Window size: **{win_w}×{win_h} px** | "
+               f"Test image: **{test_img.shape[1]}×{test_img.shape[0]} px** | "
+               f"≈ {((test_img.shape[0]-win_h)//stride + 1) * ((test_img.shape[1]-win_w)//stride + 1)} windows")
+    st.divider()
+
+    col_rce, col_cnn, col_orb = st.columns(3)
+
+    # -------------------------------------------------------------------
+    # RCE Detection
+    # -------------------------------------------------------------------
+    with col_rce:
+        st.header("🧬 RCE Detection")
+        if rce_head is None:
+            st.info("No RCE head trained.")
+        else:
+            st.caption(f"Modules: {', '.join(REGISTRY[k]['label'] for k in active_mods if active_mods[k])}")
+            rce_run = st.button("▶ Run RCE Scan", key="gen_rce_run")
+            rce_progress = st.empty()
+            rce_live = st.empty()
+            rce_results = st.container()
+
+            if rce_run:
+                dets, hmap, ms, nw = sliding_window_detect(
+                    test_img, rce_feature_fn, rce_head, win_h, win_w,
+                    stride, conf_thresh, nms_iou,
+                    progress_placeholder=rce_progress,
+                    live_image_placeholder=rce_live)
+
+                final = test_img.copy()
+                class_labels = sorted(set(d[4] for d in dets)) if dets else []
+                for x1d, y1d, x2d, y2d, lbl, cf in dets:
+                    ci = class_labels.index(lbl) if lbl in class_labels else 0
+                    clr = CLASS_COLORS[ci % len(CLASS_COLORS)]
+                    cv2.rectangle(final, (x1d, y1d), (x2d, y2d), clr, 2)
+                    cv2.putText(final, f"{lbl} {cf:.0%}", (x1d, y1d - 6),
+                                cv2.FONT_HERSHEY_SIMPLEX, 0.4, clr, 1)
+                rce_live.image(cv2.cvtColor(final, cv2.COLOR_BGR2RGB),
+                               caption="RCE — Final Detections",
+                               use_container_width=True)
+                rce_progress.empty()
+
+                with rce_results:
+                    rm1, rm2, rm3, rm4 = st.columns(4)
+                    rm1.metric("Detections", len(dets))
+                    rm2.metric("Windows", nw)
+                    rm3.metric("Total Time", f"{ms:.0f} ms")
+                    rm4.metric("Per Window", f"{ms/max(nw,1):.2f} ms")
+
+                    if hmap.max() > 0:
+                        hmap_color = cv2.applyColorMap(
+                            (hmap / hmap.max() * 255).astype(np.uint8),
+                            cv2.COLORMAP_JET)
+                        blend = cv2.addWeighted(test_img, 0.5, hmap_color, 0.5, 0)
+                        st.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
+                                 caption="RCE — Confidence Heatmap",
+                                 use_container_width=True)
+
+                    if dets:
+                        df = pd.DataFrame(dets, columns=["x1","y1","x2","y2","label","conf"])
+                        st.dataframe(df, use_container_width=True, hide_index=True)
+
+                pipe["rce_dets"] = dets
+                pipe["rce_det_ms"] = ms
+                st.session_state["gen_pipeline"] = pipe
+
+    # -------------------------------------------------------------------
+    # CNN Detection
+    # -------------------------------------------------------------------
+    with col_cnn:
+        st.header("🧠 CNN Detection")
+        trained_cnns = [n for n in BACKBONES if f"cnn_head_{n}" in pipe]
+        if not trained_cnns:
+            st.info("No CNN head trained.")
+        else:
+            selected = st.selectbox("Select Model", trained_cnns,
+                                    key="gen_det_cnn_sel")
+            bmeta = BACKBONES[selected]
+            backbone = bmeta["loader"]()
+            head = pipe[f"cnn_head_{selected}"]
+
+            st.caption(f"Backbone: **{selected}** ({bmeta['dim']}D)")
+            cnn_run = st.button(f"▶ Run {selected} Scan", key="gen_cnn_run")
+            cnn_progress = st.empty()
+            cnn_live = st.empty()
+            cnn_results = st.container()
+
+            if cnn_run:
+                dets, hmap, ms, nw = sliding_window_detect(
+                    test_img, backbone.get_features, head, win_h, win_w,
+                    stride, conf_thresh, nms_iou,
+                    progress_placeholder=cnn_progress,
+                    live_image_placeholder=cnn_live)
+
+                final = test_img.copy()
+                class_labels = sorted(set(d[4] for d in dets)) if dets else []
+                for x1d, y1d, x2d, y2d, lbl, cf in dets:
+                    ci = class_labels.index(lbl) if lbl in class_labels else 0
+                    clr = CLASS_COLORS[ci % len(CLASS_COLORS)]
+                    cv2.rectangle(final, (x1d, y1d), (x2d, y2d), clr, 2)
+                    cv2.putText(final, f"{lbl} {cf:.0%}", (x1d, y1d - 6),
+                                cv2.FONT_HERSHEY_SIMPLEX, 0.4, clr, 1)
+                cnn_live.image(cv2.cvtColor(final, cv2.COLOR_BGR2RGB),
+                               caption=f"{selected} — Final Detections",
+                               use_container_width=True)
+                cnn_progress.empty()
+
+                with cnn_results:
+                    cm1, cm2, cm3, cm4 = st.columns(4)
+                    cm1.metric("Detections", len(dets))
+                    cm2.metric("Windows", nw)
+                    cm3.metric("Total Time", f"{ms:.0f} ms")
+                    cm4.metric("Per Window", f"{ms/max(nw,1):.2f} ms")
+
+                    if hmap.max() > 0:
+                        hmap_color = cv2.applyColorMap(
+                            (hmap / hmap.max() * 255).astype(np.uint8),
+                            cv2.COLORMAP_JET)
+                        blend = cv2.addWeighted(test_img, 0.5, hmap_color, 0.5, 0)
+                        st.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
+                                 caption=f"{selected} — Confidence Heatmap",
+                                 use_container_width=True)
+
+                    if dets:
+                        df = pd.DataFrame(dets, columns=["x1","y1","x2","y2","label","conf"])
+                        st.dataframe(df, use_container_width=True, hide_index=True)
+
+                pipe["cnn_dets"] = dets
+                pipe["cnn_det_ms"] = ms
+                st.session_state["gen_pipeline"] = pipe
+
+    # -------------------------------------------------------------------
+    # ORB Detection
+    # -------------------------------------------------------------------
+    with col_orb:
+        st.header("🏛️ ORB Detection")
+        if not has_orb:
+            st.info("No ORB reference trained.")
+        else:
+            orb_det = pipe["orb_detector"]
+            orb_refs = pipe["orb_refs"]
+            dt_thresh = pipe.get("orb_dist_thresh", 70)
+            min_m = pipe.get("orb_min_matches", 5)
+            st.caption(f"References: {', '.join(orb_refs.keys())} | "
+                       f"dist<{dt_thresh}, min {min_m} matches")
+            orb_run = st.button("▶ Run ORB Scan", key="gen_orb_run")
+            orb_progress = st.empty()
+            orb_live = st.empty()
+            orb_results = st.container()
+
+            if orb_run:
+                H, W = test_img.shape[:2]
+                positions = [(x, y)
+                             for y in range(0, H - win_h + 1, stride)
+                             for x in range(0, W - win_w + 1, stride)]
+                n_total = len(positions)
+                heatmap = np.zeros((H, W), dtype=np.float32)
+                detections = []
+                t0 = time.perf_counter()
+                clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
+
+                for idx, (px, py) in enumerate(positions):
+                    patch = test_img[py:py+win_h, px:px+win_w]
+                    gray = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY)
+                    gray = clahe.apply(gray)
+                    kp, des = orb_det.orb.detectAndCompute(gray, None)
+
+                    if des is not None:
+                        best_label, best_conf = "background", 0.0
+                        for lbl, ref in orb_refs.items():
+                            if ref["descriptors"] is None:
+                                continue
+                            matches = orb_det.bf.match(ref["descriptors"], des)
+                            good = [m for m in matches if m.distance < dt_thresh]
+                            conf = min(len(good) / max(min_m, 1), 1.0)
+                            if len(good) >= min_m and conf > best_conf:
+                                best_label, best_conf = lbl, conf
+
+                        if best_label != "background":
+                            heatmap[py:py+win_h, px:px+win_w] = np.maximum(
+                                heatmap[py:py+win_h, px:px+win_w], best_conf)
+                            if best_conf >= conf_thresh:
+                                detections.append(
+                                    (px, py, px+win_w, py+win_h, best_label, best_conf))
+
+                    if idx % 5 == 0 or idx == n_total - 1:
+                        orb_progress.progress((idx+1)/n_total,
+                                              text=f"Window {idx+1}/{n_total}")
+
+                total_ms = (time.perf_counter() - t0) * 1000
+                if detections:
+                    detections = _nms(detections, nms_iou)
+
+                final = test_img.copy()
+                cls_labels = sorted(set(d[4] for d in detections)) if detections else []
+                for x1d, y1d, x2d, y2d, lbl, cf in detections:
+                    ci = cls_labels.index(lbl) if lbl in cls_labels else 0
+                    clr = CLASS_COLORS[ci % len(CLASS_COLORS)]
+                    cv2.rectangle(final, (x1d, y1d), (x2d, y2d), clr, 2)
+                    cv2.putText(final, f"{lbl} {cf:.0%}", (x1d, y1d - 6),
+                                cv2.FONT_HERSHEY_SIMPLEX, 0.4, clr, 1)
+                orb_live.image(cv2.cvtColor(final, cv2.COLOR_BGR2RGB),
+                               caption="ORB — Final Detections",
+                               use_container_width=True)
+                orb_progress.empty()
+
+                with orb_results:
+                    om1, om2, om3, om4 = st.columns(4)
+                    om1.metric("Detections", len(detections))
+                    om2.metric("Windows", n_total)
+                    om3.metric("Total Time", f"{total_ms:.0f} ms")
+                    om4.metric("Per Window", f"{total_ms/max(n_total,1):.2f} ms")
+
+                    if heatmap.max() > 0:
+                        hmap_color = cv2.applyColorMap(
+                            (heatmap / heatmap.max() * 255).astype(np.uint8),
+                            cv2.COLORMAP_JET)
+                        blend = cv2.addWeighted(test_img, 0.5, hmap_color, 0.5, 0)
+                        st.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
+                                 caption="ORB — Confidence Heatmap",
+                                 use_container_width=True)
+
+                    if detections:
+                        df = pd.DataFrame(detections,
+                                          columns=["x1","y1","x2","y2","label","conf"])
+                        st.dataframe(df, use_container_width=True, hide_index=True)
+
+                pipe["orb_dets"] = detections
+                pipe["orb_det_ms"] = total_ms
+                st.session_state["gen_pipeline"] = pipe
+
+    # ===================================================================
+    # Bottom — Comparison
+    # ===================================================================
+    rce_dets = pipe.get("rce_dets")
+    cnn_dets = pipe.get("cnn_dets")
+    orb_dets = pipe.get("orb_dets")
+
+    methods = {}
+    if rce_dets is not None:
+        methods["RCE"] = (rce_dets, pipe.get("rce_det_ms", 0), (0,255,0))
+    if cnn_dets is not None:
+        methods["CNN"] = (cnn_dets, pipe.get("cnn_det_ms", 0), (0,0,255))
+    if orb_dets is not None:
+        methods["ORB"] = (orb_dets, pipe.get("orb_det_ms", 0), (255,165,0))
+
+    if len(methods) >= 2:
+        st.divider()
+        st.subheader("📊 Side-by-Side Comparison")
+        comp = {"Metric": ["Detections", "Best Confidence", "Total Time (ms)"]}
+        for name, (dets, ms, _) in methods.items():
+            comp[name] = [
+                len(dets),
+                f"{max((d[5] for d in dets), default=0):.1%}",
+                f"{ms:.0f}",
+            ]
+        st.dataframe(pd.DataFrame(comp), use_container_width=True, hide_index=True)
+
+        overlay = test_img.copy()
+        for name, (dets, _, clr) in methods.items():
+            for x1d, y1d, x2d, y2d, lbl, cf in dets:
+                cv2.rectangle(overlay, (x1d, y1d), (x2d, y2d), clr, 2)
+                cv2.putText(overlay, f"{name}:{lbl} {cf:.0%}", (x1d, y1d - 6),
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.35, clr, 1)
+        # Colors are BGR: (0,0,255) renders as red after BGR→RGB conversion
+        legend = " | ".join(f"{n}={'green' if c==(0,255,0) else 'red' if c==(0,0,255) else 'orange'}"
+                            for n, (_, _, c) in methods.items())
+        st.image(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB),
+                 caption=legend, use_container_width=True)
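All three scanners finish by calling `src.localization.nms` (imported as `_nms`), which is not shown in this diff. A plausible greedy NMS over the `(x1, y1, x2, y2, label, conf)` tuples used above could be sketched as follows; the real function may differ (e.g. per-class suppression):

```python
def nms(detections, iou_thresh):
    """Greedy non-maximum suppression over (x1, y1, x2, y2, label, conf) tuples."""
    def iou(a, b):
        # Intersection-over-union of two (x1, y1, x2, y2) boxes
        xi1, yi1 = max(a[0], b[0]), max(a[1], b[1])
        xi2, yi2 = min(a[2], b[2]), min(a[3], b[3])
        inter = max(0, xi2 - xi1) * max(0, yi2 - yi1)
        area_a = (a[2] - a[0]) * (a[3] - a[1])
        area_b = (b[2] - b[0]) * (b[3] - b[1])
        return inter / (area_a + area_b - inter + 1e-6)

    keep = []
    # Highest-confidence boxes first; keep a box only if it does not
    # overlap any already-kept box above the IoU threshold.
    for det in sorted(detections, key=lambda d: d[5], reverse=True):
        if all(iou(det[:4], k[:4]) < iou_thresh for k in keep):
            keep.append(det)
    return keep
```

With a dense sliding-window stride, many overlapping windows fire on the same object; this pass keeps only the locally best-scoring one.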
tabs/generalisation/evaluation.py ADDED
@@ -0,0 +1,205 @@
+"""Generalisation Evaluation — Stage 6 of the Generalisation pipeline."""
+
+import streamlit as st
+import cv2
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
+import plotly.figure_factory as ff
+
+
+def _iou(a, b):
+    xi1 = max(a[0], b[0]); yi1 = max(a[1], b[1])
+    xi2 = min(a[2], b[2]); yi2 = min(a[3], b[3])
+    inter = max(0, xi2 - xi1) * max(0, yi2 - yi1)
+    aa = (a[2] - a[0]) * (a[3] - a[1])
+    ab = (b[2] - b[0]) * (b[3] - b[1])
+    return inter / (aa + ab - inter + 1e-6)
+
+
+def match_detections(dets, gt_list, iou_thr):
+    dets_sorted = sorted(dets, key=lambda d: d[5], reverse=True)
+    matched_gt = set()
+    results = []
+    for det in dets_sorted:
+        det_box = det[:4]
+        best_iou, best_gt_idx, best_gt_label = 0.0, -1, None
+        for gi, (gt_box, gt_label) in enumerate(gt_list):
+            if gi in matched_gt:
+                continue
+            iou_val = _iou(det_box, gt_box)
+            if iou_val > best_iou:
+                best_iou, best_gt_idx, best_gt_label = iou_val, gi, gt_label
+        if best_iou >= iou_thr and best_gt_idx >= 0:
+            matched_gt.add(best_gt_idx)
+            results.append((det, best_gt_label, best_iou))
+        else:
+            results.append((det, None, best_iou))
+    return results, len(gt_list) - len(matched_gt), matched_gt
+
+
+def compute_pr_curve(dets, gt_list, iou_thr, steps=50):
+    if not dets:
+        return [], [], [], []
+    thresholds = np.linspace(0.0, 1.0, steps)
+    precisions, recalls, f1s = [], [], []
+    for thr in thresholds:
+        filtered = [d for d in dets if d[5] >= thr]
+        if not filtered:
+            precisions.append(1.0); recalls.append(0.0); f1s.append(0.0)
+            continue
+        matched, n_missed, _ = match_detections(filtered, gt_list, iou_thr)
+        tp = sum(1 for _, gt_lbl, _ in matched if gt_lbl is not None)
+        fp = sum(1 for _, gt_lbl, _ in matched if gt_lbl is None)
+        fn = n_missed
+        prec = tp / (tp + fp) if (tp + fp) > 0 else 1.0
+        rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
+        f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
+        precisions.append(prec); recalls.append(rec); f1s.append(f1)
+    return thresholds.tolist(), precisions, recalls, f1s
+
+
+def build_confusion_matrix(dets, gt_list, iou_thr):
+    gt_labels = sorted(set(lbl for _, lbl in gt_list))
+    all_labels = gt_labels + ["background"]
+    n = len(all_labels)
+    matrix = np.zeros((n, n), dtype=int)
+    label_to_idx = {lbl: i for i, lbl in enumerate(all_labels)}
+    matched, n_missed, matched_gt_indices = match_detections(dets, gt_list, iou_thr)
+    for det, gt_lbl, _ in matched:
+        pred_lbl = det[4]
+        if gt_lbl is not None:
+            pi = label_to_idx.get(pred_lbl, label_to_idx["background"])
+            gi = label_to_idx[gt_lbl]
+            matrix[pi][gi] += 1
+        else:
+            pi = label_to_idx.get(pred_lbl, label_to_idx["background"])
+            matrix[pi][label_to_idx["background"]] += 1
+    for gi, (_, gt_lbl) in enumerate(gt_list):
+        if gi not in matched_gt_indices:
+            matrix[label_to_idx["background"]][label_to_idx[gt_lbl]] += 1
+    return matrix, all_labels
+
+
+def render():
+    st.title("📈 Evaluation: Confusion Matrix & PR Curves")
+
+    pipe = st.session_state.get("gen_pipeline")
+    if not pipe:
+        st.error("Complete the **Data Lab** first.")
+        st.stop()
+
+    crop = pipe.get("crop")
+    crop_aug = pipe.get("crop_aug", crop)
+    bbox = pipe.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0])) if crop is not None else None
+    rois = pipe.get("rois", [{"label": "object", "bbox": bbox,
+                              "crop": crop, "crop_aug": crop_aug}])
+
+    rce_dets = pipe.get("rce_dets")
+    cnn_dets = pipe.get("cnn_dets")
+    orb_dets = pipe.get("orb_dets")
+
+    if rce_dets is None and cnn_dets is None and orb_dets is None:
+        st.warning("Run detection first in **Real-Time Detection**.")
+        st.stop()
+
+    gt_boxes = [(roi["bbox"], roi["label"]) for roi in rois]
+
+    st.sidebar.subheader("Evaluation Settings")
+    iou_thresh = st.sidebar.slider("IoU Threshold", 0.1, 0.9, 0.5, 0.05,
+                                   help="Minimum IoU to count as TP",
+                                   key="gen_eval_iou")
+
+    st.subheader("Ground Truth (from Data Lab ROIs)")
+    st.caption(f"{len(gt_boxes)} ground-truth ROIs defined")
+    gt_vis = pipe["test_image"].copy()
+    for (bx0, by0, bx1, by1), lbl in gt_boxes:
+        cv2.rectangle(gt_vis, (bx0, by0), (bx1, by1), (0, 255, 255), 2)
+        cv2.putText(gt_vis, lbl, (bx0, by0 - 6),
+                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 1)
+    st.image(cv2.cvtColor(gt_vis, cv2.COLOR_BGR2RGB),
+             caption="Ground Truth Annotations", use_container_width=True)
+    st.divider()
+
+    methods = {}
+    if rce_dets is not None:
+        methods["RCE"] = rce_dets
+    if cnn_dets is not None:
+        methods["CNN"] = cnn_dets
+    if orb_dets is not None:
+        methods["ORB"] = orb_dets
+
+    # Confusion Matrices
+    st.subheader("🔲 Confusion Matrices")
+    cm_cols = st.columns(len(methods))
+    for col, (name, dets) in zip(cm_cols, methods.items()):
+        with col:
+            st.markdown(f"**{name}**")
+            matrix, labels = build_confusion_matrix(dets, gt_boxes, iou_thresh)
+            fig_cm = ff.create_annotated_heatmap(
+                z=matrix.tolist(), x=labels, y=labels,
+                colorscale="Blues", showscale=True)
+            fig_cm.update_layout(title=f"{name} Confusion Matrix",
+                                 xaxis_title="Actual", yaxis_title="Predicted",
+                                 template="plotly_dark", height=350)
+            fig_cm.update_yaxes(autorange="reversed")
+            st.plotly_chart(fig_cm, use_container_width=True)
+
+            matched, n_missed, _ = match_detections(dets, gt_boxes, iou_thresh)
+            tp = sum(1 for _, g, _ in matched if g is not None)
+            fp = sum(1 for _, g, _ in matched if g is None)
+            fn = n_missed
+            prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
+            rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
+            f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
+            m1, m2, m3 = st.columns(3)
+            m1.metric("Precision", f"{prec:.1%}")
+            m2.metric("Recall", f"{rec:.1%}")
+            m3.metric("F1 Score", f"{f1:.1%}")
+
+    # PR Curves
+    st.divider()
+    st.subheader("📉 Precision-Recall Curves")
+    method_colors = {"RCE": "#00ff88", "CNN": "#4488ff", "ORB": "#ff8800"}
+    fig_pr = go.Figure()
+    fig_f1 = go.Figure()
+    summary_rows = []
+
+    for name, dets in methods.items():
+        thrs, precs, recs, f1s = compute_pr_curve(dets, gt_boxes, iou_thresh)
+        clr = method_colors.get(name, "#ffffff")
+        fig_pr.add_trace(go.Scatter(
+            x=recs, y=precs, mode="lines+markers",
+            name=name, line=dict(color=clr, width=2), marker=dict(size=4)))
+        fig_f1.add_trace(go.Scatter(
+            x=thrs, y=f1s, mode="lines",
+            name=name, line=dict(color=clr, width=2)))
+        ap = float(np.trapz(precs, recs)) if recs and precs else 0.0
+        best_f1_idx = int(np.argmax(f1s)) if f1s else 0
+        summary_rows.append({
+            "Method": name,
+            "AP": f"{abs(ap):.3f}",
+            "Best F1": f"{f1s[best_f1_idx]:.3f}" if f1s else "N/A",
+            "@ Threshold": f"{thrs[best_f1_idx]:.2f}" if thrs else "N/A",
+            "Detections": len(dets),
+        })
+
+    fig_pr.update_layout(title="Precision vs Recall",
+                         xaxis_title="Recall", yaxis_title="Precision",
+                         template="plotly_dark", height=400,
+                         xaxis=dict(range=[0, 1.05]), yaxis=dict(range=[0, 1.05]))
+    fig_f1.update_layout(title="F1 Score vs Confidence Threshold",
+                         xaxis_title="Confidence Threshold", yaxis_title="F1 Score",
+                         template="plotly_dark", height=400,
+                         xaxis=dict(range=[0, 1.05]), yaxis=dict(range=[0, 1.05]))
+    pc1, pc2 = st.columns(2)
+    pc1.plotly_chart(fig_pr, use_container_width=True)
+    pc2.plotly_chart(fig_f1, use_container_width=True)
+
+    # Summary Table
+    st.divider()
+    st.subheader("📊 Summary")
+    st.dataframe(pd.DataFrame(summary_rows), use_container_width=True, hide_index=True)
+    st.caption(f"All metrics computed at IoU threshold = **{iou_thresh:.2f}**.")
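The `abs()` around the AP value in the summary table deserves a note: `compute_pr_curve` sweeps confidence thresholds from 0 to 1, so recall typically comes out in descending order, and integrating precision over a descending x-axis with `np.trapz` yields a negative area. A small numeric check of the trapezoidal AP on toy PR points (hypothetical values, with recall already ascending for clarity):

```python
import numpy as np

# Toy PR points: recall ascending, precision degrading as recall grows.
recalls    = [0.0, 0.5, 1.0]
precisions = [1.0, 1.0, 0.5]

# Mirror the file's np.trapz call; fall back to np.trapezoid on NumPy >= 2.0,
# where trapz was renamed.
trapz = np.trapz if hasattr(np, "trapz") else np.trapezoid
ap = abs(float(trapz(precisions, recalls)))
# Segments: 0.5 * (1.0 + 1.0) / 2 + 0.5 * (1.0 + 0.5) / 2 = 0.5 + 0.375
```

Note this trapezoidal estimate differs from the interpolated AP used by COCO/Pascal VOC tooling, but it is consistent within the dashboard since all three methods are scored the same way.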
tabs/generalisation/feature_lab.py ADDED
@@ -0,0 +1,102 @@
+"""Generalisation Feature Lab — Stage 2 of the Generalisation pipeline."""
+
+import streamlit as st
+import cv2
+import numpy as np
+import plotly.graph_objects as go
+
+from src.detectors.rce.features import REGISTRY
+from src.models import BACKBONES
+
+
+def render():
+    pipe = st.session_state.get("gen_pipeline")
+    if not pipe or "crop" not in pipe:
+        st.error("Please complete the **Data Lab** first!")
+        st.stop()
+
+    obj = pipe.get("crop_aug", pipe.get("crop"))
+    if obj is None:
+        st.error("No crop found. Go back to Data Lab and define a ROI.")
+        st.stop()
+    gray = cv2.cvtColor(obj, cv2.COLOR_BGR2GRAY)
+
+    st.title("🔬 Feature Lab: Physical Module Selection")
+
+    col_rce, col_cnn = st.columns([3, 2])
+
+    with col_rce:
+        st.header("🧬 RCE: Modular Physics Engine")
+        st.subheader("Select Active Modules")
+
+        active = {}
+        items = list(REGISTRY.items())
+        for row_start in range(0, len(items), 4):
+            row_items = items[row_start:row_start + 4]
+            m_cols = st.columns(4)
+            for col, (key, meta) in zip(m_cols, row_items):
+                active[key] = col.checkbox(meta["label"],
+                                           value=(key in ("intensity", "sobel", "spectral")),
+                                           key=f"gen_fl_{key}")
+
+        final_vector = []
+        viz_images = []
+        for key, meta in REGISTRY.items():
+            if active[key]:
+                vec, viz = meta["fn"](gray)
+                final_vector.extend(vec)
+                viz_images.append((meta["viz_title"], viz))
+
+        st.divider()
+        if viz_images:
+            for row_start in range(0, len(viz_images), 3):
+                row = viz_images[row_start:row_start + 3]
+                v_cols = st.columns(3)
+                for col, (title, img) in zip(v_cols, row):
+                    col.image(img, caption=title, use_container_width=True)
+        else:
+            st.warning("No modules selected — vector is empty.")
+
+        st.write(f"### Current DNA Vector Size: **{len(final_vector)}**")
+        fig_vec = go.Figure(data=[go.Bar(y=final_vector, marker_color="#00d4ff")])
+        fig_vec.update_layout(title="Active Feature Vector (RCE Input)",
+                              template="plotly_dark", height=300)
+        st.plotly_chart(fig_vec, use_container_width=True)
+
+    with col_cnn:
+        st.header("🧠 CNN: Static Architecture")
+        selected_cnn = st.selectbox("Compare against Model", list(BACKBONES.keys()),
+                                    key="gen_fl_cnn")
+        st.info("CNN features are fixed by pre-trained weights.")
+
+        with st.spinner(f"Loading {selected_cnn} and extracting activations..."):
+            try:
+                bmeta = BACKBONES[selected_cnn]
+                backbone = bmeta["loader"]()
+                layer_name = bmeta["hook_layer"]
+                act_maps = backbone.get_activation_maps(obj, n_maps=6)
+                st.caption(f"Hooked layer: `{layer_name}` — showing 6 of "
+                           f"{len(act_maps)} channels")
+                act_cols = st.columns(3)
+                for i, amap in enumerate(act_maps):
+                    act_cols[i % 3].image(amap, caption=f"Channel {i}",
+                                          use_container_width=True)
+            except Exception as e:
+                st.error(f"Could not load model: {e}")
+
+    st.divider()
+    st.markdown(f"""
+    **Analysis:**
+    - **Modularity:** RCE is **High** | CNN is **Zero**
+    - **Explainability:** RCE is **High** | CNN is **Low**
+    - **Compute Cost:** {len(final_vector)} floats | 512+ floats
+    """)
+
+    if st.button("🚀 Lock Modular Configuration", key="gen_fl_lock"):
+        if not final_vector:
+            st.error("Please select at least one module!")
+        else:
+            pipe["final_vector"] = np.array(final_vector)
+            pipe["active_modules"] = {k: v for k, v in active.items()}
+            st.session_state["gen_pipeline"] = pipe
+            st.success("Modular DNA Locked! Ready for Model Tuning.")
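The Feature Lab only consumes `REGISTRY`; the usage above implies each entry carries a `"label"`, a `"viz_title"`, and an `"fn"` that maps a grayscale crop to `(feature_slice, viz_image)`. The registry itself lives in `src.detectors.rce.features` and is not part of this diff, so here is a hypothetical entry illustrating that contract, using a made-up intensity-histogram module:

```python
import numpy as np

def intensity_module(gray):
    """Hypothetical feature fn: 16-bin normalised intensity histogram.

    Returns (vec, viz) — the feature slice that gets concatenated into the
    RCE "DNA vector", plus an image for the Feature Lab visualisation grid.
    """
    hist, _ = np.histogram(gray, bins=16, range=(0, 256))
    vec = (hist / max(hist.sum(), 1)).tolist()  # list of floats, extend()-able
    viz = gray                                   # what col.image() would show
    return vec, viz

# Shape of one REGISTRY entry as inferred from the Feature Lab code
REGISTRY_ENTRY = {
    "label": "Intensity",
    "viz_title": "Intensity Map",
    "fn": intensity_module,
}
```

Because `render()` only does `meta["fn"](gray)` and `final_vector.extend(vec)`, any module honouring this tuple contract can be dropped into the registry without touching the UI code.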
tabs/generalisation/localization.py ADDED
@@ -0,0 +1,302 @@
+ """Generalisation Localization Lab — Stage 4 of the Generalisation pipeline."""
+
+ import streamlit as st
+ import cv2
+ import numpy as np
+ import pandas as pd
+ import plotly.graph_objects as go
+
+ from src.detectors.rce.features import REGISTRY
+ from src.models import BACKBONES
+ from src.utils import build_rce_vector
+ from src.localization import (
+ exhaustive_sliding_window,
+ image_pyramid,
+ coarse_to_fine,
+ contour_proposals,
+ template_matching,
+ STRATEGIES,
+ )
+
+
+ def render():
+ st.title("🔍 Localization Lab")
+ st.markdown(
+ "Compare **localization strategies** — algorithms that decide *where* "
+ "to look in the image. The recognition head stays the same; only the "
+ "search method changes."
+ )
+
+ pipe = st.session_state.get("gen_pipeline")
+ if not pipe or "crop" not in pipe:
+ st.error("Complete **Data Lab** first (upload assets & define a crop).")
+ st.stop()
+
+ test_img = pipe["test_image"]
+ crop = pipe["crop"]
+ crop_aug = pipe.get("crop_aug", crop)
+ bbox = pipe.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
+ active_mods = pipe.get("active_modules", {k: True for k in REGISTRY})
+
+ x0, y0, x1, y1 = bbox
+ win_h, win_w = y1 - y0, x1 - x0
+
+ if win_h <= 0 or win_w <= 0:
+ st.error("Invalid window size from crop bbox. "
+ "Go back to **Data Lab** and redefine the ROI.")
+ st.stop()
+
+ rce_head = pipe.get("rce_head")
+ has_any_cnn = any(f"cnn_head_{n}" in pipe for n in BACKBONES)
+
+ if rce_head is None and not has_any_cnn:
+ st.warning("No trained heads found. Go to **Model Tuning** first.")
+ st.stop()
+
+ def rce_feature_fn(patch_bgr):
+ return build_rce_vector(patch_bgr, active_mods)
+
+ # Algorithm Reference
+ st.divider()
+ with st.expander("📚 **Algorithm Reference** — click to expand", expanded=False):
+ tabs = st.tabs([f"{v['icon']} {k}" for k, v in STRATEGIES.items()])
+ for tab, (name, meta) in zip(tabs, STRATEGIES.items()):
+ with tab:
+ st.markdown(f"### {meta['icon']} {name}")
+ st.caption(meta["short"])
+ st.markdown(meta["detail"])
+
+ # Configuration
+ st.divider()
+ st.header("⚙️ Configuration")
+
+ col_head, col_info = st.columns([2, 3])
+ with col_head:
+ head_options = []
+ if rce_head is not None:
+ head_options.append("RCE")
+ trained_cnns = [n for n in BACKBONES if f"cnn_head_{n}" in pipe]
+ head_options.extend(trained_cnns)
+ selected_head = st.selectbox("Recognition Head", head_options,
+ key="gen_loc_head")
+
+ if selected_head == "RCE":
+ feature_fn = rce_feature_fn
+ head = rce_head
+ else:
+ bmeta = BACKBONES[selected_head]
+ backbone = bmeta["loader"]()
+ feature_fn = backbone.get_features
+ head = pipe[f"cnn_head_{selected_head}"]
+
+ with col_info:
+ if selected_head == "RCE":
+ mods = [REGISTRY[k]["label"] for k in active_mods if active_mods[k]]
+ st.info(f"**RCE** — Modules: {', '.join(mods)}")
+ else:
+ st.info(f"**{selected_head}** — "
+ f"{BACKBONES[selected_head]['dim']}D feature vector")
+
+ # Algorithm checkboxes
+ st.subheader("Select Algorithms to Compare")
+ algo_cols = st.columns(5)
+ algo_names = list(STRATEGIES.keys())
+ algo_checks = {}
+ for col, name in zip(algo_cols, algo_names):
+ algo_checks[name] = col.checkbox(
+ f"{STRATEGIES[name]['icon']} {name}",
+ value=(name != "Template Matching"),
+ key=f"gen_chk_{name}")
+
+ any_selected = any(algo_checks.values())
+
+ # Parameters
+ st.subheader("Parameters")
+ sp1, sp2, sp3 = st.columns(3)
+ stride = sp1.slider("Base Stride (px)", 4, max(win_w, win_h),
+ max(win_w // 4, 4), step=2, key="gen_loc_stride")
+ conf_thresh = sp2.slider("Confidence Threshold", 0.5, 1.0, 0.7, 0.05,
+ key="gen_loc_conf")
+ nms_iou = sp3.slider("NMS IoU Threshold", 0.1, 0.9, 0.3, 0.05,
+ key="gen_loc_nms")
+
+ with st.expander("🔧 Per-Algorithm Settings"):
+ pa1, pa2, pa3 = st.columns(3)
+ with pa1:
+ st.markdown("**Image Pyramid**")
+ pyr_min = st.slider("Min Scale", 0.3, 1.0, 0.5, 0.05, key="gen_pyr_min")
+ pyr_max = st.slider("Max Scale", 1.0, 2.0, 1.5, 0.1, key="gen_pyr_max")
+ pyr_n = st.slider("Number of Scales", 3, 7, 5, key="gen_pyr_n")
+ with pa2:
+ st.markdown("**Coarse-to-Fine**")
+ c2f_factor = st.slider("Coarse Factor", 2, 8, 4, key="gen_c2f_factor")
+ c2f_radius = st.slider("Refine Radius (strides)", 1, 5, 2, key="gen_c2f_radius")
+ with pa3:
+ st.markdown("**Contour Proposals**")
+ cnt_low = st.slider("Canny Low", 10, 100, 50, key="gen_cnt_low")
+ cnt_high = st.slider("Canny High", 50, 300, 150, key="gen_cnt_high")
+ cnt_tol = st.slider("Area Tolerance", 1.5, 10.0, 3.0, 0.5, key="gen_cnt_tol")
+
+ st.caption(
+ f"Window: **{win_w}×{win_h} px** · "
+ f"Image: **{test_img.shape[1]}×{test_img.shape[0]} px** · "
+ f"Stride: **{stride} px**"
+ )
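The caption just above reports the window, image, and stride sizes, which together determine how many proposals an exhaustive scan must evaluate — the baseline cost the other strategies are trying to beat. As a rough sanity check (standard sliding-window arithmetic, not code taken from `src.localization`):

```python
def sliding_window_proposals(img_w, img_h, win_w, win_h, stride):
    """Number of window positions an exhaustive scan evaluates."""
    if win_w > img_w or win_h > img_h:
        return 0  # window doesn't fit in the image at all
    n_x = (img_w - win_w) // stride + 1
    n_y = (img_h - win_h) // stride + 1
    return n_x * n_y

# e.g. a 640x480 image, 64x64 window, stride 16 px
print(sliding_window_proposals(640, 480, 64, 64, 16))  # 37 * 27 = 999
```

Halving the stride roughly quadruples this count, which is why the stride slider above dominates the runtime of the exhaustive and pyramid strategies.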
+
+ # Run
+ st.divider()
+ run_btn = st.button("▶ Run Comparison", type="primary",
+ disabled=not any_selected, use_container_width=True,
+ key="gen_loc_run")
+
+ if run_btn:
+ selected_algos = [n for n in algo_names if algo_checks[n]]
+ progress = st.progress(0, text="Starting…")
+ results = {}
+ edge_maps = {}
+
+ for i, name in enumerate(selected_algos):
+ progress.progress(i / len(selected_algos), text=f"Running **{name}**…")
+
+ if name == "Exhaustive Sliding Window":
+ dets, n, ms, hmap = exhaustive_sliding_window(
+ test_img, win_h, win_w, feature_fn, head,
+ stride, conf_thresh, nms_iou)
+ elif name == "Image Pyramid":
+ scales = np.linspace(pyr_min, pyr_max, pyr_n).tolist()
+ dets, n, ms, hmap = image_pyramid(
+ test_img, win_h, win_w, feature_fn, head,
+ stride, conf_thresh, nms_iou, scales=scales)
+ elif name == "Coarse-to-Fine":
+ dets, n, ms, hmap = coarse_to_fine(
+ test_img, win_h, win_w, feature_fn, head,
+ stride, conf_thresh, nms_iou,
+ coarse_factor=c2f_factor, refine_radius=c2f_radius)
+ elif name == "Contour Proposals":
+ dets, n, ms, hmap, edges = contour_proposals(
+ test_img, win_h, win_w, feature_fn, head,
+ conf_thresh, nms_iou,
+ canny_low=cnt_low, canny_high=cnt_high,
+ area_tolerance=cnt_tol)
+ edge_maps[name] = edges
+ elif name == "Template Matching":
+ dets, n, ms, hmap = template_matching(
+ test_img, crop_aug, conf_thresh, nms_iou)
+
+ results[name] = {"dets": dets, "n_proposals": n,
+ "time_ms": ms, "heatmap": hmap}
+
+ progress.progress(1.0, text="Done!")
+
+ # Summary Table
+ st.header("📊 Results")
+ baseline_ms = results.get("Exhaustive Sliding Window", {}).get("time_ms")
+ rows = []
+ for name, r in results.items():
+ speedup = (baseline_ms / r["time_ms"]
+ if baseline_ms and r["time_ms"] > 0 else None)
+ rows.append({
+ "Algorithm": name,
+ "Proposals": r["n_proposals"],
+ "Time (ms)": round(r["time_ms"], 1),
+ "Detections": len(r["dets"]),
+ "ms / Proposal": round(r["time_ms"] / max(r["n_proposals"], 1), 4),
+ "Speedup": f"{speedup:.1f}×" if speedup else "—",
+ })
+ st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
+
+ # Detection Images & Heatmaps
+ st.subheader("Detection Results")
+ COLORS = {
+ "Exhaustive Sliding Window": (0, 255, 0),
+ "Image Pyramid": (255, 128, 0),
+ "Coarse-to-Fine": (0, 128, 255),
+ "Contour Proposals": (255, 0, 255),
+ "Template Matching": (0, 255, 255),
+ }
+
+ result_tabs = st.tabs(
+ [f"{STRATEGIES[n]['icon']} {n}" for n in results])
+
+ for tab, (name, r) in zip(result_tabs, results.items()):
+ with tab:
+ c1, c2 = st.columns(2)
+ color = COLORS.get(name, (0, 255, 0))
+
+ vis = test_img.copy()
+ for x1d, y1d, x2d, y2d, _, cf in r["dets"]:
+ cv2.rectangle(vis, (x1d, y1d), (x2d, y2d), color, 2)
+ cv2.putText(vis, f"{cf:.0%}", (x1d, y1d - 6),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
+ c1.image(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB),
+ caption=f"{name} — {len(r['dets'])} detections",
+ use_container_width=True)
+
+ hmap = r["heatmap"]
+ if hmap.max() > 0:
+ hmap_color = cv2.applyColorMap(
+ (hmap / hmap.max() * 255).astype(np.uint8),
+ cv2.COLORMAP_JET)
+ blend = cv2.addWeighted(test_img, 0.5, hmap_color, 0.5, 0)
+ c2.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
+ caption=f"{name} — Confidence Heatmap",
+ use_container_width=True)
+ else:
+ c2.info("No positive responses above threshold.")
+
+ if name in edge_maps:
+ st.image(edge_maps[name],
+ caption="Canny Edge Map",
+ use_container_width=True, clamp=True)
+
+ m1, m2, m3, m4 = st.columns(4)
+ m1.metric("Proposals", r["n_proposals"])
+ m2.metric("Time", f"{r['time_ms']:.0f} ms")
+ m3.metric("Detections", len(r["dets"]))
+ m4.metric("ms / Proposal",
+ f"{r['time_ms'] / max(r['n_proposals'], 1):.3f}")
+
+ if r["dets"]:
+ df = pd.DataFrame(r["dets"],
+ columns=["x1", "y1", "x2", "y2", "label", "conf"])
+ st.dataframe(df, use_container_width=True, hide_index=True)
+
+ # Performance Charts
+ st.subheader("📈 Performance Comparison")
+ ch1, ch2 = st.columns(2)
+ names = list(results.keys())
+ times = [results[n]["time_ms"] for n in names]
+ props = [results[n]["n_proposals"] for n in names]
+ n_dets = [len(results[n]["dets"]) for n in names]
+ colors_hex = ["#00cc66", "#ff8800", "#0088ff", "#ff00ff", "#00cccc"]
+
+ with ch1:
+ fig = go.Figure(go.Bar(
+ x=names, y=times,
+ text=[f"{t:.0f}" for t in times], textposition="auto",
+ marker_color=colors_hex[:len(names)]))
+ fig.update_layout(title="Total Time (ms)", yaxis_title="ms", height=400)
+ st.plotly_chart(fig, use_container_width=True)
+
+ with ch2:
+ fig = go.Figure(go.Bar(
+ x=names, y=props,
+ text=[str(p) for p in props], textposition="auto",
+ marker_color=colors_hex[:len(names)]))
+ fig.update_layout(title="Proposals Evaluated",
+ yaxis_title="Count", height=400)
+ st.plotly_chart(fig, use_container_width=True)
+
+ fig = go.Figure()
+ for i, name in enumerate(names):
+ fig.add_trace(go.Scatter(
+ x=[props[i]], y=[times[i]],
+ mode="markers+text",
+ marker=dict(size=max(n_dets[i] * 12, 18),
+ color=colors_hex[i % len(colors_hex)]),
+ text=[name], textposition="top center", name=name))
+ fig.update_layout(
+ title="Proposals vs Time (marker size ∝ detections)",
+ xaxis_title="Proposals Evaluated",
+ yaxis_title="Time (ms)", height=500)
+ st.plotly_chart(fig, use_container_width=True)
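Every strategy in this lab funnels its raw hits through the same `nms_iou` threshold before anything is drawn. For reference, a generic sketch of IoU-based non-maximum suppression over the `(x1, y1, x2, y2, label, conf)` tuples the lab displays — this is an illustration of the technique, not the implementation in `src.localization`:

```python
def iou(a, b):
    """Intersection-over-union of two (x1, y1, x2, y2) boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter) if inter else 0.0

def nms(dets, iou_thresh):
    """Keep highest-confidence boxes; drop lower-scored overlapping ones."""
    dets = sorted(dets, key=lambda d: d[5], reverse=True)
    kept = []
    for d in dets:
        if all(iou(d[:4], k[:4]) < iou_thresh for k in kept):
            kept.append(d)
    return kept

boxes = [(10, 10, 60, 60, "obj", 0.9),
         (12, 12, 62, 62, "obj", 0.8),    # heavy overlap -> suppressed
         (200, 200, 250, 250, "obj", 0.7)]
print(len(nms(boxes, 0.3)))  # 2
```

A lower IoU threshold suppresses more aggressively (fewer, cleaner boxes); a higher one tolerates overlapping detections, which is why the slider range here stops at 0.9.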