Spaces:
Sleeping
Sleeping
DariusGiannoli committed on
Commit ·
a51a1a7
1
Parent(s): 397a1b0
refactor: tab-based routing with two pipelines (Stereo+Depth & Generalisation)
Browse files- Rewrite app.py as sidebar routing controller (no more multi-page Streamlit)
- Add tabs/stereo/ (7 stages: data_lab, feature_lab, model_tuning, localization,
detection, evaluation, stereo_depth)
- Add tabs/generalisation/ (6 stages: same minus stereo_depth)
- Add utils/middlebury_loader.py (PFM + calib parsing, scene group scanner)
- Namespace session state into stereo_pipeline / gen_pipeline dicts
- Fix data leakage: train on LEFT/variant-A, detect on RIGHT/variant-B
- All widget keys prefixed per pipeline to prevent collisions
- Remove deprecated pages/ directory
This view is limited to 50 files because it contains too many changes. See raw diff
- CLAUDE_CODE_PROMPT.md +656 -0
- app.py +171 -155
- dataOLD/README.md +5 -0
- dataOLD/artroom/bird/yolo/bird_data.yaml +7 -0
- dataOLD/artroom/bird/yolo/train/images/bird_01_original.png +3 -0
- dataOLD/artroom/bird/yolo/train/images/bird_02_rot_pos5.png +3 -0
- dataOLD/artroom/bird/yolo/train/images/bird_03_rot_neg5.png +3 -0
- dataOLD/artroom/bird/yolo/train/images/bird_04_bright.png +3 -0
- dataOLD/artroom/bird/yolo/train/images/bird_05_dark.png +3 -0
- dataOLD/artroom/bird/yolo/train/images/bird_06_noisy.png +3 -0
- dataOLD/artroom/bird/yolo/train/images/bird_07_flip.png +3 -0
- dataOLD/artroom/bird/yolo/train/images/bird_08_blur.png +3 -0
- dataOLD/artroom/bird/yolo/train/images/bird_09_shift_x.png +3 -0
- dataOLD/artroom/bird/yolo/train/images/bird_10_shift_y.png +3 -0
- dataOLD/artroom/bird/yolo/train/images/room_1.png +3 -0
- dataOLD/artroom/bird/yolo/train/images/room_2.png +3 -0
- dataOLD/artroom/bird/yolo/train/images/room_3.png +3 -0
- dataOLD/artroom/bird/yolo/train/images/room_4.png +3 -0
- dataOLD/artroom/bird/yolo/train/images/room_5.png +3 -0
- dataOLD/artroom/bird/yolo/train/labels.cache +0 -0
- dataOLD/artroom/bird/yolo/train/labels/bird_01_original.txt +1 -0
- dataOLD/artroom/bird/yolo/train/labels/bird_02_rot_pos5.txt +1 -0
- dataOLD/artroom/bird/yolo/train/labels/bird_03_rot_neg5.txt +1 -0
- dataOLD/artroom/bird/yolo/train/labels/bird_04_bright.txt +1 -0
- dataOLD/artroom/bird/yolo/train/labels/bird_05_dark.txt +1 -0
- dataOLD/artroom/bird/yolo/train/labels/bird_06_noisy.txt +1 -0
- dataOLD/artroom/bird/yolo/train/labels/bird_07_flip.txt +1 -0
- dataOLD/artroom/bird/yolo/train/labels/bird_08_blur.txt +1 -0
- dataOLD/artroom/bird/yolo/train/labels/bird_09_shift_x.txt +1 -0
- dataOLD/artroom/bird/yolo/train/labels/bird_10_shift_y.txt +1 -0
- dataOLD/artroom/bird/yolo/train/labels/room_1.txt +0 -0
- dataOLD/artroom/bird/yolo/train/labels/room_2.txt +0 -0
- dataOLD/artroom/bird/yolo/train/labels/room_3.txt +0 -0
- dataOLD/artroom/bird/yolo/train/labels/room_4.txt +0 -0
- dataOLD/artroom/bird/yolo/train/labels/room_5.txt +0 -0
- dataOLD/artroom/im0.png +3 -0
- pages/2_Data_Lab.py +0 -321
- pages/3_Feature_Lab.py +0 -111
- pages/4_Model_Tuning.py +0 -475
- pages/5_Localization_Lab.py +0 -348
- pages/6_RealTime_Detection.py +0 -435
- pages/7_Evaluation.py +0 -295
- pages/8_Stereo_Geometry.py +0 -353
- tabs/__init__.py +0 -0
- tabs/generalisation/__init__.py +0 -0
- tabs/generalisation/data_lab.py +269 -0
- tabs/generalisation/detection.py +388 -0
- tabs/generalisation/evaluation.py +205 -0
- tabs/generalisation/feature_lab.py +102 -0
- tabs/generalisation/localization.py +302 -0
CLAUDE_CODE_PROMPT.md
ADDED
|
@@ -0,0 +1,656 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Claude Code Implementation Prompt
|
| 2 |
+
# Recognition-BenchMark — Full Restructure
|
| 3 |
+
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
## Context
|
| 7 |
+
|
| 8 |
+
This is a Streamlit-based stereo-vision benchmarking platform called **Recognition-BenchMark**. It compares a custom hand-crafted feature extractor called **RCE (Relative Contextual Encoding)** against CNN-based deep learning approaches for object recognition and depth estimation.
|
| 9 |
+
|
| 10 |
+
### Current Project Structure
|
| 11 |
+
|
| 12 |
+
```
|
| 13 |
+
app.py ← Landing page (home)
|
| 14 |
+
pages/
|
| 15 |
+
├── 2_Data_Lab.py ← Stage 1
|
| 16 |
+
├── 3_Feature_Lab.py ← Stage 2
|
| 17 |
+
├── 4_Model_Tuning.py ← Stage 3
|
| 18 |
+
├── 5_Localization_Lab.py ← Stage 4
|
| 19 |
+
├── 6_RealTime_Detection.py ← Stage 5
|
| 20 |
+
├── 7_Evaluation.py ← Stage 6
|
| 21 |
+
└── 8_Stereo_Geometry.py ← Stage 7
|
| 22 |
+
src/
|
| 23 |
+
├── config.py ← App configuration constants
|
| 24 |
+
├── detectors/
|
| 25 |
+
│ ├── base.py ← Base detector class
|
| 26 |
+
│ ├── rce/
|
| 27 |
+
│ │ ├── __init__.py
|
| 28 |
+
│ │ └── features.py ← RCE feature extractor (DO NOT MODIFY)
|
| 29 |
+
│ ├── mobilenet.py ← MobileNetV3 detector (DO NOT MODIFY)
|
| 30 |
+
│ ├── mobilevit.py ← MobileViT detector (DO NOT MODIFY)
|
| 31 |
+
│ ├── resnet.py ← ResNet-18 detector (DO NOT MODIFY)
|
| 32 |
+
│ ├── orb.py ← ORB detector (DO NOT MODIFY)
|
| 33 |
+
│ └── yolo.py ← YOLOv8 detector (DO NOT MODIFY)
|
| 34 |
+
├── localization.py ← Localization strategies (DO NOT MODIFY)
|
| 35 |
+
└── models.py ← Model loading utilities (DO NOT MODIFY)
|
| 36 |
+
models/
|
| 37 |
+
├── mobilenet_v3_head.pkl
|
| 38 |
+
├── mobilenet_v3.pth
|
| 39 |
+
├── mobilevit_head.pkl
|
| 40 |
+
├── mobilevit_xxs.pth
|
| 41 |
+
├── orb_reference.pkl
|
| 42 |
+
├── resnet18_head.pkl
|
| 43 |
+
├── resnet18.pth
|
| 44 |
+
└── yolov8n.pt
|
| 45 |
+
data/
|
| 46 |
+
└── middlebury/ ← Bundled dataset (already present)
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
**The entire `src/` directory must not be modified.** All detector logic, feature extraction, localization strategies, and model loading are already implemented there. The pages in `pages/` import from `src/` and must be migrated to the new `tabs/` structure while continuing to import from `src/`.
|
| 50 |
+
|
| 51 |
+
---
|
| 52 |
+
|
| 53 |
+
## Critical Bug To Fix
|
| 54 |
+
|
| 55 |
+
**Data leakage through circular evaluation.** Currently the detection/recognition stage runs on the same left image used to define the training ROI. This is scientifically invalid — the model is tested on its own training source.
|
| 56 |
+
|
| 57 |
+
**The fix:**
|
| 58 |
+
- In the Stereo pipeline: train on LEFT image crop → detect on RIGHT image
|
| 59 |
+
- In the Generalisation pipeline: train on image 1 crop → detect on image 2
|
| 60 |
+
|
| 61 |
+
This must be propagated through session state so every stage after Data Lab knows which image is for training (source) and which is for testing (target).
|
| 62 |
+
|
| 63 |
+
---
|
| 64 |
+
|
| 65 |
+
## Target Architecture
|
| 66 |
+
|
| 67 |
+
### File Structure After Refactor
|
| 68 |
+
|
| 69 |
+
```
|
| 70 |
+
app.py ← REPLACE: routing controller + home page
|
| 71 |
+
tabs/
|
| 72 |
+
├── stereo/
|
| 73 |
+
│ ├── __init__.py
|
| 74 |
+
│ ├── data_lab.py ← NEW: replaces pages/2_Data_Lab.py for stereo
|
| 75 |
+
│ ├── feature_lab.py ← MIGRATE: from pages/3_Feature_Lab.py
|
| 76 |
+
│ ├── model_tuning.py ← MIGRATE: from pages/4_Model_Tuning.py
|
| 77 |
+
│ ├── localization.py ← MIGRATE: from pages/5_Localization_Lab.py
|
| 78 |
+
│ ├── detection.py ← MIGRATE + FIX: from pages/6_RealTime_Detection.py
|
| 79 |
+
│ ├── evaluation.py ← MIGRATE: from pages/7_Evaluation.py
|
| 80 |
+
│ └── stereo_depth.py ← MIGRATE: from pages/8_Stereo_Geometry.py
|
| 81 |
+
├── generalisation/
|
| 82 |
+
│ ├── __init__.py
|
| 83 |
+
│ ├── data_lab.py ← NEW: generalisation-specific data loading
|
| 84 |
+
│ ├── feature_lab.py ← ADAPT: stereo version with gen_pipeline keys
|
| 85 |
+
│ ├── model_tuning.py ← ADAPT: stereo version with gen_pipeline keys
|
| 86 |
+
│ ├── localization.py ← ADAPT: stereo version with gen_pipeline keys
|
| 87 |
+
│ ├── detection.py ← ADAPT + FIX: stereo version with gen_pipeline keys
|
| 88 |
+
│ └── evaluation.py ← ADAPT: stereo version with gen_pipeline keys
|
| 89 |
+
utils/
|
| 90 |
+
└── middlebury_loader.py ← NEW: dataset scanning, loading, parsing
|
| 91 |
+
src/ ← DO NOT TOUCH: all detector/model logic stays here
|
| 92 |
+
pages/ ← DELETE after migration is complete and verified
|
| 93 |
+
data/
|
| 94 |
+
└── middlebury/ ← Already present, do not modify
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
---
|
| 98 |
+
|
| 99 |
+
## Part 1 — app.py Routing Controller
|
| 100 |
+
|
| 101 |
+
Replace the existing `app.py` entirely. The new `app.py` is a **routing controller** that:
|
| 102 |
+
|
| 103 |
+
1. Sets page config (keep existing title/icon/layout)
|
| 104 |
+
2. Builds the sidebar navigation manually using session state
|
| 105 |
+
3. Renders the correct module based on navigation state
|
| 106 |
+
4. Preserves the existing landing page content (pipeline overview, models, depth info)
|
| 107 |
+
|
| 108 |
+
### Sidebar Logic
|
| 109 |
+
|
| 110 |
+
```python
|
| 111 |
+
import streamlit as st
|
| 112 |
+
|
| 113 |
+
# Top-level navigation
|
| 114 |
+
st.sidebar.title("🦅 Recognition BenchMark")
|
| 115 |
+
|
| 116 |
+
top_section = st.sidebar.radio(
|
| 117 |
+
"Navigation",
|
| 118 |
+
["🏠 Home", "📷 Stereo + Depth", "🌍 Generalisation"],
|
| 119 |
+
key="top_nav"
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
if top_section == "🏠 Home":
|
| 123 |
+
# render home/landing page content inline in app.py
|
| 124 |
+
render_home()
|
| 125 |
+
|
| 126 |
+
elif top_section == "📷 Stereo + Depth":
|
| 127 |
+
stereo_stage = st.sidebar.radio(
|
| 128 |
+
"Pipeline Stage",
|
| 129 |
+
[
|
| 130 |
+
"🧪 1 · Data Lab",
|
| 131 |
+
"🔬 2 · Feature Lab",
|
| 132 |
+
"⚙️ 3 · Model Tuning",
|
| 133 |
+
"🔍 4 · Localization",
|
| 134 |
+
"🎯 5 · Detection",
|
| 135 |
+
"📈 6 · Evaluation",
|
| 136 |
+
"📐 7 · Stereo Depth"
|
| 137 |
+
],
|
| 138 |
+
key="stereo_stage"
|
| 139 |
+
)
|
| 140 |
+
# import and call the appropriate render() function from tabs/stereo/
|
| 141 |
+
|
| 142 |
+
elif top_section == "🌍 Generalisation":
|
| 143 |
+
gen_stage = st.sidebar.radio(
|
| 144 |
+
"Pipeline Stage",
|
| 145 |
+
[
|
| 146 |
+
"🧪 1 · Data Lab",
|
| 147 |
+
"🔬 2 · Feature Lab",
|
| 148 |
+
"⚙️ 3 · Model Tuning",
|
| 149 |
+
"🔍 4 · Localization",
|
| 150 |
+
"🎯 5 · Detection",
|
| 151 |
+
"📈 6 · Evaluation"
|
| 152 |
+
],
|
| 153 |
+
key="gen_stage"
|
| 154 |
+
)
|
| 155 |
+
# import and call the appropriate render() function from tabs/generalisation/
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
### Stage Guard Pattern
|
| 159 |
+
|
| 160 |
+
Every stage except Data Lab must check if the previous stage is complete. Use this pattern at the top of each stage's `render()` function:
|
| 161 |
+
|
| 162 |
+
```python
|
| 163 |
+
def render():
|
| 164 |
+
pipe = st.session_state.get("stereo_pipeline", {})
|
| 165 |
+
if "train_image" not in pipe:
|
| 166 |
+
st.warning("⚠️ Complete **Data Lab** first before accessing this stage.")
|
| 167 |
+
st.stop()
|
| 168 |
+
# ... rest of stage logic
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
+
### Session State Namespacing
|
| 172 |
+
|
| 173 |
+
**Critical:** The two pipelines must never share session state keys.
|
| 174 |
+
|
| 175 |
+
- Stereo pipeline uses: `st.session_state["stereo_pipeline"]` — a dict containing all stereo stage data
|
| 176 |
+
- Generalisation pipeline uses: `st.session_state["gen_pipeline"]` — a dict containing all generalisation stage data
|
| 177 |
+
|
| 178 |
+
Within each dict, use consistent keys:
|
| 179 |
+
```python
|
| 180 |
+
# Stereo pipeline dict keys
|
| 181 |
+
stereo_pipeline = {
|
| 182 |
+
"train_image": np.ndarray, # LEFT image — used for ROI + training
|
| 183 |
+
"test_image": np.ndarray, # RIGHT image — used for detection
|
| 184 |
+
"calib": dict, # parsed calibration parameters
|
| 185 |
+
"disparity_gt": np.ndarray, # ground truth disparity (optional, may be None)
|
| 186 |
+
"roi": dict, # {"x", "y", "w", "h", "label"}
|
| 187 |
+
"crop": np.ndarray, # cropped ROI from train_image
|
| 188 |
+
"crop_aug": list, # augmented crop variants
|
| 189 |
+
"active_modules": list, # RCE modules ["intensity", "sobel", "spectral"]
|
| 190 |
+
"rce_head": object, # trained LogisticRegression
|
| 191 |
+
"cnn_heads": dict, # {"ResNet-18": ..., "MobileNetV3": ..., "MobileViT-XXS": ...}
|
| 192 |
+
"rce_dets": list, # detection results on test_image
|
| 193 |
+
"cnn_dets": dict, # detection results per CNN model
|
| 194 |
+
"source": str, # "middlebury" or "custom"
|
| 195 |
+
"scene_name": str, # Middlebury scene name (if source == "middlebury")
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
# Generalisation pipeline dict keys — same structure minus calib/disparity_gt
|
| 199 |
+
gen_pipeline = {
|
| 200 |
+
"train_image": np.ndarray, # im0.png from training scene variant
|
| 201 |
+
"test_image": np.ndarray, # im0.png from test scene variant
|
| 202 |
+
"roi": dict,
|
| 203 |
+
"crop": np.ndarray,
|
| 204 |
+
"crop_aug": list,
|
| 205 |
+
"active_modules": list,
|
| 206 |
+
"rce_head": object,
|
| 207 |
+
"cnn_heads": dict,
|
| 208 |
+
"rce_dets": list,
|
| 209 |
+
"cnn_dets": dict,
|
| 210 |
+
"source": str, # "middlebury" or "custom"
|
| 211 |
+
"scene_group": str, # e.g. "artroom" (Middlebury only)
|
| 212 |
+
"train_scene": str, # e.g. "artroom1" (Middlebury only)
|
| 213 |
+
"test_scene": str, # e.g. "artroom2" (Middlebury only)
|
| 214 |
+
}
|
| 215 |
+
```
|
| 216 |
+
|
| 217 |
+
---
|
| 218 |
+
|
| 219 |
+
## Part 2 — Middlebury Loader Utility
|
| 220 |
+
|
| 221 |
+
Create `utils/middlebury_loader.py` with the following functions:
|
| 222 |
+
|
| 223 |
+
### `scan_dataset_root(root_path: str) -> list[str]`
|
| 224 |
+
- Scan root directory for valid scene folders
|
| 225 |
+
- A valid scene must contain: `im0.png`, `im1.png`, `calib.txt`
|
| 226 |
+
- Return sorted list of scene names
|
| 227 |
+
|
| 228 |
+
### `get_scene_groups(root_path: str) -> dict`
|
| 229 |
+
- Scan all valid scenes and group them by scene base name (strip trailing digit)
|
| 230 |
+
- e.g. `artroom1`, `artroom2` → group `"artroom"`
|
| 231 |
+
- Return dict: `{"artroom": ["artroom1", "artroom2"], "curule": ["curule1", "curule2", "curule3"], ...}`
|
| 232 |
+
- Used by Tab 2 to present scene group selection then variant selection
|
| 233 |
+
|
| 234 |
+
### `get_available_views(scene_path: str) -> list[dict]`
|
| 235 |
+
- This dataset has NO multi-exposure variants (no im0E.png etc.)
|
| 236 |
+
- Function kept for future compatibility but always returns single entry:
|
| 237 |
+
`[{"suffix": "", "label": "Primary (im0/im1)"}]`
|
| 238 |
+
|
| 239 |
+
### `load_stereo_pair(scene_path: str, view_suffix: str = '') -> dict`
|
| 240 |
+
- Load `im0{suffix}.png` as left image (train_image)
|
| 241 |
+
- Load `im1{suffix}.png` as right image (test_image)
|
| 242 |
+
- Load and parse `calib.txt`
|
| 243 |
+
- Load `disp0.pfm` if it exists (else None)
|
| 244 |
+
- Return dict with keys: `left`, `right`, `calib`, `disparity_gt`
|
| 245 |
+
|
| 246 |
+
### `load_single_view(scene_path: str, view_suffix: str) -> np.ndarray`
|
| 247 |
+
- Load and return a single image: `im0{suffix}.png`
|
| 248 |
+
- Used by generalisation tab when selecting individual views
|
| 249 |
+
|
| 250 |
+
### `parse_calib(calib_path: str) -> dict`
|
| 251 |
+
Parse Middlebury `calib.txt` format:
|
| 252 |
+
```
|
| 253 |
+
cam0=[fx 0 cx; 0 fy cy; 0 0 1]
|
| 254 |
+
cam1=[fx 0 cx; 0 fy cy; 0 0 1]
|
| 255 |
+
doffs=x_offset
|
| 256 |
+
baseline=Bmm
|
| 257 |
+
width=W
|
| 258 |
+
height=H
|
| 259 |
+
ndisp=N
|
| 260 |
+
vmin=v
|
| 261 |
+
vmax=v
|
| 262 |
+
```
|
| 263 |
+
Extract and return: `{"fx": float, "baseline": float, "doffs": float, "width": int, "height": int, "ndisp": int}`
|
| 264 |
+
|
| 265 |
+
Use regex to extract `fx` from the camera matrix string: first numeric value after `cam0=[`.
|
| 266 |
+
|
| 267 |
+
### `load_pfm(filepath: str) -> np.ndarray`
|
| 268 |
+
Load PFM (Portable FloatMap) file:
|
| 269 |
+
- Read header line (`PF` = color, `Pf` = grayscale)
|
| 270 |
+
- Read dimensions line
|
| 271 |
+
- Read scale factor (negative = little-endian)
|
| 272 |
+
- Read float32 binary data
|
| 273 |
+
- Flip vertically (PFM origin is bottom-left)
|
| 274 |
+
- Return numpy array
|
| 275 |
+
|
| 276 |
+
### Dataset Root Resolution
|
| 277 |
+
|
| 278 |
+
The dataset is **bundled directly in the repo** at `./data/middlebury/`. No user configuration needed.
|
| 279 |
+
|
| 280 |
+
```python
|
| 281 |
+
DEFAULT_MIDDLEBURY_ROOT = "./data/middlebury"
|
| 282 |
+
```
|
| 283 |
+
|
| 284 |
+
If the path does not exist or contains no valid scenes, show a clear error. This should not happen in normal deployment since the data is bundled.
|
| 285 |
+
|
| 286 |
+
### Bundled Scenes Reference
|
| 287 |
+
|
| 288 |
+
The following 10 scene folders are bundled, forming 4 scene groups:
|
| 289 |
+
|
| 290 |
+
```python
|
| 291 |
+
BUNDLED_SCENES = {
|
| 292 |
+
"artroom": ["artroom1", "artroom2"],
|
| 293 |
+
"curule": ["curule1", "curule2", "curule3"],
|
| 294 |
+
"skates": ["skates1", "skates2"],
|
| 295 |
+
"skiboots": ["skiboots1", "skiboots2", "skiboots3"],
|
| 296 |
+
}
|
| 297 |
+
```
|
| 298 |
+
|
| 299 |
+
Each folder contains exactly: `im0.png`, `im1.png`, `disp0.pfm`, `disp1.pfm`, `calib.txt`.
|
| 300 |
+
|
| 301 |
+
There are **no multi-exposure variants** (no `im0E.png` etc.) — the scene groups ARE the multi-condition variants. `artroom1` and `artroom2` are different captures of the same artroom scene.
|
| 302 |
+
|
| 303 |
+
---
|
| 304 |
+
|
| 305 |
+
## Part 3 — Tab 1: Stereo Data Lab
|
| 306 |
+
|
| 307 |
+
Create `tabs/stereo/data_lab.py` with a `render()` function.
|
| 308 |
+
|
| 309 |
+
### Data Source Selection
|
| 310 |
+
|
| 311 |
+
```python
|
| 312 |
+
st.header("🧪 Data Lab — Stereo + Depth")
|
| 313 |
+
st.info("**How this works:** Define your object of interest in the LEFT image. The system trains on it and attempts to recognise it in the RIGHT image — a genuinely different viewpoint.")
|
| 314 |
+
|
| 315 |
+
source = st.radio(
|
| 316 |
+
"Data source",
|
| 317 |
+
["📦 Middlebury Dataset", "📁 Upload your own files"],
|
| 318 |
+
horizontal=True
|
| 319 |
+
)
|
| 320 |
+
```
|
| 321 |
+
|
| 322 |
+
### If Middlebury Selected
|
| 323 |
+
|
| 324 |
+
```
|
| 325 |
+
1. Scan dataset root → show selectbox of available scenes
|
| 326 |
+
2. Auto-load im0.png (train/left) and im1.png (test/right)
|
| 327 |
+
3. Auto-load calib.txt → parse parameters
|
| 328 |
+
4. Auto-load disp0.pfm if present
|
| 329 |
+
5. Display LEFT image (train) and RIGHT image (test) side by side
|
| 330 |
+
6. Show parsed calibration parameters in an expander
|
| 331 |
+
7. Show ground truth disparity colormap if available
|
| 332 |
+
```
|
| 333 |
+
|
| 334 |
+
Show a clear visual label:
|
| 335 |
+
- Left image labeled: `🟦 TRAIN IMAGE (Left)`
|
| 336 |
+
- Right image labeled: `🟥 TEST IMAGE (Right)`
|
| 337 |
+
|
| 338 |
+
### If Custom Upload Selected
|
| 339 |
+
|
| 340 |
+
```
|
| 341 |
+
- Left image uploader (png/jpg) → labeled as TRAIN IMAGE
|
| 342 |
+
- Right image uploader (png/jpg) → labeled as TEST IMAGE
|
| 343 |
+
- Calibration file uploader (txt) — REQUIRED for depth estimation
|
| 344 |
+
- PFM ground truth uploader (pfm) — optional, disables depth evaluation if missing
|
| 345 |
+
- If calibration file not provided: show warning "Depth estimation will be disabled"
|
| 346 |
+
```
|
| 347 |
+
|
| 348 |
+
### ROI Definition
|
| 349 |
+
|
| 350 |
+
After images are loaded (either source):
|
| 351 |
+
|
| 352 |
+
```
|
| 353 |
+
1. Display LEFT (train) image only for ROI definition
|
| 354 |
+
2. Use streamlit-cropper or manual coordinate inputs for ROI selection
|
| 355 |
+
- If streamlit-cropper available: use it
|
| 356 |
+
- Fallback: four number_input widgets for x, y, w, h
|
| 357 |
+
3. Text input for class label (default: "object")
|
| 358 |
+
4. Show cropped ROI preview
|
| 359 |
+
5. "Lock Data Lab" button → saves everything to st.session_state["stereo_pipeline"]
|
| 360 |
+
```
|
| 361 |
+
|
| 362 |
+
### Data Augmentation
|
| 363 |
+
|
| 364 |
+
After ROI is locked, show augmentation controls (preserve existing augmentation logic):
|
| 365 |
+
- Rotation, brightness, contrast, noise, blur, flip
|
| 366 |
+
- Preview augmented crops
|
| 367 |
+
- "Apply Augmentation" button
|
| 368 |
+
|
| 369 |
+
### What Gets Saved to Session State
|
| 370 |
+
|
| 371 |
+
```python
|
| 372 |
+
st.session_state["stereo_pipeline"] = {
|
| 373 |
+
"train_image": left_image, # numpy array, BGR
|
| 374 |
+
"test_image": right_image, # numpy array, BGR ← KEY FIX
|
| 375 |
+
"calib": calib_dict, # parsed params or None
|
| 376 |
+
"disparity_gt": disp_gt, # numpy array or None
|
| 377 |
+
"roi": {"x":x, "y":y, "w":w, "h":h, "label":label},
|
| 378 |
+
"crop": cropped_roi,
|
| 379 |
+
"crop_aug": augmented_list,
|
| 380 |
+
"source": "middlebury" or "custom",
|
| 381 |
+
"scene_name": scene_name or "",
|
| 382 |
+
}
|
| 383 |
+
```
|
| 384 |
+
|
| 385 |
+
---
|
| 386 |
+
|
| 387 |
+
## Part 4 — Tab 2: Generalisation Data Lab
|
| 388 |
+
|
| 389 |
+
Create `tabs/generalisation/data_lab.py` with a `render()` function.
|
| 390 |
+
|
| 391 |
+
### Key Difference From Stereo Data Lab
|
| 392 |
+
|
| 393 |
+
- No calibration file
|
| 394 |
+
- No depth estimation
|
| 395 |
+
- Two images can be completely independent OR different views from Middlebury
|
| 396 |
+
- Goal is testing appearance generalisation, not stereo geometry
|
| 397 |
+
|
| 398 |
+
### Data Source Selection
|
| 399 |
+
|
| 400 |
+
```python
|
| 401 |
+
st.header("🧪 Data Lab — Generalisation")
|
| 402 |
+
st.info("**How this works:** Train on one image, test on a completely different image of the same object. No stereo geometry — pure recognition generalisation.")
|
| 403 |
+
|
| 404 |
+
source = st.radio(
|
| 405 |
+
"Data source",
|
| 406 |
+
["📦 Middlebury Multi-View", "📁 Upload your own files"],
|
| 407 |
+
horizontal=True
|
| 408 |
+
)
|
| 409 |
+
```
|
| 410 |
+
|
| 411 |
+
### If Middlebury Selected
|
| 412 |
+
|
| 413 |
+
```
|
| 414 |
+
1. Show scene group selector: ["artroom", "curule", "skates", "skiboots"]
|
| 415 |
+
2. Based on selected group, show available variants:
|
| 416 |
+
- artroom → [artroom1, artroom2]
|
| 417 |
+
- curule → [curule1, curule2, curule3]
|
| 418 |
+
- skates → [skates1, skates2]
|
| 419 |
+
- skiboots → [skiboots1, skiboots2, skiboots3]
|
| 420 |
+
3. Two selectboxes:
|
| 421 |
+
- "Training scene" → user picks one variant (e.g. artroom1)
|
| 422 |
+
- "Test scene" → user picks a DIFFERENT variant (e.g. artroom2)
|
| 423 |
+
- Validate: training scene ≠ test scene, show error if same selected
|
| 424 |
+
4. Load train_scene/im0.png as train_image
|
| 425 |
+
5. Load test_scene/im0.png as test_image
|
| 426 |
+
(NOTE: both are LEFT images im0.png, from different scene variants)
|
| 427 |
+
6. Display both side by side with clear labels
|
| 428 |
+
```
|
| 429 |
+
|
| 430 |
+
Show labels:
|
| 431 |
+
- Train image: `🟦 TRAIN IMAGE (artroom1)`
|
| 432 |
+
- Test image: `🟥 TEST IMAGE (artroom2)`
|
| 433 |
+
|
| 434 |
+
Also show an explanation: *"Both images show the same scene type captured under different conditions. The model trains on one variant and must recognise the same object class in the other — testing genuine appearance generalisation."*
|
| 435 |
+
|
| 436 |
+
### If Custom Upload Selected
|
| 437 |
+
|
| 438 |
+
```
|
| 439 |
+
- Train image uploader → labeled TRAIN IMAGE
|
| 440 |
+
- Test image uploader → labeled TEST IMAGE
|
| 441 |
+
- No calibration, no PFM needed
|
| 442 |
+
- Simple, low barrier
|
| 443 |
+
```
|
| 444 |
+
|
| 445 |
+
### ROI Definition and Augmentation
|
| 446 |
+
|
| 447 |
+
Same as stereo data lab but on the TRAIN image only. Save to `st.session_state["gen_pipeline"]` with same key structure (minus calib and disparity_gt).
|
| 448 |
+
|
| 449 |
+
---
|
| 450 |
+
|
| 451 |
+
## Part 5 — Migrate Existing Pipeline Stages
|
| 452 |
+
|
| 453 |
+
The existing pages (Feature Lab, Model Tuning, Localization, Detection, Evaluation, Stereo Depth) must be migrated into the new `tabs/` structure.
|
| 454 |
+
|
| 455 |
+
### Migration Rules
|
| 456 |
+
|
| 457 |
+
1. **Read each existing page file** before migrating it:
|
| 458 |
+
- `pages/2_Data_Lab.py` → split into `tabs/stereo/data_lab.py` and `tabs/generalisation/data_lab.py`
|
| 459 |
+
- `pages/3_Feature_Lab.py` → `tabs/stereo/feature_lab.py` (adapt for `tabs/generalisation/feature_lab.py`)
|
| 460 |
+
- `pages/4_Model_Tuning.py` → `tabs/stereo/model_tuning.py`
|
| 461 |
+
- `pages/5_Localization_Lab.py` → `tabs/stereo/localization.py`
|
| 462 |
+
- `pages/6_RealTime_Detection.py` → `tabs/stereo/detection.py` ← apply data leakage fix here
|
| 463 |
+
- `pages/7_Evaluation.py` → `tabs/stereo/evaluation.py`
|
| 464 |
+
- `pages/8_Stereo_Geometry.py` → `tabs/stereo/stereo_depth.py`
|
| 465 |
+
|
| 466 |
+
2. **Each page becomes a module** with a `render()` function — wrap all existing page code inside `def render(): ...`
|
| 467 |
+
|
| 468 |
+
3. **Update all session state reads** — the existing pages use `st.session_state.get("pipeline_data", {})` or similar flat keys. Replace with namespaced dicts:
|
| 469 |
+
- Stereo stages: `st.session_state.get("stereo_pipeline", {})`
|
| 470 |
+
- Generalisation stages: `st.session_state.get("gen_pipeline", {})`
|
| 471 |
+
|
| 472 |
+
4. **Preserve all imports from `src/`** — every `from src.xxx import yyy` in the existing pages must be kept exactly as-is
|
| 473 |
+
|
| 474 |
+
5. **Detection stage fix** — the critical data leakage fix:
|
| 475 |
+
- OLD: detection runs on the same image used for training ROI definition
|
| 476 |
+
- NEW: `image_to_scan = pipe["test_image"]` ← the other image
|
| 477 |
+
|
| 478 |
+
### Stage Guard Template
|
| 479 |
+
|
| 480 |
+
```python
|
| 481 |
+
def render():
|
| 482 |
+
pipe = st.session_state.get("stereo_pipeline", {}) # or gen_pipeline
|
| 483 |
+
|
| 484 |
+
required_keys = ["train_image", "test_image", "crop_aug"]
|
| 485 |
+
missing = [k for k in required_keys if k not in pipe]
|
| 486 |
+
if missing:
|
| 487 |
+
st.warning("⚠️ Complete the **Data Lab** stage first.")
|
| 488 |
+
st.info("Go to: 📷 Stereo + Depth → 🧪 Data Lab")
|
| 489 |
+
st.stop()
|
| 490 |
+
```
|
| 491 |
+
|
| 492 |
+
### Specific Migration Notes Per Stage
|
| 493 |
+
|
| 494 |
+
**Feature Lab (Stage 2):**
|
| 495 |
+
- Visualise features extracted from `pipe["crop"]` (from train_image)
|
| 496 |
+
- No changes needed beyond session state key updates
|
| 497 |
+
|
| 498 |
+
**Model Tuning (Stage 3):**
|
| 499 |
+
- Training data comes from `pipe["crop_aug"]` (augmented crops of train_image ROI)
|
| 500 |
+
- Negatives sampled from `pipe["train_image"]` (not test_image)
|
| 501 |
+
- No changes needed beyond session state key updates
|
| 502 |
+
|
| 503 |
+
**Detection (Stage 5) — CRITICAL FIX:**
|
| 504 |
+
- Run sliding window on `pipe["test_image"]` NOT `pipe["train_image"]`
|
| 505 |
+
- Add a visual reminder in the UI: *"Running detection on TEST image (right/second image)"*
|
| 506 |
+
- For stereo: show test_image (right) with detection results
|
| 507 |
+
- For generalisation: show test_image (the different scene variant) with detection results
|
| 508 |
+
|
| 509 |
+
**Stereo Depth (Stage 7 — stereo only):**
|
| 510 |
+
- Requires `pipe["calib"]` to be not None
|
| 511 |
+
- If calib is None (custom upload without calibration): show warning and disable depth computation
|
| 512 |
+
- If disparity_gt is None: skip ground truth comparison, show note
|
| 513 |
+
- Otherwise: preserve existing StereoSGBM logic entirely
|
| 514 |
+
|
| 515 |
+
---
|
| 516 |
+
|
| 517 |
+
## Part 6 — Session Status Widget Update
|
| 518 |
+
|
| 519 |
+
Update the session status display in `app.py` (home page) to show status for BOTH pipelines:
|
| 520 |
+
|
| 521 |
+
```python
|
| 522 |
+
st.header("📋 Session Status")
|
| 523 |
+
|
| 524 |
+
col1, col2 = st.columns(2)
|
| 525 |
+
|
| 526 |
+
with col1:
|
| 527 |
+
st.subheader("📷 Stereo Pipeline")
|
| 528 |
+
stereo = st.session_state.get("stereo_pipeline", {})
|
| 529 |
+
stereo_checks = {
|
| 530 |
+
"Data loaded": "train_image" in stereo and "test_image" in stereo,
|
| 531 |
+
"ROI defined": "roi" in stereo,
|
| 532 |
+
"Augmentation done": "crop_aug" in stereo,
|
| 533 |
+
"Modules locked": "active_modules" in stereo,
|
| 534 |
+
"Models trained": "rce_head" in stereo,
|
| 535 |
+
"Detection run": "rce_dets" in stereo,
|
| 536 |
+
}
|
| 537 |
+
for label, done in stereo_checks.items():
|
| 538 |
+
st.markdown(f"{'✅' if done else '⬜'} {label}")
|
| 539 |
+
|
| 540 |
+
with col2:
|
| 541 |
+
st.subheader("🌍 Generalisation Pipeline")
|
| 542 |
+
gen = st.session_state.get("gen_pipeline", {})
|
| 543 |
+
gen_checks = {
|
| 544 |
+
"Data loaded": "train_image" in gen and "test_image" in gen,
|
| 545 |
+
"ROI defined": "roi" in gen,
|
| 546 |
+
"Augmentation done": "crop_aug" in gen,
|
| 547 |
+
"Modules locked": "active_modules" in gen,
|
| 548 |
+
"Models trained": "rce_head" in gen,
|
| 549 |
+
"Detection run": "rce_dets" in gen,
|
| 550 |
+
}
|
| 551 |
+
for label, done in gen_checks.items():
|
| 552 |
+
st.markdown(f"{'✅' if done else '⬜'} {label}")
|
| 553 |
+
```
|
| 554 |
+
|
| 555 |
+
---
|
| 556 |
+
|
| 557 |
+
## Part 7 — Shared Utility Modules
|
| 558 |
+
|
| 559 |
+
All core logic lives in `src/` and must be imported identically by both `tabs/stereo/` and `tabs/generalisation/` stages. Do not duplicate or move anything from `src/`.
|
| 560 |
+
|
| 561 |
+
Key imports used by the stage files:
|
| 562 |
+
```python
|
| 563 |
+
from src.detectors.rce.features import RCEExtractor # RCE feature extraction
|
| 564 |
+
from src.detectors.resnet import ResNetDetector # ResNet-18
|
| 565 |
+
from src.detectors.mobilenet import MobileNetDetector # MobileNetV3
|
| 566 |
+
from src.detectors.mobilevit import MobileViTDetector # MobileViT-XXS
|
| 567 |
+
from src.detectors.orb import ORBDetector # ORB keypoint matching
|
| 568 |
+
from src.detectors.yolo import YOLODetector # YOLOv8
|
| 569 |
+
from src.localization import LocalizationStrategy # All 5 localization strategies
|
| 570 |
+
from src.models import load_model # Model loading from models/
|
| 571 |
+
from src.config import * # App constants
|
| 572 |
+
```
|
| 573 |
+
|
| 574 |
+
The `models/` directory contains pre-trained weights referenced by `src/models.py`. Do not move or rename any files in `models/`.
|
| 575 |
+
|
| 576 |
+
---
|
| 577 |
+
|
| 578 |
+
## Part 8 — Landing Page (Home)
|
| 579 |
+
|
| 580 |
+
The existing landing page content in `app.py` must be preserved and rendered when `top_nav == "🏠 Home"`. Extract it into a `render_home()` function within `app.py`.
|
| 581 |
+
|
| 582 |
+
Update the Pipeline Overview section to reflect the new two-pipeline structure:
|
| 583 |
+
|
| 584 |
+
```
|
| 585 |
+
🗺️ Pipeline Overview
|
| 586 |
+
|
| 587 |
+
This platform provides two evaluation pipelines:
|
| 588 |
+
|
| 589 |
+
📷 Stereo + Depth (7 stages)
|
| 590 |
+
Train on the LEFT image, detect in the RIGHT image, estimate metric depth.
|
| 591 |
+
Evaluates RCE in a constrained stereo-vision scenario.
|
| 592 |
+
|
| 593 |
+
🌍 Generalisation (6 stages)
|
| 594 |
+
Train on one view/exposure, detect in a different view/exposure.
|
| 595 |
+
Evaluates RCE's robustness to appearance variation.
|
| 596 |
+
|
| 597 |
+
Both pipelines compare: RCE · ResNet-18 · MobileNetV3-Small · MobileViT-XXS · ORB
|
| 598 |
+
```
|
| 599 |
+
|
| 600 |
+
Update the bottom caption: *"Navigate using the sidebar → Choose a pipeline to begin"*
|
| 601 |
+
|
| 602 |
+
---
|
| 603 |
+
|
| 604 |
+
## Implementation Order
|
| 605 |
+
|
| 606 |
+
Implement in this exact order to avoid breaking dependencies:
|
| 607 |
+
|
| 608 |
+
1. `utils/middlebury_loader.py` — no dependencies, can be tested in isolation
|
| 609 |
+
2. `app.py` — routing shell, import stubs for tabs not yet created
|
| 610 |
+
3. `tabs/stereo/data_lab.py` — foundation of stereo pipeline
|
| 611 |
+
4. `tabs/generalisation/data_lab.py` — foundation of generalisation pipeline
|
| 612 |
+
5. Migrate existing stages into `tabs/stereo/` — feature_lab, model_tuning, localization, detection (with fix), evaluation, stereo_depth
|
| 613 |
+
6. Create `tabs/generalisation/` stages — reuse stereo logic with gen_pipeline session keys
|
| 614 |
+
7. Update home page session status widget
|
| 615 |
+
|
| 616 |
+
---
|
| 617 |
+
|
| 618 |
+
## What NOT To Change
|
| 619 |
+
|
| 620 |
+
- **Entire `src/` directory** — all detector logic, RCE, CNN backbones, ORB, localization, model loading
|
| 621 |
+
- **Entire `models/` directory** — pre-trained weights
|
| 622 |
+
- **Entire `data/` directory** — Middlebury dataset
|
| 623 |
+
- **`notebooks/`, `training/`, `scripts/`** — development artifacts, leave untouched
|
| 624 |
+
- **`Dockerfile`, `packages.txt`, `requirements.txt`** — deployment config (add `streamlit-cropper` to requirements.txt only)
|
| 625 |
+
- The visual style and markdown descriptions of each stage
|
| 626 |
+
- The models description tabs on the home page (RCE, ResNet-18, MobileNetV3, MobileViT-XXS)
|
| 627 |
+
- The depth estimation explanation and LaTeX formula
|
| 628 |
+
|
| 629 |
+
---
|
| 630 |
+
|
| 631 |
+
## Dependencies To Add
|
| 632 |
+
|
| 633 |
+
Add to `requirements.txt` if not already present:
|
| 634 |
+
```
|
| 635 |
+
streamlit-cropper # for ROI selection (optional but preferred)
|
| 636 |
+
```
|
| 637 |
+
|
| 638 |
+
No other new dependencies are needed. The Middlebury loader uses only `numpy`, `opencv-python`, `re`, `os`, and `pathlib` — all already present.
|
| 639 |
+
|
| 640 |
+
---
|
| 641 |
+
|
| 642 |
+
## Testing Checklist
|
| 643 |
+
|
| 644 |
+
After implementation, verify:
|
| 645 |
+
|
| 646 |
+
- [ ] Home page renders with both pipeline status widgets
|
| 647 |
+
- [ ] Clicking "Stereo + Depth" shows 7-stage sub-navigation
|
| 648 |
+
- [ ] Clicking "Generalisation" shows 6-stage sub-navigation
|
| 649 |
+
- [ ] Clicking stages before Data Lab shows guard warning
|
| 650 |
+
- [ ] Middlebury loader finds scenes from `./data/middlebury/`
|
| 651 |
+
- [ ] Stereo Data Lab correctly assigns LEFT → train_image, RIGHT → test_image
|
| 652 |
+
- [ ] Generalisation Data Lab correctly assigns View1 → train_image, View2 → test_image
|
| 653 |
+
- [ ] Detection stage uses `pipe["test_image"]` in BOTH pipelines
|
| 654 |
+
- [ ] Stereo and Generalisation pipelines do not share session state
|
| 655 |
+
- [ ] Depth estimation gracefully disabled when calib is None
|
| 656 |
+
- [ ] All existing stage logic works after migration
|
app.py
CHANGED
|
@@ -3,195 +3,211 @@ import streamlit as st
|
|
| 3 |
st.set_page_config(page_title="Perception Benchmark", layout="wide", page_icon="🦅")
|
| 4 |
|
| 5 |
# ===================================================================
|
| 6 |
-
#
|
| 7 |
# ===================================================================
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
# ===================================================================
|
| 15 |
-
#
|
| 16 |
# ===================================================================
|
| 17 |
-
|
| 18 |
-
st.
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
""")
|
| 22 |
-
|
| 23 |
-
stages = [
|
| 24 |
-
("🧪", "1 · Data Lab", "Upload a stereo image pair, camera calibration file, and two PFM ground-truth depth maps. "
|
| 25 |
-
"Define one or more object ROIs (bounding boxes) with class labels, then apply live data augmentation "
|
| 26 |
-
"(brightness, contrast, rotation, noise, blur, shift, flip). "
|
| 27 |
-
"All assets are locked into session state — nothing is written to disk."),
|
| 28 |
-
("🔬", "2 · Feature Lab", "Toggle RCE physics modules (Intensity · Sobel · Spectral) to build a modular "
|
| 29 |
-
"feature vector. Compare it live against CNN activation maps extracted from a "
|
| 30 |
-
"frozen backbone via forward hooks. Lock your active module configuration."),
|
| 31 |
-
("⚙️", "3 · Model Tuning", "Train lightweight **heads** on your session data (augmented crop = positives, "
|
| 32 |
-
"random non-overlapping patches = negatives). Compare three paradigms side by side: "
|
| 33 |
-
"RCE (with feature importance), CNN (with activation overlay), and ORB (keypoint matching)."),
|
| 34 |
-
("🔍", "4 · Localization Lab", "Compare **five localization strategies** on top of your trained head: "
|
| 35 |
-
"Exhaustive Sliding Window, Image Pyramid (multi-scale), Coarse-to-Fine "
|
| 36 |
-
"hierarchical search, Contour Proposals (edge-driven), and Template "
|
| 37 |
-
"Matching (cross-correlation)."),
|
| 38 |
-
("🎯", "5 · Real-Time Detection","Run a **sliding window** across the right image using RCE, CNN, and ORB "
|
| 39 |
-
"simultaneously. Watch the scan live, then compare bounding boxes, "
|
| 40 |
-
"confidence heatmaps, and latency across all three methods."),
|
| 41 |
-
("📈", "6 · Evaluation", "Quantitative evaluation with **confusion matrices**, **precision-recall curves**, "
|
| 42 |
-
"and **F1 scores** per method. Ground truth is derived from your Data Lab ROIs."),
|
| 43 |
-
("📐", "7 · Stereo Geometry", "Compute a disparity map with **StereoSGBM**, convert it to metric depth "
|
| 44 |
-
"using the stereo formula $Z = fB/(d+d_{\\text{offs}})$, then read depth "
|
| 45 |
-
"directly at every detected bounding box. Compare against PFM ground truth."),
|
| 46 |
-
]
|
| 47 |
-
|
| 48 |
-
for icon, title, desc in stages:
|
| 49 |
-
with st.container(border=True):
|
| 50 |
-
c1, c2 = st.columns([1, 12])
|
| 51 |
-
c1.markdown(f"## {icon}")
|
| 52 |
-
c2.markdown(f"**{title}** \n{desc}")
|
| 53 |
-
|
| 54 |
-
st.divider()
|
| 55 |
|
| 56 |
-
|
| 57 |
-
# Models
|
| 58 |
-
# ===================================================================
|
| 59 |
-
st.header("🧠 Models Used")
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
st.markdown("### 🧬 RCE — Relative Contextual Encoding")
|
| 66 |
st.markdown("""
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
| Module | Input | Operation |
|
| 71 |
|--------|-------|-----------|
|
| 72 |
| **Intensity** | Grayscale | Pixel-value histogram (global appearance) |
|
| 73 |
| **Sobel** | Gradient magnitude | Edge strength distribution (texture) |
|
| 74 |
| **Spectral** | FFT log-magnitude | Frequency content (pattern / structure) |
|
| 75 |
-
|
| 76 |
-
**
|
| 77 |
-
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
**
|
| 89 |
-
**Pre-training:** ImageNet-1k (1.28 M images, 1 000 classes)
|
| 90 |
-
**Backbone output:** 512-dimensional embedding (after `avgpool`)
|
| 91 |
**Head:** LogisticRegression trained on your session data
|
| 92 |
|
| 93 |
-
**
|
| 94 |
-
- 18 layers with residual (skip) connections
|
| 95 |
-
- Residual blocks prevent vanishing gradients in deeper networks
|
| 96 |
-
- `layer4` is hooked for activation map visualisation
|
| 97 |
-
|
| 98 |
-
**In this app:** The entire backbone is **frozen** (`requires_grad=False`).
|
| 99 |
Only the lightweight head adapts to your specific object.
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
with tab_mobilenet:
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
**Source:** PyTorch Hub (`torchvision.models.MobileNet_V3_Small_Weights.DEFAULT`)
|
| 106 |
-
**Pre-training:** ImageNet-1k
|
| 107 |
-
**Backbone output:** 576-dimensional embedding
|
| 108 |
**Head:** LogisticRegression trained on your session data
|
| 109 |
|
| 110 |
-
**
|
| 111 |
-
|
| 112 |
-
- Hard-Swish / Hard-Sigmoid activations (hardware-friendly)
|
| 113 |
-
- Squeeze-and-Excitation (SE) blocks for channel attention
|
| 114 |
-
- Designed for **edge / mobile inference** — ~2.5 M parameters
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
st.markdown("""
|
| 123 |
-
**Source:** timm — `mobilevit_xxs.cvnets_in1k` (Apple Research, 2022)
|
| 124 |
-
**Pre-training:** ImageNet-1k
|
| 125 |
-
**Backbone output:** 320-dimensional embedding (`num_classes=0`)
|
| 126 |
**Head:** LogisticRegression trained on your session data
|
| 127 |
|
| 128 |
-
**
|
| 129 |
-
|
| 130 |
-
global self-attention for long-range context
|
| 131 |
-
- MobileNetV2 stem + MobileViT blocks (attention on non-overlapping patches)
|
| 132 |
-
- Only ~1.3 M parameters — smallest of the three
|
| 133 |
-
|
| 134 |
-
**In this app:** The final transformer stage `stages[-1]` is hooked.
|
| 135 |
-
Slower than MobileNetV3 but captures global structure.
|
| 136 |
-
""")
|
| 137 |
|
| 138 |
-
st.divider()
|
| 139 |
|
| 140 |
-
#
|
| 141 |
-
# Depth Estimation
|
| 142 |
-
#
|
| 143 |
-
st.header("📐 Stereo Depth Estimation")
|
| 144 |
|
| 145 |
-
col_d1, col_d2 = st.columns(2)
|
| 146 |
-
with col_d1:
|
| 147 |
-
|
| 148 |
**Algorithm:** `cv2.StereoSGBM` (Semi-Global Block Matching)
|
| 149 |
|
| 150 |
SGBM minimises a global energy function combining:
|
| 151 |
- Data cost (pixel intensity difference)
|
| 152 |
- Smoothness penalty (P1, P2 regularisation)
|
| 153 |
|
| 154 |
-
It processes multiple horizontal and diagonal scan-line passes,
|
| 155 |
making it significantly more accurate than basic block matching.
|
| 156 |
-
|
| 157 |
-
with col_d2:
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
- $
|
| 164 |
-
- $B$ — baseline (mm, from calibration file)
|
| 165 |
-
- $d$ — disparity (pixels)
|
| 166 |
- $d_\\text{offs}$ — optical-center offset between cameras
|
| 167 |
-
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
-
st.divider()
|
| 170 |
|
| 171 |
# ===================================================================
|
| 172 |
-
#
|
| 173 |
# ===================================================================
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
checks = {
|
| 179 |
-
"Data Lab locked": "left" in pipe,
|
| 180 |
-
"Crop defined": "crop" in pipe,
|
| 181 |
-
"Augmentation applied": "crop_aug" in pipe,
|
| 182 |
-
"Active modules locked": "active_modules" in st.session_state,
|
| 183 |
-
"RCE head trained": "rce_head" in st.session_state,
|
| 184 |
-
"CNN head trained": any(f"cnn_head_{n}" in st.session_state
|
| 185 |
-
for n in ["ResNet-18", "MobileNetV3", "MobileViT-XXS"]),
|
| 186 |
-
"RCE detections ready": "rce_dets" in st.session_state,
|
| 187 |
-
"CNN detections ready": "cnn_dets" in st.session_state,
|
| 188 |
-
}
|
| 189 |
-
|
| 190 |
-
cols = st.columns(4)
|
| 191 |
-
for i, (label, done) in enumerate(checks.items()):
|
| 192 |
-
cols[i % 4].markdown(
|
| 193 |
-
f"{'✅' if done else '⬜'} {'~~' if not done else ''}{label}{'~~' if not done else ''}"
|
| 194 |
-
)
|
| 195 |
-
|
| 196 |
-
st.divider()
|
| 197 |
-
st.caption("Navigate using the sidebar → Start with **🧪 Data Lab**")
|
|
|
|
| 3 |
st.set_page_config(page_title="Perception Benchmark", layout="wide", page_icon="🦅")
|
| 4 |
|
| 5 |
# ===================================================================
|
| 6 |
+
# Routing — Sidebar Navigation
|
| 7 |
# ===================================================================
|
| 8 |
+
# Maps sidebar pipeline label -> {stage label -> importable module path}.
# "🏠 Home" carries no stages (None) and falls through to render_home().
PIPELINES = {
    "🏠 Home": None,
    "📐 Stereo + Depth": {
        "🧪 Data Lab": "tabs.stereo.data_lab",
        "🔬 Feature Lab": "tabs.stereo.feature_lab",
        "⚙️ Model Tuning": "tabs.stereo.model_tuning",
        "🔍 Localization Lab": "tabs.stereo.localization",
        "🎯 Real-Time Detection": "tabs.stereo.detection",
        "📈 Evaluation": "tabs.stereo.evaluation",
        "📐 Stereo Geometry": "tabs.stereo.stereo_depth",
    },
    "🌍 Generalisation": {
        "🧪 Data Lab": "tabs.generalisation.data_lab",
        "🔬 Feature Lab": "tabs.generalisation.feature_lab",
        "⚙️ Model Tuning": "tabs.generalisation.model_tuning",
        "🔍 Localization Lab": "tabs.generalisation.localization",
        "🎯 Real-Time Detection": "tabs.generalisation.detection",
        "📈 Evaluation": "tabs.generalisation.evaluation",
    },
}

