DariusGiannoli committed
Commit 8ac50b6 · 1 Parent(s): 565e61c

feat: central model registry, real-time detection, stereo geometry & home page

app.py CHANGED
@@ -1,17 +1,192 @@
  import streamlit as st
- import cv2
- import numpy as np
- from src.detectors.yolo import YOLODetector
-
- st.set_page_config(page_title="Perception Benchmark", layout="wide")
-
- st.title("🦅 Bird Perception Stack")
- st.write("Current Status: Recognition Engine Online. Stereo Depth Engine Pending.")
-
- # Simple test of your existing YOLO class
- if st.button("Initialize YOLOv8n"):
-     try:
-         detector = YOLODetector()
-         st.success("YOLOv8n Loaded Successfully from weights!")
-     except Exception as e:
-         st.error(f"Error loading weights: {e}")
+
+ st.set_page_config(page_title="Perception Benchmark", layout="wide", page_icon="🦅")
+
+ # ===================================================================
+ # Header
+ # ===================================================================
+ st.title("🦅 Recognition Benchmark")
+ st.subheader("A stereo-vision pipeline for object recognition & depth estimation")
+ st.caption("Compare classical feature engineering (RCE) against modern deep learning backbones — end-to-end, in your browser.")
+
+ st.divider()
+
+ # ===================================================================
+ # Pipeline Overview
+ # ===================================================================
+ st.header("🗺️ Pipeline Overview")
+ st.markdown("""
+ The app is structured as a **5-stage sequential pipeline**.
+ Complete each page in order — every stage feeds the next.
+ """)
+
+ stages = [
+     ("🧪", "1 · Data Lab", "Upload a stereo image pair, camera calibration file, and two PFM ground-truth depth maps. "
+                            "Define an object ROI (bounding box), then apply live data augmentation "
+                            "(brightness, contrast, rotation, noise, blur, shift, flip). "
+                            "All assets are locked into session state — nothing is written to disk."),
+     ("🔬", "2 · Feature Lab", "Toggle RCE physics modules (Intensity · Sobel · Spectral) to build a modular "
+                               "feature vector. Compare it live against CNN activation maps extracted from a "
+                               "frozen backbone via forward hooks. Lock your active module configuration."),
+     ("⚙️", "3 · Model Tuning", "Train lightweight **heads** on your session data (augmented crop = positives, "
+                                "random non-overlapping patches from the scene = negatives). "
+                                "Both RCE and CNN heads are trained identically with LogisticRegression "
+                                "and stored in session state only — no disk writes."),
+     ("🎯", "4 · Real-Time Detection", "Run a **sliding window** across the right image using both the RCE head and "
+                                       "your chosen CNN head simultaneously. Watch the scan live, then compare "
+                                       "bounding boxes, confidence heatmaps, and latency."),
+     ("📐", "5 · Stereo Geometry", "Compute a disparity map with **StereoSGBM**, convert it to metric depth "
+                                   "using the stereo formula $Z = fB/(d+d_{\\text{offs}})$, then read depth "
+                                   "directly at every detected bounding box. Compare against PFM ground truth."),
+ ]
+
+ for icon, title, desc in stages:
+     with st.container(border=True):
+         c1, c2 = st.columns([1, 12])
+         c1.markdown(f"## {icon}")
+         c2.markdown(f"**{title}**  \n{desc}")
+
+ st.divider()
+
+ # ===================================================================
+ # Models
+ # ===================================================================
+ st.header("🧠 Models Used")
+
+ tab_rce, tab_resnet, tab_mobilenet, tab_mobilevit = st.tabs(
+     ["RCE Engine", "ResNet-18", "MobileNetV3-Small", "MobileViT-XXS"])
+
+ with tab_rce:
+     st.markdown("### 🧬 RCE — Relative Contextual Encoding")
+     st.markdown("""
+ **Type:** Modular hand-crafted feature extractor
+ **Architecture:** Three physics-inspired modules, each producing a 10-bin histogram:
+
+ | Module | Input | Operation |
+ |--------|-------|-----------|
+ | **Intensity** | Grayscale | Pixel-value histogram (global appearance) |
+ | **Sobel** | Gradient magnitude | Edge strength distribution (texture) |
+ | **Spectral** | FFT log-magnitude | Frequency content (pattern / structure) |
+
+ **Strengths:**
+ - Fully explainable — every dimension has a physical meaning
+ - Extremely fast (µs per patch, no GPU needed)
+ - Modular: disable any module and immediately see the effect on the vector
+ - Zero pre-training needed
+
+ **Weakness:** Less discriminative than deep features for complex visual scenes.
+ """)
+
+ with tab_resnet:
+     st.markdown("### 🏗️ ResNet-18")
+     st.markdown("""
+ **Source:** PyTorch Hub (`torchvision.models.ResNet18_Weights.DEFAULT`)
+ **Pre-training:** ImageNet-1k (1.28 M images, 1 000 classes)
+ **Backbone output:** 512-dimensional embedding (after `avgpool`)
+ **Head:** LogisticRegression trained on your session data
+
+ **Architecture highlights:**
+ - 18 layers with residual (skip) connections
+ - Residual blocks prevent vanishing gradients in deeper networks
+ - `layer4` is hooked for activation map visualisation
+
+ **In this app:** The entire backbone is **frozen** (`requires_grad=False`).
+ Only the lightweight head adapts to your specific object.
+ """)
+
+ with tab_mobilenet:
+     st.markdown("### 📱 MobileNetV3-Small")
+     st.markdown("""
+ **Source:** PyTorch Hub (`torchvision.models.MobileNet_V3_Small_Weights.DEFAULT`)
+ **Pre-training:** ImageNet-1k
+ **Backbone output:** 576-dimensional embedding (classifier replaced with `Identity`)
+ **Head:** LogisticRegression trained on your session data
+
+ **Architecture highlights:**
+ - Inverted residuals + linear bottlenecks (MobileNetV2 heritage)
+ - Hard-Swish / Hard-Sigmoid activations (hardware-friendly)
+ - Squeeze-and-Excitation (SE) blocks for channel attention
+ - Designed for **edge / mobile inference** — ~2.5 M parameters
+
+ **In this app:** Typically 3–5× faster than ResNet-18.
+ `features[-1]` is hooked for activation maps.
+ """)
+
+ with tab_mobilevit:
+     st.markdown("### 🤖 MobileViT-XXS")
+     st.markdown("""
+ **Source:** timm — `mobilevit_xxs.cvnets_in1k` (Apple Research, 2022)
+ **Pre-training:** ImageNet-1k
+ **Backbone output:** 320-dimensional embedding (`num_classes=0`)
+ **Head:** LogisticRegression trained on your session data
+
+ **Architecture highlights:**
+ - **Hybrid CNN + Vision Transformer** — local convolutions for spatial features,
+   global self-attention for long-range context
+ - MobileNetV2 stem + MobileViT blocks (attention on non-overlapping patches)
+ - Only ~1.3 M parameters — smallest of the three
+
+ **In this app:** The final transformer stage `stages[-1]` is hooked.
+ Slower than MobileNetV3 but captures global structure.
+ """)
+
+ st.divider()
+
+ # ===================================================================
+ # Depth Estimation
+ # ===================================================================
+ st.header("📐 Stereo Depth Estimation")
+
+ col_d1, col_d2 = st.columns(2)
+ with col_d1:
+     st.markdown("""
+ **Algorithm:** `cv2.StereoSGBM` (Semi-Global Block Matching)
+
+ SGBM minimises a global energy function combining:
+ - Data cost (pixel intensity difference)
+ - Smoothness penalty (P1, P2 regularisation)
+
+ It processes multiple horizontal and diagonal scan-line passes,
+ making it significantly more accurate than basic block matching.
+ """)
+ with col_d2:
+     st.markdown("""
+ **Depth formula (Middlebury convention):**
+ """)
+     st.latex(r"Z = \frac{f \times B}{d + d_{\text{offs}}}")
+     st.markdown("""
+ - $f$ — focal length (pixels)
+ - $B$ — baseline (mm, from calibration file)
+ - $d$ — disparity (pixels)
+ - $d_\\text{offs}$ — optical-center offset between cameras
+ """)
+
+ st.divider()
+
+ # ===================================================================
+ # Session Status
+ # ===================================================================
+ st.header("📋 Session Status")
+
+ pipe = st.session_state.get("pipeline_data", {})
+
+ checks = {
+     "Data Lab locked": "left" in pipe,
+     "Crop defined": "crop" in pipe,
+     "Augmentation applied": "crop_aug" in pipe,
+     "Active modules locked": "active_modules" in st.session_state,
+     "RCE head trained": "rce_head" in st.session_state,
+     "CNN head trained": any(f"cnn_head_{n}" in st.session_state
+                             for n in ["ResNet-18", "MobileNetV3", "MobileViT-XXS"]),
+     "RCE detections ready": "rce_dets" in st.session_state,
+     "CNN detections ready": "cnn_dets" in st.session_state,
+ }
+
+ cols = st.columns(4)
+ for i, (label, done) in enumerate(checks.items()):
+     cols[i % 4].markdown(
+         f"{'✅' if done else '⬜'} {'~~' if not done else ''}{label}{'~~' if not done else ''}"
+     )
+
+ st.divider()
+ st.caption("Navigate using the sidebar → Start with **🧪 Data Lab**")
pages/3_Feature_Lab.py CHANGED
@@ -5,25 +5,7 @@ import plotly.graph_objects as go
  import sys, os
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  from src.detectors.rce.features import REGISTRY
-
-
- # ---------------------------------------------------------------------------
- # Cached model loaders — instantiated once, reused across reruns
- # ---------------------------------------------------------------------------
- @st.cache_resource
- def load_resnet():
-     from src.detectors.resnet import ResNetDetector
-     return ResNetDetector()
-
- @st.cache_resource
- def load_mobilenet():
-     from src.detectors.mobilenet import MobileNetDetector
-     return MobileNetDetector()
-
- @st.cache_resource
- def load_mobilevit():
-     from src.detectors.mobilevit import MobileViTDetector
-     return MobileViTDetector()
+ from src.models import BACKBONES

  st.set_page_config(page_title="Feature Lab", layout="wide")

@@ -86,22 +68,16 @@ with col_rce:
  # ---------------------------------------------------------------------------
  with col_cnn:
      st.header("🧠 CNN: Static Architecture")
-     selected_cnn = st.selectbox("Compare against Model", ["ResNet-18", "MobileViT-XXS", "MobileNetV3"])
+     selected_cnn = st.selectbox("Compare against Model", list(BACKBONES.keys()))
      st.info("CNN features are fixed by pre-trained weights. You cannot toggle them like the RCE.")

      with st.spinner(f"Loading {selected_cnn} and extracting activations..."):
          try:
-             if selected_cnn == "ResNet-18":
-                 detector = load_resnet()
-                 layer_name = "layer4 (last conv block)"
-             elif selected_cnn == "MobileViT-XXS":
-                 detector = load_mobilevit()
-                 layer_name = "stages[-1] (last transformer stage)"
-             else:
-                 detector = load_mobilenet()
-                 layer_name = "features[-1] (last features block)"
-
-             act_maps = detector.get_activation_maps(obj, n_maps=6)
+             bmeta = BACKBONES[selected_cnn]
+             backbone = bmeta["loader"]()   # cached frozen backbone
+             layer_name = bmeta["hook_layer"]
+
+             act_maps = backbone.get_activation_maps(obj, n_maps=6)
              st.caption(f"Hooked layer: `{layer_name}` — showing 6 of {len(act_maps)} channels")
              act_cols = st.columns(3)
              for i, amap in enumerate(act_maps):
pages/4_Model_Tuning.py CHANGED
@@ -7,6 +7,7 @@ import sys, os
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

  from src.detectors.rce.features import REGISTRY
+ from src.models import BACKBONES, RecognitionHead

  st.set_page_config(page_title="Model Tuning", layout="wide")
  st.title("⚙️ Model Tuning: Train & Compare")
@@ -26,31 +27,6 @@ bbox = assets.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
  active_modules = st.session_state.get("active_modules", {k: True for k in REGISTRY})


- # ---------------------------------------------------------------------------
- # Cached model loaders
- # ---------------------------------------------------------------------------
- @st.cache_resource
- def load_resnet():
-     from src.detectors.resnet import ResNetDetector
-     return ResNetDetector()
-
- @st.cache_resource
- def load_mobilenet():
-     from src.detectors.mobilenet import MobileNetDetector
-     return MobileNetDetector()
-
- @st.cache_resource
- def load_mobilevit():
-     from src.detectors.mobilevit import MobileViTDetector
-     return MobileViTDetector()
-
- CNN_MODELS = {
-     "ResNet-18": {"loader": load_resnet, "dim": 512},
-     "MobileNetV3": {"loader": load_mobilenet, "dim": 576},
-     "MobileViT-XXS": {"loader": load_mobilevit, "dim": 320},
- }
-
-
  # ---------------------------------------------------------------------------
  # Build training set from session data (no disk reads)
  # ---------------------------------------------------------------------------
@@ -137,7 +113,6 @@ with col_rce:

      if st.button("🚀 Train RCE Head"):
          images, labels = build_training_set()
-         from sklearn.linear_model import LogisticRegression
          from sklearn.metrics import accuracy_score

          progress = st.progress(0, text="Extracting RCE features...")
@@ -151,12 +126,11 @@ with col_rce:
          progress.progress(1.0, text="Fitting Logistic Regression...")

          t0 = time.perf_counter()
-         head = LogisticRegression(max_iter=rce_max_iter, C=rce_C)
-         head.fit(X, labels)
+         head = RecognitionHead(C=rce_C, max_iter=rce_max_iter).fit(X, labels)
          train_time = time.perf_counter() - t0
          progress.progress(1.0, text="✅ Training complete!")

-         preds = head.predict(X)
+         preds = head.model.predict(X)
          train_acc = accuracy_score(labels, preds)

          st.success(f"Trained in **{train_time:.2f}s**")
@@ -184,10 +158,9 @@ with col_rce:
          head = st.session_state["rce_head"]
          t0 = time.perf_counter()
          vec = build_rce_vector(crop_aug)
-         probs = head.predict_proba([vec])[0]
+         label, conf = head.predict(vec)
          dt = (time.perf_counter() - t0) * 1000
-         idx = np.argmax(probs)
-         st.write(f"**{head.classes_[idx]}** — {probs[idx]:.1%} confidence — {dt:.1f} ms")
+         st.write(f"**{label}** {conf:.1%} confidence — {dt:.1f} ms")


  # ---------------------------------------------------------------------------
@@ -196,8 +169,8 @@ with col_rce:
  with col_cnn:
      st.header("🧠 CNN Fine-Tuning")

-     selected = st.selectbox("Select Model", list(CNN_MODELS.keys()))
-     meta = CNN_MODELS[selected]
+     selected = st.selectbox("Select Model", list(BACKBONES.keys()))
+     meta = BACKBONES[selected]
      st.caption(f"Backbone embedding: **{meta['dim']}D** → Logistic Regression head")

      st.subheader("Training Parameters")
@@ -208,28 +181,26 @@ with col_cnn:

      if st.button(f"🚀 Train {selected} Head"):
          images, labels = build_training_set()
-         detector = meta["loader"]()
+         backbone = meta["loader"]()   # cached frozen backbone

-         from sklearn.linear_model import LogisticRegression
          from sklearn.metrics import accuracy_score

          progress = st.progress(0, text=f"Extracting {selected} features...")
          n = len(images)
          X = []
          for i, img in enumerate(images):
-             X.append(detector._get_features(img))
+             X.append(backbone.get_features(img))
              progress.progress((i + 1) / n, text=f"Feature extraction: {i+1}/{n}")

          X = np.array(X)
          progress.progress(1.0, text="Fitting Logistic Regression...")

          t0 = time.perf_counter()
-         head = LogisticRegression(max_iter=cnn_max_iter, C=cnn_C)
-         head.fit(X, labels)
+         head = RecognitionHead(C=cnn_C, max_iter=cnn_max_iter).fit(X, labels)
          train_time = time.perf_counter() - t0
          progress.progress(1.0, text="✅ Training complete!")

-         preds = head.predict(X)
+         preds = head.model.predict(X)
          train_acc = accuracy_score(labels, preds)

          st.success(f"Trained in **{train_time:.2f}s**")
@@ -254,14 +225,13 @@ with col_cnn:
      if f"cnn_head_{selected}" in st.session_state:
          st.divider()
          st.subheader("Quick Predict (Crop)")
-         detector = meta["loader"]()
+         backbone = meta["loader"]()   # cached frozen backbone
          head = st.session_state[f"cnn_head_{selected}"]
          t0 = time.perf_counter()
-         feats = detector._get_features(crop_aug)
-         probs = head.predict_proba([feats])[0]
+         feats = backbone.get_features(crop_aug)
+         label, conf = head.predict(feats)
          dt = (time.perf_counter() - t0) * 1000
-         idx = np.argmax(probs)
-         st.write(f"**{head.classes_[idx]}** — {probs[idx]:.1%} confidence — {dt:.1f} ms")
+         st.write(f"**{label}** {conf:.1%} confidence — {dt:.1f} ms")


  # ===========================================================================
@@ -275,11 +245,11 @@ rows = []
  if rce_acc is not None:
      rows.append({"Model": "RCE", "Train Accuracy": f"{rce_acc:.1%}",
                   "Vector Size": str(sum(10 for k in active_modules if active_modules[k]))})
- for name in CNN_MODELS:
+ for name in BACKBONES:
      acc = st.session_state.get(f"cnn_acc_{name}")
      if acc is not None:
          rows.append({"Model": name, "Train Accuracy": f"{acc:.1%}",
-                      "Vector Size": f"{CNN_MODELS[name]['dim']}D"})
+                      "Vector Size": f"{BACKBONES[name]['dim']}D"})

  if rows:
      import pandas as pd
pages/5_RealTime_Detection.py CHANGED
@@ -0,0 +1,338 @@
+ import streamlit as st
+ import cv2
+ import numpy as np
+ import time
+ import plotly.graph_objects as go
+ import sys, os
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+ from src.detectors.rce.features import REGISTRY
+ from src.models import BACKBONES, RecognitionHead
+
+ st.set_page_config(page_title="Real-Time Detection", layout="wide")
+ st.title("🎯 Real-Time Detection")
+
+ # ---------------------------------------------------------------------------
+ # Guard
+ # ---------------------------------------------------------------------------
+ if "pipeline_data" not in st.session_state or "crop" not in st.session_state.get("pipeline_data", {}):
+     st.error("Complete **Data Lab** first (upload assets & define a crop).")
+     st.stop()
+
+ assets = st.session_state["pipeline_data"]
+ right_img = assets["right"]
+ crop = assets["crop"]
+ crop_aug = assets.get("crop_aug", crop)
+ bbox = assets.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
+ active_mods = st.session_state.get("active_modules", {k: True for k in REGISTRY})
+
+ x0, y0, x1, y1 = bbox
+ win_h, win_w = y1 - y0, x1 - x0   # window = same size as crop
+
+ rce_head = st.session_state.get("rce_head")
+ has_any_cnn = any(f"cnn_head_{n}" in st.session_state for n in BACKBONES)
+
+ if rce_head is None and not has_any_cnn:
+     st.warning("No trained heads found. Go to **Model Tuning** and train at least one head.")
+     st.stop()
+
+
+ # ===================================================================
+ # Sliding Window Engine (shared by both sides)
+ # ===================================================================
+ def sliding_window_detect(
+     image: np.ndarray,
+     feature_fn,                # callable(patch_bgr) -> 1-D np.ndarray
+     head: RecognitionHead,
+     stride: int,
+     conf_thresh: float,
+     nms_iou: float,
+     progress_placeholder=None,
+     live_image_placeholder=None,
+ ):
+     """
+     Slide a window of size (win_h, win_w) across *image* with *stride*.
+     At each position call *feature_fn* → *head.predict*.
+     Returns (detections, heatmap, total_time_ms, n_windows).
+
+     Each detection is (x, y, x+win_w, y+win_h, label, confidence).
+     heatmap is a float32 array same size as image (object confidence).
+     """
+     H, W = image.shape[:2]
+     heatmap = np.zeros((H, W), dtype=np.float32)
+     detections = []
+     t0 = time.perf_counter()
+
+     positions = []
+     for y in range(0, H - win_h + 1, stride):
+         for x in range(0, W - win_w + 1, stride):
+             positions.append((x, y))
+
+     n_total = len(positions)
+     if n_total == 0:
+         return [], heatmap, 0.0, 0
+
+     for idx, (x, y) in enumerate(positions):
+         patch = image[y:y+win_h, x:x+win_w]
+         feats = feature_fn(patch)
+         label, conf = head.predict(feats)
+
+         # Fill heatmap with object confidence
+         if label == "object":
+             heatmap[y:y+win_h, x:x+win_w] = np.maximum(
+                 heatmap[y:y+win_h, x:x+win_w], conf)
+             if conf >= conf_thresh:
+                 detections.append((x, y, x+win_w, y+win_h, label, conf))
+
+         # Live updates (every 5th window or last)
+         if live_image_placeholder is not None and (idx % 5 == 0 or idx == n_total - 1):
+             vis = image.copy()
+             # Draw current scan position
+             cv2.rectangle(vis, (x, y), (x+win_w, y+win_h), (255, 255, 0), 1)
+             # Draw current detections
+             for dx, dy, dx2, dy2, dl, dc in detections:
+                 cv2.rectangle(vis, (dx, dy), (dx2, dy2), (0, 255, 0), 2)
+                 cv2.putText(vis, f"{dc:.0%}", (dx, dy - 4),
+                             cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)
+             live_image_placeholder.image(
+                 cv2.cvtColor(vis, cv2.COLOR_BGR2RGB),
+                 caption=f"Scanning… {idx+1}/{n_total}",
+                 use_container_width=True)
+
+         if progress_placeholder is not None:
+             progress_placeholder.progress(
+                 (idx + 1) / n_total,
+                 text=f"Window {idx+1}/{n_total}")
+
+     total_ms = (time.perf_counter() - t0) * 1000
+
+     # --- Non-Maximum Suppression ---
+     if detections:
+         detections = _nms(detections, nms_iou)
+
+     return detections, heatmap, total_ms, n_total
+
+
+ def _nms(dets, iou_thresh):
+     """Greedy NMS on list of (x1,y1,x2,y2,label,conf)."""
+     dets = sorted(dets, key=lambda d: d[5], reverse=True)
+     keep = []
+     while dets:
+         best = dets.pop(0)
+         keep.append(best)
+         dets = [d for d in dets if _iou(best, d) < iou_thresh]
+     return keep
+
+
+ def _iou(a, b):
+     """IoU between two (x1,y1,x2,y2,…) tuples."""
+     xi1 = max(a[0], b[0]); yi1 = max(a[1], b[1])
+     xi2 = min(a[2], b[2]); yi2 = min(a[3], b[3])
+     inter = max(0, xi2-xi1) * max(0, yi2-yi1)
+     aa = (a[2]-a[0])*(a[3]-a[1])
+     ab = (b[2]-b[0])*(b[3]-b[1])
+     return inter / (aa + ab - inter + 1e-6)
+
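+ # Worked IoU example: two 10x10 boxes offset by 5 px in x overlap in a
+ # 5x10 region, so IoU = 50 / (100 + 100 - 50) = 1/3, which exceeds the
+ # default NMS threshold of 0.3 below (the weaker box would be suppressed).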
+
+ # ===================================================================
+ # RCE feature function
+ # ===================================================================
+ def rce_feature_fn(patch_bgr):
+     gray = cv2.cvtColor(patch_bgr, cv2.COLOR_BGR2GRAY)
+     vec = []
+     for key, meta in REGISTRY.items():
+         if active_mods.get(key, False):
+             v, _ = meta["fn"](gray)
+             vec.extend(v)
+     return np.array(vec, dtype=np.float32)
+
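+ # Each enabled module's "fn" maps the grayscale patch to a 10-bin histogram
+ # plus a second value this page discards. A hypothetical sketch of the
+ # Intensity module (the real ones live in src/detectors/rce/features.py):
+ #     def intensity(gray):
+ #         hist, _ = np.histogram(gray, bins=10, range=(0, 255), density=True)
+ #         return hist, None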
+
+ # ===================================================================
+ # Controls
+ # ===================================================================
+ st.subheader("Sliding Window Parameters")
+ p1, p2, p3 = st.columns(3)
+ stride = p1.slider("Stride (px)", 4, max(win_w, win_h),
+                    max(win_w // 4, 4), step=2,
+                    help="Lower = more windows = slower but finer")
+ conf_thresh = p2.slider("Confidence Threshold", 0.5, 1.0, 0.7, 0.05)
+ nms_iou = p3.slider("NMS IoU Threshold", 0.1, 0.9, 0.3, 0.05)
+
+ st.caption(f"Window size: **{win_w}×{win_h} px** | "
+            f"Right image: **{right_img.shape[1]}×{right_img.shape[0]} px** | "
+            f"≈ {((right_img.shape[0]-win_h)//stride + 1) * ((right_img.shape[1]-win_w)//stride + 1)} windows")
+
+ st.divider()
+
+ # ===================================================================
+ # Side-by-side layout
+ # ===================================================================
+ col_rce, col_cnn = st.columns(2)
+
+ # -------------------------------------------------------------------
+ # LEFT — RCE Detection
+ # -------------------------------------------------------------------
+ with col_rce:
+     st.header("🧬 RCE Detection")
+     if rce_head is None:
+         st.info("No RCE head trained. Train one in **Model Tuning**.")
+     else:
+         st.caption(f"Modules: {', '.join(REGISTRY[k]['label'] for k in active_mods if active_mods[k])}")
+         rce_run = st.button("▶ Run RCE Scan", key="rce_run")
+
+         rce_progress = st.empty()
+         rce_live = st.empty()
+         rce_results = st.container()
+
+         if rce_run:
+             dets, hmap, ms, nw = sliding_window_detect(
+                 right_img, rce_feature_fn, rce_head,
+                 stride, conf_thresh, nms_iou,
+                 progress_placeholder=rce_progress,
+                 live_image_placeholder=rce_live,
+             )
+
+             # Final image with boxes
+             final = right_img.copy()
+             for x1d, y1d, x2d, y2d, lbl, cf in dets:
+                 cv2.rectangle(final, (x1d, y1d), (x2d, y2d), (0, 255, 0), 2)
+                 cv2.putText(final, f"{cf:.0%}", (x1d, y1d - 6),
+                             cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
+             rce_live.image(cv2.cvtColor(final, cv2.COLOR_BGR2RGB),
+                            caption="RCE — Final Detections",
+                            use_container_width=True)
+             rce_progress.empty()
+
+             with rce_results:
+                 # Metrics
+                 rm1, rm2, rm3, rm4 = st.columns(4)
+                 rm1.metric("Detections", len(dets))
+                 rm2.metric("Windows", nw)
+                 rm3.metric("Total Time", f"{ms:.0f} ms")
+                 rm4.metric("Per Window", f"{ms/max(nw,1):.2f} ms")
+
+                 # Confidence heatmap
+                 if hmap.max() > 0:
+                     hmap_color = cv2.applyColorMap(
+                         (hmap / hmap.max() * 255).astype(np.uint8),
+                         cv2.COLORMAP_JET)
+                     blend = cv2.addWeighted(right_img, 0.5, hmap_color, 0.5, 0)
+                     st.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
+                              caption="RCE — Confidence Heatmap",
+                              use_container_width=True)
+
+                 # Detection table
+                 if dets:
+                     import pandas as pd
+                     df = pd.DataFrame(dets, columns=["x1","y1","x2","y2","label","conf"])
+                     st.dataframe(df, use_container_width=True, hide_index=True)
+
+             st.session_state["rce_dets"] = dets
+             st.session_state["rce_det_ms"] = ms
+
+ # -------------------------------------------------------------------
+ # RIGHT — CNN Detection
+ # -------------------------------------------------------------------
+ with col_cnn:
+     st.header("🧠 CNN Detection")
+
+     # Find which CNN heads are trained
+     trained_cnns = [n for n in BACKBONES if f"cnn_head_{n}" in st.session_state]
+     if not trained_cnns:
+         st.info("No CNN head trained. Train one in **Model Tuning**.")
+     else:
+         selected = st.selectbox("Select Model", trained_cnns, key="det_cnn_sel")
+         bmeta = BACKBONES[selected]
+         backbone = bmeta["loader"]()
+         head = st.session_state[f"cnn_head_{selected}"]
+
+         st.caption(f"Backbone: **{selected}** ({bmeta['dim']}D) — Head in session state")
+         cnn_run = st.button(f"▶ Run {selected} Scan", key="cnn_run")
+
+         cnn_progress = st.empty()
+         cnn_live = st.empty()
+         cnn_results = st.container()
+
+         if cnn_run:
+             dets, hmap, ms, nw = sliding_window_detect(
+                 right_img, backbone.get_features, head,
+                 stride, conf_thresh, nms_iou,
+                 progress_placeholder=cnn_progress,
+                 live_image_placeholder=cnn_live,
+             )
+
+             # Final image
+             final = right_img.copy()
+             for x1d, y1d, x2d, y2d, lbl, cf in dets:
+                 cv2.rectangle(final, (x1d, y1d), (x2d, y2d), (0, 0, 255), 2)
+                 cv2.putText(final, f"{cf:.0%}", (x1d, y1d - 6),
+                             cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
+             cnn_live.image(cv2.cvtColor(final, cv2.COLOR_BGR2RGB),
+                            caption=f"{selected} — Final Detections",
+                            use_container_width=True)
+             cnn_progress.empty()
+
+             with cnn_results:
+                 cm1, cm2, cm3, cm4 = st.columns(4)
+                 cm1.metric("Detections", len(dets))
+                 cm2.metric("Windows", nw)
+                 cm3.metric("Total Time", f"{ms:.0f} ms")
+                 cm4.metric("Per Window", f"{ms/max(nw,1):.2f} ms")
+
+                 if hmap.max() > 0:
+                     hmap_color = cv2.applyColorMap(
+                         (hmap / hmap.max() * 255).astype(np.uint8),
+                         cv2.COLORMAP_JET)
+                     blend = cv2.addWeighted(right_img, 0.5, hmap_color, 0.5, 0)
+                     st.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
+                              caption=f"{selected} — Confidence Heatmap",
+                              use_container_width=True)
+
+                 if dets:
+                     import pandas as pd
+                     df = pd.DataFrame(dets, columns=["x1","y1","x2","y2","label","conf"])
+                     st.dataframe(df, use_container_width=True, hide_index=True)
+
+             st.session_state["cnn_dets"] = dets
+             st.session_state["cnn_det_ms"] = ms
+
+
+ # ===================================================================
+ # Bottom — Comparison (if both have run)
+ # ===================================================================
+ rce_dets = st.session_state.get("rce_dets")
+ cnn_dets = st.session_state.get("cnn_dets")
+
+ if rce_dets is not None and cnn_dets is not None:
+     st.divider()
+     st.subheader("📊 Side-by-Side Comparison")
+
+     import pandas as pd
+     comp = pd.DataFrame({
+         "Metric": ["Detections", "Best Confidence", "Total Time (ms)"],
+         "RCE": [
+             len(rce_dets),
+             f"{max((d[5] for d in rce_dets), default=0):.1%}",
+             f"{st.session_state.get('rce_det_ms', 0):.0f}",
+         ],
+         "CNN": [
+             len(cnn_dets),
+             f"{max((d[5] for d in cnn_dets), default=0):.1%}",
+             f"{st.session_state.get('cnn_det_ms', 0):.0f}",
+         ],
+     })
+     st.dataframe(comp, use_container_width=True, hide_index=True)
+
+     # Overlay both on one image
+     overlay = right_img.copy()
+     for x1d, y1d, x2d, y2d, _, cf in rce_dets:
+         cv2.rectangle(overlay, (x1d, y1d), (x2d, y2d), (0, 255, 0), 2)
+         cv2.putText(overlay, f"RCE {cf:.0%}", (x1d, y1d - 6),
+                     cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)
+     for x1d, y1d, x2d, y2d, _, cf in cnn_dets:
+         cv2.rectangle(overlay, (x1d, y1d), (x2d, y2d), (0, 0, 255), 2)
+         cv2.putText(overlay, f"CNN {cf:.0%}", (x1d, y2d + 12),
+                     cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1)
+     st.image(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB),
+              caption="Green = RCE | Blue = CNN",
+              use_container_width=True)
pages/6_Stereo_Geometry.py CHANGED
@@ -0,0 +1,327 @@
+ import streamlit as st
+ import cv2
+ import numpy as np
+ import re
+ import plotly.graph_objects as go
+ import sys, os
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+ st.set_page_config(page_title="Stereo Geometry", layout="wide")
+ st.title("📐 Stereo Geometry: Distance Estimation")
+
+ # ---------------------------------------------------------------------------
+ # Guard
+ # ---------------------------------------------------------------------------
+ if "pipeline_data" not in st.session_state or "left" not in st.session_state.get("pipeline_data", {}):
+     st.error("Complete **Data Lab** first.")
+     st.stop()
+
+ assets = st.session_state["pipeline_data"]
+ img_l = assets["left"]
+ img_r = assets["right"]
+ gt_left = assets.get("gt_left")        # float32 depth map from PFM
+ gt_right = assets.get("gt_right")
+ conf_raw = assets.get("conf_raw", "")
+ crop_bbox = assets.get("crop_bbox")    # (x0, y0, x1, y1) on LEFT image
+
+ rce_dets = st.session_state.get("rce_dets", [])   # list of (x1,y1,x2,y2,label,conf)
+ cnn_dets = st.session_state.get("cnn_dets", [])
+
+
+ # ===================================================================
+ # Parse Middlebury-style camera config
+ # ===================================================================
+ def parse_config(text: str) -> dict:
+     """
+     Parse a Middlebury .txt / .conf calibration file.
+     Expected keys: cam0, cam1, doffs, baseline, width, height, ndisp, vmin, vmax
+     cam0 / cam1 are 3×3 matrices in bracket notation: [f 0 cx; 0 f cy; 0 0 1]
+     """
+     params = {}
+     for line in text.strip().splitlines():
+         line = line.strip()
+         if "=" not in line:
+             continue
+         key, val = line.split("=", 1)
+         key = key.strip()
+         val = val.strip()
+         # Matrix?
+         if "[" in val:
+             nums = list(map(float, re.findall(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?", val)))
+             if len(nums) == 9:
+                 params[key] = np.array(nums).reshape(3, 3)
+             else:
+                 params[key] = nums
+         else:
+             try:
+                 params[key] = float(val)
+             except ValueError:
+                 params[key] = val
+     return params
+
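+ # Example input in the format parse_config expects (values are hypothetical,
+ # not from a real calibration file):
+ #     cam0=[3997.684 0 1176.728; 0 3997.684 1011.728; 0 0 1]
+ #     cam1=[3997.684 0 1307.839; 0 3997.684 1011.728; 0 0 1]
+ #     doffs=131.111
+ #     baseline=193.001
+ #     width=2964
+ #     height=2000
+ #     ndisp=280
+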
+ calib = parse_config(conf_raw)
+
+ # Extract intrinsics
+ focal = calib.get("cam0", np.eye(3))[0, 0] if isinstance(calib.get("cam0"), np.ndarray) else 0.0
+ doffs = float(calib.get("doffs", 0.0))
+ baseline = float(calib.get("baseline", 1.0))
+ ndisp = int(calib.get("ndisp", 128))
+
+ st.subheader("Camera Calibration")
+ cc1, cc2, cc3, cc4 = st.columns(4)
+ cc1.metric("Focal Length (px)", f"{focal:.1f}")
+ cc2.metric("Baseline (mm)", f"{baseline:.1f}")
+ cc3.metric("Doffs (px)", f"{doffs:.2f}")
+ cc4.metric("ndisp", str(ndisp))
+
+ with st.expander("Full Calibration"):
+     st.json({k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in calib.items()})
+
+ st.divider()
+
+
+ # ===================================================================
+ # Step 1 — Compute Disparity Map
+ # ===================================================================
+ st.subheader("Step 1: Disparity Map (StereoSGBM)")
+
+ sc1, sc2, sc3 = st.columns(3)
+ block_size = sc1.slider("Block Size", 3, 21, 5, step=2)
+ p1_mult = sc2.slider("P1 multiplier", 1, 32, 8)
+ p2_mult = sc3.slider("P2 multiplier", 1, 128, 32)
+
+ @st.cache_data   # params kept hashable so slider changes invalidate the cache
+ def compute_disparity(left, right, ndisp, block_size, p1m, p2m):
+     gray_l = cv2.cvtColor(left, cv2.COLOR_BGR2GRAY)
+     gray_r = cv2.cvtColor(right, cv2.COLOR_BGR2GRAY)
+
+     # Align ndisp to a multiple of 16 (StereoSGBM requirement)
+     nd = max(16, (ndisp // 16) * 16)
+     channels = 1
+     sgbm = cv2.StereoSGBM_create(
+         minDisparity=0,
+         numDisparities=nd,
+         blockSize=block_size,
+         P1=p1m * channels * block_size ** 2,
+         P2=p2m * channels * block_size ** 2,
+         disp12MaxDiff=1,
+         uniquenessRatio=10,
+         speckleWindowSize=100,
+         speckleRange=32,
+         mode=cv2.STEREO_SGBM_MODE_SGBM_3WAY,
+     )
+     disp = sgbm.compute(gray_l, gray_r).astype(np.float32) / 16.0   # SGBM outputs fixed-point (4 fractional bits)
+     return disp
+
+
+ with st.spinner("Computing disparity..."):
+     disp = compute_disparity(img_l, img_r, ndisp, block_size, p1_mult, p2_mult)
+
+ # Visualize disparity
+ disp_vis = disp.copy()
+ disp_vis[disp_vis <= 0] = 0
+ disp_max = disp_vis.max() if disp_vis.max() > 0 else 1.0
+ disp_norm = (disp_vis / disp_max * 255).astype(np.uint8)
+ disp_color = cv2.applyColorMap(disp_norm, cv2.COLORMAP_INFERNO)
+
+ dc1, dc2 = st.columns(2)
+ dc1.image(cv2.cvtColor(img_l, cv2.COLOR_BGR2RGB), caption="Left Image", use_container_width=True)
+ dc2.image(cv2.cvtColor(disp_color, cv2.COLOR_BGR2RGB), caption="Disparity Map (SGBM)", use_container_width=True)
+
+
+ # ===================================================================
+ # Step 2 — Depth Map from Disparity
+ # ===================================================================
+ st.divider()
+ st.subheader("Step 2: Depth Map from Disparity")
+
+ st.latex(r"Z = \frac{f \times B}{d + d_{\text{offs}}}")
+ st.caption("Z = depth (mm), f = focal length (px), B = baseline (mm), d = disparity (px), d_offs = optical center offset (px)")
+
+ # Compute depth from disparity
+ valid = (disp + doffs) > 0
+ depth_map = np.zeros_like(disp)
+ depth_map[valid] = (focal * baseline) / (disp[valid] + doffs)
+ depth_map[~valid] = 0
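+ # Sanity check with hypothetical Middlebury-scale numbers:
+ #     f = 3980 px, B = 193 mm, d = 124 px, doffs = 124 px
+ #     Z = 3980 * 193 / (124 + 124) ≈ 3097 mm ≈ 3.1 m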
+
+ # Visualize
+ depth_vis = depth_map.copy()
+ finite = depth_vis[depth_vis > 0]
+ if len(finite) > 0:
+     clip_max = np.percentile(finite, 98)
+     depth_vis = np.clip(depth_vis, 0, clip_max)
+     depth_norm = (depth_vis / clip_max * 255).astype(np.uint8)
+ else:
+     depth_norm = np.zeros_like(depth_map, dtype=np.uint8)
+
+ depth_color = cv2.applyColorMap(depth_norm, cv2.COLORMAP_TURBO)
+
+ zc1, zc2 = st.columns(2)
+ zc1.image(cv2.cvtColor(depth_color, cv2.COLOR_BGR2RGB),
+           caption="Estimated Depth (SGBM)", use_container_width=True)
+
+ # Ground truth comparison
+ if gt_left is not None:
+     gt_vis = gt_left.copy()
+     gt_finite = gt_vis[np.isfinite(gt_vis) & (gt_vis > 0)]
+     if len(gt_finite) > 0:
+         gt_clip = np.percentile(gt_finite, 98)
+         gt_vis = np.clip(np.nan_to_num(gt_vis, nan=0), 0, gt_clip)
+         gt_norm = (gt_vis / gt_clip * 255).astype(np.uint8)
+     else:
+         gt_norm = np.zeros_like(gt_vis, dtype=np.uint8)
+     gt_color = cv2.applyColorMap(gt_norm, cv2.COLORMAP_TURBO)
+     zc2.image(cv2.cvtColor(gt_color, cv2.COLOR_BGR2RGB),
+               caption="Ground Truth Depth", use_container_width=True)
+
+
+ # ===================================================================
+ # Step 3 — Error Map (SGBM vs Ground Truth)
+ # ===================================================================
+ if gt_left is not None:
+     st.divider()
+     st.subheader("Step 3: Error Analysis (SGBM vs Ground Truth)")
+
+     # Middlebury PFM ground truth stores DISPARITY, not depth;
+     # convert it to metric depth for a like-for-like comparison
+     gt_disp = gt_left   # Middlebury standard: PFM = disparity map
+     gt_depth_from_disp = np.zeros_like(gt_disp)
+     gt_valid = np.isfinite(gt_disp) & (gt_disp + doffs > 0)   # isfinite already excludes inf
+     gt_depth_from_disp[gt_valid] = (focal * baseline) / (gt_disp[gt_valid] + doffs)
+
+     # Crop to common valid region
+     both_valid = valid & gt_valid
+     if both_valid.any():
+         # Disparity error
+         disp_err = np.abs(disp - gt_disp)
+         disp_err[~both_valid] = 0
+
+         # Stats
+         err_vals = disp_err[both_valid]
+         mae = float(np.mean(err_vals))
+         rmse = float(np.sqrt(np.mean(err_vals ** 2)))
+         bad_2 = float(np.mean(err_vals > 2.0)) * 100   # % of pixels with error > 2 px
+
+         em1, em2, em3 = st.columns(3)
+         em1.metric("MAE (px)", f"{mae:.2f}")
+         em2.metric("RMSE (px)", f"{rmse:.2f}")
+         em3.metric("Bad-2.0 (%)", f"{bad_2:.1f}%")
+
+         # Error heatmap
+         err_clip = np.clip(disp_err, 0, 10)
+         err_norm = (err_clip / 10 * 255).astype(np.uint8)
+         err_color = cv2.applyColorMap(err_norm, cv2.COLORMAP_HOT)
+         st.image(cv2.cvtColor(err_color, cv2.COLOR_BGR2RGB),
+                  caption="Disparity Error Map (red = high error, clipped at 10 px)",
+                  use_container_width=True)
+
+         # Histogram
+         fig = go.Figure(data=[go.Histogram(x=err_vals, nbinsx=50,
+                                            marker_color="#ff6361")])
+         fig.update_layout(title="Disparity Error Distribution",
+                           xaxis_title="Absolute Error (px)",
+                           yaxis_title="Pixel Count",
+                           template="plotly_dark", height=300)
+         st.plotly_chart(fig, use_container_width=True)
+     else:
+         st.warning("No overlapping valid pixels between SGBM disparity and ground truth.")
+
+
+ # ===================================================================
+ # Step 4 — Object Distance from Detections
+ # ===================================================================
+ st.divider()
+ st.subheader("Step 4: Object Distance Estimation")
+
+ all_dets = []
+ if rce_dets:
+     for d in rce_dets:
+         all_dets.append(("RCE", *d))
+ if cnn_dets:
+     for d in cnn_dets:
+         all_dets.append(("CNN", *d))
+
+ if not all_dets and crop_bbox is not None:
+     st.info("No detections from the Real-Time Detection page. Using the **crop bounding box on the left image** as a fallback.")
+     x0, y0, x1, y1 = crop_bbox
+     all_dets.append(("Crop (left)", x0, y0, x1, y1, "object", 1.0))
+ elif not all_dets:
+     st.warning("No detections found. Run **Real-Time Detection** first, or define a crop in **Data Lab**.")
+     st.stop()
+
+
+ # For each detection, compute median depth inside the bounding box
+ import pandas as pd
+
+ rows = []
+ det_overlay = img_l.copy() if all_dets and all_dets[0][0] == "Crop (left)" else img_r.copy()
+
+ for source, dx1, dy1, dx2, dy2, lbl, conf in all_dets:
+     dx1, dy1, dx2, dy2 = int(dx1), int(dy1), int(dx2), int(dy2)
+
+     # Clamp to image bounds
+     H, W = depth_map.shape[:2]
+     dx1c = max(0, min(dx1, W-1))
+     dy1c = max(0, min(dy1, H-1))
+     dx2c = max(0, min(dx2, W))
+     dy2c = max(0, min(dy2, H))
+
+     roi_depth = depth_map[dy1c:dy2c, dx1c:dx2c]
+     roi_disp = disp[dy1c:dy2c, dx1c:dx2c]
+     roi_valid = roi_depth[roi_depth > 0]
+
+     if len(roi_valid) > 0:
+         med_depth = float(np.median(roi_valid))
+         mean_depth = float(np.mean(roi_valid))
+         med_disp = float(np.median(roi_disp[roi_disp > 0])) if (roi_disp > 0).any() else 0
+     else:
+         med_depth = mean_depth = med_disp = 0.0
+
+     # Ground truth depth at this region (for comparison)
+     gt_depth_val = 0.0
+     if gt_left is not None:
+         gt_roi = gt_left[dy1c:dy2c, dx1c:dx2c]
+         gt_roi_valid = gt_roi[np.isfinite(gt_roi) & (gt_roi > 0)]
+         if len(gt_roi_valid) > 0:
+             # Convert GT disparity → depth
+             gt_med_disp = float(np.median(gt_roi_valid))
+             gt_depth_val = (focal * baseline) / (gt_med_disp + doffs) if (gt_med_disp + doffs) > 0 else 0
+
+     error_mm = abs(med_depth - gt_depth_val) if gt_depth_val > 0 else float('nan')
+
+     rows.append({
+         "Source": source,
+         "Box": f"({dx1},{dy1})→({dx2},{dy2})",
+         "Confidence": f"{conf:.1%}" if isinstance(conf, float) else str(conf),
+         "Med Disparity": f"{med_disp:.1f} px",
+         "Med Depth": f"{med_depth:.0f} mm",
+         "Mean Depth": f"{mean_depth:.0f} mm",
+         "GT Depth": f"{gt_depth_val:.0f} mm" if gt_depth_val > 0 else "N/A",
+         "Error": f"{error_mm:.0f} mm" if not np.isnan(error_mm) else "N/A",
+     })
+
+     # Draw on overlay
+     color = (0, 255, 0) if "RCE" in source else (0, 0, 255) if "CNN" in source else (255, 255, 0)
+     cv2.rectangle(det_overlay, (dx1c, dy1c), (dx2c, dy2c), color, 2)
+     depth_str = f"{med_depth/1000:.2f}m" if med_depth > 0 else "?"
+     cv2.putText(det_overlay, f"{source} {depth_str}",
+                 (dx1c, dy1c - 6), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
+
+ # Show overlay
+ st.image(cv2.cvtColor(det_overlay, cv2.COLOR_BGR2RGB),
+          caption="Detections with Estimated Distance",
+          use_container_width=True)
+
+ # Table
+ st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
+
+ # Big metric cards for the best detection
+ if rows:
+     best = rows[0]
+     st.divider()
+     st.subheader("🎯 Primary Detection — Distance")
+     bc1, bc2, bc3 = st.columns(3)
+     bc1.metric("Estimated Depth", best["Med Depth"])
+     bc2.metric("Ground Truth", best["GT Depth"])
+     bc3.metric("Absolute Error", best["Error"])
src/models.py ADDED
@@ -0,0 +1,250 @@
+ """
+ src/models.py — Central Model Registry
+ ======================================
+ Downloads backbone weights **once** from the internet (PyTorch Hub / timm),
+ freezes every feature-extraction layer, and caches the result in RAM with
+ Streamlit's ``@st.cache_resource``.
+
+ Strategy
+ --------
+ 1. **Freeze the Backbone** → ``requires_grad = False`` on every parameter.
+    The backbone is a pure feature extractor — no gradient updates, ever.
+ 2. **Cache the Resource** → ``@st.cache_resource`` keeps the heavy model
+    in RAM even when you switch pages.
+ 3. **Define the Head** → ``RecognitionHead``: a tiny sklearn
+    LogisticRegression that takes the backbone's feature vector and
+    produces a recognition score. Lives only in ``st.session_state``.
+ """
+
+ import streamlit as st
+ import torch
+ import torch.nn as nn
+ import torchvision.models as models
+ import torchvision.transforms as transforms
+ import timm
+ import cv2
+ import numpy as np
+
+ # ---------------------------------------------------------------------------
+ # Device selection (MPS > CUDA > CPU)
+ # ---------------------------------------------------------------------------
+ DEVICE = (
+     "mps" if torch.backends.mps.is_available() else
+     "cuda" if torch.cuda.is_available() else
+     "cpu"
+ )
+
+ # ---------------------------------------------------------------------------
+ # Shared ImageNet preprocessing
+ # ---------------------------------------------------------------------------
+ _IMAGENET_TRANSFORM = transforms.Compose([
+     transforms.ToPILImage(),
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                          std=[0.229, 0.224, 0.225]),
+ ])
+
+
+ # ===================================================================
+ # Base class
+ # ===================================================================
+ class _FrozenBackbone:
+     """Shared helpers: freeze, normalise activation maps."""
+
+     DIM: int = 0   # overridden by subclasses
+
+     # --- freeze every parameter ---
+     def _freeze(self, model: nn.Module) -> nn.Module:
+         model.eval()
+         for p in model.parameters():
+             p.requires_grad = False
+         return model.to(DEVICE)
+
+     # --- public interface ---
+     def get_features(self, img_bgr: np.ndarray) -> np.ndarray:
+         """Return a 1-D float32 feature vector for *img_bgr* (BGR uint8)."""
+         raise NotImplementedError
+
+     def get_activation_maps(self, img_bgr: np.ndarray,
+                             n_maps: int = 6) -> list[np.ndarray]:
+         """Return *n_maps* normalised float32 spatial activation maps."""
+         raise NotImplementedError
+
+     @staticmethod
+     def _norm(m: np.ndarray) -> np.ndarray:
+         lo, hi = m.min(), m.max()
+         return ((m - lo) / (hi - lo + 1e-5)).astype(np.float32)
+
+
+ # ===================================================================
+ # ResNet-18
+ # ===================================================================
+ class ResNet18Backbone(_FrozenBackbone):
+     """ResNet-18 downloaded from PyTorch Hub, frozen, classifier removed."""
+
+     DIM = 512
+
+     def __init__(self):
+         full = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
+         self.backbone = self._freeze(full)
+         self.extractor = nn.Sequential(*list(full.children())[:-1]).to(DEVICE)
+         self.transform = _IMAGENET_TRANSFORM
+
+     def get_features(self, img_bgr):
+         t = self.transform(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
+         with torch.no_grad():
+             return self.extractor(t.unsqueeze(0).to(DEVICE)).cpu().numpy().flatten()
+
+     def get_activation_maps(self, img_bgr, n_maps=6):
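+         # A forward hook on layer4 captures its output feature maps
+         # (for a 224×224 input, a (1, 512, 7, 7) tensor) into `cap`
+         # during a single no-grad pass; the hook is removed right after.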
+         cap = {}
+         hook = self.backbone.layer4.register_forward_hook(
+             lambda _m, _i, o: cap.update(feat=o))
+         t = self.transform(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
+         with torch.no_grad():
+             self.backbone(t.unsqueeze(0).to(DEVICE))
+         hook.remove()
+         acts = cap["feat"][0].cpu().numpy()
+         return [self._norm(acts[i]) for i in range(min(n_maps, acts.shape[0]))]
+
+
+ # ===================================================================
+ # MobileNetV3-Small
+ # ===================================================================
+ class MobileNetV3Backbone(_FrozenBackbone):
+     """MobileNetV3-Small from PyTorch Hub, frozen, classifier = Identity."""
+
+     DIM = 576
+
+     def __init__(self):
+         self.backbone = models.mobilenet_v3_small(
+             weights=models.MobileNet_V3_Small_Weights.DEFAULT)
+         self.backbone.classifier = nn.Identity()
+         self._freeze(self.backbone)
+         self.transform = _IMAGENET_TRANSFORM
+
+     def get_features(self, img_bgr):
+         t = self.transform(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
+         with torch.no_grad():
+             return self.backbone(t.unsqueeze(0).to(DEVICE)).cpu().numpy().flatten()
+
+     def get_activation_maps(self, img_bgr, n_maps=6):
+         cap = {}
+         hook = self.backbone.features[-1].register_forward_hook(
+             lambda _m, _i, o: cap.update(feat=o))
+         t = self.transform(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
+         with torch.no_grad():
+             self.backbone(t.unsqueeze(0).to(DEVICE))
+         hook.remove()
+         acts = cap["feat"][0].cpu().numpy()
+         return [self._norm(acts[i]) for i in range(min(n_maps, acts.shape[0]))]
+
+
+ # ===================================================================
+ # MobileViT-XXS
+ # ===================================================================
+ class MobileViTBackbone(_FrozenBackbone):
+     """MobileViT-XXS from timm (Apple Research), frozen."""
+
+     DIM = 320
+
+     def __init__(self):
+         self.backbone = timm.create_model(
+             "mobilevit_xxs.cvnets_in1k", pretrained=True, num_classes=0)
+         self._freeze(self.backbone)
+         cfg = timm.data.resolve_model_data_config(self.backbone)
+         self.transform = timm.data.create_transform(**cfg, is_training=False)
+
+     def _to_tensor(self, img_bgr):
+         from PIL import Image
+         pil = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
+         return self.transform(pil).unsqueeze(0).to(DEVICE)
+
+     def get_features(self, img_bgr):
+         with torch.no_grad():
+             return self.backbone(self._to_tensor(img_bgr)).cpu().numpy().flatten()
+
+     def get_activation_maps(self, img_bgr, n_maps=6):
+         cap = {}
+         hook = self.backbone.stages[-1].register_forward_hook(
+             lambda _m, _i, o: cap.update(feat=o))
+         with torch.no_grad():
+             self.backbone(self._to_tensor(img_bgr))
+         hook.remove()
+         acts = cap["feat"][0].cpu().numpy()
+         return [self._norm(acts[i]) for i in range(min(n_maps, acts.shape[0]))]
+
+
+ # ===================================================================
+ # Lightweight Head (lives in session state, never on disk)
+ # ===================================================================
+ class RecognitionHead:
+     """
+     A tiny trainable layer on top of a frozen backbone.
+     Wraps sklearn ``LogisticRegression`` for binary classification.
+     Stored in ``st.session_state`` — never saved to disk.
+     """
+
+     def __init__(self, C: float = 1.0, max_iter: int = 1000):
+         from sklearn.linear_model import LogisticRegression
+         self.model = LogisticRegression(C=C, max_iter=max_iter)
+         self.is_trained = False
+
+     def fit(self, X, y):
+         self.model.fit(X, y)
+         self.is_trained = True
+         return self
+
+     def predict(self, features: np.ndarray):
+         """Return *(label, confidence)* for a single feature vector."""
+         probs = self.model.predict_proba([features])[0]
+         idx = int(np.argmax(probs))
+         return self.model.classes_[idx], probs[idx]
+
+     def predict_proba(self, X):
+         return self.model.predict_proba(X)
+
+     @property
+     def classes_(self):
+         return self.model.classes_
+
+
+ # ===================================================================
+ # Cached loaders — @st.cache_resource keeps models in RAM
+ # ===================================================================
+ @st.cache_resource
+ def get_resnet() -> ResNet18Backbone:
+     """Download & freeze ResNet-18. Stays in RAM across page switches."""
+     return ResNet18Backbone()
+
+ @st.cache_resource
+ def get_mobilenet() -> MobileNetV3Backbone:
+     """Download & freeze MobileNetV3-Small. Stays in RAM."""
+     return MobileNetV3Backbone()
+
+ @st.cache_resource
+ def get_mobilevit() -> MobileViTBackbone:
+     """Download & freeze MobileViT-XXS. Stays in RAM."""
+     return MobileViTBackbone()
+
+
+ # ===================================================================
+ # BACKBONES — The Registry Dict
+ # ===================================================================
+ BACKBONES = {
+     "ResNet-18": {
+         "loader": get_resnet,
+         "dim": ResNet18Backbone.DIM,
+         "hook_layer": "layer4 (last conv block)",
+     },
+     "MobileNetV3": {
+         "loader": get_mobilenet,
+         "dim": MobileNetV3Backbone.DIM,
+         "hook_layer": "features[-1] (last features block)",
+     },
+     "MobileViT-XXS": {
+         "loader": get_mobilevit,
+         "dim": MobileViTBackbone.DIM,
+         "hook_layer": "stages[-1] (last transformer stage)",
+     },
+ }
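+
+ # Typical page-side usage (see pages/4_Model_Tuning.py):
+ #     backbone = BACKBONES["ResNet-18"]["loader"]()   # cached, frozen
+ #     feats = backbone.get_features(img_bgr)          # 1-D vector of length "dim"
+ #     head = RecognitionHead(C=1.0).fit(X, y)         # lives in st.session_state only
+ #     label, conf = head.predict(feats)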