st.sidebar.title("🦅 Recognition BenchMark")
pipeline_choice = st.sidebar.radio("Pipeline", list(PIPELINES.keys()), key="nav_pipeline")

stage_module = None
if PIPELINES[pipeline_choice] is not None:
    stages_map = PIPELINES[pipeline_choice]
    # BUGFIX: key the stage radio *per pipeline*.  With a single shared
    # "nav_stage" key, switching from Stereo (7 stages) to Generalisation
    # (6 stages) can leave a stored value ("📐 Stereo Geometry") that is
    # not in the new option list, so Streamlit resets the widget and
    # emits a warning.  A per-pipeline key also remembers each pipeline's
    # last visited stage independently.
    stage_choice = st.sidebar.radio(
        "Stage",
        list(stages_map.keys()),
        key=f"nav_stage_{pipeline_choice}",
    )
    module_path = stages_map[stage_choice]
    import importlib  # local import: only needed when a stage is active
    try:
        stage_module = importlib.import_module(module_path)
    except ImportError as exc:
        # Show a readable error instead of a raw traceback if a stage
        # module is missing (e.g. during incremental migration).
        st.error(f"Stage module `{module_path}` could not be imported: {exc}")
        st.stop()

# Session status widget (always visible in sidebar)
st.sidebar.divider()
st.sidebar.subheader("📋 Session Status")

# One collapsible checklist per pipeline; each pipeline keeps its own
# namespaced dict in session_state so they never share state.
for pipe_label, pipe_key in [("Stereo", "stereo_pipeline"), ("General", "gen_pipeline")]:
    pipe = st.session_state.get(pipe_key, {})
    checks = {
        "Data locked": "train_image" in pipe,
        "Crop defined": "crop" in pipe,
        "Modules set": "active_modules" in pipe,
        "RCE trained": "rce_head" in pipe,
        "CNN trained": any(f"cnn_head_{n}" in pipe
                           for n in ["ResNet-18", "MobileNetV3", "MobileViT-XXS"]),
        "Dets ready": "rce_dets" in pipe or "cnn_dets" in pipe,
    }
    with st.sidebar.expander(f"**{pipe_label}** — {sum(checks.values())}/{len(checks)}"):
        for label, done in checks.items():
            st.markdown(f"{'✅' if done else '⬜'} {label}")
|
| 59 |
|
| 60 |
# ===================================================================
|
| 61 |
+
# Home Page
|
| 62 |
# ===================================================================
|
| 63 |
+
def render_home():
|
| 64 |
+
st.title("🦅 Recognition BenchMark")
|
| 65 |
+
st.subheader("A stereo-vision pipeline for object recognition & depth estimation")
|
| 66 |
+
st.caption("Compare classical feature engineering (RCE) against modern deep learning backbones — end-to-end, in your browser.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
+
st.divider()
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
+
# -------------------------------------------------------------------
|
| 71 |
+
# Two Pipelines
|
| 72 |
+
# -------------------------------------------------------------------
|
| 73 |
+
st.header("🗺️ Two Pipelines")
|
|
|
|
| 74 |
st.markdown("""
|
| 75 |
+
Choose a pipeline from the **sidebar**:
|
| 76 |
+
|
| 77 |
+
- **📐 Stereo + Depth** — 7 stages. Uses a stereo image pair (LEFT=train, RIGHT=test)
|
| 78 |
+
with calibration data and ground-truth disparities. Ends with metric depth estimation.
|
| 79 |
+
- **🌍 Generalisation** — 6 stages. Uses different scene *variants* from the Middlebury dataset
|
| 80 |
+
(train on one variant, test on another). Tests how well models generalise across viewpoints.
|
| 81 |
+
""")
|
| 82 |
+
|
| 83 |
+
col_s, col_g = st.columns(2)
|
| 84 |
+
with col_s:
|
| 85 |
+
st.markdown("### 📐 Stereo + Depth (7 stages)")
|
| 86 |
+
stereo_stages = [
|
| 87 |
+
("🧪", "Data Lab", "Load stereo pair, calib, GT depth. Define ROIs."),
|
| 88 |
+
("🔬", "Feature Lab", "Toggle RCE modules, compare CNN activations."),
|
| 89 |
+
("⚙️", "Model Tuning", "Train RCE / CNN / ORB heads."),
|
| 90 |
+
("🔍", "Localization Lab", "Compare 5 localization strategies."),
|
| 91 |
+
("🎯", "Real-Time Detection","Sliding window on the TEST image."),
|
| 92 |
+
("📈", "Evaluation", "Confusion matrices, PR curves, F1."),
|
| 93 |
+
("📐", "Stereo Geometry", "StereoSGBM disparity → metric depth."),
|
| 94 |
+
]
|
| 95 |
+
for icon, title, desc in stereo_stages:
|
| 96 |
+
st.markdown(f"{icon} **{title}** — {desc}")
|
| 97 |
+
|
| 98 |
+
with col_g:
|
| 99 |
+
st.markdown("### 🌍 Generalisation (6 stages)")
|
| 100 |
+
gen_stages = [
|
| 101 |
+
("🧪", "Data Lab", "Pick scene group & variants (train ≠ test)."),
|
| 102 |
+
("🔬", "Feature Lab", "Toggle RCE modules, compare CNN activations."),
|
| 103 |
+
("⚙️", "Model Tuning", "Train RCE / CNN / ORB heads."),
|
| 104 |
+
("🔍", "Localization Lab", "Compare 5 localization strategies."),
|
| 105 |
+
("🎯", "Real-Time Detection","Sliding window on a different variant."),
|
| 106 |
+
("📈", "Evaluation", "Confusion matrices, PR curves, F1."),
|
| 107 |
+
]
|
| 108 |
+
for icon, title, desc in gen_stages:
|
| 109 |
+
st.markdown(f"{icon} **{title}** — {desc}")
|
| 110 |
+
|
| 111 |
+
st.divider()
|
| 112 |
+
|
| 113 |
+
# -------------------------------------------------------------------
|
| 114 |
+
# Models
|
| 115 |
+
# -------------------------------------------------------------------
|
| 116 |
+
st.header("🧠 Models Used")
|
| 117 |
+
|
| 118 |
+
tab_rce, tab_resnet, tab_mobilenet, tab_mobilevit = st.tabs(
|
| 119 |
+
["RCE Engine", "ResNet-18", "MobileNetV3-Small", "MobileViT-XXS"])
|
| 120 |
+
|
| 121 |
+
with tab_rce:
|
| 122 |
+
st.markdown("### 🧬 RCE — Relative Contextual Encoding")
|
| 123 |
+
st.markdown("""
|
| 124 |
+
**Type:** Modular hand-crafted feature extractor
|
| 125 |
+
**Architecture:** Seven physics-inspired modules, each producing a 10-bin histogram:
|
| 126 |
|
| 127 |
| Module | Input | Operation |
|
| 128 |
|--------|-------|-----------|
|
| 129 |
| **Intensity** | Grayscale | Pixel-value histogram (global appearance) |
|
| 130 |
| **Sobel** | Gradient magnitude | Edge strength distribution (texture) |
|
| 131 |
| **Spectral** | FFT log-magnitude | Frequency content (pattern / structure) |
|
| 132 |
+
| **Laplacian** | Laplacian response | Second-derivative focus / sharpness |
|
| 133 |
+
| **Gradient Orientation** | Sobel angles | Edge direction histogram |
|
| 134 |
+
| **Gabor** | Multi-kernel response | Texture at multiple orientations / scales |
|
| 135 |
+
| **LBP** | Local Binary Patterns | Micro-texture descriptor |
|
| 136 |
+
|
| 137 |
+
Max feature vector = **70D** (7 modules × 10 bins).
|
| 138 |
+
""")
|
| 139 |
+
|
| 140 |
+
with tab_resnet:
|
| 141 |
+
st.markdown("### 🏗️ ResNet-18")
|
| 142 |
+
st.markdown("""
|
| 143 |
+
**Source:** PyTorch Hub (`torchvision.models.ResNet18_Weights.DEFAULT`)
|
| 144 |
+
**Pre-training:** ImageNet-1k (1.28 M images, 1 000 classes)
|
| 145 |
+
**Backbone output:** 512-dimensional embedding (after `avgpool`)
|
|
|
|
|
|
|
| 146 |
**Head:** LogisticRegression trained on your session data
|
| 147 |
|
| 148 |
+
**In this app:** The entire backbone is **frozen** (`requires_grad=False`).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
Only the lightweight head adapts to your specific object.
|
| 150 |
+
""")
|
| 151 |
+
|
| 152 |
+
with tab_mobilenet:
|
| 153 |
+
st.markdown("### 📱 MobileNetV3-Small")
|
| 154 |
+
st.markdown("""
|
| 155 |
+
**Source:** PyTorch Hub (`torchvision.models.MobileNet_V3_Small_Weights.DEFAULT`)
|
| 156 |
+
**Pre-training:** ImageNet-1k
|
| 157 |
+
**Backbone output:** 576-dimensional embedding
|
| 158 |
**Head:** LogisticRegression trained on your session data
|
| 159 |
|
| 160 |
+
**In this app:** Typically 3–5× faster than ResNet-18.
|
| 161 |
+
""")
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
+
with tab_mobilevit:
|
| 164 |
+
st.markdown("### 🤖 MobileViT-XXS")
|
| 165 |
+
st.markdown("""
|
| 166 |
+
**Source:** timm — `mobilevit_xxs.cvnets_in1k` (Apple Research, 2022)
|
| 167 |
+
**Pre-training:** ImageNet-1k
|
| 168 |
+
**Backbone output:** 320-dimensional embedding (`num_classes=0`)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
**Head:** LogisticRegression trained on your session data
|
| 170 |
|
| 171 |
+
**In this app:** Hybrid CNN + Vision Transformer. Only ~1.3 M parameters.
|
| 172 |
+
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
+
st.divider()
|
| 175 |
|
| 176 |
+
# -------------------------------------------------------------------
|
| 177 |
+
# Depth Estimation
|
| 178 |
+
# -------------------------------------------------------------------
|
| 179 |
+
st.header("📐 Stereo Depth Estimation")
|
| 180 |
|
| 181 |
+
col_d1, col_d2 = st.columns(2)
|
| 182 |
+
with col_d1:
|
| 183 |
+
st.markdown("""
|
| 184 |
**Algorithm:** `cv2.StereoSGBM` (Semi-Global Block Matching)
|
| 185 |
|
| 186 |
SGBM minimises a global energy function combining:
|
| 187 |
- Data cost (pixel intensity difference)
|
| 188 |
- Smoothness penalty (P1, P2 regularisation)
|
| 189 |
|
| 190 |
+
It processes multiple horizontal and diagonal scan-line passes,
|
| 191 |
making it significantly more accurate than basic block matching.
|
| 192 |
+
""")
|
| 193 |
+
with col_d2:
|
| 194 |
+
st.markdown("**Depth formula (Middlebury convention):**")
|
| 195 |
+
st.latex(r"Z = \frac{f \times B}{d + d_{\text{offs}}}")
|
| 196 |
+
st.markdown("""
|
| 197 |
+
- $f$ — focal length (pixels)
|
| 198 |
+
- $B$ — baseline (mm, from calibration file)
|
| 199 |
+
- $d$ — disparity (pixels)
|
|
|
|
|
|
|
| 200 |
- $d_\\text{offs}$ — optical-center offset between cameras
|
| 201 |
+
""")
|
| 202 |
+
|
| 203 |
+
st.divider()
|
| 204 |
+
st.caption("Select a pipeline from the **sidebar** to begin.")
|
| 205 |
|
|
|
|
| 206 |
|
| 207 |
# ===================================================================
|
| 208 |
+
# Dispatch
|
| 209 |
# ===================================================================
|
| 210 |
+
# Dispatch: render the selected stage module, or fall back to the
# landing page when "🏠 Home" is active (stage_module stays None).
if stage_module is None:
    render_home()
else:
    stage_module.render()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dataOLD/README.md
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
This directory contains two data views:
|
| 2 |
+
|
| 3 |
+
- Classification (ResNet, MobileNet, RCE)
|
| 4 |
+
- Detection (YOLO)
|
| 5 |
+
|
dataOLD/artroom/bird/yolo/bird_data.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
path: /Users/dariusgiannoli/Desktop/Recognition-BenchMark/data/artroom/bird/yolo
|
| 2 |
+
train: train/images
|
| 3 |
+
val: train/images
|
| 4 |
+
|
| 5 |
+
# Classes
|
| 6 |
+
nc: 1
|
| 7 |
+
names: ['bird']
|
dataOLD/artroom/bird/yolo/train/images/bird_01_original.png
ADDED
|
Git LFS Details
|
dataOLD/artroom/bird/yolo/train/images/bird_02_rot_pos5.png
ADDED
|
Git LFS Details
|
dataOLD/artroom/bird/yolo/train/images/bird_03_rot_neg5.png
ADDED
|
Git LFS Details
|
dataOLD/artroom/bird/yolo/train/images/bird_04_bright.png
ADDED
|
Git LFS Details
|
dataOLD/artroom/bird/yolo/train/images/bird_05_dark.png
ADDED
|
Git LFS Details
|
dataOLD/artroom/bird/yolo/train/images/bird_06_noisy.png
ADDED
|
Git LFS Details
|
dataOLD/artroom/bird/yolo/train/images/bird_07_flip.png
ADDED
|
Git LFS Details
|
dataOLD/artroom/bird/yolo/train/images/bird_08_blur.png
ADDED
|
Git LFS Details
|
dataOLD/artroom/bird/yolo/train/images/bird_09_shift_x.png
ADDED
|
Git LFS Details
|
dataOLD/artroom/bird/yolo/train/images/bird_10_shift_y.png
ADDED
|
Git LFS Details
|
dataOLD/artroom/bird/yolo/train/images/room_1.png
ADDED
|
Git LFS Details
|
dataOLD/artroom/bird/yolo/train/images/room_2.png
ADDED
|
Git LFS Details
|
dataOLD/artroom/bird/yolo/train/images/room_3.png
ADDED
|
Git LFS Details
|
dataOLD/artroom/bird/yolo/train/images/room_4.png
ADDED
|
Git LFS Details
|
dataOLD/artroom/bird/yolo/train/images/room_5.png
ADDED
|
Git LFS Details
|
dataOLD/artroom/bird/yolo/train/labels.cache
ADDED
|
Binary file (3.76 kB). View file
|
|
|
dataOLD/artroom/bird/yolo/train/labels/bird_01_original.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
0 0.5 0.5 0.8 0.8
|
dataOLD/artroom/bird/yolo/train/labels/bird_02_rot_pos5.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
0 0.5 0.5 0.8 0.8
|
dataOLD/artroom/bird/yolo/train/labels/bird_03_rot_neg5.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
0 0.5 0.5 0.8 0.8
|
dataOLD/artroom/bird/yolo/train/labels/bird_04_bright.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
0 0.5 0.5 0.8 0.8
|
dataOLD/artroom/bird/yolo/train/labels/bird_05_dark.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
0 0.5 0.5 0.8 0.8
|
dataOLD/artroom/bird/yolo/train/labels/bird_06_noisy.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
0 0.5 0.5 0.8 0.8
|
dataOLD/artroom/bird/yolo/train/labels/bird_07_flip.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
0 0.5 0.5 0.8 0.8
|
dataOLD/artroom/bird/yolo/train/labels/bird_08_blur.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
0 0.5 0.5 0.8 0.8
|
dataOLD/artroom/bird/yolo/train/labels/bird_09_shift_x.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
0 0.5 0.5 0.8 0.8
|
dataOLD/artroom/bird/yolo/train/labels/bird_10_shift_y.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
0 0.5 0.5 0.8 0.8
|
dataOLD/artroom/bird/yolo/train/labels/room_1.txt
ADDED
|
File without changes
|
dataOLD/artroom/bird/yolo/train/labels/room_2.txt
ADDED
|
File without changes
|
dataOLD/artroom/bird/yolo/train/labels/room_3.txt
ADDED
|
File without changes
|
dataOLD/artroom/bird/yolo/train/labels/room_4.txt
ADDED
|
File without changes
|
dataOLD/artroom/bird/yolo/train/labels/room_5.txt
ADDED
|
File without changes
|
dataOLD/artroom/im0.png
ADDED
|
Git LFS Details
|
pages/2_Data_Lab.py
DELETED
|
@@ -1,321 +0,0 @@
|
|
| 1 |
-
import streamlit as st
|
| 2 |
-
import cv2
|
| 3 |
-
import numpy as np
|
| 4 |
-
import io
|
| 5 |
-
|
| 6 |
-
st.set_page_config(page_title="Data Lab", layout="wide")
|
| 7 |
-
|
| 8 |
-
st.title("🧪 Data Lab: Stereo Asset Loader")
|
| 9 |
-
st.write("Upload your stereo images, camera configuration, and ground truth depth maps.")
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
# ---------------------------------------------------------------------------
|
| 13 |
-
# Helpers
|
| 14 |
-
# ---------------------------------------------------------------------------
|
| 15 |
-
def read_pfm(file_bytes: bytes) -> np.ndarray:
    """Parse a PFM (Portable Float Map) buffer and return a float32 ndarray.

    Supports both grayscale ("Pf") and colour ("PF") variants.  PFM stores
    rows bottom-to-top, so the result is flipped into the usual
    top-to-bottom image order.

    Args:
        file_bytes: Raw bytes of a .pfm file.

    Returns:
        An (H, W) float32 array for "Pf", or (H, W, 3) for "PF".
        The array is writable (a copy, not a view of the input buffer).

    Raises:
        ValueError: If the header is invalid or the pixel payload is
            shorter than width * height * channels floats.
    """
    buf = io.BytesIO(file_bytes)
    header = buf.readline().decode("ascii").strip()
    if header not in ("Pf", "PF"):
        raise ValueError(f"Not a valid PFM file (header: {header!r})")
    color = header == "PF"

    # Skip any '#' comment lines between the header and the dimensions.
    line = buf.readline().decode("ascii").strip()
    while line.startswith("#"):
        line = buf.readline().decode("ascii").strip()
    w, h = map(int, line.split())

    # The sign of the scale encodes endianness (< 0 = little-endian); its
    # magnitude is a scale factor, conventionally 1.0 in Middlebury files,
    # so it is ignored here.
    scale = float(buf.readline().decode("ascii").strip())
    endian = "<" if scale < 0 else ">"
    channels = 3 if color else 1

    data = np.frombuffer(buf.read(), dtype=np.dtype(endian + "f4"))
    expected = w * h * channels
    if data.size < expected:
        # Fail with a clear message instead of a cryptic reshape error.
        raise ValueError(
            f"Truncated PFM payload: got {data.size} floats, expected {expected}"
        )
    data = data[:expected].reshape((h, w, channels) if color else (h, w))
    # PFM rows are stored bottom-up; flip to top-down.  copy() makes the
    # result writable (frombuffer yields a read-only view).
    return np.flipud(data).copy()
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def vis_depth(depth: np.ndarray) -> np.ndarray:
    """Normalise a depth map to [0, 1] float32 for display.

    Non-finite values are sanitised before normalisation: NaN and -inf map
    to 0, +inf maps to the largest finite value present.  A map with no
    finite values at all (or an all-zero map) is returned as zeros rather
    than raising or dividing by zero.
    """
    finite = depth[np.isfinite(depth)]
    if finite.size == 0:
        # Nothing displayable — avoid calling max() on an empty array.
        return np.zeros_like(depth, dtype=np.float32)
    # neginf=0.0 fixes the original behaviour, where -inf became the most
    # negative float and wrecked the normalisation range.
    d = np.nan_to_num(depth, nan=0.0, posinf=float(finite.max()), neginf=0.0)
    peak = d.max()
    return (d / peak).astype(np.float32) if peak > 0 else d.astype(np.float32)
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def augment(img: np.ndarray, brightness: float, contrast: float,
            rotation: float, flip_h: bool, flip_v: bool,
            noise: float, blur: int, shift_x: int, shift_y: int,
            rng=None) -> np.ndarray:
    """Apply a chain of augmentations to a BGR crop.

    Order: brightness/contrast -> additive Gaussian noise -> blur ->
    rotation -> translation -> flips. Geometric ops reflect at the borders
    so no black corners appear.

    Args:
        img: uint8 BGR crop of shape (H, W, 3).
        brightness: additive offset in [0, 255] space.
        contrast: multiplicative gain applied before the offset.
        rotation: degrees, counter-clockwise about the crop centre.
        flip_h / flip_v: mirror horizontally / vertically.
        noise: sigma of additive Gaussian noise (0 disables).
        blur: blur strength; the Gaussian kernel side is 2*blur+1 (0 disables).
        shift_x / shift_y: translation in pixels.
        rng: optional ``np.random.Generator`` for reproducible noise. Defaults
            to the global NumPy RNG, preserving the original behaviour.

    Returns:
        Augmented uint8 image with the same shape as ``img``.
    """
    out = img.astype(np.float32)

    # Photometric pass: out = contrast * out + brightness, clipped to range.
    out = np.clip(contrast * out + brightness, 0, 255)

    # Additive Gaussian noise (seedable via `rng` for reproducible datasets).
    if noise > 0:
        sampler = rng if rng is not None else np.random
        out = np.clip(out + sampler.normal(0, noise, out.shape), 0, 255)

    out = out.astype(np.uint8)

    # Blur — kernel side must be odd; blur == 0 yields k == 1 (no-op).
    k = blur * 2 + 1
    if k > 1:
        out = cv2.GaussianBlur(out, (k, k), 0)

    # Rotation about the centre, reflecting at the borders.
    if rotation != 0:
        h, w = out.shape[:2]
        M = cv2.getRotationMatrix2D((w / 2, h / 2), rotation, 1.0)
        out = cv2.warpAffine(out, M, (w, h), borderMode=cv2.BORDER_REFLECT)

    # Translation.
    if shift_x != 0 or shift_y != 0:
        h, w = out.shape[:2]
        M = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
        out = cv2.warpAffine(out, M, (w, h), borderMode=cv2.BORDER_REFLECT)

    # Mirroring.
    if flip_h:
        out = cv2.flip(out, 1)
    if flip_v:
        out = cv2.flip(out, 0)

    return out
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
# ---------------------------------------------------------------------------
# Session state & upload limits
# ---------------------------------------------------------------------------
MAX_UPLOAD_BYTES = 50 * 1024 * 1024  # 50 MB hard cap per uploaded file

if "pipeline_data" not in st.session_state:
    st.session_state["pipeline_data"] = {}

# ---------------------------------------------------------------------------
# Step 1 — Upload Assets
# ---------------------------------------------------------------------------
st.subheader("Step 1: Upload Assets")
col1, col2 = st.columns(2)

with col1:
    up_l = st.file_uploader("Left Image (Reference)", type=["png", "jpg", "jpeg"])
    if up_l:
        if up_l.size > MAX_UPLOAD_BYTES:
            st.error(f"❌ Left image too large ({up_l.size / 1e6:.1f} MB). Max 50 MB.")
            up_l = None
        else:
            img_l_preview = cv2.imdecode(np.frombuffer(up_l.read(), np.uint8), cv2.IMREAD_COLOR)
            up_l.seek(0)  # rewind so the full pipeline can re-read the bytes
            st.image(cv2.cvtColor(img_l_preview, cv2.COLOR_BGR2RGB), caption="Left Image Preview", use_container_width=True)

    up_conf = st.file_uploader("Camera Config (.txt or .conf)", type=["txt", "conf"])

    up_gt_l = st.file_uploader("Left Ground Truth Depth (.pfm)", type=["pfm"])
    if up_gt_l:
        try:
            gt_l_prev = read_pfm(up_gt_l.read())
            up_gt_l.seek(0)
            st.image(vis_depth(gt_l_prev), caption="Left GT Depth Preview", use_container_width=True)
        except Exception as e:
            st.error(f"❌ Invalid PFM file (left): {e}")
            up_gt_l = None

with col2:
    up_r = st.file_uploader("Right Image (Stereo Match)", type=["png", "jpg", "jpeg"])
    if up_r:
        if up_r.size > MAX_UPLOAD_BYTES:
            st.error(f"❌ Right image too large ({up_r.size / 1e6:.1f} MB). Max 50 MB.")
            up_r = None
        else:
            img_r_preview = cv2.imdecode(np.frombuffer(up_r.read(), np.uint8), cv2.IMREAD_COLOR)
            up_r.seek(0)
            st.image(cv2.cvtColor(img_r_preview, cv2.COLOR_BGR2RGB), caption="Right Image Preview", use_container_width=True)

    up_gt_r = st.file_uploader("Right Ground Truth Depth (.pfm)", type=["pfm"])
    if up_gt_r:
        try:
            gt_r_prev = read_pfm(up_gt_r.read())
            up_gt_r.seek(0)
            st.image(vis_depth(gt_r_prev), caption="Right GT Depth Preview", use_container_width=True)
        except Exception as e:
            st.error(f"❌ Invalid PFM file (right): {e}")
            up_gt_r = None

# ---------------------------------------------------------------------------
# Steps 2-5 — only run once every one of the 5 assets is present
# ---------------------------------------------------------------------------
if up_l and up_r and up_conf and up_gt_l and up_gt_r:
    img_l = cv2.imdecode(np.frombuffer(up_l.read(), np.uint8), cv2.IMREAD_COLOR)
    img_r = cv2.imdecode(np.frombuffer(up_r.read(), np.uint8), cv2.IMREAD_COLOR)
    conf_content = up_conf.read().decode("utf-8")
    gt_depth_l = read_pfm(up_gt_l.read())
    gt_depth_r = read_pfm(up_gt_r.read())

    st.success("✅ All assets loaded successfully!")

    # --- Step 2: visualise what was loaded -------------------------------
    st.divider()
    st.subheader("Step 2: Asset Visualization")
    st.write("### 📸 Stereo Pair")
    v1, v2 = st.columns(2)
    v1.image(cv2.cvtColor(img_l, cv2.COLOR_BGR2RGB), caption="Left View", use_container_width=True)
    v2.image(cv2.cvtColor(img_r, cv2.COLOR_BGR2RGB), caption="Right View", use_container_width=True)

    st.write("### 📊 Ground Truth Depth Maps")
    d1, d2 = st.columns(2)
    d1.image(vis_depth(gt_depth_l), caption="Left GT Depth (Normalized)", use_container_width=True)
    d2.image(vis_depth(gt_depth_r), caption="Right GT Depth (Normalized)", use_container_width=True)

    with st.expander("📄 Camera Configuration"):
        st.text_area("Raw Config", conf_content, height=200)

    # --- Step 3: crop one or more ROIs from the left image ---------------
    st.divider()
    st.subheader("Step 3: Crop Region(s) of Interest")
    st.write("Define one or more bounding boxes — each becomes a separate class for recognition.")

    H, W = img_l.shape[:2]

    # The ROI list lives in session state so it survives reruns.
    if "rois" not in st.session_state:
        st.session_state["rois"] = [{"label": "object", "x0": 0, "y0": 0,
                                     "x1": min(W, 100), "y1": min(H, 100)}]

    def _add_roi():
        # Cap at 20 boxes to keep the UI responsive.
        if len(st.session_state["rois"]) >= 20:
            return
        st.session_state["rois"].append(
            {"label": f"object_{len(st.session_state['rois'])+1}",
             "x0": 0, "y0": 0,
             "x1": min(W, 100), "y1": min(H, 100)})

    def _remove_roi(idx):
        # Always keep at least one ROI defined.
        if len(st.session_state["rois"]) > 1:
            st.session_state["rois"].pop(idx)

    ROI_COLORS = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0),
                  (255, 0, 255), (0, 255, 255), (128, 255, 0), (255, 128, 0)]

    # One bordered editor card per ROI.
    for i, roi in enumerate(st.session_state["rois"]):
        color = ROI_COLORS[i % len(ROI_COLORS)]
        color_hex = "#{:02x}{:02x}{:02x}".format(*color)
        with st.container(border=True):
            hc1, hc2, hc3 = st.columns([3, 6, 1])
            hc1.markdown(f"**ROI {i+1}** <span style='color:{color_hex}'>■</span>",
                         unsafe_allow_html=True)
            roi["label"] = hc2.text_input("Class Label", roi["label"], key=f"roi_lbl_{i}")
            if len(st.session_state["rois"]) > 1:
                hc3.button("✕", key=f"roi_del_{i}", on_click=_remove_roi, args=(i,))

            cr1, cr2, cr3, cr4 = st.columns(4)
            roi["x0"] = int(cr1.number_input("X start", 0, W-2, int(roi["x0"]),
                                             step=1, key=f"roi_x0_{i}"))
            roi["y0"] = int(cr2.number_input("Y start", 0, H-2, int(roi["y0"]),
                                             step=1, key=f"roi_y0_{i}"))
            roi["x1"] = int(cr3.number_input("X end", roi["x0"]+1, W,
                                             min(W, int(roi["x1"])),
                                             step=1, key=f"roi_x1_{i}"))
            roi["y1"] = int(cr4.number_input("Y end", roi["y0"]+1, H,
                                             min(H, int(roi["y1"])),
                                             step=1, key=f"roi_y1_{i}"))

    st.button("➕ Add Another ROI", on_click=_add_roi,
              disabled=len(st.session_state["rois"]) >= 20)

    # Draw every ROI on an overlay and extract the corresponding crops.
    overlay = img_l.copy()
    crops = []
    for i, roi in enumerate(st.session_state["rois"]):
        color = ROI_COLORS[i % len(ROI_COLORS)]
        x0, y0, x1, y1 = roi["x0"], roi["y0"], roi["x1"], roi["y1"]
        cv2.rectangle(overlay, (x0, y0), (x1, y1), color, 2)
        cv2.putText(overlay, roi["label"], (x0, y0 - 6),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        crops.append(img_l[y0:y1, x0:x1].copy())

    ov1, ov2 = st.columns([3, 2])
    ov1.image(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB),
              caption="Left Image — ROIs highlighted", use_container_width=True)
    with ov2:
        for c, roi in zip(crops, st.session_state["rois"]):
            st.image(cv2.cvtColor(c, cv2.COLOR_BGR2RGB),
                     caption=f"{roi['label']} ({c.shape[1]}×{c.shape[0]})",
                     width=160)

    # Backward compatibility: the first ROI remains the "primary" crop.
    crop_bgr = crops[0]
    x0, y0, x1, y1 = (st.session_state["rois"][0]["x0"],
                      st.session_state["rois"][0]["y0"],
                      st.session_state["rois"][0]["x1"],
                      st.session_state["rois"][0]["y1"])

    # --- Step 4: live data augmentation ----------------------------------
    st.divider()
    st.subheader("Step 4: Data Augmentation")
    st.write("Tune the parameters below — the augmented crop updates live.")

    ac1, ac2 = st.columns(2)
    with ac1:
        brightness = st.slider("Brightness offset", -100, 100, 0, step=1)
        contrast = st.slider("Contrast scale", 0.5, 3.0, 1.0, step=0.05)
        rotation = st.slider("Rotation (°)", -180, 180, 0, step=1)
        noise = st.slider("Gaussian noise σ", 0, 50, 0, step=1)
    with ac2:
        blur = st.slider("Blur kernel (0 = off)", 0, 10, 0, step=1)
        shift_x = st.slider("Shift X (px)", -100, 100, 0, step=1)
        shift_y = st.slider("Shift Y (px)", -100, 100, 0, step=1)
        flip_h = st.checkbox("Flip Horizontal")
        flip_v = st.checkbox("Flip Vertical")

    aug = augment(crop_bgr, brightness, contrast, rotation,
                  flip_h, flip_v, noise, blur, shift_x, shift_y)

    # The same augmentation chain is applied to every ROI crop.
    all_augs = [augment(c, brightness, contrast, rotation,
                        flip_h, flip_v, noise, blur, shift_x, shift_y)
                for c in crops]

    aug_col1, aug_col2 = st.columns(2)
    aug_col1.image(cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB),
                   caption="Original Crop (ROI 1)", use_container_width=True)
    aug_col2.image(cv2.cvtColor(aug, cv2.COLOR_BGR2RGB),
                   caption="Augmented Crop (ROI 1)", use_container_width=True)

    if len(crops) > 1:
        st.caption(f"Augmentation applied identically to all {len(crops)} ROIs.")

    # --- Step 5: lock everything into session state ----------------------
    st.divider()
    if st.button("🚀 Lock Data & Proceed to Benchmark"):
        if not st.session_state.get("rois"):
            st.error("❌ Define at least one ROI before locking!")
            st.stop()
        rois_data = [{"label": roi["label"],
                      "bbox": (roi["x0"], roi["y0"], roi["x1"], roi["y1"]),
                      "crop": crops[i],
                      "crop_aug": all_augs[i]}
                     for i, roi in enumerate(st.session_state["rois"])]

        st.session_state["pipeline_data"] = {
            "left": img_l,
            "right": img_r,
            "gt_left": gt_depth_l,
            "gt_right": gt_depth_r,
            "conf_raw": conf_content,
            # Backward compatibility: first ROI
            "crop": crop_bgr,
            "crop_aug": aug,
            "crop_bbox": (x0, y0, x1, y1),
            # Multi-object
            "rois": rois_data,
        }
        st.success(f"Data stored with **{len(rois_data)} ROI(s)**! "
                   f"Move to Feature Lab.")

else:
    st.info("Please upload all 5 files (left image, right image, config, left GT, right GT) to proceed.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/3_Feature_Lab.py
DELETED
|
@@ -1,111 +0,0 @@
|
|
| 1 |
-
import streamlit as st
import cv2
import numpy as np
import plotly.graph_objects as go
import sys, os

# Make the repo root importable so the `src.*` packages resolve from pages/.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.detectors.rce.features import REGISTRY
from src.models import BACKBONES

st.set_page_config(page_title="Feature Lab", layout="wide")

# Guard: the Data Lab must have populated the pipeline first.
if "pipeline_data" not in st.session_state:
    st.error("Please complete the Data Lab first!")
    st.stop()

assets = st.session_state["pipeline_data"]
# Prefer the augmented crop; fall back to the raw crop.
obj = assets.get("crop_aug", assets.get("crop"))
if obj is None:
    st.error("No crop found. Please go back to Data Lab and define a ROI.")
    st.stop()
gray = cv2.cvtColor(obj, cv2.COLOR_BGR2GRAY)

st.title("🔬 Feature Lab: Physical Module Selection")

col_rce, col_cnn = st.columns([3, 2])

# ---------------------------------------------------------------------------
# LEFT — RCE modular engine (pure UI — all math lives in features.py)
# ---------------------------------------------------------------------------
with col_rce:
    st.header("🧬 RCE: Modular Physics Engine")
    st.subheader("Select Active Modules")

    # One checkbox per registered module, laid out four to a row.
    active = {}
    registry_items = list(REGISTRY.items())
    for start in range(0, len(registry_items), 4):
        for box, (key, meta) in zip(st.columns(4), registry_items[start:start + 4]):
            active[key] = box.checkbox(meta["label"],
                                       value=(key in ("intensity", "sobel", "spectral")))

    # Run every enabled module: concatenate its vector, collect its viz image.
    final_vector = []
    viz_images = []
    for key, meta in REGISTRY.items():
        if active[key]:
            vec, viz = meta["fn"](gray)
            final_vector.extend(vec)
            viz_images.append((meta["viz_title"], viz))

    # Visualizations, three to a row.
    st.divider()
    if viz_images:
        for start in range(0, len(viz_images), 3):
            for cell, (title, img) in zip(st.columns(3), viz_images[start:start + 3]):
                cell.image(img, caption=title, use_container_width=True)
    else:
        st.warning("No modules selected — vector is empty.")

    # DNA vector bar chart.
    st.write(f"### Current DNA Vector Size: **{len(final_vector)}**")
    fig_vec = go.Figure(data=[go.Bar(y=final_vector, marker_color="#00d4ff")])
    fig_vec.update_layout(title="Active Feature Vector (RCE Input)",
                          template="plotly_dark", height=300)
    st.plotly_chart(fig_vec, use_container_width=True)

# ---------------------------------------------------------------------------
# RIGHT — CNN comparison panel
# ---------------------------------------------------------------------------
with col_cnn:
    st.header("🧠 CNN: Static Architecture")
    selected_cnn = st.selectbox("Compare against Model", list(BACKBONES.keys()))
    st.info("CNN features are fixed by pre-trained weights. You cannot toggle them like the RCE.")

    with st.spinner(f"Loading {selected_cnn} and extracting activations..."):
        try:
            bmeta = BACKBONES[selected_cnn]
            backbone = bmeta["loader"]()  # cached frozen backbone
            layer_name = bmeta["hook_layer"]

            act_maps = backbone.get_activation_maps(obj, n_maps=6)
            st.caption(f"Hooked layer: `{layer_name}` — showing 6 of {len(act_maps)} channels")
            act_cols = st.columns(3)
            for i, amap in enumerate(act_maps):
                act_cols[i % 3].image(amap, caption=f"Channel {i}", use_container_width=True)

        except Exception as e:
            st.error(f"Could not load model: {e}")

    st.divider()
    st.markdown(f"""
    **Analysis:**
    - **Modularity:** RCE is **High** | CNN is **Zero**
    - **Explainability:** RCE is **High** | CNN is **Low**
    - **Compute Cost:** {len(final_vector)} floats | 512+ floats
    """)

# ---------------------------------------------------------------------------
# Lock the chosen module configuration into session state
# ---------------------------------------------------------------------------
if st.button("🚀 Lock Modular Configuration"):
    if not final_vector:
        st.error("Please select at least one module!")
    else:
        st.session_state["pipeline_data"]["final_vector"] = np.array(final_vector)
        st.session_state["active_modules"] = dict(active)
        st.success("Modular DNA Locked! Ready for Model Tuning.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/4_Model_Tuning.py
DELETED
|
@@ -1,475 +0,0 @@
|
|
| 1 |
-
import streamlit as st
import cv2
import numpy as np
import time
import plotly.graph_objects as go
import sys, os

# Repo root on sys.path so the `src` package imports resolve from pages/.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.detectors.rce.features import REGISTRY
from src.models import BACKBONES, RecognitionHead
from src.utils import build_rce_vector

st.set_page_config(page_title="Model Tuning", layout="wide")
st.title("⚙️ Model Tuning: Train & Compare")

# ---------------------------------------------------------------------------
# Guard: require Data Lab completion (assets uploaded + crop defined)
# ---------------------------------------------------------------------------
if "pipeline_data" not in st.session_state or "crop" not in st.session_state.get("pipeline_data", {}):
    st.error("Please complete the **Data Lab** first (upload assets & define a crop).")
    st.stop()

# Unpack the shared pipeline state produced by the Data Lab.
assets = st.session_state["pipeline_data"]
crop = assets["crop"]
crop_aug = assets.get("crop_aug", crop)
left_img = assets["left"]
bbox = assets.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
# Single-ROI sessions are wrapped into the multi-ROI shape for uniform handling.
rois = assets.get("rois", [{"label": "object", "bbox": bbox,
                            "crop": crop, "crop_aug": crop_aug}])
active_modules = st.session_state.get("active_modules", {k: True for k in REGISTRY})

is_multi = len(rois) > 1
|
| 33 |
-
|
| 34 |
-
# ---------------------------------------------------------------------------
|
| 35 |
-
# Build training set from session data (no disk reads)
|
| 36 |
-
# ---------------------------------------------------------------------------
|
| 37 |
-
def build_training_set(roi_list=None, scene=None, seed=42):
    """
    Multi-class aware training set builder.

    Positive samples per class: original crop + augmented crop.
    Negative samples: random patches that don't overlap ANY ROI. Negative
    patches reuse the size of the FIRST ROI.

    Parameters are optional and default to the page-level globals, so the
    original zero-argument call sites keep working unchanged:
        roi_list: list of ROI dicts with "label", "bbox", "crop", "crop_aug".
        scene:    full BGR image that negatives are sampled from.
        seed:     RNG seed for reproducible negative sampling.

    Returns:
        (images, labels, neg_short) — `neg_short` is True when fewer than
        half of the targeted negatives could be placed (imbalance warning).
    """
    if roi_list is None:
        roi_list = rois
    if scene is None:
        scene = left_img

    images = []
    labels = []
    for roi in roi_list:
        images.append(roi["crop"])
        labels.append(roi["label"])
        images.append(roi["crop_aug"])
        labels.append(roi["label"])

    all_bboxes = [roi["bbox"] for roi in roi_list]
    H, W = scene.shape[:2]
    x0r, y0r, x1r, y1r = roi_list[0]["bbox"]
    ch, cw = y1r - y0r, x1r - x0r
    rng = np.random.default_rng(seed)

    # Aim for a 2:1 negative:positive ratio, bounded by 300 placement tries.
    n_neg_target = len(images) * 2
    attempts = 0
    negatives = []
    while len(negatives) < n_neg_target and attempts < 300:
        rx = rng.integers(0, max(W - cw, 1))
        ry = rng.integers(0, max(H - ch, 1))
        attempts += 1
        # Reject any patch that intersects one of the labelled boxes.
        if any(rx < bx1 and rx + cw > bx0 and ry < by1 and ry + ch > by0
               for bx0, by0, bx1, by1 in all_bboxes):
            continue
        patch = scene[ry:ry + ch, rx:rx + cw]
        if patch.shape[0] > 0 and patch.shape[1] > 0:
            negatives.append(patch)

    images.extend(negatives)
    labels.extend(["background"] * len(negatives))
    return images, labels, len(negatives) < n_neg_target // 2
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
# ===================================================================
# Show the training data assembled in the Data Lab
# ===================================================================
st.subheader("Training Data (from Data Lab)")
if is_multi:
    st.caption(f"**{len(rois)} classes** defined — each ROI becomes a separate class.")
    roi_cols = st.columns(min(len(rois), 4))
    for i, roi in enumerate(rois):
        with roi_cols[i % len(roi_cols)]:
            st.image(cv2.cvtColor(roi["crop"], cv2.COLOR_BGR2RGB),
                     caption=f"✅ {roi['label']}", width=140)
else:
    st.caption("Positives = your crop + augmented crop | "
               "Negatives = random non-overlapping patches")
    td1, td2 = st.columns(2)
    td1.image(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB),
              caption="Original Crop (positive)", width=180)
    td2.image(cv2.cvtColor(crop_aug, cv2.COLOR_BGR2RGB),
              caption="Augmented Crop (positive)", width=180)

st.divider()

# ===================================================================
# LAYOUT: RCE | CNN | ORB
# ===================================================================
col_rce, col_cnn, col_orb = st.columns(3)

# -------------------------------------------------------------------
# LEFT — RCE Training
# -------------------------------------------------------------------
with col_rce:
    st.header("🧬 RCE Training")

    active_names = [REGISTRY[k]["label"] for k in active_modules if active_modules[k]]
    if not active_names:
        st.error("No RCE modules selected. Go back to Feature Lab.")
    else:
        st.write(f"**Active modules:** {', '.join(active_names)}")

    st.subheader("Training Parameters")
    rce_C = st.slider("Regularization (C)", 0.01, 10.0, 1.0, step=0.01,
                      help="Higher = less regularization, may overfit")
    rce_max_iter = st.slider("Max Iterations", 100, 5000, 1000, step=100)

    if st.button("🚀 Train RCE Head"):
        images, labels, neg_short = build_training_set()
        if neg_short:
            st.warning(f"⚠️ Only {sum(1 for l in labels if l == 'background')} "
                       f"negatives collected (target was {sum(1 for l in labels if l != 'background') * 2}). "
                       f"Training data may be imbalanced.")
        from sklearn.metrics import accuracy_score
        from sklearn.model_selection import cross_val_score

        # Feature extraction with a live progress bar.
        progress = st.progress(0, text="Extracting RCE features...")
        n = len(images)
        X = []
        for i, img in enumerate(images):
            X.append(build_rce_vector(img, active_modules))
            progress.progress((i + 1) / n, text=f"Feature extraction: {i+1}/{n}")

        X = np.array(X)
        progress.progress(1.0, text="Fitting Logistic Regression...")

        t0 = time.perf_counter()
        try:
            head = RecognitionHead(C=rce_C, max_iter=rce_max_iter).fit(X, labels)
        except ValueError as e:
            st.error(f"Training failed: {e}")
            st.stop()
        train_time = time.perf_counter() - t0
        progress.progress(1.0, text="✅ Training complete!")

        preds = head.model.predict(X)
        train_acc = accuracy_score(labels, preds)

        st.success(f"Trained in **{train_time:.2f}s**")
        m1, m2, m3, m4 = st.columns(4)
        m1.metric("Train Accuracy", f"{train_acc:.1%}")
        # Cross-validation is only meaningful with enough samples/classes.
        if len(images) >= 6:
            n_splits = min(5, len(set(labels)))
            if n_splits >= 2:
                cv_scores = cross_val_score(head.model, X, labels,
                                            cv=min(3, len(images) // 2))
                m2.metric("CV Accuracy", f"{cv_scores.mean():.1%}",
                          delta=f"±{cv_scores.std():.1%}")
            else:
                m2.metric("CV Accuracy", "N/A")
        else:
            m2.metric("CV Accuracy", "N/A")
        m3.metric("Vector Size", f"{X.shape[1]} floats")
        m4.metric("Samples", f"{len(images)}")
        if len(images) < 10:
            st.warning("⚠️ Training set is small (<10 samples). "
                       "Reported accuracy may not reflect real performance.")
        if is_multi:
            st.caption(f"Classes: {', '.join(head.classes_)}")

        # Per-class confidence histograms over the training set.
        probs = head.predict_proba(X)
        fig = go.Figure()
        for ci, cls in enumerate(head.classes_):
            fig.add_trace(go.Histogram(x=probs[:, ci], name=cls,
                                       opacity=0.7, nbinsx=20))
        fig.update_layout(title="Confidence Distribution", barmode="overlay",
                          template="plotly_dark", height=280,
                          xaxis_title="Confidence", yaxis_title="Count")
        st.plotly_chart(fig, use_container_width=True)

        # ---- Feature importance (LogReg coefficient magnitudes) ----
        st.subheader("🔍 Feature Importance")
        coefs = head.model.coef_
        # Each active module contributes a 10-bin slice of the vector.
        feat_names = [f"{meta_r['label']}[{b}]"
                      for key, meta_r in REGISTRY.items()
                      if active_modules.get(key, False)
                      for b in range(10)]

        if coefs.shape[0] == 1:
            importance = np.abs(coefs[0])
            fig_imp = go.Figure(go.Bar(
                x=feat_names, y=importance,
                marker_color=["#00d4ff" if "Intensity" in fn
                              else "#ff6600" if "Sobel" in fn
                              else "#aa00ff" for fn in feat_names]))
            fig_imp.update_layout(title="LogReg Coefficient Magnitude",
                                  template="plotly_dark", height=300,
                                  xaxis_title="Feature", yaxis_title="|Coefficient|")
        else:
            fig_imp = go.Figure()
            for ci, cls in enumerate(head.classes_):
                if cls == "background":
                    continue
                fig_imp.add_trace(go.Bar(
                    x=feat_names, y=np.abs(coefs[ci]),
                    name=cls, opacity=0.8))
            fig_imp.update_layout(title="LogReg Coefficients per Class",
                                  template="plotly_dark", height=300,
                                  barmode="group",
                                  xaxis_title="Feature", yaxis_title="|Coefficient|")
        st.plotly_chart(fig_imp, use_container_width=True)

        # Aggregate |coefficients| per module for the donut chart.
        module_importance = {}
        idx = 0
        for key, meta_r in REGISTRY.items():
            if active_modules.get(key, False):
                module_importance[meta_r["label"]] = float(
                    np.abs(coefs[:, idx:idx+10]).mean())
                idx += 10

        if module_importance:
            fig_mod = go.Figure(go.Pie(
                labels=list(module_importance.keys()),
                values=list(module_importance.values()),
                hole=0.4))
            fig_mod.update_layout(title="Module Contribution (avg |coef|)",
                                  template="plotly_dark", height=280)
            st.plotly_chart(fig_mod, use_container_width=True)

        st.session_state["rce_head"] = head
        st.session_state["rce_train_acc"] = train_acc

    # Quick single-crop inference once a head has been trained.
    if "rce_head" in st.session_state:
        st.divider()
        st.subheader("Quick Predict (Crop)")
        head = st.session_state["rce_head"]
        t0 = time.perf_counter()
        vec = build_rce_vector(crop_aug, active_modules)
        label, conf = head.predict(vec)
        dt = (time.perf_counter() - t0) * 1000
        st.write(f"**{label}** — {conf:.1%} confidence — {dt:.1f} ms")
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
# ---------------------------------------------------------------------------
|
| 256 |
-
# MIDDLE — CNN Fine-Tuning
|
| 257 |
-
# ---------------------------------------------------------------------------
|
| 258 |
-
with col_cnn:
|
| 259 |
-
st.header("🧠 CNN Fine-Tuning")
|
| 260 |
-
|
| 261 |
-
selected = st.selectbox("Select Model", list(BACKBONES.keys()))
|
| 262 |
-
meta = BACKBONES[selected]
|
| 263 |
-
st.caption(f"Backbone embedding: **{meta['dim']}D** → Logistic Regression head")
|
| 264 |
-
|
| 265 |
-
st.subheader("Training Parameters")
|
| 266 |
-
cnn_C = st.slider("Regularization (C) ", 0.01, 10.0, 1.0, step=0.01,
|
| 267 |
-
key="cnn_c", help="Higher = less regularization")
|
| 268 |
-
cnn_max_iter = st.slider("Max Iterations ", 100, 5000, 1000, step=100,
|
| 269 |
-
key="cnn_iter")
|
| 270 |
-
|
| 271 |
-
if st.button(f"🚀 Train {selected} Head"):
|
| 272 |
-
images, labels, neg_short = build_training_set()
|
| 273 |
-
if neg_short:
|
| 274 |
-
st.warning(f"⚠️ Negative sample shortfall — training may be imbalanced.")
|
| 275 |
-
backbone = meta["loader"]()
|
| 276 |
-
|
| 277 |
-
from sklearn.metrics import accuracy_score
|
| 278 |
-
from sklearn.model_selection import cross_val_score
|
| 279 |
-
|
| 280 |
-
progress = st.progress(0, text=f"Extracting {selected} features...")
|
| 281 |
-
n = len(images)
|
| 282 |
-
X = []
|
| 283 |
-
for i, img in enumerate(images):
|
| 284 |
-
X.append(backbone.get_features(img))
|
| 285 |
-
progress.progress((i + 1) / n, text=f"Feature extraction: {i+1}/{n}")
|
| 286 |
-
|
| 287 |
-
X = np.array(X)
|
| 288 |
-
progress.progress(1.0, text="Fitting Logistic Regression...")
|
| 289 |
-
|
| 290 |
-
t0 = time.perf_counter()
|
| 291 |
-
try:
|
| 292 |
-
head = RecognitionHead(C=cnn_C, max_iter=cnn_max_iter).fit(X, labels)
|
| 293 |
-
except ValueError as e:
|
| 294 |
-
st.error(f"Training failed: {e}")
|
| 295 |
-
st.stop()
|
| 296 |
-
train_time = time.perf_counter() - t0
|
| 297 |
-
progress.progress(1.0, text="✅ Training complete!")
|
| 298 |
-
|
| 299 |
-
preds = head.model.predict(X)
|
| 300 |
-
train_acc = accuracy_score(labels, preds)
|
| 301 |
-
|
| 302 |
-
st.success(f"Trained in **{train_time:.2f}s**")
|
| 303 |
-
m1, m2, m3, m4 = st.columns(4)
|
| 304 |
-
m1.metric("Train Accuracy", f"{train_acc:.1%}")
|
| 305 |
-
if len(images) >= 6:
|
| 306 |
-
n_splits = min(5, len(set(labels)))
|
| 307 |
-
if n_splits >= 2:
|
| 308 |
-
cv_scores = cross_val_score(head.model, X, labels,
|
| 309 |
-
cv=min(3, len(images) // 2))
|
| 310 |
-
m2.metric("CV Accuracy", f"{cv_scores.mean():.1%}",
|
| 311 |
-
delta=f"±{cv_scores.std():.1%}")
|
| 312 |
-
else:
|
| 313 |
-
m2.metric("CV Accuracy", "N/A")
|
| 314 |
-
else:
|
| 315 |
-
m2.metric("CV Accuracy", "N/A")
|
| 316 |
-
m3.metric("Vector Size", f"{X.shape[1]}D")
|
| 317 |
-
m4.metric("Samples", f"{len(images)}")
|
| 318 |
-
if is_multi:
|
| 319 |
-
st.caption(f"Classes: {', '.join(head.classes_)}")
|
| 320 |
-
|
| 321 |
-
probs = head.predict_proba(X)
|
| 322 |
-
fig = go.Figure()
|
| 323 |
-
for ci, cls in enumerate(head.classes_):
|
| 324 |
-
fig.add_trace(go.Histogram(x=probs[:, ci], name=cls,
|
| 325 |
-
opacity=0.7, nbinsx=20))
|
| 326 |
-
fig.update_layout(title="Confidence Distribution", barmode="overlay",
|
| 327 |
-
template="plotly_dark", height=280,
|
| 328 |
-
xaxis_title="Confidence", yaxis_title="Count")
|
| 329 |
-
st.plotly_chart(fig, use_container_width=True)
|
| 330 |
-
|
| 331 |
-
# ---- Activation Overlay (Grad-CAM style) ----
|
| 332 |
-
st.subheader("🔍 Activation Overlay")
|
| 333 |
-
st.caption("Highest-activation spatial regions from the hooked layer, "
|
| 334 |
-
"overlaid on the crop as a Grad-CAM–style heatmap.")
|
| 335 |
-
try:
|
| 336 |
-
act_maps = backbone.get_activation_maps(crop_aug, n_maps=1)
|
| 337 |
-
if act_maps:
|
| 338 |
-
cam = act_maps[0]
|
| 339 |
-
cam_resized = cv2.resize(cam, (crop_aug.shape[1], crop_aug.shape[0]))
|
| 340 |
-
cam_color = cv2.applyColorMap(
|
| 341 |
-
(cam_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
|
| 342 |
-
overlay_img = cv2.addWeighted(crop_aug, 0.5, cam_color, 0.5, 0)
|
| 343 |
-
gc1, gc2 = st.columns(2)
|
| 344 |
-
gc1.image(cv2.cvtColor(crop_aug, cv2.COLOR_BGR2RGB),
|
| 345 |
-
caption="Input Crop", use_container_width=True)
|
| 346 |
-
gc2.image(cv2.cvtColor(overlay_img, cv2.COLOR_BGR2RGB),
|
| 347 |
-
caption="Activation Overlay", use_container_width=True)
|
| 348 |
-
except Exception:
|
| 349 |
-
pass
|
| 350 |
-
|
| 351 |
-
st.session_state[f"cnn_head_{selected}"] = head
|
| 352 |
-
st.session_state[f"cnn_acc_{selected}"] = train_acc
|
| 353 |
-
|
| 354 |
-
if f"cnn_head_{selected}" in st.session_state:
|
| 355 |
-
st.divider()
|
| 356 |
-
st.subheader("Quick Predict (Crop)")
|
| 357 |
-
backbone = meta["loader"]()
|
| 358 |
-
head = st.session_state[f"cnn_head_{selected}"]
|
| 359 |
-
t0 = time.perf_counter()
|
| 360 |
-
feats = backbone.get_features(crop_aug)
|
| 361 |
-
label, conf = head.predict(feats)
|
| 362 |
-
dt = (time.perf_counter() - t0) * 1000
|
| 363 |
-
st.write(f"**{label}** — {conf:.1%} confidence — {dt:.1f} ms")
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
# ---------------------------------------------------------------------------
|
| 367 |
-
# RIGHT — ORB Training
|
| 368 |
-
# ---------------------------------------------------------------------------
|
| 369 |
-
with col_orb:
|
| 370 |
-
st.header("🏛️ ORB Matching")
|
| 371 |
-
st.caption("Keypoint-based matching — a fundamentally different paradigm. "
|
| 372 |
-
"Extracts ORB descriptors from each ROI crop and matches them "
|
| 373 |
-
"against image patches using brute-force Hamming distance.")
|
| 374 |
-
|
| 375 |
-
from src.detectors.orb import ORBDetector
|
| 376 |
-
|
| 377 |
-
orb_dist_thresh = st.slider("Match Distance Threshold", 10, 100, 70,
|
| 378 |
-
key="orb_dist")
|
| 379 |
-
orb_min_matches = st.slider("Min Good Matches", 1, 20, 5, key="orb_min")
|
| 380 |
-
|
| 381 |
-
if st.button("🚀 Train ORB Reference"):
|
| 382 |
-
orb = ORBDetector()
|
| 383 |
-
progress = st.progress(0, text="Extracting ORB descriptors...")
|
| 384 |
-
|
| 385 |
-
orb_refs = {}
|
| 386 |
-
for i, roi in enumerate(rois):
|
| 387 |
-
gray = cv2.cvtColor(roi["crop_aug"], cv2.COLOR_BGR2GRAY)
|
| 388 |
-
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
| 389 |
-
gray = clahe.apply(gray)
|
| 390 |
-
kp, des = orb.orb.detectAndCompute(gray, None)
|
| 391 |
-
n_feat = 0 if des is None else len(des)
|
| 392 |
-
orb_refs[roi["label"]] = {
|
| 393 |
-
"descriptors": des,
|
| 394 |
-
"n_features": n_feat,
|
| 395 |
-
"keypoints": kp,
|
| 396 |
-
"crop": roi["crop_aug"],
|
| 397 |
-
}
|
| 398 |
-
progress.progress((i + 1) / len(rois),
|
| 399 |
-
text=f"ROI {i+1}/{len(rois)}: {n_feat} features")
|
| 400 |
-
|
| 401 |
-
progress.progress(1.0, text="✅ ORB references extracted!")
|
| 402 |
-
|
| 403 |
-
for lbl, ref in orb_refs.items():
|
| 404 |
-
if ref["keypoints"]:
|
| 405 |
-
vis = cv2.drawKeypoints(ref["crop"], ref["keypoints"],
|
| 406 |
-
None, color=(0, 255, 0))
|
| 407 |
-
st.image(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB),
|
| 408 |
-
caption=f"{lbl}: {ref['n_features']} keypoints",
|
| 409 |
-
use_container_width=True)
|
| 410 |
-
else:
|
| 411 |
-
st.warning(f"{lbl}: No keypoints detected")
|
| 412 |
-
|
| 413 |
-
st.session_state["orb_detector"] = orb
|
| 414 |
-
st.session_state["orb_refs"] = orb_refs
|
| 415 |
-
st.session_state["orb_dist_thresh"] = orb_dist_thresh
|
| 416 |
-
st.session_state["orb_min_matches"] = orb_min_matches
|
| 417 |
-
st.success("ORB references stored in session!")
|
| 418 |
-
|
| 419 |
-
if "orb_refs" in st.session_state:
|
| 420 |
-
st.divider()
|
| 421 |
-
st.subheader("Quick Predict (Crop)")
|
| 422 |
-
orb = st.session_state["orb_detector"]
|
| 423 |
-
refs = st.session_state["orb_refs"]
|
| 424 |
-
dt_thresh = st.session_state["orb_dist_thresh"]
|
| 425 |
-
min_m = st.session_state["orb_min_matches"]
|
| 426 |
-
|
| 427 |
-
gray = cv2.cvtColor(crop_aug, cv2.COLOR_BGR2GRAY)
|
| 428 |
-
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
| 429 |
-
gray = clahe.apply(gray)
|
| 430 |
-
kp, des = orb.orb.detectAndCompute(gray, None)
|
| 431 |
-
|
| 432 |
-
if des is not None:
|
| 433 |
-
for lbl, ref in refs.items():
|
| 434 |
-
if ref["descriptors"] is None:
|
| 435 |
-
st.write(f"**{lbl}:** no reference features")
|
| 436 |
-
continue
|
| 437 |
-
matches = orb.bf.match(ref["descriptors"], des)
|
| 438 |
-
good = [m for m in matches if m.distance < dt_thresh]
|
| 439 |
-
conf = min(len(good) / max(min_m, 1), 1.0)
|
| 440 |
-
verdict = lbl if len(good) >= min_m else "background"
|
| 441 |
-
st.write(f"**{verdict}** — {len(good)} matches — "
|
| 442 |
-
f"{conf:.0%} confidence")
|
| 443 |
-
else:
|
| 444 |
-
st.write("No keypoints in test image.")
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
# ===========================================================================
|
| 448 |
-
# Bottom — Side-by-side comparison table
|
| 449 |
-
# ===========================================================================
|
| 450 |
-
st.divider()
|
| 451 |
-
st.subheader("📊 Training Comparison")
|
| 452 |
-
|
| 453 |
-
rows = []
|
| 454 |
-
rce_acc = st.session_state.get("rce_train_acc")
|
| 455 |
-
if rce_acc is not None:
|
| 456 |
-
rows.append({"Model": "RCE", "Type": "Feature Engineering",
|
| 457 |
-
"Train Accuracy": f"{rce_acc:.1%}",
|
| 458 |
-
"Vector Size": str(sum(10 for k in active_modules if active_modules[k]))})
|
| 459 |
-
for name in BACKBONES:
|
| 460 |
-
acc = st.session_state.get(f"cnn_acc_{name}")
|
| 461 |
-
if acc is not None:
|
| 462 |
-
rows.append({"Model": name, "Type": "CNN Backbone",
|
| 463 |
-
"Train Accuracy": f"{acc:.1%}",
|
| 464 |
-
"Vector Size": f"{BACKBONES[name]['dim']}D"})
|
| 465 |
-
if "orb_refs" in st.session_state:
|
| 466 |
-
total_kp = sum(r["n_features"] for r in st.session_state["orb_refs"].values())
|
| 467 |
-
rows.append({"Model": "ORB", "Type": "Keypoint Matching",
|
| 468 |
-
"Train Accuracy": "N/A (matching)",
|
| 469 |
-
"Vector Size": f"{total_kp} descriptors"})
|
| 470 |
-
|
| 471 |
-
if rows:
|
| 472 |
-
import pandas as pd
|
| 473 |
-
st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
|
| 474 |
-
else:
|
| 475 |
-
st.info("Train at least one model to see the comparison.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/5_Localization_Lab.py
DELETED
|
@@ -1,348 +0,0 @@
|
|
| 1 |
-
import streamlit as st
|
| 2 |
-
import cv2
|
| 3 |
-
import numpy as np
|
| 4 |
-
import pandas as pd
|
| 5 |
-
import plotly.graph_objects as go
|
| 6 |
-
import sys, os
|
| 7 |
-
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 8 |
-
|
| 9 |
-
from src.detectors.rce.features import REGISTRY
|
| 10 |
-
from src.models import BACKBONES, RecognitionHead
|
| 11 |
-
from src.utils import build_rce_vector
|
| 12 |
-
from src.localization import (
|
| 13 |
-
exhaustive_sliding_window,
|
| 14 |
-
image_pyramid,
|
| 15 |
-
coarse_to_fine,
|
| 16 |
-
contour_proposals,
|
| 17 |
-
template_matching,
|
| 18 |
-
STRATEGIES,
|
| 19 |
-
)
|
| 20 |
-
|
| 21 |
-
st.set_page_config(page_title="Localization Lab", layout="wide")
|
| 22 |
-
st.title("🔍 Localization Lab")
|
| 23 |
-
st.markdown(
|
| 24 |
-
"Compare **localization strategies** — algorithms that decide *where* "
|
| 25 |
-
"to look in the image. The recognition head stays the same; only the "
|
| 26 |
-
"search method changes."
|
| 27 |
-
)
|
| 28 |
-
|
| 29 |
-
# ===================================================================
|
| 30 |
-
# Guard
|
| 31 |
-
# ===================================================================
|
| 32 |
-
if "pipeline_data" not in st.session_state or \
|
| 33 |
-
"crop" not in st.session_state.get("pipeline_data", {}):
|
| 34 |
-
st.error("Complete **Data Lab** first (upload assets & define a crop).")
|
| 35 |
-
st.stop()
|
| 36 |
-
|
| 37 |
-
assets = st.session_state["pipeline_data"]
|
| 38 |
-
right_img = assets["right"]
|
| 39 |
-
crop = assets["crop"]
|
| 40 |
-
crop_aug = assets.get("crop_aug", crop)
|
| 41 |
-
bbox = assets.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
|
| 42 |
-
active_mods = st.session_state.get("active_modules",
|
| 43 |
-
{k: True for k in REGISTRY})
|
| 44 |
-
|
| 45 |
-
x0, y0, x1, y1 = bbox
|
| 46 |
-
win_h, win_w = y1 - y0, x1 - x0
|
| 47 |
-
|
| 48 |
-
if win_h <= 0 or win_w <= 0:
|
| 49 |
-
st.error("Invalid window size from crop bbox. "
|
| 50 |
-
"Go back to **Data Lab** and redefine the ROI.")
|
| 51 |
-
st.stop()
|
| 52 |
-
|
| 53 |
-
rce_head = st.session_state.get("rce_head")
|
| 54 |
-
has_any_cnn = any(f"cnn_head_{n}" in st.session_state for n in BACKBONES)
|
| 55 |
-
|
| 56 |
-
if rce_head is None and not has_any_cnn:
|
| 57 |
-
st.warning("No trained heads found. Go to **Model Tuning** first.")
|
| 58 |
-
st.stop()
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
# ===================================================================
|
| 62 |
-
# RCE feature function
|
| 63 |
-
# ===================================================================
|
| 64 |
-
def rce_feature_fn(patch_bgr):
|
| 65 |
-
return build_rce_vector(patch_bgr, active_mods)
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
# ===================================================================
|
| 69 |
-
# Algorithm Reference (collapsible)
|
| 70 |
-
# ===================================================================
|
| 71 |
-
st.divider()
|
| 72 |
-
with st.expander("📚 **Algorithm Reference** — click to expand", expanded=False):
|
| 73 |
-
tabs = st.tabs([f"{v['icon']} {k}" for k, v in STRATEGIES.items()])
|
| 74 |
-
for tab, (name, meta) in zip(tabs, STRATEGIES.items()):
|
| 75 |
-
with tab:
|
| 76 |
-
st.markdown(f"### {meta['icon']} {name}")
|
| 77 |
-
st.caption(meta["short"])
|
| 78 |
-
st.markdown(meta["detail"])
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
# ===================================================================
|
| 82 |
-
# Configuration
|
| 83 |
-
# ===================================================================
|
| 84 |
-
st.divider()
|
| 85 |
-
st.header("⚙️ Configuration")
|
| 86 |
-
|
| 87 |
-
# --- Head selection ---
|
| 88 |
-
col_head, col_info = st.columns([2, 3])
|
| 89 |
-
with col_head:
|
| 90 |
-
head_options = []
|
| 91 |
-
if rce_head is not None:
|
| 92 |
-
head_options.append("RCE")
|
| 93 |
-
trained_cnns = [n for n in BACKBONES if f"cnn_head_{n}" in st.session_state]
|
| 94 |
-
head_options.extend(trained_cnns)
|
| 95 |
-
selected_head = st.selectbox("Recognition Head", head_options,
|
| 96 |
-
key="loc_head")
|
| 97 |
-
|
| 98 |
-
if selected_head == "RCE":
|
| 99 |
-
feature_fn = rce_feature_fn
|
| 100 |
-
head = rce_head
|
| 101 |
-
else:
|
| 102 |
-
bmeta = BACKBONES[selected_head]
|
| 103 |
-
backbone = bmeta["loader"]()
|
| 104 |
-
feature_fn = backbone.get_features
|
| 105 |
-
head = st.session_state[f"cnn_head_{selected_head}"]
|
| 106 |
-
|
| 107 |
-
with col_info:
|
| 108 |
-
if selected_head == "RCE":
|
| 109 |
-
mods = [REGISTRY[k]["label"] for k in active_mods if active_mods[k]]
|
| 110 |
-
st.info(f"**RCE** — Modules: {', '.join(mods)}")
|
| 111 |
-
else:
|
| 112 |
-
st.info(f"**{selected_head}** — "
|
| 113 |
-
f"{BACKBONES[selected_head]['dim']}D feature vector")
|
| 114 |
-
|
| 115 |
-
# --- Algorithm checkboxes ---
|
| 116 |
-
st.subheader("Select Algorithms to Compare")
|
| 117 |
-
algo_cols = st.columns(5)
|
| 118 |
-
algo_names = list(STRATEGIES.keys())
|
| 119 |
-
algo_checks = {}
|
| 120 |
-
for col, name in zip(algo_cols, algo_names):
|
| 121 |
-
algo_checks[name] = col.checkbox(
|
| 122 |
-
f"{STRATEGIES[name]['icon']} {name}",
|
| 123 |
-
value=(name != "Template Matching"), # default all on except TM
|
| 124 |
-
key=f"chk_{name}")
|
| 125 |
-
|
| 126 |
-
any_selected = any(algo_checks.values())
|
| 127 |
-
|
| 128 |
-
# --- Shared parameters ---
|
| 129 |
-
st.subheader("Parameters")
|
| 130 |
-
sp1, sp2, sp3 = st.columns(3)
|
| 131 |
-
stride = sp1.slider("Base Stride (px)", 4, max(win_w, win_h),
|
| 132 |
-
max(win_w // 4, 4), step=2, key="loc_stride")
|
| 133 |
-
conf_thresh = sp2.slider("Confidence Threshold", 0.5, 1.0, 0.7, 0.05,
|
| 134 |
-
key="loc_conf")
|
| 135 |
-
nms_iou = sp3.slider("NMS IoU Threshold", 0.1, 0.9, 0.3, 0.05,
|
| 136 |
-
key="loc_nms")
|
| 137 |
-
|
| 138 |
-
# --- Per-algorithm settings ---
|
| 139 |
-
with st.expander("🔧 Per-Algorithm Settings"):
|
| 140 |
-
pa1, pa2, pa3 = st.columns(3)
|
| 141 |
-
with pa1:
|
| 142 |
-
st.markdown("**Image Pyramid**")
|
| 143 |
-
pyr_min = st.slider("Min Scale", 0.3, 1.0, 0.5, 0.05, key="pyr_min")
|
| 144 |
-
pyr_max = st.slider("Max Scale", 1.0, 2.0, 1.5, 0.1, key="pyr_max")
|
| 145 |
-
pyr_n = st.slider("Number of Scales", 3, 7, 5, key="pyr_n")
|
| 146 |
-
with pa2:
|
| 147 |
-
st.markdown("**Coarse-to-Fine**")
|
| 148 |
-
c2f_factor = st.slider("Coarse Factor", 2, 8, 4, key="c2f_factor")
|
| 149 |
-
c2f_radius = st.slider("Refine Radius (strides)", 1, 5, 2,
|
| 150 |
-
key="c2f_radius")
|
| 151 |
-
with pa3:
|
| 152 |
-
st.markdown("**Contour Proposals**")
|
| 153 |
-
cnt_low = st.slider("Canny Low", 10, 100, 50, key="cnt_low")
|
| 154 |
-
cnt_high = st.slider("Canny High", 50, 300, 150, key="cnt_high")
|
| 155 |
-
cnt_tol = st.slider("Area Tolerance", 1.5, 10.0, 3.0, 0.5,
|
| 156 |
-
key="cnt_tol")
|
| 157 |
-
|
| 158 |
-
st.caption(
|
| 159 |
-
f"Window: **{win_w}×{win_h} px** · "
|
| 160 |
-
f"Image: **{right_img.shape[1]}×{right_img.shape[0]} px** · "
|
| 161 |
-
f"Stride: **{stride} px**"
|
| 162 |
-
)
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
# ===================================================================
|
| 166 |
-
# Run
|
| 167 |
-
# ===================================================================
|
| 168 |
-
st.divider()
|
| 169 |
-
run_btn = st.button("▶ Run Comparison", type="primary",
|
| 170 |
-
disabled=not any_selected, use_container_width=True)
|
| 171 |
-
|
| 172 |
-
if run_btn:
|
| 173 |
-
selected_algos = [n for n in algo_names if algo_checks[n]]
|
| 174 |
-
progress = st.progress(0, text="Starting…")
|
| 175 |
-
results = {}
|
| 176 |
-
edge_maps = {} # for contour visualisation
|
| 177 |
-
|
| 178 |
-
for i, name in enumerate(selected_algos):
|
| 179 |
-
progress.progress(i / len(selected_algos), text=f"Running **{name}**…")
|
| 180 |
-
|
| 181 |
-
if name == "Exhaustive Sliding Window":
|
| 182 |
-
dets, n, ms, hmap = exhaustive_sliding_window(
|
| 183 |
-
right_img, win_h, win_w, feature_fn, head,
|
| 184 |
-
stride, conf_thresh, nms_iou)
|
| 185 |
-
|
| 186 |
-
elif name == "Image Pyramid":
|
| 187 |
-
scales = np.linspace(pyr_min, pyr_max, pyr_n).tolist()
|
| 188 |
-
dets, n, ms, hmap = image_pyramid(
|
| 189 |
-
right_img, win_h, win_w, feature_fn, head,
|
| 190 |
-
stride, conf_thresh, nms_iou, scales=scales)
|
| 191 |
-
|
| 192 |
-
elif name == "Coarse-to-Fine":
|
| 193 |
-
dets, n, ms, hmap = coarse_to_fine(
|
| 194 |
-
right_img, win_h, win_w, feature_fn, head,
|
| 195 |
-
stride, conf_thresh, nms_iou,
|
| 196 |
-
coarse_factor=c2f_factor, refine_radius=c2f_radius)
|
| 197 |
-
|
| 198 |
-
elif name == "Contour Proposals":
|
| 199 |
-
dets, n, ms, hmap, edges = contour_proposals(
|
| 200 |
-
right_img, win_h, win_w, feature_fn, head,
|
| 201 |
-
conf_thresh, nms_iou,
|
| 202 |
-
canny_low=cnt_low, canny_high=cnt_high,
|
| 203 |
-
area_tolerance=cnt_tol)
|
| 204 |
-
edge_maps[name] = edges
|
| 205 |
-
|
| 206 |
-
elif name == "Template Matching":
|
| 207 |
-
dets, n, ms, hmap = template_matching(
|
| 208 |
-
right_img, crop_aug, conf_thresh, nms_iou)
|
| 209 |
-
|
| 210 |
-
results[name] = {
|
| 211 |
-
"dets": dets, "n_proposals": n,
|
| 212 |
-
"time_ms": ms, "heatmap": hmap,
|
| 213 |
-
}
|
| 214 |
-
|
| 215 |
-
progress.progress(1.0, text="Done!")
|
| 216 |
-
|
| 217 |
-
# ===============================================================
|
| 218 |
-
# Summary Table
|
| 219 |
-
# ===============================================================
|
| 220 |
-
st.header("📊 Results")
|
| 221 |
-
|
| 222 |
-
baseline_ms = results.get("Exhaustive Sliding Window", {}).get("time_ms")
|
| 223 |
-
rows = []
|
| 224 |
-
for name, r in results.items():
|
| 225 |
-
speedup = (baseline_ms / r["time_ms"]
|
| 226 |
-
if baseline_ms and r["time_ms"] > 0 else None)
|
| 227 |
-
rows.append({
|
| 228 |
-
"Algorithm": name,
|
| 229 |
-
"Proposals": r["n_proposals"],
|
| 230 |
-
"Time (ms)": round(r["time_ms"], 1),
|
| 231 |
-
"Detections": len(r["dets"]),
|
| 232 |
-
"ms / Proposal": round(r["time_ms"] / max(r["n_proposals"], 1), 4),
|
| 233 |
-
"Speedup": f"{speedup:.1f}×" if speedup else "—",
|
| 234 |
-
})
|
| 235 |
-
|
| 236 |
-
st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
|
| 237 |
-
|
| 238 |
-
# ===============================================================
|
| 239 |
-
# Detection Images & Heatmaps (one tab per algorithm)
|
| 240 |
-
# ===============================================================
|
| 241 |
-
st.subheader("Detection Results")
|
| 242 |
-
COLORS = {
|
| 243 |
-
"Exhaustive Sliding Window": (0, 255, 0),
|
| 244 |
-
"Image Pyramid": (255, 128, 0),
|
| 245 |
-
"Coarse-to-Fine": (0, 128, 255),
|
| 246 |
-
"Contour Proposals": (255, 0, 255),
|
| 247 |
-
"Template Matching": (0, 255, 255),
|
| 248 |
-
}
|
| 249 |
-
|
| 250 |
-
result_tabs = st.tabs(
|
| 251 |
-
[f"{STRATEGIES[n]['icon']} {n}" for n in results])
|
| 252 |
-
|
| 253 |
-
for tab, (name, r) in zip(result_tabs, results.items()):
|
| 254 |
-
with tab:
|
| 255 |
-
c1, c2 = st.columns(2)
|
| 256 |
-
color = COLORS.get(name, (0, 255, 0))
|
| 257 |
-
|
| 258 |
-
# --- Detection overlay ---
|
| 259 |
-
vis = right_img.copy()
|
| 260 |
-
for x1d, y1d, x2d, y2d, _, cf in r["dets"]:
|
| 261 |
-
cv2.rectangle(vis, (x1d, y1d), (x2d, y2d), color, 2)
|
| 262 |
-
cv2.putText(vis, f"{cf:.0%}", (x1d, y1d - 6),
|
| 263 |
-
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
|
| 264 |
-
c1.image(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB),
|
| 265 |
-
caption=f"{name} — {len(r['dets'])} detections",
|
| 266 |
-
use_container_width=True)
|
| 267 |
-
|
| 268 |
-
# --- Heatmap ---
|
| 269 |
-
hmap = r["heatmap"]
|
| 270 |
-
if hmap.max() > 0:
|
| 271 |
-
hmap_color = cv2.applyColorMap(
|
| 272 |
-
(hmap / hmap.max() * 255).astype(np.uint8),
|
| 273 |
-
cv2.COLORMAP_JET)
|
| 274 |
-
blend = cv2.addWeighted(right_img, 0.5, hmap_color, 0.5, 0)
|
| 275 |
-
c2.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
|
| 276 |
-
caption=f"{name} — Confidence Heatmap",
|
| 277 |
-
use_container_width=True)
|
| 278 |
-
else:
|
| 279 |
-
c2.info("No positive responses above threshold.")
|
| 280 |
-
|
| 281 |
-
# --- Contour edge map (extra) ---
|
| 282 |
-
if name in edge_maps:
|
| 283 |
-
st.image(edge_maps[name],
|
| 284 |
-
caption="Canny Edge Map (proposals derived from these contours)",
|
| 285 |
-
use_container_width=True, clamp=True)
|
| 286 |
-
|
| 287 |
-
# --- Per-algorithm metrics ---
|
| 288 |
-
m1, m2, m3, m4 = st.columns(4)
|
| 289 |
-
m1.metric("Proposals", r["n_proposals"])
|
| 290 |
-
m2.metric("Time", f"{r['time_ms']:.0f} ms")
|
| 291 |
-
m3.metric("Detections", len(r["dets"]))
|
| 292 |
-
m4.metric("ms / Proposal",
|
| 293 |
-
f"{r['time_ms'] / max(r['n_proposals'], 1):.3f}")
|
| 294 |
-
|
| 295 |
-
# --- Detection table ---
|
| 296 |
-
if r["dets"]:
|
| 297 |
-
df = pd.DataFrame(r["dets"],
|
| 298 |
-
columns=["x1","y1","x2","y2","label","conf"])
|
| 299 |
-
st.dataframe(df, use_container_width=True, hide_index=True)
|
| 300 |
-
|
| 301 |
-
# ===============================================================
|
| 302 |
-
# Performance Charts
|
| 303 |
-
# ===============================================================
|
| 304 |
-
st.subheader("📈 Performance Comparison")
|
| 305 |
-
ch1, ch2 = st.columns(2)
|
| 306 |
-
|
| 307 |
-
names = list(results.keys())
|
| 308 |
-
times = [results[n]["time_ms"] for n in names]
|
| 309 |
-
props = [results[n]["n_proposals"] for n in names]
|
| 310 |
-
n_dets = [len(results[n]["dets"]) for n in names]
|
| 311 |
-
colors_hex = ["#00cc66", "#ff8800", "#0088ff", "#ff00ff", "#00cccc"]
|
| 312 |
-
|
| 313 |
-
with ch1:
|
| 314 |
-
fig = go.Figure(go.Bar(
|
| 315 |
-
x=names, y=times,
|
| 316 |
-
text=[f"{t:.0f}" for t in times], textposition="auto",
|
| 317 |
-
marker_color=colors_hex[:len(names)]))
|
| 318 |
-
fig.update_layout(title="Total Time (ms)",
|
| 319 |
-
yaxis_title="ms", height=400)
|
| 320 |
-
st.plotly_chart(fig, use_container_width=True)
|
| 321 |
-
|
| 322 |
-
with ch2:
|
| 323 |
-
fig = go.Figure(go.Bar(
|
| 324 |
-
x=names, y=props,
|
| 325 |
-
text=[str(p) for p in props], textposition="auto",
|
| 326 |
-
marker_color=colors_hex[:len(names)]))
|
| 327 |
-
fig.update_layout(title="Proposals Evaluated",
|
| 328 |
-
yaxis_title="Count", height=400)
|
| 329 |
-
st.plotly_chart(fig, use_container_width=True)
|
| 330 |
-
|
| 331 |
-
# --- Scatter: proposals vs time (marker = detections) ---
|
| 332 |
-
fig = go.Figure()
|
| 333 |
-
for i, name in enumerate(names):
|
| 334 |
-
fig.add_trace(go.Scatter(
|
| 335 |
-
x=[props[i]], y=[times[i]],
|
| 336 |
-
mode="markers+text",
|
| 337 |
-
marker=dict(size=max(n_dets[i] * 12, 18),
|
| 338 |
-
color=colors_hex[i % len(colors_hex)]),
|
| 339 |
-
text=[name], textposition="top center",
|
| 340 |
-
name=name,
|
| 341 |
-
))
|
| 342 |
-
fig.update_layout(
|
| 343 |
-
title="Proposals vs Time (marker size ∝ detections)",
|
| 344 |
-
xaxis_title="Proposals Evaluated",
|
| 345 |
-
yaxis_title="Time (ms)",
|
| 346 |
-
height=500,
|
| 347 |
-
)
|
| 348 |
-
st.plotly_chart(fig, use_container_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/6_RealTime_Detection.py
DELETED
|
@@ -1,435 +0,0 @@
|
|
| 1 |
-
import streamlit as st
|
| 2 |
-
import cv2
|
| 3 |
-
import numpy as np
|
| 4 |
-
import time
|
| 5 |
-
import plotly.graph_objects as go
|
| 6 |
-
import sys, os
|
| 7 |
-
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 8 |
-
|
| 9 |
-
from src.detectors.rce.features import REGISTRY
|
| 10 |
-
from src.models import BACKBONES, RecognitionHead
|
| 11 |
-
from src.utils import build_rce_vector
|
| 12 |
-
from src.localization import nms as _nms, _iou
|
| 13 |
-
|
| 14 |
-
st.set_page_config(page_title="Real-Time Detection", layout="wide")
|
| 15 |
-
st.title("🎯 Real-Time Detection")
|
| 16 |
-
|
| 17 |
-
# ---------------------------------------------------------------------------
|
| 18 |
-
# Guard
|
| 19 |
-
# ---------------------------------------------------------------------------
|
| 20 |
-
if "pipeline_data" not in st.session_state or "crop" not in st.session_state.get("pipeline_data", {}):
|
| 21 |
-
st.error("Complete **Data Lab** first (upload assets & define a crop).")
|
| 22 |
-
st.stop()
|
| 23 |
-
|
| 24 |
-
assets = st.session_state["pipeline_data"]
|
| 25 |
-
right_img = assets["right"]
|
| 26 |
-
crop = assets["crop"]
|
| 27 |
-
crop_aug = assets.get("crop_aug", crop)
|
| 28 |
-
bbox = assets.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
|
| 29 |
-
rois = assets.get("rois", [{"label": "object", "bbox": bbox,
|
| 30 |
-
"crop": crop, "crop_aug": crop_aug}])
|
| 31 |
-
active_mods = st.session_state.get("active_modules", {k: True for k in REGISTRY})
|
| 32 |
-
|
| 33 |
-
x0, y0, x1, y1 = bbox
|
| 34 |
-
win_h, win_w = y1 - y0, x1 - x0 # window = same size as crop
|
| 35 |
-
|
| 36 |
-
if win_h <= 0 or win_w <= 0:
|
| 37 |
-
st.error("Invalid window size from crop bbox. "
|
| 38 |
-
"Go back to **Data Lab** and redefine the ROI.")
|
| 39 |
-
st.stop()
|
| 40 |
-
|
| 41 |
-
# Color palette for multi-class drawing
|
| 42 |
-
CLASS_COLORS = [(0,255,0),(0,0,255),(255,165,0),(255,0,255),(0,255,255),
|
| 43 |
-
(128,255,0),(255,128,0),(0,128,255)]
|
| 44 |
-
|
| 45 |
-
rce_head = st.session_state.get("rce_head")
|
| 46 |
-
has_any_cnn = any(f"cnn_head_{n}" in st.session_state for n in BACKBONES)
|
| 47 |
-
has_orb = "orb_refs" in st.session_state
|
| 48 |
-
|
| 49 |
-
if rce_head is None and not has_any_cnn and not has_orb:
|
| 50 |
-
st.warning("No trained heads found. Go to **Model Tuning** and train at least one head.")
|
| 51 |
-
st.stop()
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
# ===================================================================
|
| 55 |
-
# Sliding Window Engine (shared by both sides)
|
| 56 |
-
# ===================================================================
|
| 57 |
-
def sliding_window_detect(
|
| 58 |
-
image: np.ndarray,
|
| 59 |
-
feature_fn, # callable(patch_bgr) -> 1-D np.ndarray
|
| 60 |
-
head: RecognitionHead,
|
| 61 |
-
stride: int,
|
| 62 |
-
conf_thresh: float,
|
| 63 |
-
nms_iou: float,
|
| 64 |
-
progress_placeholder=None,
|
| 65 |
-
live_image_placeholder=None,
|
| 66 |
-
):
|
| 67 |
-
"""
|
| 68 |
-
Slide a window of size (win_h, win_w) across *image* with *stride*.
|
| 69 |
-
At each position call *feature_fn* → *head.predict*.
|
| 70 |
-
Returns (detections, heatmap, total_time_ms, n_windows).
|
| 71 |
-
|
| 72 |
-
Each detection is (x, y, x+win_w, y+win_h, label, confidence).
|
| 73 |
-
heatmap is a float32 array same size as image (object confidence).
|
| 74 |
-
"""
|
| 75 |
-
H, W = image.shape[:2]
|
| 76 |
-
heatmap = np.zeros((H, W), dtype=np.float32)
|
| 77 |
-
detections = []
|
| 78 |
-
t0 = time.perf_counter()
|
| 79 |
-
|
| 80 |
-
positions = []
|
| 81 |
-
for y in range(0, H - win_h + 1, stride):
|
| 82 |
-
for x in range(0, W - win_w + 1, stride):
|
| 83 |
-
positions.append((x, y))
|
| 84 |
-
|
| 85 |
-
n_total = len(positions)
|
| 86 |
-
if n_total == 0:
|
| 87 |
-
return [], heatmap, 0.0, 0
|
| 88 |
-
|
| 89 |
-
for idx, (x, y) in enumerate(positions):
|
| 90 |
-
patch = image[y:y+win_h, x:x+win_w]
|
| 91 |
-
feats = feature_fn(patch)
|
| 92 |
-
label, conf = head.predict(feats)
|
| 93 |
-
|
| 94 |
-
# Fill heatmap with non-background confidence
|
| 95 |
-
if label != "background":
|
| 96 |
-
heatmap[y:y+win_h, x:x+win_w] = np.maximum(
|
| 97 |
-
heatmap[y:y+win_h, x:x+win_w], conf)
|
| 98 |
-
if conf >= conf_thresh:
|
| 99 |
-
detections.append((x, y, x+win_w, y+win_h, label, conf))
|
| 100 |
-
|
| 101 |
-
# Live updates (every 5th window or last)
|
| 102 |
-
if live_image_placeholder is not None and (idx % 5 == 0 or idx == n_total - 1):
|
| 103 |
-
vis = image.copy()
|
| 104 |
-
# Draw current scan position
|
| 105 |
-
cv2.rectangle(vis, (x, y), (x+win_w, y+win_h), (255, 255, 0), 1)
|
| 106 |
-
# Draw current detections
|
| 107 |
-
for dx, dy, dx2, dy2, dl, dc in detections:
|
| 108 |
-
cv2.rectangle(vis, (dx, dy), (dx2, dy2), (0, 255, 0), 2)
|
| 109 |
-
cv2.putText(vis, f"{dc:.0%}", (dx, dy - 4),
|
| 110 |
-
cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)
|
| 111 |
-
live_image_placeholder.image(
|
| 112 |
-
cv2.cvtColor(vis, cv2.COLOR_BGR2RGB),
|
| 113 |
-
caption=f"Scanning… {idx+1}/{n_total}",
|
| 114 |
-
use_container_width=True)
|
| 115 |
-
|
| 116 |
-
if progress_placeholder is not None:
|
| 117 |
-
progress_placeholder.progress(
|
| 118 |
-
(idx + 1) / n_total,
|
| 119 |
-
text=f"Window {idx+1}/{n_total}")
|
| 120 |
-
|
| 121 |
-
total_ms = (time.perf_counter() - t0) * 1000
|
| 122 |
-
|
| 123 |
-
# --- Non-Maximum Suppression ---
|
| 124 |
-
if detections:
|
| 125 |
-
detections = _nms(detections, nms_iou)
|
| 126 |
-
|
| 127 |
-
return detections, heatmap, total_ms, n_total
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
# ===================================================================
|
| 131 |
-
# RCE feature function
|
| 132 |
-
# ===================================================================
|
| 133 |
-
def rce_feature_fn(patch_bgr):
|
| 134 |
-
return build_rce_vector(patch_bgr, active_mods)
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
# ===================================================================
|
| 138 |
-
# Controls
|
| 139 |
-
# ===================================================================
|
| 140 |
-
st.subheader("Sliding Window Parameters")
|
| 141 |
-
p1, p2, p3 = st.columns(3)
|
| 142 |
-
stride = p1.slider("Stride (px)", 4, max(win_w // 2, 4),
|
| 143 |
-
max(win_w // 4, 4), step=2,
|
| 144 |
-
help="Lower = more windows = slower but finer")
|
| 145 |
-
conf_thresh = p2.slider("Confidence Threshold", 0.5, 1.0, 0.7, 0.05)
|
| 146 |
-
nms_iou = p3.slider("NMS IoU Threshold", 0.1, 0.9, 0.3, 0.05)
|
| 147 |
-
|
| 148 |
-
st.caption(f"Window size: **{win_w}×{win_h} px** | "
|
| 149 |
-
f"Right image: **{right_img.shape[1]}×{right_img.shape[0]} px** | "
|
| 150 |
-
f"≈ {((right_img.shape[0]-win_h)//stride + 1) * ((right_img.shape[1]-win_w)//stride + 1)} windows")
|
| 151 |
-
|
| 152 |
-
st.divider()
|
| 153 |
-
|
| 154 |
-
# ===================================================================
|
| 155 |
-
# Side-by-side layout
|
| 156 |
-
# ===================================================================
|
| 157 |
-
col_rce, col_cnn, col_orb = st.columns(3)
|
| 158 |
-
|
| 159 |
-
# -------------------------------------------------------------------
|
| 160 |
-
# LEFT — RCE Detection
|
| 161 |
-
# -------------------------------------------------------------------
|
| 162 |
-
with col_rce:
|
| 163 |
-
st.header("🧬 RCE Detection")
|
| 164 |
-
if rce_head is None:
|
| 165 |
-
st.info("No RCE head trained. Train one in **Model Tuning**.")
|
| 166 |
-
else:
|
| 167 |
-
st.caption(f"Modules: {', '.join(REGISTRY[k]['label'] for k in active_mods if active_mods[k])}")
|
| 168 |
-
rce_run = st.button("▶ Run RCE Scan", key="rce_run")
|
| 169 |
-
|
| 170 |
-
rce_progress = st.empty()
|
| 171 |
-
rce_live = st.empty()
|
| 172 |
-
rce_results = st.container()
|
| 173 |
-
|
| 174 |
-
if rce_run:
|
| 175 |
-
dets, hmap, ms, nw = sliding_window_detect(
|
| 176 |
-
right_img, rce_feature_fn, rce_head,
|
| 177 |
-
stride, conf_thresh, nms_iou,
|
| 178 |
-
progress_placeholder=rce_progress,
|
| 179 |
-
live_image_placeholder=rce_live,
|
| 180 |
-
)
|
| 181 |
-
|
| 182 |
-
# Final image with boxes
|
| 183 |
-
final = right_img.copy()
|
| 184 |
-
class_labels = sorted(set(d[4] for d in dets)) if dets else []
|
| 185 |
-
for x1d, y1d, x2d, y2d, lbl, cf in dets:
|
| 186 |
-
ci = class_labels.index(lbl) if lbl in class_labels else 0
|
| 187 |
-
clr = CLASS_COLORS[ci % len(CLASS_COLORS)]
|
| 188 |
-
cv2.rectangle(final, (x1d, y1d), (x2d, y2d), clr, 2)
|
| 189 |
-
cv2.putText(final, f"{lbl} {cf:.0%}", (x1d, y1d - 6),
|
| 190 |
-
cv2.FONT_HERSHEY_SIMPLEX, 0.4, clr, 1)
|
| 191 |
-
rce_live.image(cv2.cvtColor(final, cv2.COLOR_BGR2RGB),
|
| 192 |
-
caption="RCE — Final Detections",
|
| 193 |
-
use_container_width=True)
|
| 194 |
-
rce_progress.empty()
|
| 195 |
-
|
| 196 |
-
with rce_results:
|
| 197 |
-
# Metrics
|
| 198 |
-
rm1, rm2, rm3, rm4 = st.columns(4)
|
| 199 |
-
rm1.metric("Detections", len(dets))
|
| 200 |
-
rm2.metric("Windows", nw)
|
| 201 |
-
rm3.metric("Total Time", f"{ms:.0f} ms")
|
| 202 |
-
rm4.metric("Per Window", f"{ms/max(nw,1):.2f} ms")
|
| 203 |
-
|
| 204 |
-
# Confidence heatmap
|
| 205 |
-
if hmap.max() > 0:
|
| 206 |
-
hmap_color = cv2.applyColorMap(
|
| 207 |
-
(hmap / hmap.max() * 255).astype(np.uint8),
|
| 208 |
-
cv2.COLORMAP_JET)
|
| 209 |
-
blend = cv2.addWeighted(right_img, 0.5, hmap_color, 0.5, 0)
|
| 210 |
-
st.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
|
| 211 |
-
caption="RCE — Confidence Heatmap",
|
| 212 |
-
use_container_width=True)
|
| 213 |
-
|
| 214 |
-
# Detection table
|
| 215 |
-
if dets:
|
| 216 |
-
import pandas as pd
|
| 217 |
-
df = pd.DataFrame(dets, columns=["x1","y1","x2","y2","label","conf"])
|
| 218 |
-
st.dataframe(df, use_container_width=True, hide_index=True)
|
| 219 |
-
|
| 220 |
-
st.session_state["rce_dets"] = dets
|
| 221 |
-
st.session_state["rce_det_ms"] = ms
|
| 222 |
-
|
| 223 |
-
# -------------------------------------------------------------------
|
| 224 |
-
# RIGHT — CNN Detection
|
| 225 |
-
# -------------------------------------------------------------------
|
| 226 |
-
with col_cnn:
|
| 227 |
-
st.header("🧠 CNN Detection")
|
| 228 |
-
|
| 229 |
-
# Find which CNN heads are trained
|
| 230 |
-
trained_cnns = [n for n in BACKBONES if f"cnn_head_{n}" in st.session_state]
|
| 231 |
-
if not trained_cnns:
|
| 232 |
-
st.info("No CNN head trained. Train one in **Model Tuning**.")
|
| 233 |
-
else:
|
| 234 |
-
selected = st.selectbox("Select Model", trained_cnns, key="det_cnn_sel")
|
| 235 |
-
bmeta = BACKBONES[selected]
|
| 236 |
-
backbone = bmeta["loader"]()
|
| 237 |
-
head = st.session_state[f"cnn_head_{selected}"]
|
| 238 |
-
|
| 239 |
-
st.caption(f"Backbone: **{selected}** ({bmeta['dim']}D) — Head in session state")
|
| 240 |
-
cnn_run = st.button(f"▶ Run {selected} Scan", key="cnn_run")
|
| 241 |
-
|
| 242 |
-
cnn_progress = st.empty()
|
| 243 |
-
cnn_live = st.empty()
|
| 244 |
-
cnn_results = st.container()
|
| 245 |
-
|
| 246 |
-
if cnn_run:
|
| 247 |
-
dets, hmap, ms, nw = sliding_window_detect(
|
| 248 |
-
right_img, backbone.get_features, head,
|
| 249 |
-
stride, conf_thresh, nms_iou,
|
| 250 |
-
progress_placeholder=cnn_progress,
|
| 251 |
-
live_image_placeholder=cnn_live,
|
| 252 |
-
)
|
| 253 |
-
|
| 254 |
-
# Final image
|
| 255 |
-
final = right_img.copy()
|
| 256 |
-
class_labels = sorted(set(d[4] for d in dets)) if dets else []
|
| 257 |
-
for x1d, y1d, x2d, y2d, lbl, cf in dets:
|
| 258 |
-
ci = class_labels.index(lbl) if lbl in class_labels else 0
|
| 259 |
-
clr = CLASS_COLORS[ci % len(CLASS_COLORS)]
|
| 260 |
-
cv2.rectangle(final, (x1d, y1d), (x2d, y2d), clr, 2)
|
| 261 |
-
cv2.putText(final, f"{lbl} {cf:.0%}", (x1d, y1d - 6),
|
| 262 |
-
cv2.FONT_HERSHEY_SIMPLEX, 0.4, clr, 1)
|
| 263 |
-
cnn_live.image(cv2.cvtColor(final, cv2.COLOR_BGR2RGB),
|
| 264 |
-
caption=f"{selected} — Final Detections",
|
| 265 |
-
use_container_width=True)
|
| 266 |
-
cnn_progress.empty()
|
| 267 |
-
|
| 268 |
-
with cnn_results:
|
| 269 |
-
cm1, cm2, cm3, cm4 = st.columns(4)
|
| 270 |
-
cm1.metric("Detections", len(dets))
|
| 271 |
-
cm2.metric("Windows", nw)
|
| 272 |
-
cm3.metric("Total Time", f"{ms:.0f} ms")
|
| 273 |
-
cm4.metric("Per Window", f"{ms/max(nw,1):.2f} ms")
|
| 274 |
-
|
| 275 |
-
if hmap.max() > 0:
|
| 276 |
-
hmap_color = cv2.applyColorMap(
|
| 277 |
-
(hmap / hmap.max() * 255).astype(np.uint8),
|
| 278 |
-
cv2.COLORMAP_JET)
|
| 279 |
-
blend = cv2.addWeighted(right_img, 0.5, hmap_color, 0.5, 0)
|
| 280 |
-
st.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
|
| 281 |
-
caption=f"{selected} — Confidence Heatmap",
|
| 282 |
-
use_container_width=True)
|
| 283 |
-
|
| 284 |
-
if dets:
|
| 285 |
-
import pandas as pd
|
| 286 |
-
df = pd.DataFrame(dets, columns=["x1","y1","x2","y2","label","conf"])
|
| 287 |
-
st.dataframe(df, use_container_width=True, hide_index=True)
|
| 288 |
-
|
| 289 |
-
st.session_state["cnn_dets"] = dets
|
| 290 |
-
st.session_state["cnn_det_ms"] = ms
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
# -------------------------------------------------------------------
|
| 294 |
-
# RIGHT — ORB Detection
|
| 295 |
-
# -------------------------------------------------------------------
|
| 296 |
-
with col_orb:
|
| 297 |
-
st.header("🏛️ ORB Detection")
|
| 298 |
-
if not has_orb:
|
| 299 |
-
st.info("No ORB reference trained. Train one in **Model Tuning**.")
|
| 300 |
-
else:
|
| 301 |
-
orb_det = st.session_state["orb_detector"]
|
| 302 |
-
orb_refs = st.session_state["orb_refs"]
|
| 303 |
-
dt_thresh = st.session_state.get("orb_dist_thresh", 70)
|
| 304 |
-
min_m = st.session_state.get("orb_min_matches", 5)
|
| 305 |
-
st.caption(f"References: {', '.join(orb_refs.keys())} | "
|
| 306 |
-
f"dist<{dt_thresh}, min {min_m} matches")
|
| 307 |
-
orb_run = st.button("▶ Run ORB Scan", key="orb_run")
|
| 308 |
-
|
| 309 |
-
orb_progress = st.empty()
|
| 310 |
-
orb_live = st.empty()
|
| 311 |
-
orb_results = st.container()
|
| 312 |
-
|
| 313 |
-
if orb_run:
|
| 314 |
-
H, W = right_img.shape[:2]
|
| 315 |
-
positions = [(x, y)
|
| 316 |
-
for y in range(0, H - win_h + 1, stride)
|
| 317 |
-
for x in range(0, W - win_w + 1, stride)]
|
| 318 |
-
n_total = len(positions)
|
| 319 |
-
heatmap = np.zeros((H, W), dtype=np.float32)
|
| 320 |
-
detections = []
|
| 321 |
-
t0 = time.perf_counter()
|
| 322 |
-
|
| 323 |
-
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
| 324 |
-
|
| 325 |
-
for idx, (px, py) in enumerate(positions):
|
| 326 |
-
patch = right_img[py:py+win_h, px:px+win_w]
|
| 327 |
-
gray = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY)
|
| 328 |
-
gray = clahe.apply(gray)
|
| 329 |
-
kp, des = orb_det.orb.detectAndCompute(gray, None)
|
| 330 |
-
|
| 331 |
-
if des is not None:
|
| 332 |
-
best_label, best_conf = "background", 0.0
|
| 333 |
-
for lbl, ref in orb_refs.items():
|
| 334 |
-
if ref["descriptors"] is None:
|
| 335 |
-
continue
|
| 336 |
-
matches = orb_det.bf.match(ref["descriptors"], des)
|
| 337 |
-
good = [m for m in matches if m.distance < dt_thresh]
|
| 338 |
-
conf = min(len(good) / max(min_m, 1), 1.0)
|
| 339 |
-
if len(good) >= min_m and conf > best_conf:
|
| 340 |
-
best_label, best_conf = lbl, conf
|
| 341 |
-
|
| 342 |
-
if best_label != "background":
|
| 343 |
-
heatmap[py:py+win_h, px:px+win_w] = np.maximum(
|
| 344 |
-
heatmap[py:py+win_h, px:px+win_w], best_conf)
|
| 345 |
-
if best_conf >= conf_thresh:
|
| 346 |
-
detections.append(
|
| 347 |
-
(px, py, px+win_w, py+win_h, best_label, best_conf))
|
| 348 |
-
|
| 349 |
-
if idx % 5 == 0 or idx == n_total - 1:
|
| 350 |
-
orb_progress.progress((idx+1)/n_total,
|
| 351 |
-
text=f"Window {idx+1}/{n_total}")
|
| 352 |
-
|
| 353 |
-
total_ms = (time.perf_counter() - t0) * 1000
|
| 354 |
-
if detections:
|
| 355 |
-
detections = _nms(detections, nms_iou)
|
| 356 |
-
|
| 357 |
-
final = right_img.copy()
|
| 358 |
-
cls_labels = sorted(set(d[4] for d in detections)) if detections else []
|
| 359 |
-
for x1d, y1d, x2d, y2d, lbl, cf in detections:
|
| 360 |
-
ci = cls_labels.index(lbl) if lbl in cls_labels else 0
|
| 361 |
-
clr = CLASS_COLORS[ci % len(CLASS_COLORS)]
|
| 362 |
-
cv2.rectangle(final, (x1d, y1d), (x2d, y2d), clr, 2)
|
| 363 |
-
cv2.putText(final, f"{lbl} {cf:.0%}", (x1d, y1d - 6),
|
| 364 |
-
cv2.FONT_HERSHEY_SIMPLEX, 0.4, clr, 1)
|
| 365 |
-
orb_live.image(cv2.cvtColor(final, cv2.COLOR_BGR2RGB),
|
| 366 |
-
caption="ORB — Final Detections",
|
| 367 |
-
use_container_width=True)
|
| 368 |
-
orb_progress.empty()
|
| 369 |
-
|
| 370 |
-
with orb_results:
|
| 371 |
-
om1, om2, om3, om4 = st.columns(4)
|
| 372 |
-
om1.metric("Detections", len(detections))
|
| 373 |
-
om2.metric("Windows", n_total)
|
| 374 |
-
om3.metric("Total Time", f"{total_ms:.0f} ms")
|
| 375 |
-
om4.metric("Per Window", f"{total_ms/max(n_total,1):.2f} ms")
|
| 376 |
-
|
| 377 |
-
if heatmap.max() > 0:
|
| 378 |
-
hmap_color = cv2.applyColorMap(
|
| 379 |
-
(heatmap / heatmap.max() * 255).astype(np.uint8),
|
| 380 |
-
cv2.COLORMAP_JET)
|
| 381 |
-
blend = cv2.addWeighted(right_img, 0.5, hmap_color, 0.5, 0)
|
| 382 |
-
st.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
|
| 383 |
-
caption="ORB — Confidence Heatmap",
|
| 384 |
-
use_container_width=True)
|
| 385 |
-
|
| 386 |
-
if detections:
|
| 387 |
-
import pandas as pd
|
| 388 |
-
df = pd.DataFrame(detections,
|
| 389 |
-
columns=["x1","y1","x2","y2","label","conf"])
|
| 390 |
-
st.dataframe(df, use_container_width=True, hide_index=True)
|
| 391 |
-
|
| 392 |
-
st.session_state["orb_dets"] = detections
|
| 393 |
-
st.session_state["orb_det_ms"] = total_ms
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
# ===================================================================
|
| 397 |
-
# Bottom — Comparison (if any two have run)
|
| 398 |
-
# ===================================================================
|
| 399 |
-
rce_dets = st.session_state.get("rce_dets")
|
| 400 |
-
cnn_dets = st.session_state.get("cnn_dets")
|
| 401 |
-
orb_dets = st.session_state.get("orb_dets")
|
| 402 |
-
|
| 403 |
-
methods = {}
|
| 404 |
-
if rce_dets is not None:
|
| 405 |
-
methods["RCE"] = (rce_dets, st.session_state.get("rce_det_ms", 0), (0,255,0))
|
| 406 |
-
if cnn_dets is not None:
|
| 407 |
-
methods["CNN"] = (cnn_dets, st.session_state.get("cnn_det_ms", 0), (0,0,255))
|
| 408 |
-
if orb_dets is not None:
|
| 409 |
-
methods["ORB"] = (orb_dets, st.session_state.get("orb_det_ms", 0), (255,165,0))
|
| 410 |
-
|
| 411 |
-
if len(methods) >= 2:
|
| 412 |
-
st.divider()
|
| 413 |
-
st.subheader("📊 Side-by-Side Comparison")
|
| 414 |
-
|
| 415 |
-
import pandas as pd
|
| 416 |
-
comp = {"Metric": ["Detections", "Best Confidence", "Total Time (ms)"]}
|
| 417 |
-
for name, (dets, ms, _) in methods.items():
|
| 418 |
-
comp[name] = [
|
| 419 |
-
len(dets),
|
| 420 |
-
f"{max((d[5] for d in dets), default=0):.1%}",
|
| 421 |
-
f"{ms:.0f}",
|
| 422 |
-
]
|
| 423 |
-
st.dataframe(pd.DataFrame(comp), use_container_width=True, hide_index=True)
|
| 424 |
-
|
| 425 |
-
# Overlay all methods on one image
|
| 426 |
-
overlay = right_img.copy()
|
| 427 |
-
for name, (dets, _, clr) in methods.items():
|
| 428 |
-
for x1d, y1d, x2d, y2d, lbl, cf in dets:
|
| 429 |
-
cv2.rectangle(overlay, (x1d, y1d), (x2d, y2d), clr, 2)
|
| 430 |
-
cv2.putText(overlay, f"{name}:{lbl} {cf:.0%}", (x1d, y1d - 6),
|
| 431 |
-
cv2.FONT_HERSHEY_SIMPLEX, 0.35, clr, 1)
|
| 432 |
-
legend = " | ".join(f"{n}={'green' if c==(0,255,0) else 'blue' if c==(0,0,255) else 'orange'}"
|
| 433 |
-
for n, (_, _, c) in methods.items())
|
| 434 |
-
st.image(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB),
|
| 435 |
-
caption=legend, use_container_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/7_Evaluation.py
DELETED
|
@@ -1,295 +0,0 @@
|
|
| 1 |
-
import streamlit as st
|
| 2 |
-
import cv2
|
| 3 |
-
import numpy as np
|
| 4 |
-
import plotly.graph_objects as go
|
| 5 |
-
import plotly.figure_factory as ff
|
| 6 |
-
import sys, os
|
| 7 |
-
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 8 |
-
|
| 9 |
-
from src.detectors.rce.features import REGISTRY
|
| 10 |
-
from src.models import BACKBONES
|
| 11 |
-
|
| 12 |
-
st.set_page_config(page_title="Evaluation", layout="wide")
|
| 13 |
-
st.title("📈 Evaluation: Confusion Matrix & PR Curves")
|
| 14 |
-
|
| 15 |
-
# ---------------------------------------------------------------------------
|
| 16 |
-
# Guard
|
| 17 |
-
# ---------------------------------------------------------------------------
|
| 18 |
-
if "pipeline_data" not in st.session_state:
|
| 19 |
-
st.error("Complete the **Data Lab** first.")
|
| 20 |
-
st.stop()
|
| 21 |
-
|
| 22 |
-
assets = st.session_state["pipeline_data"]
|
| 23 |
-
crop = assets["crop"]
|
| 24 |
-
crop_aug = assets.get("crop_aug", crop)
|
| 25 |
-
bbox = assets.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
|
| 26 |
-
rois = assets.get("rois", [{"label": "object", "bbox": bbox,
|
| 27 |
-
"crop": crop, "crop_aug": crop_aug}])
|
| 28 |
-
|
| 29 |
-
rce_dets = st.session_state.get("rce_dets")
|
| 30 |
-
cnn_dets = st.session_state.get("cnn_dets")
|
| 31 |
-
orb_dets = st.session_state.get("orb_dets")
|
| 32 |
-
|
| 33 |
-
if rce_dets is None and cnn_dets is None and orb_dets is None:
|
| 34 |
-
st.warning("Run detection on at least one method in **Real-Time Detection** first.")
|
| 35 |
-
st.stop()
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
# ---------------------------------------------------------------------------
|
| 39 |
-
# Ground truth from ROIs
|
| 40 |
-
# ---------------------------------------------------------------------------
|
| 41 |
-
gt_boxes = [(roi["bbox"], roi["label"]) for roi in rois]
|
| 42 |
-
|
| 43 |
-
st.sidebar.subheader("Evaluation Settings")
|
| 44 |
-
iou_thresh = st.sidebar.slider("IoU Threshold", 0.1, 0.9, 0.5, 0.05,
|
| 45 |
-
help="Minimum IoU to count a detection as TP")
|
| 46 |
-
|
| 47 |
-
st.subheader("Ground Truth (from Data Lab ROIs)")
|
| 48 |
-
st.caption(f"{len(gt_boxes)} ground-truth ROIs defined")
|
| 49 |
-
gt_vis = assets["right"].copy()
|
| 50 |
-
for (bx0, by0, bx1, by1), lbl in gt_boxes:
|
| 51 |
-
cv2.rectangle(gt_vis, (bx0, by0), (bx1, by1), (0, 255, 255), 2)
|
| 52 |
-
cv2.putText(gt_vis, lbl, (bx0, by0 - 6),
|
| 53 |
-
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 1)
|
| 54 |
-
st.image(cv2.cvtColor(gt_vis, cv2.COLOR_BGR2RGB),
|
| 55 |
-
caption="Ground Truth Annotations", use_container_width=True)
|
| 56 |
-
|
| 57 |
-
st.divider()
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
# ---------------------------------------------------------------------------
|
| 61 |
-
# Matching helpers
|
| 62 |
-
# ---------------------------------------------------------------------------
|
| 63 |
-
def _iou(a, b):
|
| 64 |
-
xi1 = max(a[0], b[0]); yi1 = max(a[1], b[1])
|
| 65 |
-
xi2 = min(a[2], b[2]); yi2 = min(a[3], b[3])
|
| 66 |
-
inter = max(0, xi2 - xi1) * max(0, yi2 - yi1)
|
| 67 |
-
aa = (a[2] - a[0]) * (a[3] - a[1])
|
| 68 |
-
ab = (b[2] - b[0]) * (b[3] - b[1])
|
| 69 |
-
return inter / (aa + ab - inter + 1e-6)
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
def match_detections(dets, gt_list, iou_thr):
|
| 73 |
-
"""
|
| 74 |
-
Match detections to GT boxes.
|
| 75 |
-
Returns (results, n_missed, matched_gt_indices).
|
| 76 |
-
results = list of (det, matched_gt_label_or_None, iou) sorted by confidence.
|
| 77 |
-
matched_gt_indices = set of GT indices that were matched.
|
| 78 |
-
"""
|
| 79 |
-
dets_sorted = sorted(dets, key=lambda d: d[5], reverse=True)
|
| 80 |
-
matched_gt = set()
|
| 81 |
-
results = []
|
| 82 |
-
|
| 83 |
-
for det in dets_sorted:
|
| 84 |
-
det_box = det[:4]
|
| 85 |
-
det_label = det[4]
|
| 86 |
-
best_iou = 0.0
|
| 87 |
-
best_gt_idx = -1
|
| 88 |
-
best_gt_label = None
|
| 89 |
-
|
| 90 |
-
for gi, (gt_box, gt_label) in enumerate(gt_list):
|
| 91 |
-
if gi in matched_gt:
|
| 92 |
-
continue
|
| 93 |
-
iou_val = _iou(det_box, gt_box)
|
| 94 |
-
if iou_val > best_iou:
|
| 95 |
-
best_iou = iou_val
|
| 96 |
-
best_gt_idx = gi
|
| 97 |
-
best_gt_label = gt_label
|
| 98 |
-
|
| 99 |
-
if best_iou >= iou_thr and best_gt_idx >= 0:
|
| 100 |
-
matched_gt.add(best_gt_idx)
|
| 101 |
-
results.append((det, best_gt_label, best_iou))
|
| 102 |
-
else:
|
| 103 |
-
results.append((det, None, best_iou))
|
| 104 |
-
|
| 105 |
-
return results, len(gt_list) - len(matched_gt), matched_gt
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
def compute_pr_curve(dets, gt_list, iou_thr, steps=50):
|
| 109 |
-
"""
|
| 110 |
-
Sweep confidence thresholds and compute precision/recall.
|
| 111 |
-
Returns (thresholds, precisions, recalls, f1s).
|
| 112 |
-
"""
|
| 113 |
-
if not dets:
|
| 114 |
-
return [], [], [], []
|
| 115 |
-
|
| 116 |
-
thresholds = np.linspace(0.0, 1.0, steps)
|
| 117 |
-
precisions, recalls, f1s = [], [], []
|
| 118 |
-
|
| 119 |
-
for thr in thresholds:
|
| 120 |
-
filtered = [d for d in dets if d[5] >= thr]
|
| 121 |
-
if not filtered:
|
| 122 |
-
precisions.append(1.0)
|
| 123 |
-
recalls.append(0.0)
|
| 124 |
-
f1s.append(0.0)
|
| 125 |
-
continue
|
| 126 |
-
|
| 127 |
-
matched, n_missed, _ = match_detections(filtered, gt_list, iou_thr)
|
| 128 |
-
tp = sum(1 for _, gt_lbl, _ in matched if gt_lbl is not None)
|
| 129 |
-
fp = sum(1 for _, gt_lbl, _ in matched if gt_lbl is None)
|
| 130 |
-
fn = n_missed
|
| 131 |
-
|
| 132 |
-
prec = tp / (tp + fp) if (tp + fp) > 0 else 1.0
|
| 133 |
-
rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
|
| 134 |
-
f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
|
| 135 |
-
precisions.append(prec)
|
| 136 |
-
recalls.append(rec)
|
| 137 |
-
f1s.append(f1)
|
| 138 |
-
|
| 139 |
-
return thresholds.tolist(), precisions, recalls, f1s
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
def build_confusion_matrix(dets, gt_list, iou_thr):
|
| 143 |
-
"""
|
| 144 |
-
Build a confusion matrix: rows = predicted, cols = actual.
|
| 145 |
-
Classes = all GT labels + 'background'.
|
| 146 |
-
"""
|
| 147 |
-
gt_labels = sorted(set(lbl for _, lbl in gt_list))
|
| 148 |
-
all_labels = gt_labels + ["background"]
|
| 149 |
-
|
| 150 |
-
n = len(all_labels)
|
| 151 |
-
matrix = np.zeros((n, n), dtype=int)
|
| 152 |
-
label_to_idx = {lbl: i for i, lbl in enumerate(all_labels)}
|
| 153 |
-
|
| 154 |
-
matched, n_missed, matched_gt_indices = match_detections(dets, gt_list, iou_thr)
|
| 155 |
-
|
| 156 |
-
for det, gt_lbl, _ in matched:
|
| 157 |
-
pred_lbl = det[4]
|
| 158 |
-
if gt_lbl is not None:
|
| 159 |
-
# TP or mislabel
|
| 160 |
-
pi = label_to_idx.get(pred_lbl, label_to_idx["background"])
|
| 161 |
-
gi = label_to_idx[gt_lbl]
|
| 162 |
-
matrix[pi][gi] += 1
|
| 163 |
-
else:
|
| 164 |
-
# FP
|
| 165 |
-
pi = label_to_idx.get(pred_lbl, label_to_idx["background"])
|
| 166 |
-
matrix[pi][label_to_idx["background"]] += 1
|
| 167 |
-
|
| 168 |
-
# FN: unmatched GT
|
| 169 |
-
for gi, (_, gt_lbl) in enumerate(gt_list):
|
| 170 |
-
if gi not in matched_gt_indices:
|
| 171 |
-
matrix[label_to_idx["background"]][label_to_idx[gt_lbl]] += 1
|
| 172 |
-
|
| 173 |
-
return matrix, all_labels
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
# ---------------------------------------------------------------------------
|
| 177 |
-
# Collect all methods with detections
|
| 178 |
-
# ---------------------------------------------------------------------------
|
| 179 |
-
methods = {}
|
| 180 |
-
if rce_dets is not None:
|
| 181 |
-
methods["RCE"] = rce_dets
|
| 182 |
-
if cnn_dets is not None:
|
| 183 |
-
methods["CNN"] = cnn_dets
|
| 184 |
-
if orb_dets is not None:
|
| 185 |
-
methods["ORB"] = orb_dets
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
# ===================================================================
|
| 189 |
-
# 1. Confusion Matrices
|
| 190 |
-
# ===================================================================
|
| 191 |
-
st.subheader("🔲 Confusion Matrices")
|
| 192 |
-
cm_cols = st.columns(len(methods))
|
| 193 |
-
|
| 194 |
-
for col, (name, dets) in zip(cm_cols, methods.items()):
|
| 195 |
-
with col:
|
| 196 |
-
st.markdown(f"**{name}**")
|
| 197 |
-
matrix, labels = build_confusion_matrix(dets, gt_boxes, iou_thresh)
|
| 198 |
-
|
| 199 |
-
fig_cm = ff.create_annotated_heatmap(
|
| 200 |
-
z=matrix.tolist(),
|
| 201 |
-
x=labels, y=labels,
|
| 202 |
-
colorscale="Blues",
|
| 203 |
-
showscale=True)
|
| 204 |
-
fig_cm.update_layout(
|
| 205 |
-
title=f"{name} Confusion Matrix",
|
| 206 |
-
xaxis_title="Actual",
|
| 207 |
-
yaxis_title="Predicted",
|
| 208 |
-
template="plotly_dark",
|
| 209 |
-
height=350)
|
| 210 |
-
fig_cm.update_yaxes(autorange="reversed")
|
| 211 |
-
st.plotly_chart(fig_cm, use_container_width=True)
|
| 212 |
-
|
| 213 |
-
# Summary metrics at this default threshold
|
| 214 |
-
matched, n_missed, _ = match_detections(dets, gt_boxes, iou_thresh)
|
| 215 |
-
tp = sum(1 for _, g, _ in matched if g is not None)
|
| 216 |
-
fp = sum(1 for _, g, _ in matched if g is None)
|
| 217 |
-
fn = n_missed
|
| 218 |
-
prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
|
| 219 |
-
rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
|
| 220 |
-
f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
|
| 221 |
-
|
| 222 |
-
m1, m2, m3 = st.columns(3)
|
| 223 |
-
m1.metric("Precision", f"{prec:.1%}")
|
| 224 |
-
m2.metric("Recall", f"{rec:.1%}")
|
| 225 |
-
m3.metric("F1 Score", f"{f1:.1%}")
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
# ===================================================================
|
| 229 |
-
# 2. Precision-Recall Curves
|
| 230 |
-
# ===================================================================
|
| 231 |
-
st.divider()
|
| 232 |
-
st.subheader("📉 Precision-Recall Curves")
|
| 233 |
-
|
| 234 |
-
method_colors = {"RCE": "#00ff88", "CNN": "#4488ff", "ORB": "#ff8800"}
|
| 235 |
-
fig_pr = go.Figure()
|
| 236 |
-
fig_f1 = go.Figure()
|
| 237 |
-
|
| 238 |
-
summary_rows = []
|
| 239 |
-
|
| 240 |
-
for name, dets in methods.items():
|
| 241 |
-
thrs, precs, recs, f1s = compute_pr_curve(dets, gt_boxes, iou_thresh)
|
| 242 |
-
clr = method_colors.get(name, "#ffffff")
|
| 243 |
-
|
| 244 |
-
fig_pr.add_trace(go.Scatter(
|
| 245 |
-
x=recs, y=precs, mode="lines+markers",
|
| 246 |
-
name=name, line=dict(color=clr, width=2),
|
| 247 |
-
marker=dict(size=4)))
|
| 248 |
-
|
| 249 |
-
fig_f1.add_trace(go.Scatter(
|
| 250 |
-
x=thrs, y=f1s, mode="lines",
|
| 251 |
-
name=name, line=dict(color=clr, width=2)))
|
| 252 |
-
|
| 253 |
-
# AP (area under PR curve)
|
| 254 |
-
if recs and precs:
|
| 255 |
-
ap = float(np.trapz(precs, recs))
|
| 256 |
-
else:
|
| 257 |
-
ap = 0.0
|
| 258 |
-
|
| 259 |
-
best_f1_idx = int(np.argmax(f1s)) if f1s else 0
|
| 260 |
-
summary_rows.append({
|
| 261 |
-
"Method": name,
|
| 262 |
-
"AP": f"{abs(ap):.3f}",
|
| 263 |
-
"Best F1": f"{f1s[best_f1_idx]:.3f}" if f1s else "N/A",
|
| 264 |
-
"@ Threshold": f"{thrs[best_f1_idx]:.2f}" if thrs else "N/A",
|
| 265 |
-
"Detections": len(dets),
|
| 266 |
-
})
|
| 267 |
-
|
| 268 |
-
fig_pr.update_layout(
|
| 269 |
-
title="Precision vs Recall",
|
| 270 |
-
xaxis_title="Recall", yaxis_title="Precision",
|
| 271 |
-
template="plotly_dark", height=400,
|
| 272 |
-
xaxis=dict(range=[0, 1.05]), yaxis=dict(range=[0, 1.05]))
|
| 273 |
-
|
| 274 |
-
fig_f1.update_layout(
|
| 275 |
-
title="F1 Score vs Confidence Threshold",
|
| 276 |
-
xaxis_title="Confidence Threshold", yaxis_title="F1 Score",
|
| 277 |
-
template="plotly_dark", height=400,
|
| 278 |
-
xaxis=dict(range=[0, 1.05]), yaxis=dict(range=[0, 1.05]))
|
| 279 |
-
|
| 280 |
-
pc1, pc2 = st.columns(2)
|
| 281 |
-
pc1.plotly_chart(fig_pr, use_container_width=True)
|
| 282 |
-
pc2.plotly_chart(fig_f1, use_container_width=True)
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
# ===================================================================
|
| 286 |
-
# 3. Summary Table
|
| 287 |
-
# ===================================================================
|
| 288 |
-
st.divider()
|
| 289 |
-
st.subheader("📊 Summary")
|
| 290 |
-
|
| 291 |
-
import pandas as pd
|
| 292 |
-
st.dataframe(pd.DataFrame(summary_rows), use_container_width=True, hide_index=True)
|
| 293 |
-
|
| 294 |
-
st.caption(f"All metrics computed at IoU threshold = **{iou_thresh:.2f}**. "
|
| 295 |
-
"Adjust in the sidebar to explore sensitivity.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/8_Stereo_Geometry.py
DELETED
|
@@ -1,353 +0,0 @@
|
|
| 1 |
-
import streamlit as st
|
| 2 |
-
import cv2
|
| 3 |
-
import numpy as np
|
| 4 |
-
import re
|
| 5 |
-
import pandas as pd
|
| 6 |
-
import plotly.graph_objects as go
|
| 7 |
-
import sys, os
|
| 8 |
-
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 9 |
-
|
| 10 |
-
st.set_page_config(page_title="Stereo Geometry", layout="wide")
|
| 11 |
-
st.title("📐 Stereo Geometry: Distance Estimation")
|
| 12 |
-
|
| 13 |
-
# ---------------------------------------------------------------------------
|
| 14 |
-
# Guard
|
| 15 |
-
# ---------------------------------------------------------------------------
|
| 16 |
-
if "pipeline_data" not in st.session_state or "left" not in st.session_state.get("pipeline_data", {}):
|
| 17 |
-
st.error("Complete **Data Lab** first.")
|
| 18 |
-
st.stop()
|
| 19 |
-
|
| 20 |
-
assets = st.session_state["pipeline_data"]
|
| 21 |
-
img_l = assets["left"]
|
| 22 |
-
img_r = assets["right"]
|
| 23 |
-
gt_left = assets.get("gt_left") # float32 disparity map from PFM
|
| 24 |
-
gt_right = assets.get("gt_right")
|
| 25 |
-
conf_raw = assets.get("conf_raw", "")
|
| 26 |
-
crop_bbox = assets.get("crop_bbox") # (x0, y0, x1, y1) on LEFT image
|
| 27 |
-
|
| 28 |
-
rce_dets = st.session_state.get("rce_dets", [])
|
| 29 |
-
cnn_dets = st.session_state.get("cnn_dets", [])
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
# ===================================================================
|
| 33 |
-
# Parse Middlebury-style camera config
|
| 34 |
-
# ===================================================================
|
| 35 |
-
def parse_config(text: str) -> dict:
|
| 36 |
-
"""
|
| 37 |
-
Parse a Middlebury .txt / .conf calibration file.
|
| 38 |
-
Expected keys: cam0, cam1, doffs, baseline, width, height, ndisp, vmin, vmax
|
| 39 |
-
cam0 / cam1 are 3×3 matrices in bracket notation: [f 0 cx; 0 f cy; 0 0 1]
|
| 40 |
-
"""
|
| 41 |
-
params = {}
|
| 42 |
-
if not text or not text.strip():
|
| 43 |
-
return params
|
| 44 |
-
for line in text.strip().splitlines():
|
| 45 |
-
line = line.strip()
|
| 46 |
-
if "=" not in line:
|
| 47 |
-
continue
|
| 48 |
-
key, val = line.split("=", 1)
|
| 49 |
-
key = key.strip()
|
| 50 |
-
val = val.strip()
|
| 51 |
-
if "[" in val:
|
| 52 |
-
nums = list(map(float, re.findall(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?", val)))
|
| 53 |
-
params[key] = np.array(nums).reshape(3, 3) if len(nums) == 9 else nums
|
| 54 |
-
else:
|
| 55 |
-
try:
|
| 56 |
-
params[key] = float(val)
|
| 57 |
-
except ValueError:
|
| 58 |
-
params[key] = val
|
| 59 |
-
return params
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
calib = parse_config(conf_raw)
|
| 63 |
-
|
| 64 |
-
# Extract intrinsics
|
| 65 |
-
cam0 = calib.get("cam0")
|
| 66 |
-
focal = float(cam0[0, 0]) if isinstance(cam0, np.ndarray) and cam0.shape == (3, 3) else 0.0
|
| 67 |
-
doffs = float(calib.get("doffs", 0.0))
|
| 68 |
-
baseline = float(calib.get("baseline", 1.0))
|
| 69 |
-
ndisp = int(calib.get("ndisp", 128))
|
| 70 |
-
|
| 71 |
-
if focal <= 0:
|
| 72 |
-
st.error("❌ Focal length is **0** — the camera config is missing or malformed. "
|
| 73 |
-
"Depth estimation cannot proceed. Return to **Data Lab** and upload a valid "
|
| 74 |
-
"Middlebury camera config.")
|
| 75 |
-
st.stop()
|
| 76 |
-
|
| 77 |
-
if focal > 10000:
|
| 78 |
-
st.error(f"❌ Focal length ({focal:.0f} px) is suspiciously large. "
|
| 79 |
-
"Check your camera config file.")
|
| 80 |
-
st.stop()
|
| 81 |
-
|
| 82 |
-
if baseline <= 0 or baseline > 1000:
|
| 83 |
-
st.error(f"❌ Invalid baseline ({baseline:.1f}). Expected 1–1000 mm.")
|
| 84 |
-
st.stop()
|
| 85 |
-
|
| 86 |
-
st.subheader("Camera Calibration")
|
| 87 |
-
cc1, cc2, cc3, cc4 = st.columns(4)
|
| 88 |
-
cc1.metric("Focal Length (px)", f"{focal:.1f}")
|
| 89 |
-
cc2.metric("Baseline (mm)", f"{baseline:.1f}")
|
| 90 |
-
cc3.metric("Doffs (px)", f"{doffs:.2f}")
|
| 91 |
-
cc4.metric("ndisp", str(ndisp))
|
| 92 |
-
|
| 93 |
-
with st.expander("Full Calibration"):
|
| 94 |
-
st.json({k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in calib.items()})
|
| 95 |
-
|
| 96 |
-
st.divider()
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
# ===================================================================
|
| 100 |
-
# Image-size validation
|
| 101 |
-
# ===================================================================
|
| 102 |
-
if img_l.shape[:2] != img_r.shape[:2]:
|
| 103 |
-
st.error(f"Left ({img_l.shape[1]}×{img_l.shape[0]}) and right "
|
| 104 |
-
f"({img_r.shape[1]}×{img_r.shape[0]}) images must be the same size.")
|
| 105 |
-
st.stop()
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
# ===================================================================
|
| 109 |
-
# Step 1 — Compute Disparity Map
|
| 110 |
-
# ===================================================================
|
| 111 |
-
st.subheader("Step 1: Disparity Map (StereoSGBM)")
|
| 112 |
-
|
| 113 |
-
sc1, sc2, sc3 = st.columns(3)
|
| 114 |
-
block_size = sc1.slider("Block Size", 3, 21, 5, step=2)
|
| 115 |
-
p1_mult = sc2.slider("P1 multiplier", 1, 32, 8)
|
| 116 |
-
p2_mult = sc3.slider("P2 multiplier", 1, 128, 32)
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
@st.cache_data
|
| 120 |
-
def compute_disparity(_left, _right, ndisp, block_size, p1m, p2m):
|
| 121 |
-
"""StereoSGBM disparity. _left/_right are un-hashed (numpy arrays)."""
|
| 122 |
-
gray_l = cv2.cvtColor(_left, cv2.COLOR_BGR2GRAY)
|
| 123 |
-
gray_r = cv2.cvtColor(_right, cv2.COLOR_BGR2GRAY)
|
| 124 |
-
|
| 125 |
-
nd = max(16, (ndisp // 16) * 16)
|
| 126 |
-
sgbm = cv2.StereoSGBM_create(
|
| 127 |
-
minDisparity=0,
|
| 128 |
-
numDisparities=nd,
|
| 129 |
-
blockSize=block_size,
|
| 130 |
-
P1=p1m * 1 * block_size ** 2,
|
| 131 |
-
P2=p2m * 1 * block_size ** 2,
|
| 132 |
-
disp12MaxDiff=1,
|
| 133 |
-
uniquenessRatio=10,
|
| 134 |
-
speckleWindowSize=100,
|
| 135 |
-
speckleRange=32,
|
| 136 |
-
mode=cv2.STEREO_SGBM_MODE_SGBM_3WAY,
|
| 137 |
-
)
|
| 138 |
-
return sgbm.compute(gray_l, gray_r).astype(np.float32) / 16.0
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
with st.spinner("Computing disparity…"):
|
| 142 |
-
try:
|
| 143 |
-
disp = compute_disparity(img_l, img_r, ndisp, block_size, p1_mult, p2_mult)
|
| 144 |
-
except cv2.error as e:
|
| 145 |
-
st.error(f"StereoSGBM failed: {e}")
|
| 146 |
-
st.stop()
|
| 147 |
-
|
| 148 |
-
# Visualize disparity
|
| 149 |
-
disp_vis = np.clip(disp, 0, None)
|
| 150 |
-
disp_max = disp_vis.max() if disp_vis.max() > 0 else 1.0
|
| 151 |
-
disp_norm = (disp_vis / disp_max * 255).astype(np.uint8)
|
| 152 |
-
disp_color = cv2.applyColorMap(disp_norm, cv2.COLORMAP_INFERNO)
|
| 153 |
-
|
| 154 |
-
dc1, dc2 = st.columns(2)
|
| 155 |
-
dc1.image(cv2.cvtColor(img_l, cv2.COLOR_BGR2RGB), caption="Left Image", use_container_width=True)
|
| 156 |
-
dc2.image(cv2.cvtColor(disp_color, cv2.COLOR_BGR2RGB), caption="Disparity Map (SGBM)", use_container_width=True)
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
# ===================================================================
|
| 160 |
-
# Step 2 — Depth Map from Disparity
|
| 161 |
-
# ===================================================================
|
| 162 |
-
st.divider()
|
| 163 |
-
st.subheader("Step 2: Depth Map from Disparity")
|
| 164 |
-
|
| 165 |
-
st.latex(r"Z = \frac{f \times B}{d + d_{\text{offs}}}")
|
| 166 |
-
st.caption("Z = depth (mm), f = focal length (px), B = baseline (mm), d = disparity (px), d_offs = optical center offset (px)")
|
| 167 |
-
|
| 168 |
-
# Compute depth from disparity
|
| 169 |
-
valid = (disp + doffs) > 0
|
| 170 |
-
depth_map = np.zeros_like(disp)
|
| 171 |
-
if focal > 0:
|
| 172 |
-
depth_map[valid] = (focal * baseline) / (disp[valid] + doffs)
|
| 173 |
-
|
| 174 |
-
# Visualize
|
| 175 |
-
depth_vis = depth_map.copy()
|
| 176 |
-
finite = depth_vis[depth_vis > 0]
|
| 177 |
-
if len(finite) > 0:
|
| 178 |
-
clip_max = np.percentile(finite, 98)
|
| 179 |
-
depth_vis = np.clip(depth_vis, 0, clip_max)
|
| 180 |
-
depth_norm = (depth_vis / clip_max * 255).astype(np.uint8)
|
| 181 |
-
else:
|
| 182 |
-
depth_norm = np.zeros_like(depth_map, dtype=np.uint8)
|
| 183 |
-
|
| 184 |
-
depth_color = cv2.applyColorMap(depth_norm, cv2.COLORMAP_TURBO)
|
| 185 |
-
|
| 186 |
-
zc1, zc2 = st.columns(2)
|
| 187 |
-
zc1.image(cv2.cvtColor(depth_color, cv2.COLOR_BGR2RGB),
|
| 188 |
-
caption="Estimated Depth (SGBM)", use_container_width=True)
|
| 189 |
-
|
| 190 |
-
# Ground truth comparison
|
| 191 |
-
if gt_left is not None:
|
| 192 |
-
gt_vis = gt_left.copy()
|
| 193 |
-
gt_finite = gt_vis[np.isfinite(gt_vis) & (gt_vis > 0)]
|
| 194 |
-
if len(gt_finite) > 0:
|
| 195 |
-
gt_clip = np.percentile(gt_finite, 98)
|
| 196 |
-
gt_vis = np.clip(np.nan_to_num(gt_vis, nan=0), 0, gt_clip)
|
| 197 |
-
gt_norm = (gt_vis / gt_clip * 255).astype(np.uint8)
|
| 198 |
-
else:
|
| 199 |
-
gt_norm = np.zeros_like(gt_vis, dtype=np.uint8)
|
| 200 |
-
gt_color = cv2.applyColorMap(gt_norm, cv2.COLORMAP_TURBO)
|
| 201 |
-
zc2.image(cv2.cvtColor(gt_color, cv2.COLOR_BGR2RGB),
|
| 202 |
-
caption="Ground Truth Disparity (from PFM)", use_container_width=True)
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
# ===================================================================
|
| 206 |
-
# Step 3 — Error Map (SGBM vs Ground Truth)
|
| 207 |
-
# ===================================================================
|
| 208 |
-
if gt_left is not None:
|
| 209 |
-
st.divider()
|
| 210 |
-
st.subheader("Step 3: Error Analysis (SGBM vs Ground Truth)")
|
| 211 |
-
|
| 212 |
-
gt_disp = gt_left # Middlebury standard: PFM = disparity map
|
| 213 |
-
|
| 214 |
-
# Ensure GT and SGBM disparity have the same shape
|
| 215 |
-
if gt_disp.shape[:2] != disp.shape[:2]:
|
| 216 |
-
st.warning(
|
| 217 |
-
f"Ground truth shape ({gt_disp.shape[1]}×{gt_disp.shape[0]}) differs from "
|
| 218 |
-
f"disparity shape ({disp.shape[1]}×{disp.shape[0]}). Resizing GT to match."
|
| 219 |
-
)
|
| 220 |
-
gt_disp = cv2.resize(gt_disp, (disp.shape[1], disp.shape[0]),
|
| 221 |
-
interpolation=cv2.INTER_NEAREST)
|
| 222 |
-
|
| 223 |
-
gt_valid = np.isfinite(gt_disp) & (gt_disp > 0)
|
| 224 |
-
both_valid = valid & gt_valid
|
| 225 |
-
|
| 226 |
-
if both_valid.any():
|
| 227 |
-
# Disparity error
|
| 228 |
-
disp_err = np.abs(disp - gt_disp)
|
| 229 |
-
disp_err[~both_valid] = 0
|
| 230 |
-
|
| 231 |
-
# Stats
|
| 232 |
-
err_vals = disp_err[both_valid]
|
| 233 |
-
mae = float(np.mean(err_vals))
|
| 234 |
-
rmse = float(np.sqrt(np.mean(err_vals ** 2)))
|
| 235 |
-
bad_2 = float(np.mean(err_vals > 2.0)) * 100
|
| 236 |
-
|
| 237 |
-
em1, em2, em3 = st.columns(3)
|
| 238 |
-
em1.metric("MAE (px)", f"{mae:.2f}")
|
| 239 |
-
em2.metric("RMSE (px)", f"{rmse:.2f}")
|
| 240 |
-
em3.metric("Bad-2.0 (%)", f"{bad_2:.1f}%")
|
| 241 |
-
|
| 242 |
-
# Error heatmap
|
| 243 |
-
err_clip = np.clip(disp_err, 0, 10)
|
| 244 |
-
err_norm = (err_clip / 10 * 255).astype(np.uint8)
|
| 245 |
-
err_color = cv2.applyColorMap(err_norm, cv2.COLORMAP_HOT)
|
| 246 |
-
st.image(cv2.cvtColor(err_color, cv2.COLOR_BGR2RGB),
|
| 247 |
-
caption="Disparity Error Map (red = high error, clipped at 10 px)",
|
| 248 |
-
use_container_width=True)
|
| 249 |
-
|
| 250 |
-
# Histogram
|
| 251 |
-
fig = go.Figure(data=[go.Histogram(x=err_vals, nbinsx=50,
|
| 252 |
-
marker_color="#ff6361")])
|
| 253 |
-
fig.update_layout(title="Disparity Error Distribution",
|
| 254 |
-
xaxis_title="Absolute Error (px)",
|
| 255 |
-
yaxis_title="Pixel Count",
|
| 256 |
-
template="plotly_dark", height=300)
|
| 257 |
-
st.plotly_chart(fig, use_container_width=True)
|
| 258 |
-
else:
|
| 259 |
-
st.warning("No overlapping valid pixels between SGBM disparity and ground truth.")
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
# ===================================================================
|
| 263 |
-
# Step 4 — Object Distance from Detections
|
| 264 |
-
# ===================================================================
|
| 265 |
-
st.divider()
|
| 266 |
-
st.subheader("Step 4: Object Distance Estimation")
|
| 267 |
-
|
| 268 |
-
all_dets = []
|
| 269 |
-
all_dets.extend(("RCE", *d) for d in rce_dets)
|
| 270 |
-
all_dets.extend(("CNN", *d) for d in cnn_dets)
|
| 271 |
-
|
| 272 |
-
if not all_dets and crop_bbox is not None:
|
| 273 |
-
st.info("No detections from the Real-Time Detection page. Using the **crop bounding box** as a fallback.")
|
| 274 |
-
x0, y0, x1, y1 = crop_bbox
|
| 275 |
-
all_dets.append(("Crop", x0, y0, x1, y1, "object", 1.0))
|
| 276 |
-
elif not all_dets:
|
| 277 |
-
st.warning("No detections found. Run **Real-Time Detection** first, or define a crop in **Data Lab**.")
|
| 278 |
-
st.stop()
|
| 279 |
-
|
| 280 |
-
if focal <= 0:
|
| 281 |
-
st.warning("Focal length is 0 — cannot compute depth. Upload a valid config in **Data Lab**.")
|
| 282 |
-
st.stop()
|
| 283 |
-
|
| 284 |
-
rows = []
|
| 285 |
-
det_overlay = img_l.copy()
|
| 286 |
-
|
| 287 |
-
for source, dx1, dy1, dx2, dy2, lbl, conf in all_dets:
|
| 288 |
-
dx1, dy1, dx2, dy2 = int(dx1), int(dy1), int(dx2), int(dy2)
|
| 289 |
-
|
| 290 |
-
# Clamp to image bounds
|
| 291 |
-
H, W = depth_map.shape[:2]
|
| 292 |
-
dx1c = max(0, min(dx1, W - 1))
|
| 293 |
-
dy1c = max(0, min(dy1, H - 1))
|
| 294 |
-
dx2c = max(0, min(dx2, W))
|
| 295 |
-
dy2c = max(0, min(dy2, H))
|
| 296 |
-
|
| 297 |
-
roi_depth = depth_map[dy1c:dy2c, dx1c:dx2c]
|
| 298 |
-
roi_disp = disp[dy1c:dy2c, dx1c:dx2c]
|
| 299 |
-
roi_valid = roi_depth[roi_depth > 0]
|
| 300 |
-
|
| 301 |
-
if len(roi_valid) > 0:
|
| 302 |
-
med_depth = float(np.median(roi_valid))
|
| 303 |
-
mean_depth = float(np.mean(roi_valid))
|
| 304 |
-
med_disp = float(np.median(roi_disp[roi_disp > 0])) if (roi_disp > 0).any() else 0
|
| 305 |
-
else:
|
| 306 |
-
med_depth = mean_depth = med_disp = 0.0
|
| 307 |
-
|
| 308 |
-
# Ground truth depth at this region
|
| 309 |
-
gt_depth_val = 0.0
|
| 310 |
-
if gt_left is not None:
|
| 311 |
-
gt_roi = gt_left[dy1c:dy2c, dx1c:dx2c]
|
| 312 |
-
gt_roi_valid = gt_roi[np.isfinite(gt_roi) & (gt_roi > 0)]
|
| 313 |
-
if len(gt_roi_valid) > 0:
|
| 314 |
-
gt_med_disp = float(np.median(gt_roi_valid))
|
| 315 |
-
gt_depth_val = (focal * baseline) / (gt_med_disp + doffs) if (gt_med_disp + doffs) > 0 else 0
|
| 316 |
-
|
| 317 |
-
error_mm = abs(med_depth - gt_depth_val) if gt_depth_val > 0 else float("nan")
|
| 318 |
-
|
| 319 |
-
rows.append({
|
| 320 |
-
"Source": source,
|
| 321 |
-
"Box": f"({dx1},{dy1})→({dx2},{dy2})",
|
| 322 |
-
"Confidence": f"{conf:.1%}" if isinstance(conf, float) else str(conf),
|
| 323 |
-
"Med Disparity": f"{med_disp:.1f} px",
|
| 324 |
-
"Med Depth": f"{med_depth:.0f} mm",
|
| 325 |
-
"Mean Depth": f"{mean_depth:.0f} mm",
|
| 326 |
-
"GT Depth": f"{gt_depth_val:.0f} mm" if gt_depth_val > 0 else "N/A",
|
| 327 |
-
"Error": f"{error_mm:.0f} mm" if not np.isnan(error_mm) else "N/A",
|
| 328 |
-
})
|
| 329 |
-
|
| 330 |
-
# Draw on overlay
|
| 331 |
-
color = (0, 255, 0) if "RCE" in source else (0, 0, 255) if "CNN" in source else (255, 255, 0)
|
| 332 |
-
cv2.rectangle(det_overlay, (dx1c, dy1c), (dx2c, dy2c), color, 2)
|
| 333 |
-
depth_str = f"{med_depth / 1000:.2f}m" if med_depth > 0 else "?"
|
| 334 |
-
cv2.putText(det_overlay, f"{source} {depth_str}",
|
| 335 |
-
(dx1c, max(dy1c - 6, 12)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
|
| 336 |
-
|
| 337 |
-
# Show overlay
|
| 338 |
-
st.image(cv2.cvtColor(det_overlay, cv2.COLOR_BGR2RGB),
|
| 339 |
-
caption="Detections with Estimated Distance",
|
| 340 |
-
use_container_width=True)
|
| 341 |
-
|
| 342 |
-
# Table
|
| 343 |
-
st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
|
| 344 |
-
|
| 345 |
-
# Primary detection summary
|
| 346 |
-
if rows:
|
| 347 |
-
best = rows[0]
|
| 348 |
-
st.divider()
|
| 349 |
-
st.subheader("🎯 Primary Detection — Distance")
|
| 350 |
-
bc1, bc2, bc3 = st.columns(3)
|
| 351 |
-
bc1.metric("Estimated Depth", best["Med Depth"])
|
| 352 |
-
bc2.metric("Ground Truth", best["GT Depth"])
|
| 353 |
-
bc3.metric("Absolute Error", best["Error"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tabs/__init__.py
ADDED
|
File without changes
|
tabs/generalisation/__init__.py
ADDED
|
File without changes
|
tabs/generalisation/data_lab.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generalisation Data Lab — Stage 1 of the Generalisation pipeline."""
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
from utils.middlebury_loader import (
|
| 9 |
+
DEFAULT_MIDDLEBURY_ROOT, get_scene_groups, load_single_view,
|
| 10 |
+
read_pfm_bytes,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# ------------------------------------------------------------------
|
| 15 |
+
# Helpers (shared with stereo data lab)
|
| 16 |
+
# ------------------------------------------------------------------
|
| 17 |
+
|
| 18 |
+
def _augment(img, brightness, contrast, rotation,
|
| 19 |
+
flip_h, flip_v, noise, blur, shift_x, shift_y):
|
| 20 |
+
out = img.astype(np.float32)
|
| 21 |
+
out = np.clip(contrast * out + brightness, 0, 255)
|
| 22 |
+
if noise > 0:
|
| 23 |
+
out = np.clip(out + np.random.normal(0, noise, out.shape), 0, 255)
|
| 24 |
+
out = out.astype(np.uint8)
|
| 25 |
+
k = blur * 2 + 1
|
| 26 |
+
if k > 1:
|
| 27 |
+
out = cv2.GaussianBlur(out, (k, k), 0)
|
| 28 |
+
if rotation != 0:
|
| 29 |
+
h, w = out.shape[:2]
|
| 30 |
+
M = cv2.getRotationMatrix2D((w / 2, h / 2), rotation, 1.0)
|
| 31 |
+
out = cv2.warpAffine(out, M, (w, h), borderMode=cv2.BORDER_REFLECT)
|
| 32 |
+
if shift_x != 0 or shift_y != 0:
|
| 33 |
+
h, w = out.shape[:2]
|
| 34 |
+
M = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
|
| 35 |
+
out = cv2.warpAffine(out, M, (w, h), borderMode=cv2.BORDER_REFLECT)
|
| 36 |
+
if flip_h:
|
| 37 |
+
out = cv2.flip(out, 1)
|
| 38 |
+
if flip_v:
|
| 39 |
+
out = cv2.flip(out, 0)
|
| 40 |
+
return out
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
ROI_COLORS = [(0,255,0),(255,0,0),(0,0,255),(255,255,0),
|
| 44 |
+
(255,0,255),(0,255,255),(128,255,0),(255,128,0)]
|
| 45 |
+
MAX_UPLOAD_BYTES = 50 * 1024 * 1024
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def render():
|
| 49 |
+
st.header("🧪 Data Lab — Generalisation")
|
| 50 |
+
st.info("**How this works:** Train on one image, test on a completely "
|
| 51 |
+
"different image of the same object. No stereo geometry — "
|
| 52 |
+
"pure recognition generalisation.")
|
| 53 |
+
|
| 54 |
+
source = st.radio("Data source",
|
| 55 |
+
["📦 Middlebury Multi-View", "📁 Upload your own files"],
|
| 56 |
+
horizontal=True, key="gen_source")
|
| 57 |
+
|
| 58 |
+
# ===================================================================
|
| 59 |
+
# Middlebury multi-view
|
| 60 |
+
# ===================================================================
|
| 61 |
+
if source == "📦 Middlebury Multi-View":
|
| 62 |
+
groups = get_scene_groups()
|
| 63 |
+
if not groups:
|
| 64 |
+
st.error("No valid Middlebury scenes found in ./data/middlebury/")
|
| 65 |
+
return
|
| 66 |
+
|
| 67 |
+
group_name = st.selectbox("Scene group", list(groups.keys()), key="gen_group")
|
| 68 |
+
variants = groups[group_name]
|
| 69 |
+
|
| 70 |
+
gc1, gc2 = st.columns(2)
|
| 71 |
+
train_scene = gc1.selectbox("Training scene", variants, key="gen_train_scene")
|
| 72 |
+
available_test = [v for v in variants if v != train_scene]
|
| 73 |
+
if not available_test:
|
| 74 |
+
st.error("Need at least 2 variants in a group.")
|
| 75 |
+
return
|
| 76 |
+
test_scene = gc2.selectbox("Test scene", available_test, key="gen_test_scene")
|
| 77 |
+
|
| 78 |
+
train_path = os.path.join(DEFAULT_MIDDLEBURY_ROOT, train_scene)
|
| 79 |
+
test_path = os.path.join(DEFAULT_MIDDLEBURY_ROOT, test_scene)
|
| 80 |
+
|
| 81 |
+
img_train = load_single_view(train_path)
|
| 82 |
+
img_test = load_single_view(test_path)
|
| 83 |
+
|
| 84 |
+
st.markdown("*Both images show the same scene type captured under different "
|
| 85 |
+
"conditions. The model trains on one variant and must recognise "
|
| 86 |
+
"the same object class in the other — testing genuine appearance "
|
| 87 |
+
"generalisation.*")
|
| 88 |
+
|
| 89 |
+
c1, c2 = st.columns(2)
|
| 90 |
+
c1.image(cv2.cvtColor(img_train, cv2.COLOR_BGR2RGB),
|
| 91 |
+
caption=f"🟦 TRAIN IMAGE ({train_scene})", use_container_width=True)
|
| 92 |
+
c2.image(cv2.cvtColor(img_test, cv2.COLOR_BGR2RGB),
|
| 93 |
+
caption=f"🟥 TEST IMAGE ({test_scene})", use_container_width=True)
|
| 94 |
+
|
| 95 |
+
scene_group = group_name
|
| 96 |
+
|
| 97 |
+
# ===================================================================
|
| 98 |
+
# Custom upload
|
| 99 |
+
# ===================================================================
|
| 100 |
+
else:
|
| 101 |
+
uc1, uc2 = st.columns(2)
|
| 102 |
+
with uc1:
|
| 103 |
+
up_train = st.file_uploader("Train Image", type=["png","jpg","jpeg"],
|
| 104 |
+
key="gen_up_train")
|
| 105 |
+
with uc2:
|
| 106 |
+
up_test = st.file_uploader("Test Image", type=["png","jpg","jpeg"],
|
| 107 |
+
key="gen_up_test")
|
| 108 |
+
|
| 109 |
+
if not (up_train and up_test):
|
| 110 |
+
st.info("Upload a train and test image to proceed.")
|
| 111 |
+
return
|
| 112 |
+
|
| 113 |
+
if up_train.size > MAX_UPLOAD_BYTES or up_test.size > MAX_UPLOAD_BYTES:
|
| 114 |
+
st.error("Image too large (max 50 MB).")
|
| 115 |
+
return
|
| 116 |
+
|
| 117 |
+
img_train = cv2.imdecode(np.frombuffer(up_train.read(), np.uint8), cv2.IMREAD_COLOR); up_train.seek(0)
|
| 118 |
+
img_test = cv2.imdecode(np.frombuffer(up_test.read(), np.uint8), cv2.IMREAD_COLOR); up_test.seek(0)
|
| 119 |
+
|
| 120 |
+
c1, c2 = st.columns(2)
|
| 121 |
+
c1.image(cv2.cvtColor(img_train, cv2.COLOR_BGR2RGB),
|
| 122 |
+
caption="🟦 TRAIN IMAGE", use_container_width=True)
|
| 123 |
+
c2.image(cv2.cvtColor(img_test, cv2.COLOR_BGR2RGB),
|
| 124 |
+
caption="🟥 TEST IMAGE", use_container_width=True)
|
| 125 |
+
|
| 126 |
+
train_scene = "custom_train"
|
| 127 |
+
test_scene = "custom_test"
|
| 128 |
+
scene_group = "custom"
|
| 129 |
+
|
| 130 |
+
# ===================================================================
|
| 131 |
+
# ROI Definition (on TRAIN image)
|
| 132 |
+
# ===================================================================
|
| 133 |
+
st.divider()
|
| 134 |
+
st.subheader("Step 2: Crop Region(s) of Interest")
|
| 135 |
+
st.write("Define bounding boxes on the **TRAIN image**.")
|
| 136 |
+
|
| 137 |
+
H, W = img_train.shape[:2]
|
| 138 |
+
|
| 139 |
+
if "gen_rois" not in st.session_state:
|
| 140 |
+
st.session_state["gen_rois"] = [
|
| 141 |
+
{"label": "object", "x0": 0, "y0": 0,
|
| 142 |
+
"x1": min(W, 100), "y1": min(H, 100)}
|
| 143 |
+
]
|
| 144 |
+
|
| 145 |
+
def _add_roi():
|
| 146 |
+
if len(st.session_state["gen_rois"]) >= 20:
|
| 147 |
+
return
|
| 148 |
+
st.session_state["gen_rois"].append(
|
| 149 |
+
{"label": f"object_{len(st.session_state['gen_rois'])+1}",
|
| 150 |
+
"x0": 0, "y0": 0,
|
| 151 |
+
"x1": min(W, 100), "y1": min(H, 100)})
|
| 152 |
+
|
| 153 |
+
def _remove_roi(idx):
|
| 154 |
+
if len(st.session_state["gen_rois"]) > 1:
|
| 155 |
+
st.session_state["gen_rois"].pop(idx)
|
| 156 |
+
|
| 157 |
+
for i, roi in enumerate(st.session_state["gen_rois"]):
|
| 158 |
+
color = ROI_COLORS[i % len(ROI_COLORS)]
|
| 159 |
+
color_hex = "#{:02x}{:02x}{:02x}".format(*color)
|
| 160 |
+
with st.container(border=True):
|
| 161 |
+
hc1, hc2, hc3 = st.columns([3, 6, 1])
|
| 162 |
+
hc1.markdown(f"**ROI {i+1}** <span style='color:{color_hex}'>■</span>",
|
| 163 |
+
unsafe_allow_html=True)
|
| 164 |
+
roi["label"] = hc2.text_input("Class Label", roi["label"],
|
| 165 |
+
key=f"gen_roi_lbl_{i}")
|
| 166 |
+
if len(st.session_state["gen_rois"]) > 1:
|
| 167 |
+
hc3.button("✕", key=f"gen_roi_del_{i}",
|
| 168 |
+
on_click=_remove_roi, args=(i,))
|
| 169 |
+
|
| 170 |
+
cr1, cr2, cr3, cr4 = st.columns(4)
|
| 171 |
+
roi["x0"] = int(cr1.number_input("X start", 0, W-2, int(roi["x0"]),
|
| 172 |
+
step=1, key=f"gen_roi_x0_{i}"))
|
| 173 |
+
roi["y0"] = int(cr2.number_input("Y start", 0, H-2, int(roi["y0"]),
|
| 174 |
+
step=1, key=f"gen_roi_y0_{i}"))
|
| 175 |
+
roi["x1"] = int(cr3.number_input("X end", roi["x0"]+1, W,
|
| 176 |
+
min(W, int(roi["x1"])),
|
| 177 |
+
step=1, key=f"gen_roi_x1_{i}"))
|
| 178 |
+
roi["y1"] = int(cr4.number_input("Y end", roi["y0"]+1, H,
|
| 179 |
+
min(H, int(roi["y1"])),
|
| 180 |
+
step=1, key=f"gen_roi_y1_{i}"))
|
| 181 |
+
|
| 182 |
+
st.button("➕ Add Another ROI", on_click=_add_roi,
|
| 183 |
+
disabled=len(st.session_state["gen_rois"]) >= 20,
|
| 184 |
+
key="gen_add_roi")
|
| 185 |
+
|
| 186 |
+
overlay = img_train.copy()
|
| 187 |
+
crops = []
|
| 188 |
+
for i, roi in enumerate(st.session_state["gen_rois"]):
|
| 189 |
+
color = ROI_COLORS[i % len(ROI_COLORS)]
|
| 190 |
+
x0, y0, x1, y1 = roi["x0"], roi["y0"], roi["x1"], roi["y1"]
|
| 191 |
+
cv2.rectangle(overlay, (x0, y0), (x1, y1), color, 2)
|
| 192 |
+
cv2.putText(overlay, roi["label"], (x0, y0 - 6),
|
| 193 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
|
| 194 |
+
crops.append(img_train[y0:y1, x0:x1].copy())
|
| 195 |
+
|
| 196 |
+
ov1, ov2 = st.columns([3, 2])
|
| 197 |
+
ov1.image(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB),
|
| 198 |
+
caption="TRAIN image — ROIs highlighted", use_container_width=True)
|
| 199 |
+
with ov2:
|
| 200 |
+
for i, (c, roi) in enumerate(zip(crops, st.session_state["gen_rois"])):
|
| 201 |
+
st.image(cv2.cvtColor(c, cv2.COLOR_BGR2RGB),
|
| 202 |
+
caption=f"{roi['label']} ({c.shape[1]}×{c.shape[0]})", width=160)
|
| 203 |
+
|
| 204 |
+
crop_bgr = crops[0]
|
| 205 |
+
x0 = st.session_state["gen_rois"][0]["x0"]
|
| 206 |
+
y0 = st.session_state["gen_rois"][0]["y0"]
|
| 207 |
+
x1 = st.session_state["gen_rois"][0]["x1"]
|
| 208 |
+
y1 = st.session_state["gen_rois"][0]["y1"]
|
| 209 |
+
|
| 210 |
+
# ===================================================================
|
| 211 |
+
# Augmentation
|
| 212 |
+
# ===================================================================
|
| 213 |
+
st.divider()
|
| 214 |
+
st.subheader("Step 3: Data Augmentation")
|
| 215 |
+
ac1, ac2 = st.columns(2)
|
| 216 |
+
with ac1:
|
| 217 |
+
brightness = st.slider("Brightness offset", -100, 100, 0, key="gen_bright")
|
| 218 |
+
contrast = st.slider("Contrast scale", 0.5, 3.0, 1.0, 0.05, key="gen_contrast")
|
| 219 |
+
rotation = st.slider("Rotation (°)", -180, 180, 0, key="gen_rot")
|
| 220 |
+
noise = st.slider("Gaussian noise σ", 0, 50, 0, key="gen_noise")
|
| 221 |
+
with ac2:
|
| 222 |
+
blur = st.slider("Blur kernel (0=off)", 0, 10, 0, key="gen_blur")
|
| 223 |
+
shift_x = st.slider("Shift X (px)", -100, 100, 0, key="gen_sx")
|
| 224 |
+
shift_y = st.slider("Shift Y (px)", -100, 100, 0, key="gen_sy")
|
| 225 |
+
flip_h = st.checkbox("Flip Horizontal", key="gen_fh")
|
| 226 |
+
flip_v = st.checkbox("Flip Vertical", key="gen_fv")
|
| 227 |
+
|
| 228 |
+
aug = _augment(crop_bgr, brightness, contrast, rotation,
|
| 229 |
+
flip_h, flip_v, noise, blur, shift_x, shift_y)
|
| 230 |
+
all_augs = [_augment(c, brightness, contrast, rotation,
|
| 231 |
+
flip_h, flip_v, noise, blur, shift_x, shift_y)
|
| 232 |
+
for c in crops]
|
| 233 |
+
|
| 234 |
+
ag1, ag2 = st.columns(2)
|
| 235 |
+
ag1.image(cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB),
|
| 236 |
+
caption="Original Crop (ROI 1)", use_container_width=True)
|
| 237 |
+
ag2.image(cv2.cvtColor(aug, cv2.COLOR_BGR2RGB),
|
| 238 |
+
caption="Augmented Crop (ROI 1)", use_container_width=True)
|
| 239 |
+
|
| 240 |
+
# ===================================================================
|
| 241 |
+
# Lock & Store
|
| 242 |
+
# ===================================================================
|
| 243 |
+
st.divider()
|
| 244 |
+
if st.button("🚀 Lock Data & Proceed", key="gen_lock"):
|
| 245 |
+
rois_data = []
|
| 246 |
+
for i, roi in enumerate(st.session_state["gen_rois"]):
|
| 247 |
+
rois_data.append({
|
| 248 |
+
"label": roi["label"],
|
| 249 |
+
"bbox": (roi["x0"], roi["y0"], roi["x1"], roi["y1"]),
|
| 250 |
+
"crop": crops[i],
|
| 251 |
+
"crop_aug": all_augs[i],
|
| 252 |
+
})
|
| 253 |
+
|
| 254 |
+
st.session_state["gen_pipeline"] = {
|
| 255 |
+
"train_image": img_train,
|
| 256 |
+
"test_image": img_test,
|
| 257 |
+
"roi": {"x": x0, "y": y0, "w": x1 - x0, "h": y1 - y0,
|
| 258 |
+
"label": st.session_state["gen_rois"][0]["label"]},
|
| 259 |
+
"crop": crop_bgr,
|
| 260 |
+
"crop_aug": aug,
|
| 261 |
+
"crop_bbox": (x0, y0, x1, y1),
|
| 262 |
+
"rois": rois_data,
|
| 263 |
+
"source": "middlebury" if source == "📦 Middlebury Multi-View" else "custom",
|
| 264 |
+
"scene_group": scene_group if "scene_group" in dir() else "",
|
| 265 |
+
"train_scene": train_scene if "train_scene" in dir() else "",
|
| 266 |
+
"test_scene": test_scene if "test_scene" in dir() else "",
|
| 267 |
+
}
|
| 268 |
+
st.success(f"✅ Data locked with **{len(rois_data)} ROI(s)**! "
|
| 269 |
+
f"Proceed to Feature Lab.")
|
tabs/generalisation/detection.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generalisation Detection — Stage 5 of the Generalisation pipeline.
|
| 2 |
+
|
| 3 |
+
CRITICAL: Detection runs on the TEST image (different scene variant).
|
| 4 |
+
Training was done on the TRAIN image.
|
| 5 |
+
This enforces the data-leakage fix.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import streamlit as st
|
| 9 |
+
import cv2
|
| 10 |
+
import numpy as np
|
| 11 |
+
import time
|
| 12 |
+
import plotly.graph_objects as go
|
| 13 |
+
|
| 14 |
+
from src.detectors.rce.features import REGISTRY
|
| 15 |
+
from src.models import BACKBONES, RecognitionHead
|
| 16 |
+
from src.utils import build_rce_vector
|
| 17 |
+
from src.localization import nms as _nms
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# BGR color palette, cycled per detected class label when drawing
# detection boxes and captions on the test image.
CLASS_COLORS = [(0,255,0),(0,0,255),(255,165,0),(255,0,255),(0,255,255),
                (128,255,0),(255,128,0),(0,128,255)]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def sliding_window_detect(image, feature_fn, head, win_h, win_w,
                          stride, conf_thresh, nms_iou,
                          progress_placeholder=None,
                          live_image_placeholder=None):
    """Scan ``image`` with a fixed-size sliding window and classify each patch.

    Args:
        image: BGR image (H, W, 3) to scan.
        feature_fn: callable(patch_bgr) -> feature vector fed to ``head``.
        head: classifier exposing ``predict(features) -> (label, conf)``;
            a label of ``"background"`` means "nothing here".
        win_h, win_w: window height/width in pixels.
        stride: step between window origins, in pixels.
        conf_thresh: minimum confidence to keep a raw detection.
        nms_iou: IoU threshold passed to NMS over the kept detections.
        progress_placeholder: optional Streamlit placeholder for a progress bar.
        live_image_placeholder: optional Streamlit placeholder for a live preview.

    Returns:
        ``(detections, heatmap, total_ms, n_windows)`` — detections is a list
        of ``(x1, y1, x2, y2, label, conf)`` after NMS, heatmap a float32
        (H, W) max-confidence map, total_ms the wall-clock scan time in
        milliseconds, and n_windows the number of evaluated windows.
    """
    H, W = image.shape[:2]
    heatmap = np.zeros((H, W), dtype=np.float32)
    detections = []
    t0 = time.perf_counter()

    positions = [(x, y)
                 for y in range(0, H - win_h + 1, stride)
                 for x in range(0, W - win_w + 1, stride)]
    n_total = len(positions)
    if n_total == 0:
        # Window is larger than the image: nothing to scan.
        return [], heatmap, 0.0, 0

    for idx, (x, y) in enumerate(positions):
        patch = image[y:y+win_h, x:x+win_w]
        feats = feature_fn(patch)
        label, conf = head.predict(feats)

        if label != "background":
            # Heatmap records the max confidence seen at each pixel.
            heatmap[y:y+win_h, x:x+win_w] = np.maximum(
                heatmap[y:y+win_h, x:x+win_w], conf)
            if conf >= conf_thresh:
                detections.append((x, y, x+win_w, y+win_h, label, conf))

        # Throttle UI updates to every 5th window (always including the
        # last one): Streamlit rerenders cost far more than the scan itself.
        is_ui_step = idx % 5 == 0 or idx == n_total - 1
        if live_image_placeholder is not None and is_ui_step:
            vis = image.copy()
            cv2.rectangle(vis, (x, y), (x+win_w, y+win_h), (255, 255, 0), 1)
            for dx, dy, dx2, dy2, _lbl, dc in detections:
                cv2.rectangle(vis, (dx, dy), (dx2, dy2), (0, 255, 0), 2)
                cv2.putText(vis, f"{dc:.0%}", (dx, dy - 4),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)
            live_image_placeholder.image(
                cv2.cvtColor(vis, cv2.COLOR_BGR2RGB),
                caption=f"Scanning… {idx+1}/{n_total}",
                use_container_width=True)

        if progress_placeholder is not None and is_ui_step:
            progress_placeholder.progress(
                (idx + 1) / n_total, text=f"Window {idx+1}/{n_total}")

    total_ms = (time.perf_counter() - t0) * 1000
    if detections:
        detections = _nms(detections, nms_iou)
    return detections, heatmap, total_ms, n_total
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def render():
    """Detection stage (Generalisation pipeline, stage 5).

    Runs sliding-window detection on the TEST image using whichever heads
    were trained in earlier stages (RCE, CNN backbones, ORB references),
    all read from ``st.session_state["gen_pipeline"]``. Each method renders
    in its own column; when at least two methods have results, a
    side-by-side comparison table and overlay are shown at the bottom.
    """
    st.title("🎯 Real-Time Detection")

    pipe = st.session_state.get("gen_pipeline")
    if not pipe or "crop" not in pipe:
        st.error("Complete **Data Lab** first (upload assets & define a crop).")
        st.stop()

    # CRITICAL: detect on TEST image, not TRAIN image
    test_img = pipe["test_image"]
    crop = pipe["crop"]
    crop_aug = pipe.get("crop_aug", crop)
    bbox = pipe.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
    # NOTE(review): `rois` and `crop_aug` are read here but never used below
    # in this stage — presumably kept for parity with the evaluation stage.
    rois = pipe.get("rois", [{"label": "object", "bbox": bbox,
                              "crop": crop, "crop_aug": crop_aug}])
    active_mods = pipe.get("active_modules", {k: True for k in REGISTRY})

    # The crop bbox from Data Lab fixes the sliding-window size.
    x0, y0, x1, y1 = bbox
    win_h, win_w = y1 - y0, x1 - x0

    if win_h <= 0 or win_w <= 0:
        st.error("Invalid window size from crop bbox.")
        st.stop()

    rce_head = pipe.get("rce_head")
    has_any_cnn = any(f"cnn_head_{n}" in pipe for n in BACKBONES)
    has_orb = pipe.get("orb_refs") is not None

    if rce_head is None and not has_any_cnn and not has_orb:
        st.warning("No trained heads found. Go to **Model Tuning** first.")
        st.stop()

    def rce_feature_fn(patch_bgr):
        # Build the RCE feature vector using only the modules enabled
        # in the Feature Lab stage.
        return build_rce_vector(patch_bgr, active_mods)

    # Controls
    st.subheader("Sliding Window Parameters")
    p1, p2, p3 = st.columns(3)
    stride = p1.slider("Stride (px)", 4, max(win_w // 2, 4),
                       max(win_w // 4, 4), step=2, key="gen_det_stride")
    conf_thresh = p2.slider("Confidence Threshold", 0.5, 1.0, 0.7, 0.05,
                            key="gen_det_conf")
    nms_iou = p3.slider("NMS IoU Threshold", 0.1, 0.9, 0.3, 0.05,
                        key="gen_det_nms")

    st.caption(f"Window size: **{win_w}×{win_h} px** | "
               f"Test image: **{test_img.shape[1]}×{test_img.shape[0]} px** | "
               f"≈ {((test_img.shape[0]-win_h)//stride + 1) * ((test_img.shape[1]-win_w)//stride + 1)} windows")
    st.divider()

    col_rce, col_cnn, col_orb = st.columns(3)

    # -------------------------------------------------------------------
    # RCE Detection
    # -------------------------------------------------------------------
    with col_rce:
        st.header("🧬 RCE Detection")
        if rce_head is None:
            st.info("No RCE head trained.")
        else:
            st.caption(f"Modules: {', '.join(REGISTRY[k]['label'] for k in active_mods if active_mods[k])}")
            rce_run = st.button("▶ Run RCE Scan", key="gen_rce_run")
            rce_progress = st.empty()
            rce_live = st.empty()
            rce_results = st.container()

            if rce_run:
                dets, hmap, ms, nw = sliding_window_detect(
                    test_img, rce_feature_fn, rce_head, win_h, win_w,
                    stride, conf_thresh, nms_iou,
                    progress_placeholder=rce_progress,
                    live_image_placeholder=rce_live)

                # Draw the final (post-NMS) detections, one color per class.
                final = test_img.copy()
                class_labels = sorted(set(d[4] for d in dets)) if dets else []
                for x1d, y1d, x2d, y2d, lbl, cf in dets:
                    ci = class_labels.index(lbl) if lbl in class_labels else 0
                    clr = CLASS_COLORS[ci % len(CLASS_COLORS)]
                    cv2.rectangle(final, (x1d, y1d), (x2d, y2d), clr, 2)
                    cv2.putText(final, f"{lbl} {cf:.0%}", (x1d, y1d - 6),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.4, clr, 1)
                rce_live.image(cv2.cvtColor(final, cv2.COLOR_BGR2RGB),
                               caption="RCE — Final Detections",
                               use_container_width=True)
                rce_progress.empty()

                with rce_results:
                    rm1, rm2, rm3, rm4 = st.columns(4)
                    rm1.metric("Detections", len(dets))
                    rm2.metric("Windows", nw)
                    rm3.metric("Total Time", f"{ms:.0f} ms")
                    rm4.metric("Per Window", f"{ms/max(nw,1):.2f} ms")

                    if hmap.max() > 0:
                        # Normalize the confidence heatmap and blend it over the image.
                        hmap_color = cv2.applyColorMap(
                            (hmap / hmap.max() * 255).astype(np.uint8),
                            cv2.COLORMAP_JET)
                        blend = cv2.addWeighted(test_img, 0.5, hmap_color, 0.5, 0)
                        st.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
                                 caption="RCE — Confidence Heatmap",
                                 use_container_width=True)

                    if dets:
                        import pandas as pd
                        df = pd.DataFrame(dets, columns=["x1","y1","x2","y2","label","conf"])
                        st.dataframe(df, use_container_width=True, hide_index=True)

                # Persist results so Evaluation and the comparison below see them.
                pipe["rce_dets"] = dets
                pipe["rce_det_ms"] = ms
                st.session_state["gen_pipeline"] = pipe

    # -------------------------------------------------------------------
    # CNN Detection
    # -------------------------------------------------------------------
    with col_cnn:
        st.header("🧠 CNN Detection")
        trained_cnns = [n for n in BACKBONES if f"cnn_head_{n}" in pipe]
        if not trained_cnns:
            st.info("No CNN head trained.")
        else:
            selected = st.selectbox("Select Model", trained_cnns,
                                    key="gen_det_cnn_sel")
            bmeta = BACKBONES[selected]
            backbone = bmeta["loader"]()
            head = pipe[f"cnn_head_{selected}"]

            st.caption(f"Backbone: **{selected}** ({bmeta['dim']}D)")
            cnn_run = st.button(f"▶ Run {selected} Scan", key="gen_cnn_run")
            cnn_progress = st.empty()
            cnn_live = st.empty()
            cnn_results = st.container()

            if cnn_run:
                dets, hmap, ms, nw = sliding_window_detect(
                    test_img, backbone.get_features, head, win_h, win_w,
                    stride, conf_thresh, nms_iou,
                    progress_placeholder=cnn_progress,
                    live_image_placeholder=cnn_live)

                final = test_img.copy()
                class_labels = sorted(set(d[4] for d in dets)) if dets else []
                for x1d, y1d, x2d, y2d, lbl, cf in dets:
                    ci = class_labels.index(lbl) if lbl in class_labels else 0
                    clr = CLASS_COLORS[ci % len(CLASS_COLORS)]
                    cv2.rectangle(final, (x1d, y1d), (x2d, y2d), clr, 2)
                    cv2.putText(final, f"{lbl} {cf:.0%}", (x1d, y1d - 6),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.4, clr, 1)
                cnn_live.image(cv2.cvtColor(final, cv2.COLOR_BGR2RGB),
                               caption=f"{selected} — Final Detections",
                               use_container_width=True)
                cnn_progress.empty()

                with cnn_results:
                    cm1, cm2, cm3, cm4 = st.columns(4)
                    cm1.metric("Detections", len(dets))
                    cm2.metric("Windows", nw)
                    cm3.metric("Total Time", f"{ms:.0f} ms")
                    cm4.metric("Per Window", f"{ms/max(nw,1):.2f} ms")

                    if hmap.max() > 0:
                        hmap_color = cv2.applyColorMap(
                            (hmap / hmap.max() * 255).astype(np.uint8),
                            cv2.COLORMAP_JET)
                        blend = cv2.addWeighted(test_img, 0.5, hmap_color, 0.5, 0)
                        st.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
                                 caption=f"{selected} — Confidence Heatmap",
                                 use_container_width=True)

                    if dets:
                        import pandas as pd
                        df = pd.DataFrame(dets, columns=["x1","y1","x2","y2","label","conf"])
                        st.dataframe(df, use_container_width=True, hide_index=True)

                pipe["cnn_dets"] = dets
                pipe["cnn_det_ms"] = ms
                st.session_state["gen_pipeline"] = pipe

    # -------------------------------------------------------------------
    # ORB Detection
    # -------------------------------------------------------------------
    with col_orb:
        st.header("🏛️ ORB Detection")
        if not has_orb:
            st.info("No ORB reference trained.")
        else:
            orb_det = pipe["orb_detector"]
            orb_refs = pipe["orb_refs"]
            dt_thresh = pipe.get("orb_dist_thresh", 70)
            min_m = pipe.get("orb_min_matches", 5)
            st.caption(f"References: {', '.join(orb_refs.keys())} | "
                       f"dist<{dt_thresh}, min {min_m} matches")
            orb_run = st.button("▶ Run ORB Scan", key="gen_orb_run")
            orb_progress = st.empty()
            orb_live = st.empty()
            orb_results = st.container()

            if orb_run:
                # ORB cannot reuse sliding_window_detect: each window is
                # scored by descriptor matching, not a classifier head.
                H, W = test_img.shape[:2]
                positions = [(x, y)
                             for y in range(0, H - win_h + 1, stride)
                             for x in range(0, W - win_w + 1, stride)]
                n_total = len(positions)
                heatmap = np.zeros((H, W), dtype=np.float32)
                detections = []
                t0 = time.perf_counter()
                # CLAHE boosts local contrast so ORB finds keypoints in
                # flat regions of the test variant.
                clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

                for idx, (px, py) in enumerate(positions):
                    patch = test_img[py:py+win_h, px:px+win_w]
                    gray = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY)
                    gray = clahe.apply(gray)
                    kp, des = orb_det.orb.detectAndCompute(gray, None)

                    if des is not None:
                        best_label, best_conf = "background", 0.0
                        for lbl, ref in orb_refs.items():
                            if ref["descriptors"] is None:
                                continue
                            matches = orb_det.bf.match(ref["descriptors"], des)
                            good = [m for m in matches if m.distance < dt_thresh]
                            # Confidence = fraction of the required match count, capped at 1.
                            conf = min(len(good) / max(min_m, 1), 1.0)
                            if len(good) >= min_m and conf > best_conf:
                                best_label, best_conf = lbl, conf

                        if best_label != "background":
                            heatmap[py:py+win_h, px:px+win_w] = np.maximum(
                                heatmap[py:py+win_h, px:px+win_w], best_conf)
                            if best_conf >= conf_thresh:
                                detections.append(
                                    (px, py, px+win_w, py+win_h, best_label, best_conf))

                    if idx % 5 == 0 or idx == n_total - 1:
                        orb_progress.progress((idx+1)/n_total,
                                              text=f"Window {idx+1}/{n_total}")

                total_ms = (time.perf_counter() - t0) * 1000
                if detections:
                    detections = _nms(detections, nms_iou)

                final = test_img.copy()
                cls_labels = sorted(set(d[4] for d in detections)) if detections else []
                for x1d, y1d, x2d, y2d, lbl, cf in detections:
                    ci = cls_labels.index(lbl) if lbl in cls_labels else 0
                    clr = CLASS_COLORS[ci % len(CLASS_COLORS)]
                    cv2.rectangle(final, (x1d, y1d), (x2d, y2d), clr, 2)
                    cv2.putText(final, f"{lbl} {cf:.0%}", (x1d, y1d - 6),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.4, clr, 1)
                orb_live.image(cv2.cvtColor(final, cv2.COLOR_BGR2RGB),
                               caption="ORB — Final Detections",
                               use_container_width=True)
                orb_progress.empty()

                with orb_results:
                    om1, om2, om3, om4 = st.columns(4)
                    om1.metric("Detections", len(detections))
                    om2.metric("Windows", n_total)
                    om3.metric("Total Time", f"{total_ms:.0f} ms")
                    om4.metric("Per Window", f"{total_ms/max(n_total,1):.2f} ms")

                    if heatmap.max() > 0:
                        hmap_color = cv2.applyColorMap(
                            (heatmap / heatmap.max() * 255).astype(np.uint8),
                            cv2.COLORMAP_JET)
                        blend = cv2.addWeighted(test_img, 0.5, hmap_color, 0.5, 0)
                        st.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
                                 caption="ORB — Confidence Heatmap",
                                 use_container_width=True)

                    if detections:
                        import pandas as pd
                        df = pd.DataFrame(detections,
                                          columns=["x1","y1","x2","y2","label","conf"])
                        st.dataframe(df, use_container_width=True, hide_index=True)

                pipe["orb_dets"] = detections
                pipe["orb_det_ms"] = total_ms
                st.session_state["gen_pipeline"] = pipe

    # ===================================================================
    # Bottom — Comparison
    # ===================================================================
    rce_dets = pipe.get("rce_dets")
    cnn_dets = pipe.get("cnn_dets")
    orb_dets = pipe.get("orb_dets")

    methods = {}
    if rce_dets is not None:
        methods["RCE"] = (rce_dets, pipe.get("rce_det_ms", 0), (0,255,0))
    if cnn_dets is not None:
        methods["CNN"] = (cnn_dets, pipe.get("cnn_det_ms", 0), (0,0,255))
    if orb_dets is not None:
        methods["ORB"] = (orb_dets, pipe.get("orb_det_ms", 0), (255,165,0))

    # Comparison only makes sense once at least two methods have run.
    if len(methods) >= 2:
        st.divider()
        st.subheader("📊 Side-by-Side Comparison")
        import pandas as pd
        comp = {"Metric": ["Detections", "Best Confidence", "Total Time (ms)"]}
        for name, (dets, ms, _) in methods.items():
            comp[name] = [
                len(dets),
                f"{max((d[5] for d in dets), default=0):.1%}",
                f"{ms:.0f}",
            ]
        st.dataframe(pd.DataFrame(comp), use_container_width=True, hide_index=True)

        # Overlay every method's boxes in its fixed method color.
        overlay = test_img.copy()
        for name, (dets, _, clr) in methods.items():
            for x1d, y1d, x2d, y2d, lbl, cf in dets:
                cv2.rectangle(overlay, (x1d, y1d), (x2d, y2d), clr, 2)
                cv2.putText(overlay, f"{name}:{lbl} {cf:.0%}", (x1d, y1d - 6),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.35, clr, 1)
        legend = " | ".join(f"{n}={'green' if c==(0,255,0) else 'blue' if c==(0,0,255) else 'orange'}"
                            for n, (_, _, c) in methods.items())
        st.image(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB),
                 caption=legend, use_container_width=True)
|
tabs/generalisation/evaluation.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generalisation Evaluation — Stage 6 of the Generalisation pipeline."""
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import plotly.graph_objects as go
|
| 7 |
+
import plotly.figure_factory as ff
|
| 8 |
+
|
| 9 |
+
from src.models import BACKBONES
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _iou(a, b):
|
| 13 |
+
xi1 = max(a[0], b[0]); yi1 = max(a[1], b[1])
|
| 14 |
+
xi2 = min(a[2], b[2]); yi2 = min(a[3], b[3])
|
| 15 |
+
inter = max(0, xi2 - xi1) * max(0, yi2 - yi1)
|
| 16 |
+
aa = (a[2] - a[0]) * (a[3] - a[1])
|
| 17 |
+
ab = (b[2] - b[0]) * (b[3] - b[1])
|
| 18 |
+
return inter / (aa + ab - inter + 1e-6)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def match_detections(dets, gt_list, iou_thr):
    """Greedily match detections to ground-truth boxes (one GT per detection).

    Detections are processed in order of decreasing confidence; each claims
    the still-unclaimed GT box with the highest IoU, provided that IoU
    reaches ``iou_thr``.

    Returns:
        ``(results, n_missed, claimed)`` — results is a list of
        ``(detection, gt_label_or_None, best_iou)`` in confidence order,
        n_missed counts GT boxes no detection claimed, and claimed is the
        set of claimed GT indices.
    """
    claimed = set()
    results = []
    # Highest-confidence detections get first pick of the GT boxes.
    for det in sorted(dets, key=lambda d: d[5], reverse=True):
        box = det[:4]
        top_iou = 0.0
        top_idx = -1
        top_label = None
        for gi, (gt_box, gt_label) in enumerate(gt_list):
            if gi in claimed:
                continue
            overlap = _iou(box, gt_box)
            if overlap > top_iou:
                top_iou, top_idx, top_label = overlap, gi, gt_label
        if top_idx >= 0 and top_iou >= iou_thr:
            claimed.add(top_idx)
            results.append((det, top_label, top_iou))
        else:
            results.append((det, None, top_iou))
    return results, len(gt_list) - len(claimed), claimed
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def compute_pr_curve(dets, gt_list, iou_thr, steps=50):
    """Sweep the confidence threshold and record precision / recall / F1.

    Evaluates ``steps`` evenly spaced thresholds in [0, 1]; at each one,
    only detections with confidence >= threshold are scored against the
    ground truth (IoU >= ``iou_thr`` counts as a true positive).

    Returns:
        ``(thresholds, precisions, recalls, f1_scores)`` as parallel lists,
        or four empty lists when there are no detections at all.
    """
    if not dets:
        return [], [], [], []
    thresholds = np.linspace(0.0, 1.0, steps)
    precisions = []
    recalls = []
    f1_scores = []
    for thr in thresholds:
        kept = [d for d in dets if d[5] >= thr]
        if not kept:
            # Nothing predicted: by convention precision 1, recall 0.
            precisions.append(1.0)
            recalls.append(0.0)
            f1_scores.append(0.0)
            continue
        matched, n_missed, _ = match_detections(kept, gt_list, iou_thr)
        tp = sum(1 for _, gt_lbl, _ in matched if gt_lbl is not None)
        fp = len(matched) - tp
        fn = n_missed
        prec = tp / (tp + fp) if tp + fp else 1.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * prec * rec / (prec + rec)) if prec + rec else 0.0
        precisions.append(prec)
        recalls.append(rec)
        f1_scores.append(f1)
    return thresholds.tolist(), precisions, recalls, f1_scores
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def build_confusion_matrix(dets, gt_list, iou_thr):
    """Build a (predicted x actual) confusion matrix over GT labels + "background".

    Detections are matched to ground truth via ``match_detections``; an
    unmatched detection counts against the "background" column (false
    positive), and an unclaimed GT box counts as a "background" prediction
    (false negative).

    Returns:
        ``(matrix, labels)`` — an int square ndarray indexed
        ``[predicted][actual]`` and the label list defining both axes.
    """
    labels = sorted({lbl for _, lbl in gt_list}) + ["background"]
    index_of = {lbl: i for i, lbl in enumerate(labels)}
    bg = index_of["background"]
    matrix = np.zeros((len(labels), len(labels)), dtype=int)

    matched, _, claimed = match_detections(dets, gt_list, iou_thr)
    for det, gt_lbl, _ in matched:
        # A predicted label unknown to the GT set falls back to "background".
        row = index_of.get(det[4], bg)
        col = index_of[gt_lbl] if gt_lbl is not None else bg
        matrix[row][col] += 1

    # Ground-truth boxes that nothing claimed are misses.
    for gi, (_, gt_lbl) in enumerate(gt_list):
        if gi not in claimed:
            matrix[bg][index_of[gt_lbl]] += 1
    return matrix, labels
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def render():
|
| 86 |
+
st.title("📈 Evaluation: Confusion Matrix & PR Curves")
|
| 87 |
+
|
| 88 |
+
pipe = st.session_state.get("gen_pipeline")
|
| 89 |
+
if not pipe:
|
| 90 |
+
st.error("Complete the **Data Lab** first.")
|
| 91 |
+
st.stop()
|
| 92 |
+
|
| 93 |
+
crop = pipe.get("crop")
|
| 94 |
+
crop_aug = pipe.get("crop_aug", crop)
|
| 95 |
+
bbox = pipe.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0])) if crop is not None else None
|
| 96 |
+
rois = pipe.get("rois", [{"label": "object", "bbox": bbox,
|
| 97 |
+
"crop": crop, "crop_aug": crop_aug}])
|
| 98 |
+
|
| 99 |
+
rce_dets = pipe.get("rce_dets")
|
| 100 |
+
cnn_dets = pipe.get("cnn_dets")
|
| 101 |
+
orb_dets = pipe.get("orb_dets")
|
| 102 |
+
|
| 103 |
+
if rce_dets is None and cnn_dets is None and orb_dets is None:
|
| 104 |
+
st.warning("Run detection first in **Real-Time Detection**.")
|
| 105 |
+
st.stop()
|
| 106 |
+
|
| 107 |
+
gt_boxes = [(roi["bbox"], roi["label"]) for roi in rois]
|
| 108 |
+
|
| 109 |
+
st.sidebar.subheader("Evaluation Settings")
|
| 110 |
+
iou_thresh = st.sidebar.slider("IoU Threshold", 0.1, 0.9, 0.5, 0.05,
|
| 111 |
+
help="Minimum IoU to count as TP",
|
| 112 |
+
key="gen_eval_iou")
|
| 113 |
+
|
| 114 |
+
st.subheader("Ground Truth (from Data Lab ROIs)")
|
| 115 |
+
st.caption(f"{len(gt_boxes)} ground-truth ROIs defined")
|
| 116 |
+
gt_vis = pipe["test_image"].copy()
|
| 117 |
+
for (bx0, by0, bx1, by1), lbl in gt_boxes:
|
| 118 |
+
cv2.rectangle(gt_vis, (bx0, by0), (bx1, by1), (0, 255, 255), 2)
|
| 119 |
+
cv2.putText(gt_vis, lbl, (bx0, by0 - 6),
|
| 120 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 1)
|
| 121 |
+
st.image(cv2.cvtColor(gt_vis, cv2.COLOR_BGR2RGB),
|
| 122 |
+
caption="Ground Truth Annotations", use_container_width=True)
|
| 123 |
+
st.divider()
|
| 124 |
+
|
| 125 |
+
methods = {}
|
| 126 |
+
if rce_dets is not None:
|
| 127 |
+
methods["RCE"] = rce_dets
|
| 128 |
+
if cnn_dets is not None:
|
| 129 |
+
methods["CNN"] = cnn_dets
|
| 130 |
+
if orb_dets is not None:
|
| 131 |
+
methods["ORB"] = orb_dets
|
| 132 |
+
|
| 133 |
+
# Confusion Matrices
|
| 134 |
+
st.subheader("🔲 Confusion Matrices")
|
| 135 |
+
cm_cols = st.columns(len(methods))
|
| 136 |
+
for col, (name, dets) in zip(cm_cols, methods.items()):
|
| 137 |
+
with col:
|
| 138 |
+
st.markdown(f"**{name}**")
|
| 139 |
+
matrix, labels = build_confusion_matrix(dets, gt_boxes, iou_thresh)
|
| 140 |
+
fig_cm = ff.create_annotated_heatmap(
|
| 141 |
+
z=matrix.tolist(), x=labels, y=labels,
|
| 142 |
+
colorscale="Blues", showscale=True)
|
| 143 |
+
fig_cm.update_layout(title=f"{name} Confusion Matrix",
|
| 144 |
+
xaxis_title="Actual", yaxis_title="Predicted",
|
| 145 |
+
template="plotly_dark", height=350)
|
| 146 |
+
fig_cm.update_yaxes(autorange="reversed")
|
| 147 |
+
st.plotly_chart(fig_cm, use_container_width=True)
|
| 148 |
+
|
| 149 |
+
matched, n_missed, _ = match_detections(dets, gt_boxes, iou_thresh)
|
| 150 |
+
tp = sum(1 for _, g, _ in matched if g is not None)
|
| 151 |
+
fp = sum(1 for _, g, _ in matched if g is None)
|
| 152 |
+
fn = n_missed
|
| 153 |
+
prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
|
| 154 |
+
rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
|
| 155 |
+
f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
|
| 156 |
+
m1, m2, m3 = st.columns(3)
|
| 157 |
+
m1.metric("Precision", f"{prec:.1%}")
|
| 158 |
+
m2.metric("Recall", f"{rec:.1%}")
|
| 159 |
+
m3.metric("F1 Score", f"{f1:.1%}")
|
| 160 |
+
|
| 161 |
+
# PR Curves
|
| 162 |
+
st.divider()
|
| 163 |
+
st.subheader("📉 Precision-Recall Curves")
|
| 164 |
+
method_colors = {"RCE": "#00ff88", "CNN": "#4488ff", "ORB": "#ff8800"}
|
| 165 |
+
fig_pr = go.Figure()
|
| 166 |
+
fig_f1 = go.Figure()
|
| 167 |
+
summary_rows = []
|
| 168 |
+
|
| 169 |
+
for name, dets in methods.items():
|
| 170 |
+
thrs, precs, recs, f1s = compute_pr_curve(dets, gt_boxes, iou_thresh)
|
| 171 |
+
clr = method_colors.get(name, "#ffffff")
|
| 172 |
+
fig_pr.add_trace(go.Scatter(
|
| 173 |
+
x=recs, y=precs, mode="lines+markers",
|
| 174 |
+
name=name, line=dict(color=clr, width=2), marker=dict(size=4)))
|
| 175 |
+
fig_f1.add_trace(go.Scatter(
|
| 176 |
+
x=thrs, y=f1s, mode="lines",
|
| 177 |
+
name=name, line=dict(color=clr, width=2)))
|
| 178 |
+
ap = float(np.trapz(precs, recs)) if recs and precs else 0.0
|
| 179 |
+
best_f1_idx = int(np.argmax(f1s)) if f1s else 0
|
| 180 |
+
summary_rows.append({
|
| 181 |
+
"Method": name,
|
| 182 |
+
"AP": f"{abs(ap):.3f}",
|
| 183 |
+
"Best F1": f"{f1s[best_f1_idx]:.3f}" if f1s else "N/A",
|
| 184 |
+
"@ Threshold": f"{thrs[best_f1_idx]:.2f}" if thrs else "N/A",
|
| 185 |
+
"Detections": len(dets),
|
| 186 |
+
})
|
| 187 |
+
|
| 188 |
+
fig_pr.update_layout(title="Precision vs Recall",
|
| 189 |
+
xaxis_title="Recall", yaxis_title="Precision",
|
| 190 |
+
template="plotly_dark", height=400,
|
| 191 |
+
xaxis=dict(range=[0, 1.05]), yaxis=dict(range=[0, 1.05]))
|
| 192 |
+
fig_f1.update_layout(title="F1 Score vs Confidence Threshold",
|
| 193 |
+
xaxis_title="Confidence Threshold", yaxis_title="F1 Score",
|
| 194 |
+
template="plotly_dark", height=400,
|
| 195 |
+
xaxis=dict(range=[0, 1.05]), yaxis=dict(range=[0, 1.05]))
|
| 196 |
+
pc1, pc2 = st.columns(2)
|
| 197 |
+
pc1.plotly_chart(fig_pr, use_container_width=True)
|
| 198 |
+
pc2.plotly_chart(fig_f1, use_container_width=True)
|
| 199 |
+
|
| 200 |
+
# Summary Table
|
| 201 |
+
st.divider()
|
| 202 |
+
st.subheader("📊 Summary")
|
| 203 |
+
import pandas as pd
|
| 204 |
+
st.dataframe(pd.DataFrame(summary_rows), use_container_width=True, hide_index=True)
|
| 205 |
+
st.caption(f"All metrics computed at IoU threshold = **{iou_thresh:.2f}**.")
|
tabs/generalisation/feature_lab.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generalisation Feature Lab — Stage 2 of the Generalisation pipeline."""
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import plotly.graph_objects as go
|
| 7 |
+
|
| 8 |
+
from src.detectors.rce.features import REGISTRY
|
| 9 |
+
from src.models import BACKBONES
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def render():
    """Render the Feature Lab stage (Stage 2) of the Generalisation pipeline.

    Compares the modular, hand-crafted RCE feature extractor against a fixed
    pre-trained CNN backbone, both applied to the object crop produced by the
    Data Lab stage. On "Lock", the concatenated feature vector and the module
    selection are persisted into ``st.session_state["gen_pipeline"]`` for the
    subsequent Model Tuning stage.
    """
    # Guard: this stage needs the Data Lab to have populated the pipeline
    # dict with at least a "crop".
    pipe = st.session_state.get("gen_pipeline")
    if not pipe or "crop" not in pipe:
        st.error("Please complete the **Data Lab** first!")
        st.stop()

    # Prefer the augmented crop when present; fall back to the raw crop.
    obj = pipe.get("crop_aug", pipe.get("crop"))
    if obj is None:
        st.error("No crop found. Go back to Data Lab and define a ROI.")
        st.stop()
    # RCE feature modules operate on a single-channel (grayscale) image.
    gray = cv2.cvtColor(obj, cv2.COLOR_BGR2GRAY)

    st.title("🔬 Feature Lab: Physical Module Selection")

    col_rce, col_cnn = st.columns([3, 2])

    with col_rce:
        st.header("🧬 RCE: Modular Physics Engine")
        st.subheader("Select Active Modules")

        # One checkbox per registered feature module, laid out 4 per row.
        # Widget keys are prefixed "gen_fl_" to keep this pipeline's state
        # separate from the stereo pipeline's widgets.
        active = {}
        items = list(REGISTRY.items())
        for row_start in range(0, len(items), 4):
            row_items = items[row_start:row_start + 4]
            m_cols = st.columns(4)
            for col, (key, meta) in zip(m_cols, row_items):
                # Default-enable a small baseline set of modules.
                active[key] = col.checkbox(meta["label"],
                    value=(key in ("intensity", "sobel", "spectral")),
                    key=f"gen_fl_{key}")

        # Run each active module on the grayscale crop; concatenate the
        # per-module vectors into one "DNA" vector and collect the
        # per-module visualisation images.
        final_vector = []
        viz_images = []
        for key, meta in REGISTRY.items():
            if active[key]:
                vec, viz = meta["fn"](gray)
                final_vector.extend(vec)
                viz_images.append((meta["viz_title"], viz))

        st.divider()
        # Show module visualisations in rows of 3, or warn when nothing
        # is selected (empty vector).
        if viz_images:
            for row_start in range(0, len(viz_images), 3):
                row = viz_images[row_start:row_start + 3]
                v_cols = st.columns(3)
                for col, (title, img) in zip(v_cols, row):
                    col.image(img, caption=title, use_container_width=True)
        else:
            st.warning("No modules selected — vector is empty.")

        st.write(f"### Current DNA Vector Size: **{len(final_vector)}**")
        # Bar chart of the concatenated feature values (RCE input vector).
        fig_vec = go.Figure(data=[go.Bar(y=final_vector, marker_color="#00d4ff")])
        fig_vec.update_layout(title="Active Feature Vector (RCE Input)",
                              template="plotly_dark", height=300)
        st.plotly_chart(fig_vec, use_container_width=True)

    with col_cnn:
        st.header("🧠 CNN: Static Architecture")
        selected_cnn = st.selectbox("Compare against Model", list(BACKBONES.keys()),
                                    key="gen_fl_cnn")
        st.info("CNN features are fixed by pre-trained weights.")

        # Load the selected backbone lazily and display a handful of its
        # activation maps for the crop. Model loading / inference can fail
        # (e.g. missing weights), so surface errors instead of crashing.
        with st.spinner(f"Loading {selected_cnn} and extracting activations..."):
            try:
                bmeta = BACKBONES[selected_cnn]
                backbone = bmeta["loader"]()
                layer_name = bmeta["hook_layer"]
                act_maps = backbone.get_activation_maps(obj, n_maps=6)
                st.caption(f"Hooked layer: `{layer_name}` — showing 6 of "
                           f"{len(act_maps)} channels")
                act_cols = st.columns(3)
                for i, amap in enumerate(act_maps):
                    act_cols[i % 3].image(amap, caption=f"Channel {i}",
                                          use_container_width=True)
            except Exception as e:
                st.error(f"Could not load model: {e}")

    st.divider()
    st.markdown(f"""
    **Analysis:**
    - **Modularity:** RCE is **High** | CNN is **Zero**
    - **Explainability:** RCE is **High** | CNN is **Low**
    - **Compute Cost:** {len(final_vector)} floats | 512+ floats
    """)

    # Persist the locked configuration into the namespaced pipeline dict so
    # later stages (Model Tuning, Localization) can reuse it.
    if st.button("🚀 Lock Modular Configuration", key="gen_fl_lock"):
        if not final_vector:
            st.error("Please select at least one module!")
        else:
            pipe["final_vector"] = np.array(final_vector)
            pipe["active_modules"] = {k: v for k, v in active.items()}
            st.session_state["gen_pipeline"] = pipe
            st.success("Modular DNA Locked! Ready for Model Tuning.")
|
tabs/generalisation/localization.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generalisation Localization Lab — Stage 4 of the Generalisation pipeline."""
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import plotly.graph_objects as go
|
| 8 |
+
|
| 9 |
+
from src.detectors.rce.features import REGISTRY
|
| 10 |
+
from src.models import BACKBONES
|
| 11 |
+
from src.utils import build_rce_vector
|
| 12 |
+
from src.localization import (
|
| 13 |
+
exhaustive_sliding_window,
|
| 14 |
+
image_pyramid,
|
| 15 |
+
coarse_to_fine,
|
| 16 |
+
contour_proposals,
|
| 17 |
+
template_matching,
|
| 18 |
+
STRATEGIES,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def render():
    """Render the Localization Lab stage (Stage 4) of the Generalisation pipeline.

    Lets the user select a trained recognition head (RCE or a CNN head from
    Model Tuning) and run up to five localization strategies over the test
    image, then compares them by proposals evaluated, wall time, and
    detections, with per-strategy visualisations and summary charts.
    """
    st.title("🔍 Localization Lab")
    st.markdown(
        "Compare **localization strategies** — algorithms that decide *where* "
        "to look in the image. The recognition head stays the same; only the "
        "search method changes."
    )

    # Guard: requires Data Lab state (test image + crop) in the namespaced
    # generalisation-pipeline dict.
    pipe = st.session_state.get("gen_pipeline")
    if not pipe or "crop" not in pipe:
        st.error("Complete **Data Lab** first (upload assets & define a crop).")
        st.stop()

    test_img = pipe["test_image"]
    crop = pipe["crop"]
    # Augmented crop (if any) is used as the template for Template Matching.
    crop_aug = pipe.get("crop_aug", crop)
    # Sliding-window size comes from the crop's bounding box; fall back to
    # the crop's own dimensions when no bbox was stored.
    bbox = pipe.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
    # Default to all feature modules active if Feature Lab was skipped.
    active_mods = pipe.get("active_modules", {k: True for k in REGISTRY})

    x0, y0, x1, y1 = bbox
    win_h, win_w = y1 - y0, x1 - x0

    # A degenerate bbox would make every strategy meaningless — bail early.
    if win_h <= 0 or win_w <= 0:
        st.error("Invalid window size from crop bbox. "
                 "Go back to **Data Lab** and redefine the ROI.")
        st.stop()

    rce_head = pipe.get("rce_head")
    # CNN heads are stored per-backbone under keys "cnn_head_<name>".
    has_any_cnn = any(f"cnn_head_{n}" in pipe for n in BACKBONES)

    if rce_head is None and not has_any_cnn:
        st.warning("No trained heads found. Go to **Model Tuning** first.")
        st.stop()

    def rce_feature_fn(patch_bgr):
        # Feature extractor for the RCE head: builds the modular feature
        # vector from a BGR patch using the locked module selection.
        return build_rce_vector(patch_bgr, active_mods)

    # Algorithm Reference — expandable docs for each strategy, one tab each.
    st.divider()
    with st.expander("📚 **Algorithm Reference** — click to expand", expanded=False):
        tabs = st.tabs([f"{v['icon']} {k}" for k, v in STRATEGIES.items()])
        for tab, (name, meta) in zip(tabs, STRATEGIES.items()):
            with tab:
                st.markdown(f"### {meta['icon']} {name}")
                st.caption(meta["short"])
                st.markdown(meta["detail"])

    # Configuration — choose the recognition head and show its summary.
    st.divider()
    st.header("⚙️ Configuration")

    col_head, col_info = st.columns([2, 3])
    with col_head:
        # Build the head dropdown from whatever was actually trained.
        head_options = []
        if rce_head is not None:
            head_options.append("RCE")
        trained_cnns = [n for n in BACKBONES if f"cnn_head_{n}" in pipe]
        head_options.extend(trained_cnns)
        selected_head = st.selectbox("Recognition Head", head_options,
                                     key="gen_loc_head")

        # Bind (feature_fn, head) for the chosen recogniser: RCE uses the
        # modular extractor; CNN heads use their backbone's feature method.
        if selected_head == "RCE":
            feature_fn = rce_feature_fn
            head = rce_head
        else:
            bmeta = BACKBONES[selected_head]
            backbone = bmeta["loader"]()
            feature_fn = backbone.get_features
            head = pipe[f"cnn_head_{selected_head}"]

    with col_info:
        if selected_head == "RCE":
            mods = [REGISTRY[k]["label"] for k in active_mods if active_mods[k]]
            st.info(f"**RCE** — Modules: {', '.join(mods)}")
        else:
            st.info(f"**{selected_head}** — "
                    f"{BACKBONES[selected_head]['dim']}D feature vector")

    # Algorithm checkboxes — Template Matching is off by default.
    st.subheader("Select Algorithms to Compare")
    algo_cols = st.columns(5)
    algo_names = list(STRATEGIES.keys())
    algo_checks = {}
    for col, name in zip(algo_cols, algo_names):
        algo_checks[name] = col.checkbox(
            f"{STRATEGIES[name]['icon']} {name}",
            value=(name != "Template Matching"),
            key=f"gen_chk_{name}")

    any_selected = any(algo_checks.values())

    # Shared parameters for all strategies.
    st.subheader("Parameters")
    sp1, sp2, sp3 = st.columns(3)
    # Default stride is a quarter of the window width, floored at 4 px.
    stride = sp1.slider("Base Stride (px)", 4, max(win_w, win_h),
                        max(win_w // 4, 4), step=2, key="gen_loc_stride")
    conf_thresh = sp2.slider("Confidence Threshold", 0.5, 1.0, 0.7, 0.05,
                             key="gen_loc_conf")
    nms_iou = sp3.slider("NMS IoU Threshold", 0.1, 0.9, 0.3, 0.05,
                         key="gen_loc_nms")

    # Strategy-specific knobs, grouped per algorithm.
    with st.expander("🔧 Per-Algorithm Settings"):
        pa1, pa2, pa3 = st.columns(3)
        with pa1:
            st.markdown("**Image Pyramid**")
            pyr_min = st.slider("Min Scale", 0.3, 1.0, 0.5, 0.05, key="gen_pyr_min")
            pyr_max = st.slider("Max Scale", 1.0, 2.0, 1.5, 0.1, key="gen_pyr_max")
            pyr_n = st.slider("Number of Scales", 3, 7, 5, key="gen_pyr_n")
        with pa2:
            st.markdown("**Coarse-to-Fine**")
            c2f_factor = st.slider("Coarse Factor", 2, 8, 4, key="gen_c2f_factor")
            c2f_radius = st.slider("Refine Radius (strides)", 1, 5, 2, key="gen_c2f_radius")
        with pa3:
            st.markdown("**Contour Proposals**")
            cnt_low = st.slider("Canny Low", 10, 100, 50, key="gen_cnt_low")
            cnt_high = st.slider("Canny High", 50, 300, 150, key="gen_cnt_high")
            cnt_tol = st.slider("Area Tolerance", 1.5, 10.0, 3.0, 0.5, key="gen_cnt_tol")

    st.caption(
        f"Window: **{win_w}×{win_h} px** · "
        f"Image: **{test_img.shape[1]}×{test_img.shape[0]} px** · "
        f"Stride: **{stride} px**"
    )

    # Run — button disabled until at least one algorithm is ticked.
    st.divider()
    run_btn = st.button("▶ Run Comparison", type="primary",
                        disabled=not any_selected, use_container_width=True,
                        key="gen_loc_run")

    if run_btn:
        selected_algos = [n for n in algo_names if algo_checks[n]]
        progress = st.progress(0, text="Starting…")
        # Per-algorithm results; each entry holds detections, proposal
        # count, wall time (ms) and a confidence heatmap.
        results = {}
        # Contour Proposals additionally returns a Canny edge map.
        edge_maps = {}

        for i, name in enumerate(selected_algos):
            progress.progress(i / len(selected_algos), text=f"Running **{name}**…")

            # Dispatch to the matching strategy. All strategies return
            # (detections, n_proposals, time_ms, heatmap); Contour
            # Proposals also returns its edge map.
            if name == "Exhaustive Sliding Window":
                dets, n, ms, hmap = exhaustive_sliding_window(
                    test_img, win_h, win_w, feature_fn, head,
                    stride, conf_thresh, nms_iou)
            elif name == "Image Pyramid":
                scales = np.linspace(pyr_min, pyr_max, pyr_n).tolist()
                dets, n, ms, hmap = image_pyramid(
                    test_img, win_h, win_w, feature_fn, head,
                    stride, conf_thresh, nms_iou, scales=scales)
            elif name == "Coarse-to-Fine":
                dets, n, ms, hmap = coarse_to_fine(
                    test_img, win_h, win_w, feature_fn, head,
                    stride, conf_thresh, nms_iou,
                    coarse_factor=c2f_factor, refine_radius=c2f_radius)
            elif name == "Contour Proposals":
                dets, n, ms, hmap, edges = contour_proposals(
                    test_img, win_h, win_w, feature_fn, head,
                    conf_thresh, nms_iou,
                    canny_low=cnt_low, canny_high=cnt_high,
                    area_tolerance=cnt_tol)
                edge_maps[name] = edges
            elif name == "Template Matching":
                # Matches against the (augmented) crop; ignores the
                # recognition head entirely.
                dets, n, ms, hmap = template_matching(
                    test_img, crop_aug, conf_thresh, nms_iou)

            results[name] = {"dets": dets, "n_proposals": n,
                             "time_ms": ms, "heatmap": hmap}

        progress.progress(1.0, text="Done!")

        # Summary Table — speedups are relative to the exhaustive baseline
        # (shown as "—" when the baseline wasn't run).
        st.header("📊 Results")
        baseline_ms = results.get("Exhaustive Sliding Window", {}).get("time_ms")
        rows = []
        for name, r in results.items():
            speedup = (baseline_ms / r["time_ms"]
                       if baseline_ms and r["time_ms"] > 0 else None)
            rows.append({
                "Algorithm": name,
                "Proposals": r["n_proposals"],
                "Time (ms)": round(r["time_ms"], 1),
                "Detections": len(r["dets"]),
                "ms / Proposal": round(r["time_ms"] / max(r["n_proposals"], 1), 4),
                "Speedup": f"{speedup:.1f}×" if speedup else "—",
            })
        st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)

        # Detection Images & Heatmaps — one tab per algorithm, with a fixed
        # BGR color per strategy for the drawn boxes.
        st.subheader("Detection Results")
        COLORS = {
            "Exhaustive Sliding Window": (0, 255, 0),
            "Image Pyramid": (255, 128, 0),
            "Coarse-to-Fine": (0, 128, 255),
            "Contour Proposals": (255, 0, 255),
            "Template Matching": (0, 255, 255),
        }

        result_tabs = st.tabs(
            [f"{STRATEGIES[n]['icon']} {n}" for n in results])

        for tab, (name, r) in zip(result_tabs, results.items()):
            with tab:
                c1, c2 = st.columns(2)
                color = COLORS.get(name, (0, 255, 0))

                # Left column: test image annotated with this strategy's
                # detections. Detections unpack as (x1, y1, x2, y2, label,
                # confidence); the label is unused in the overlay.
                vis = test_img.copy()
                for x1d, y1d, x2d, y2d, _, cf in r["dets"]:
                    cv2.rectangle(vis, (x1d, y1d), (x2d, y2d), color, 2)
                    cv2.putText(vis, f"{cf:.0%}", (x1d, y1d - 6),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
                c1.image(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB),
                         caption=f"{name} — {len(r['dets'])} detections",
                         use_container_width=True)

                # Right column: confidence heatmap blended over the image
                # (JET colormap), or a note when nothing scored positive.
                hmap = r["heatmap"]
                if hmap.max() > 0:
                    hmap_color = cv2.applyColorMap(
                        (hmap / hmap.max() * 255).astype(np.uint8),
                        cv2.COLORMAP_JET)
                    blend = cv2.addWeighted(test_img, 0.5, hmap_color, 0.5, 0)
                    c2.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
                             caption=f"{name} — Confidence Heatmap",
                             use_container_width=True)
                else:
                    c2.info("No positive responses above threshold.")

                # Contour Proposals only: show the intermediate edge map.
                if name in edge_maps:
                    st.image(edge_maps[name],
                             caption="Canny Edge Map",
                             use_container_width=True, clamp=True)

                m1, m2, m3, m4 = st.columns(4)
                m1.metric("Proposals", r["n_proposals"])
                m2.metric("Time", f"{r['time_ms']:.0f} ms")
                m3.metric("Detections", len(r["dets"]))
                m4.metric("ms / Proposal",
                          f"{r['time_ms'] / max(r['n_proposals'], 1):.3f}")

                if r["dets"]:
                    df = pd.DataFrame(r["dets"],
                                      columns=["x1","y1","x2","y2","label","conf"])
                    st.dataframe(df, use_container_width=True, hide_index=True)

        # Performance Charts — bar charts for time and proposal counts,
        # plus a scatter where marker size scales with detection count.
        st.subheader("📈 Performance Comparison")
        ch1, ch2 = st.columns(2)
        names = list(results.keys())
        times = [results[n]["time_ms"] for n in names]
        props = [results[n]["n_proposals"] for n in names]
        n_dets = [len(results[n]["dets"]) for n in names]
        colors_hex = ["#00cc66", "#ff8800", "#0088ff", "#ff00ff", "#00cccc"]

        with ch1:
            fig = go.Figure(go.Bar(
                x=names, y=times,
                text=[f"{t:.0f}" for t in times], textposition="auto",
                marker_color=colors_hex[:len(names)]))
            fig.update_layout(title="Total Time (ms)", yaxis_title="ms", height=400)
            st.plotly_chart(fig, use_container_width=True)

        with ch2:
            fig = go.Figure(go.Bar(
                x=names, y=props,
                text=[str(p) for p in props], textposition="auto",
                marker_color=colors_hex[:len(names)]))
            fig.update_layout(title="Proposals Evaluated",
                              yaxis_title="Count", height=400)
            st.plotly_chart(fig, use_container_width=True)

        fig = go.Figure()
        for i, name in enumerate(names):
            fig.add_trace(go.Scatter(
                x=[props[i]], y=[times[i]],
                mode="markers+text",
                marker=dict(size=max(n_dets[i] * 12, 18),
                            color=colors_hex[i % len(colors_hex)]),
                text=[name], textposition="top center", name=name))
        fig.update_layout(
            title="Proposals vs Time (marker size ∝ detections)",
            xaxis_title="Proposals Evaluated",
            yaxis_title="Time (ms)", height=500)
        st.plotly_chart(fig, use_container_width=True)
|