DariusGiannoli committed
Commit · 8ac50b6
Parent(s): 565e61c

feat: central model registry, real-time detection, stereo geometry & home page
Files changed:
- app.py +191 -16
- pages/3_Feature_Lab.py +7 -31
- pages/4_Model_Tuning.py +17 -47
- pages/5_RealTime_Detection.py +338 -0
- pages/6_Stereo_Geometry.py +327 -0
- src/models.py +250 -0
app.py
CHANGED
@@ -1,17 +1,192 @@
 import streamlit as st
-
-
-
-
-
-
-st.title("🦅
-st.
-
-
-
-
-
-
-
-
+
+st.set_page_config(page_title="Perception Benchmark", layout="wide", page_icon="🦅")
+
+# ===================================================================
+# Header
+# ===================================================================
+st.title("🦅 Recognition BenchMark")
+st.subheader("A stereo-vision pipeline for object recognition & depth estimation")
+st.caption("Compare classical feature engineering (RCE) against modern deep learning backbones — end-to-end, in your browser.")
+
+st.divider()
+
+# ===================================================================
+# Pipeline Overview
+# ===================================================================
+st.header("🗺️ Pipeline Overview")
+st.markdown("""
+The app is structured as a **5-stage sequential pipeline**.
+Complete each page in order — every stage feeds the next.
+""")
+
+stages = [
+    ("🧪", "1 · Data Lab", "Upload a stereo image pair, camera calibration file, and two PFM ground-truth depth maps. "
+                           "Define an object ROI (bounding box), then apply live data augmentation "
+                           "(brightness, contrast, rotation, noise, blur, shift, flip). "
+                           "All assets are locked into session state — nothing is written to disk."),
+    ("🔬", "2 · Feature Lab", "Toggle RCE physics modules (Intensity · Sobel · Spectral) to build a modular "
+                              "feature vector. Compare it live against CNN activation maps extracted from a "
+                              "frozen backbone via forward hooks. Lock your active module configuration."),
+    ("⚙️", "3 · Model Tuning", "Train lightweight **heads** on your session data (augmented crop = positives, "
+                               "random non-overlapping patches from the scene = negatives). "
+                               "Both RCE and CNN heads are trained identically with LogisticRegression "
+                               "and stored in session state only — no disk writes."),
+    ("🎯", "4 · Real-Time Detection", "Run a **sliding window** across the right image using both the RCE head and "
+                                      "your chosen CNN head simultaneously. Watch the scan live, then compare "
+                                      "bounding boxes, confidence heatmaps, and latency."),
+    ("📐", "5 · Stereo Geometry", "Compute a disparity map with **StereoSGBM**, convert it to metric depth "
+                                  "using the stereo formula $Z = fB/(d+d_{\\text{offs}})$, then read depth "
+                                  "directly at every detected bounding box. Compare against PFM ground truth."),
+]
+
+for icon, title, desc in stages:
+    with st.container(border=True):
+        c1, c2 = st.columns([1, 12])
+        c1.markdown(f"## {icon}")
+        c2.markdown(f"**{title}**  \n{desc}")
+
+st.divider()
+
+# ===================================================================
+# Models
+# ===================================================================
+st.header("🧠 Models Used")
+
+tab_rce, tab_resnet, tab_mobilenet, tab_mobilevit = st.tabs(
+    ["RCE Engine", "ResNet-18", "MobileNetV3-Small", "MobileViT-XXS"])
+
+with tab_rce:
+    st.markdown("### 🧬 RCE — Relative Contextual Encoding")
+    st.markdown("""
+**Type:** Modular hand-crafted feature extractor
+**Architecture:** Three physics-inspired modules, each producing a 10-bin histogram:
+
+| Module | Input | Operation |
+|--------|-------|-----------|
+| **Intensity** | Grayscale | Pixel-value histogram (global appearance) |
+| **Sobel** | Gradient magnitude | Edge strength distribution (texture) |
+| **Spectral** | FFT log-magnitude | Frequency content (pattern / structure) |
+
+**Strengths:**
+- Fully explainable — every dimension has a physical meaning
+- Extremely fast (µs per patch, no GPU needed)
+- Modular: disable any module and immediately see the effect on the vector
+- Zero pre-training needed
+
+**Weakness:** Less discriminative than deep features for complex visual scenes.
+""")
+
+with tab_resnet:
+    st.markdown("### 🏗️ ResNet-18")
+    st.markdown("""
+**Source:** PyTorch Hub (`torchvision.models.ResNet18_Weights.DEFAULT`)
+**Pre-training:** ImageNet-1k (1.28 M images, 1 000 classes)
+**Backbone output:** 512-dimensional embedding (after `avgpool`)
+**Head:** LogisticRegression trained on your session data
+
+**Architecture highlights:**
+- 18 layers with residual (skip) connections
+- Residual blocks prevent vanishing gradients in deeper networks
+- `layer4` is hooked for activation map visualisation
+
+**In this app:** The entire backbone is **frozen** (`requires_grad=False`).
+Only the lightweight head adapts to your specific object.
+""")
+
+with tab_mobilenet:
+    st.markdown("### 📱 MobileNetV3-Small")
+    st.markdown("""
+**Source:** PyTorch Hub (`torchvision.models.MobileNet_V3_Small_Weights.DEFAULT`)
+**Pre-training:** ImageNet-1k
+**Backbone output:** 576-dimensional embedding (classifier replaced with `Identity`)
+**Head:** LogisticRegression trained on your session data
+
+**Architecture highlights:**
+- Inverted residuals + linear bottlenecks (MobileNetV2 heritage)
+- Hard-Swish / Hard-Sigmoid activations (hardware-friendly)
+- Squeeze-and-Excitation (SE) blocks for channel attention
+- Designed for **edge / mobile inference** — ~2.5 M parameters
+
+**In this app:** Typically 3–5× faster than ResNet-18.
+`features[-1]` is hooked for activation maps.
+""")
+
+with tab_mobilevit:
+    st.markdown("### 🤖 MobileViT-XXS")
+    st.markdown("""
+**Source:** timm — `mobilevit_xxs.cvnets_in1k` (Apple Research, 2022)
+**Pre-training:** ImageNet-1k
+**Backbone output:** 320-dimensional embedding (`num_classes=0`)
+**Head:** LogisticRegression trained on your session data
+
+**Architecture highlights:**
+- **Hybrid CNN + Vision Transformer** — local convolutions for spatial features,
+  global self-attention for long-range context
+- MobileNetV2 stem + MobileViT blocks (attention on non-overlapping patches)
+- Only ~1.3 M parameters — smallest of the three
+
+**In this app:** The final transformer stage `stages[-1]` is hooked.
+Slower than MobileNetV3 but captures global structure.
+""")
+
+st.divider()
+
+# ===================================================================
+# Depth Estimation
+# ===================================================================
+st.header("📐 Stereo Depth Estimation")
+
+col_d1, col_d2 = st.columns(2)
+with col_d1:
+    st.markdown("""
+**Algorithm:** `cv2.StereoSGBM` (Semi-Global Block Matching)
+
+SGBM minimises a global energy function combining:
+- Data cost (pixel intensity difference)
+- Smoothness penalty (P1, P2 regularisation)
+
+It processes multiple horizontal and diagonal scan-line passes,
+making it significantly more accurate than basic block matching.
+""")
+with col_d2:
+    st.markdown("""
+**Depth formula (Middlebury convention):**
+""")
+    st.latex(r"Z = \frac{f \times B}{d + d_{\text{offs}}}")
+    st.markdown("""
+- $f$ — focal length (pixels)
+- $B$ — baseline (mm, from calibration file)
+- $d$ — disparity (pixels)
+- $d_\\text{offs}$ — optical-center offset between cameras
+""")
+
+st.divider()
+
+# ===================================================================
+# Session Status
+# ===================================================================
+st.header("📋 Session Status")
+
+pipe = st.session_state.get("pipeline_data", {})
+
+checks = {
+    "Data Lab locked": "left" in pipe,
+    "Crop defined": "crop" in pipe,
+    "Augmentation applied": "crop_aug" in pipe,
+    "Active modules locked": "active_modules" in st.session_state,
+    "RCE head trained": "rce_head" in st.session_state,
+    "CNN head trained": any(f"cnn_head_{n}" in st.session_state
+                            for n in ["ResNet-18", "MobileNetV3", "MobileViT-XXS"]),
+    "RCE detections ready": "rce_dets" in st.session_state,
+    "CNN detections ready": "cnn_dets" in st.session_state,
+}
+
+cols = st.columns(4)
+for i, (label, done) in enumerate(checks.items()):
+    cols[i % 4].markdown(
+        f"{'✅' if done else '⬜'} {'~~' if not done else ''}{label}{'~~' if not done else ''}"
+    )
+
+st.divider()
+st.caption("Navigate using the sidebar → Start with **🧪 Data Lab**")
pages/3_Feature_Lab.py
CHANGED
@@ -5,25 +5,7 @@ import plotly.graph_objects as go
 import sys, os
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from src.detectors.rce.features import REGISTRY
-
-
-# ---------------------------------------------------------------------------
-# Cached model loaders — instantiated once, reused across reruns
-# ---------------------------------------------------------------------------
-@st.cache_resource
-def load_resnet():
-    from src.detectors.resnet import ResNetDetector
-    return ResNetDetector()
-
-@st.cache_resource
-def load_mobilenet():
-    from src.detectors.mobilenet import MobileNetDetector
-    return MobileNetDetector()
-
-@st.cache_resource
-def load_mobilevit():
-    from src.detectors.mobilevit import MobileViTDetector
-    return MobileViTDetector()
+from src.models import BACKBONES
 
 st.set_page_config(page_title="Feature Lab", layout="wide")
 
@@ -86,22 +68,16 @@ with col_rce:
 # ---------------------------------------------------------------------------
 with col_cnn:
     st.header("🧠 CNN: Static Architecture")
-    selected_cnn = st.selectbox("Compare against Model",
+    selected_cnn = st.selectbox("Compare against Model", list(BACKBONES.keys()))
     st.info("CNN features are fixed by pre-trained weights. You cannot toggle them like the RCE.")
 
     with st.spinner(f"Loading {selected_cnn} and extracting activations..."):
        try:
-
-
-
-
-
-                layer_name = "stages[-1] (last transformer stage)"
-            else:
-                detector = load_mobilenet()
-                layer_name = "features[-1] (last features block)"
-
-            act_maps = detector.get_activation_maps(obj, n_maps=6)
+            bmeta = BACKBONES[selected_cnn]
+            backbone = bmeta["loader"]()   # cached frozen backbone
+            layer_name = bmeta["hook_layer"]
+
+            act_maps = backbone.get_activation_maps(obj, n_maps=6)
             st.caption(f"Hooked layer: `{layer_name}` — showing 6 of {len(act_maps)} channels")
             act_cols = st.columns(3)
             for i, amap in enumerate(act_maps):
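Feature Lab's `get_activation_maps` call builds on the forward-hook mechanism the home page describes (`layer4` hooked on ResNet-18, `features[-1]` on MobileNetV3, `stages[-1]` on MobileViT). A self-contained sketch of that pattern, assuming a plain torchvision ResNet-18; the wrapper's real preprocessing and return format may differ.

import torch
import torchvision

model = torchvision.models.resnet18(weights="DEFAULT").eval()
for p in model.parameters():
    p.requires_grad = False  # frozen backbone, as on the home page

acts = {}
def save_activation(module, inputs, output):
    acts["layer4"] = output.detach()  # shape (1, 512, H/32, W/32)

model.layer4.register_forward_hook(save_activation)

x = torch.rand(1, 3, 224, 224)  # stand-in for a preprocessed patch
with torch.no_grad():
    model(x)
maps = acts["layer4"][0, :6]  # first 6 channel maps, as the page displays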
pages/4_Model_Tuning.py
CHANGED
@@ -7,6 +7,7 @@ import sys, os
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 from src.detectors.rce.features import REGISTRY
+from src.models import BACKBONES, RecognitionHead
 
 st.set_page_config(page_title="Model Tuning", layout="wide")
 st.title("⚙️ Model Tuning: Train & Compare")
@@ -26,31 +27,6 @@ bbox = assets.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
 active_modules = st.session_state.get("active_modules", {k: True for k in REGISTRY})
 
 
-# ---------------------------------------------------------------------------
-# Cached model loaders
-# ---------------------------------------------------------------------------
-@st.cache_resource
-def load_resnet():
-    from src.detectors.resnet import ResNetDetector
-    return ResNetDetector()
-
-@st.cache_resource
-def load_mobilenet():
-    from src.detectors.mobilenet import MobileNetDetector
-    return MobileNetDetector()
-
-@st.cache_resource
-def load_mobilevit():
-    from src.detectors.mobilevit import MobileViTDetector
-    return MobileViTDetector()
-
-CNN_MODELS = {
-    "ResNet-18": {"loader": load_resnet, "dim": 512},
-    "MobileNetV3": {"loader": load_mobilenet, "dim": 576},
-    "MobileViT-XXS": {"loader": load_mobilevit, "dim": 320},
-}
-
-
 # ---------------------------------------------------------------------------
 # Build training set from session data (no disk reads)
 # ---------------------------------------------------------------------------
@@ -137,7 +113,6 @@ with col_rce:
 
     if st.button("🚀 Train RCE Head"):
         images, labels = build_training_set()
-        from sklearn.linear_model import LogisticRegression
         from sklearn.metrics import accuracy_score
 
         progress = st.progress(0, text="Extracting RCE features...")
@@ -151,12 +126,11 @@ with col_rce:
         progress.progress(1.0, text="Fitting Logistic Regression...")
 
         t0 = time.perf_counter()
-        head =
-        head.fit(X, labels)
+        head = RecognitionHead(C=rce_C, max_iter=rce_max_iter).fit(X, labels)
         train_time = time.perf_counter() - t0
         progress.progress(1.0, text="✅ Training complete!")
 
-        preds = head.predict(X)
+        preds = head.model.predict(X)
         train_acc = accuracy_score(labels, preds)
 
         st.success(f"Trained in **{train_time:.2f}s**")
@@ -184,10 +158,9 @@ with col_rce:
         head = st.session_state["rce_head"]
         t0 = time.perf_counter()
         vec = build_rce_vector(crop_aug)
-
+        label, conf = head.predict(vec)
         dt = (time.perf_counter() - t0) * 1000
-
-        st.write(f"**{head.classes_[idx]}** — {probs[idx]:.1%} confidence — {dt:.1f} ms")
+        st.write(f"**{label}** — {conf:.1%} confidence — {dt:.1f} ms")
 
 
 # ---------------------------------------------------------------------------
@@ -196,8 +169,8 @@ with col_rce:
 with col_cnn:
     st.header("🧠 CNN Fine-Tuning")
 
-    selected = st.selectbox("Select Model", list(
-    meta =
+    selected = st.selectbox("Select Model", list(BACKBONES.keys()))
+    meta = BACKBONES[selected]
     st.caption(f"Backbone embedding: **{meta['dim']}D** → Logistic Regression head")
 
     st.subheader("Training Parameters")
@@ -208,28 +181,26 @@ with col_cnn:
 
     if st.button(f"🚀 Train {selected} Head"):
         images, labels = build_training_set()
-
+        backbone = meta["loader"]()   # cached frozen backbone
 
-        from sklearn.linear_model import LogisticRegression
         from sklearn.metrics import accuracy_score
 
         progress = st.progress(0, text=f"Extracting {selected} features...")
         n = len(images)
         X = []
         for i, img in enumerate(images):
-            X.append(
+            X.append(backbone.get_features(img))
             progress.progress((i + 1) / n, text=f"Feature extraction: {i+1}/{n}")
 
         X = np.array(X)
         progress.progress(1.0, text="Fitting Logistic Regression...")
 
         t0 = time.perf_counter()
-        head =
-        head.fit(X, labels)
+        head = RecognitionHead(C=cnn_C, max_iter=cnn_max_iter).fit(X, labels)
         train_time = time.perf_counter() - t0
         progress.progress(1.0, text="✅ Training complete!")
 
-        preds = head.predict(X)
+        preds = head.model.predict(X)
         train_acc = accuracy_score(labels, preds)
 
         st.success(f"Trained in **{train_time:.2f}s**")
@@ -254,14 +225,13 @@ with col_cnn:
     if f"cnn_head_{selected}" in st.session_state:
         st.divider()
         st.subheader("Quick Predict (Crop)")
-
+        backbone = meta["loader"]()   # cached frozen backbone
         head = st.session_state[f"cnn_head_{selected}"]
         t0 = time.perf_counter()
-        feats =
-
+        feats = backbone.get_features(crop_aug)
+        label, conf = head.predict(feats)
         dt = (time.perf_counter() - t0) * 1000
-
-        st.write(f"**{head.classes_[idx]}** — {probs[idx]:.1%} confidence — {dt:.1f} ms")
+        st.write(f"**{label}** — {conf:.1%} confidence — {dt:.1f} ms")
 
 
 # ===========================================================================
@@ -275,11 +245,11 @@ rows = []
 if rce_acc is not None:
     rows.append({"Model": "RCE", "Train Accuracy": f"{rce_acc:.1%}",
                  "Vector Size": str(sum(10 for k in active_modules if active_modules[k]))})
-for name in
+for name in BACKBONES:
     acc = st.session_state.get(f"cnn_acc_{name}")
    if acc is not None:
        rows.append({"Model": name, "Train Accuracy": f"{acc:.1%}",
-                     "Vector Size": f"{
+                     "Vector Size": f"{BACKBONES[name]['dim']}D"})
 
 if rows:
     import pandas as pd
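The call sites above pin down `RecognitionHead`'s interface: `fit(X, labels)` returns `self`, `.model` exposes the underlying LogisticRegression, and `.predict(vec)` maps one feature vector to a `(label, confidence)` pair. A minimal implementation consistent with those usages; the real one lives in `src/models.py`, which this page shows only partially at the end of the diff.

import numpy as np
from sklearn.linear_model import LogisticRegression

class RecognitionHead:
    def __init__(self, C: float = 1.0, max_iter: int = 1000):
        self.model = LogisticRegression(C=C, max_iter=max_iter)

    def fit(self, X, labels):
        self.model.fit(np.asarray(X), labels)
        return self  # enables RecognitionHead(...).fit(...) chaining

    def predict(self, vec):
        """Return (label, confidence) for a single 1-D feature vector."""
        probs = self.model.predict_proba(np.asarray(vec).reshape(1, -1))[0]
        idx = int(np.argmax(probs))
        return self.model.classes_[idx], float(probs[idx])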
pages/5_RealTime_Detection.py
CHANGED
@@ -0,0 +1,338 @@
+import streamlit as st
+import cv2
+import numpy as np
+import time
+import plotly.graph_objects as go
+import sys, os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from src.detectors.rce.features import REGISTRY
+from src.models import BACKBONES, RecognitionHead
+
+st.set_page_config(page_title="Real-Time Detection", layout="wide")
+st.title("🎯 Real-Time Detection")
+
+# ---------------------------------------------------------------------------
+# Guard
+# ---------------------------------------------------------------------------
+if "pipeline_data" not in st.session_state or "crop" not in st.session_state.get("pipeline_data", {}):
+    st.error("Complete **Data Lab** first (upload assets & define a crop).")
+    st.stop()
+
+assets = st.session_state["pipeline_data"]
+right_img = assets["right"]
+crop = assets["crop"]
+crop_aug = assets.get("crop_aug", crop)
+bbox = assets.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0]))
+active_mods = st.session_state.get("active_modules", {k: True for k in REGISTRY})
+
+x0, y0, x1, y1 = bbox
+win_h, win_w = y1 - y0, x1 - x0  # window = same size as crop
+
+rce_head = st.session_state.get("rce_head")
+has_any_cnn = any(f"cnn_head_{n}" in st.session_state for n in BACKBONES)
+
+if rce_head is None and not has_any_cnn:
+    st.warning("No trained heads found. Go to **Model Tuning** and train at least one head.")
+    st.stop()
+
+
+# ===================================================================
+# Sliding Window Engine (shared by both sides)
+# ===================================================================
+def sliding_window_detect(
+    image: np.ndarray,
+    feature_fn,              # callable(patch_bgr) -> 1-D np.ndarray
+    head: RecognitionHead,
+    stride: int,
+    conf_thresh: float,
+    nms_iou: float,
+    progress_placeholder=None,
+    live_image_placeholder=None,
+):
+    """
+    Slide a window of size (win_h, win_w) across *image* with *stride*.
+    At each position call *feature_fn* → *head.predict*.
+    Returns (detections, heatmap, total_time_ms, n_windows).
+
+    Each detection is (x, y, x+win_w, y+win_h, label, confidence).
+    heatmap is a float32 array same size as image (object confidence).
+    """
+    H, W = image.shape[:2]
+    heatmap = np.zeros((H, W), dtype=np.float32)
+    detections = []
+    t0 = time.perf_counter()
+
+    positions = []
+    for y in range(0, H - win_h + 1, stride):
+        for x in range(0, W - win_w + 1, stride):
+            positions.append((x, y))
+
+    n_total = len(positions)
+    if n_total == 0:
+        return [], heatmap, 0.0, 0
+
+    for idx, (x, y) in enumerate(positions):
+        patch = image[y:y+win_h, x:x+win_w]
+        feats = feature_fn(patch)
+        label, conf = head.predict(feats)
+
+        # Fill heatmap with object confidence
+        if label == "object":
+            heatmap[y:y+win_h, x:x+win_w] = np.maximum(
+                heatmap[y:y+win_h, x:x+win_w], conf)
+            if conf >= conf_thresh:
+                detections.append((x, y, x+win_w, y+win_h, label, conf))
+
+        # Live updates (every 5th window or last)
+        if live_image_placeholder is not None and (idx % 5 == 0 or idx == n_total - 1):
+            vis = image.copy()
+            # Draw current scan position
+            cv2.rectangle(vis, (x, y), (x+win_w, y+win_h), (255, 255, 0), 1)
+            # Draw current detections
+            for dx, dy, dx2, dy2, dl, dc in detections:
+                cv2.rectangle(vis, (dx, dy), (dx2, dy2), (0, 255, 0), 2)
+                cv2.putText(vis, f"{dc:.0%}", (dx, dy - 4),
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)
+            live_image_placeholder.image(
+                cv2.cvtColor(vis, cv2.COLOR_BGR2RGB),
+                caption=f"Scanning… {idx+1}/{n_total}",
+                use_container_width=True)
+
+        if progress_placeholder is not None:
+            progress_placeholder.progress(
+                (idx + 1) / n_total,
+                text=f"Window {idx+1}/{n_total}")
+
+    total_ms = (time.perf_counter() - t0) * 1000
+
+    # --- Non-Maximum Suppression ---
+    if detections:
+        detections = _nms(detections, nms_iou)
+
+    return detections, heatmap, total_ms, n_total
+
+
+def _nms(dets, iou_thresh):
+    """Greedy NMS on list of (x1,y1,x2,y2,label,conf)."""
+    dets = sorted(dets, key=lambda d: d[5], reverse=True)
+    keep = []
+    while dets:
+        best = dets.pop(0)
+        keep.append(best)
+        dets = [d for d in dets if _iou(best, d) < iou_thresh]
+    return keep
+
+
+def _iou(a, b):
+    """IoU between two (x1,y1,x2,y2,…) tuples."""
+    xi1 = max(a[0], b[0]); yi1 = max(a[1], b[1])
+    xi2 = min(a[2], b[2]); yi2 = min(a[3], b[3])
+    inter = max(0, xi2-xi1) * max(0, yi2-yi1)
+    aa = (a[2]-a[0])*(a[3]-a[1])
+    ab = (b[2]-b[0])*(b[3]-b[1])
+    return inter / (aa + ab - inter + 1e-6)
+
+
+# ===================================================================
+# RCE feature function
+# ===================================================================
+def rce_feature_fn(patch_bgr):
+    gray = cv2.cvtColor(patch_bgr, cv2.COLOR_BGR2GRAY)
+    vec = []
+    for key, meta in REGISTRY.items():
+        if active_mods.get(key, False):
+            v, _ = meta["fn"](gray)
+            vec.extend(v)
+    return np.array(vec, dtype=np.float32)
+
+
+# ===================================================================
+# Controls
+# ===================================================================
+st.subheader("Sliding Window Parameters")
+p1, p2, p3 = st.columns(3)
+stride = p1.slider("Stride (px)", 4, max(win_w, win_h),
+                   max(win_w // 4, 4), step=2,
+                   help="Lower = more windows = slower but finer")
+conf_thresh = p2.slider("Confidence Threshold", 0.5, 1.0, 0.7, 0.05)
+nms_iou = p3.slider("NMS IoU Threshold", 0.1, 0.9, 0.3, 0.05)
+
+st.caption(f"Window size: **{win_w}×{win_h} px** | "
+           f"Right image: **{right_img.shape[1]}×{right_img.shape[0]} px** | "
+           f"≈ {((right_img.shape[0]-win_h)//stride + 1) * ((right_img.shape[1]-win_w)//stride + 1)} windows")
+
+st.divider()
+
+# ===================================================================
+# Side-by-side layout
+# ===================================================================
+col_rce, col_cnn = st.columns(2)
+
+# -------------------------------------------------------------------
+# LEFT — RCE Detection
+# -------------------------------------------------------------------
+with col_rce:
+    st.header("🧬 RCE Detection")
+    if rce_head is None:
+        st.info("No RCE head trained. Train one in **Model Tuning**.")
+    else:
+        st.caption(f"Modules: {', '.join(REGISTRY[k]['label'] for k in active_mods if active_mods[k])}")
+        rce_run = st.button("▶ Run RCE Scan", key="rce_run")
+
+        rce_progress = st.empty()
+        rce_live = st.empty()
+        rce_results = st.container()
+
+        if rce_run:
+            dets, hmap, ms, nw = sliding_window_detect(
+                right_img, rce_feature_fn, rce_head,
+                stride, conf_thresh, nms_iou,
+                progress_placeholder=rce_progress,
+                live_image_placeholder=rce_live,
+            )
+
+            # Final image with boxes
+            final = right_img.copy()
+            for x1d, y1d, x2d, y2d, lbl, cf in dets:
+                cv2.rectangle(final, (x1d, y1d), (x2d, y2d), (0, 255, 0), 2)
+                cv2.putText(final, f"{cf:.0%}", (x1d, y1d - 6),
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
+            rce_live.image(cv2.cvtColor(final, cv2.COLOR_BGR2RGB),
+                           caption="RCE — Final Detections",
+                           use_container_width=True)
+            rce_progress.empty()
+
+            with rce_results:
+                # Metrics
+                rm1, rm2, rm3, rm4 = st.columns(4)
+                rm1.metric("Detections", len(dets))
+                rm2.metric("Windows", nw)
+                rm3.metric("Total Time", f"{ms:.0f} ms")
+                rm4.metric("Per Window", f"{ms/max(nw,1):.2f} ms")
+
+                # Confidence heatmap
+                if hmap.max() > 0:
+                    hmap_color = cv2.applyColorMap(
+                        (hmap / hmap.max() * 255).astype(np.uint8),
+                        cv2.COLORMAP_JET)
+                    blend = cv2.addWeighted(right_img, 0.5, hmap_color, 0.5, 0)
+                    st.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
+                             caption="RCE — Confidence Heatmap",
+                             use_container_width=True)
+
+                # Detection table
+                if dets:
+                    import pandas as pd
+                    df = pd.DataFrame(dets, columns=["x1","y1","x2","y2","label","conf"])
+                    st.dataframe(df, use_container_width=True, hide_index=True)
+
+            st.session_state["rce_dets"] = dets
+            st.session_state["rce_det_ms"] = ms
+
+# -------------------------------------------------------------------
+# RIGHT — CNN Detection
+# -------------------------------------------------------------------
+with col_cnn:
+    st.header("🧠 CNN Detection")
+
+    # Find which CNN heads are trained
+    trained_cnns = [n for n in BACKBONES if f"cnn_head_{n}" in st.session_state]
+    if not trained_cnns:
+        st.info("No CNN head trained. Train one in **Model Tuning**.")
+    else:
+        selected = st.selectbox("Select Model", trained_cnns, key="det_cnn_sel")
+        bmeta = BACKBONES[selected]
+        backbone = bmeta["loader"]()
+        head = st.session_state[f"cnn_head_{selected}"]
+
+        st.caption(f"Backbone: **{selected}** ({bmeta['dim']}D) — Head in session state")
+        cnn_run = st.button(f"▶ Run {selected} Scan", key="cnn_run")
+
+        cnn_progress = st.empty()
+        cnn_live = st.empty()
+        cnn_results = st.container()
+
+        if cnn_run:
+            dets, hmap, ms, nw = sliding_window_detect(
+                right_img, backbone.get_features, head,
+                stride, conf_thresh, nms_iou,
+                progress_placeholder=cnn_progress,
+                live_image_placeholder=cnn_live,
+            )
+
+            # Final image
+            final = right_img.copy()
+            for x1d, y1d, x2d, y2d, lbl, cf in dets:
+                cv2.rectangle(final, (x1d, y1d), (x2d, y2d), (0, 0, 255), 2)
+                cv2.putText(final, f"{cf:.0%}", (x1d, y1d - 6),
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
+            cnn_live.image(cv2.cvtColor(final, cv2.COLOR_BGR2RGB),
+                           caption=f"{selected} — Final Detections",
+                           use_container_width=True)
+            cnn_progress.empty()
+
+            with cnn_results:
+                cm1, cm2, cm3, cm4 = st.columns(4)
+                cm1.metric("Detections", len(dets))
+                cm2.metric("Windows", nw)
+                cm3.metric("Total Time", f"{ms:.0f} ms")
+                cm4.metric("Per Window", f"{ms/max(nw,1):.2f} ms")
+
+                if hmap.max() > 0:
+                    hmap_color = cv2.applyColorMap(
+                        (hmap / hmap.max() * 255).astype(np.uint8),
+                        cv2.COLORMAP_JET)
+                    blend = cv2.addWeighted(right_img, 0.5, hmap_color, 0.5, 0)
+                    st.image(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB),
+                             caption=f"{selected} — Confidence Heatmap",
+                             use_container_width=True)
+
+                if dets:
+                    import pandas as pd
+                    df = pd.DataFrame(dets, columns=["x1","y1","x2","y2","label","conf"])
+                    st.dataframe(df, use_container_width=True, hide_index=True)
+
+            st.session_state["cnn_dets"] = dets
+            st.session_state["cnn_det_ms"] = ms
+
+
+# ===================================================================
+# Bottom — Comparison (if both have run)
+# ===================================================================
+rce_dets = st.session_state.get("rce_dets")
+cnn_dets = st.session_state.get("cnn_dets")
+
+if rce_dets is not None and cnn_dets is not None:
+    st.divider()
+    st.subheader("📊 Side-by-Side Comparison")
+
+    import pandas as pd
+    comp = pd.DataFrame({
+        "Metric": ["Detections", "Best Confidence", "Total Time (ms)"],
+        "RCE": [
+            len(rce_dets),
+            f"{max((d[5] for d in rce_dets), default=0):.1%}",
+            f"{st.session_state.get('rce_det_ms', 0):.0f}",
+        ],
+        "CNN": [
+            len(cnn_dets),
+            f"{max((d[5] for d in cnn_dets), default=0):.1%}",
+            f"{st.session_state.get('cnn_det_ms', 0):.0f}",
+        ],
+    })
+    st.dataframe(comp, use_container_width=True, hide_index=True)
+
+    # Overlay both on one image
+    overlay = right_img.copy()
+    for x1d, y1d, x2d, y2d, _, cf in rce_dets:
+        cv2.rectangle(overlay, (x1d, y1d), (x2d, y2d), (0, 255, 0), 2)
+        cv2.putText(overlay, f"RCE {cf:.0%}", (x1d, y1d - 6),
+                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)
+    for x1d, y1d, x2d, y2d, _, cf in cnn_dets:
+        cv2.rectangle(overlay, (x1d, y1d), (x2d, y2d), (0, 0, 255), 2)
+        cv2.putText(overlay, f"CNN {cf:.0%}", (x1d, y2d + 12),
+                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1)
+    st.image(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB),
+             caption="Green = RCE | Blue = CNN",
+             use_container_width=True)
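The window-count caption above is just the number of stride steps that fit along each axis: ((H − win_h)//stride + 1) · ((W − win_w)//stride + 1). With hypothetical sizes, say a 640×480 right image, a 96×96 crop window and stride 24, that is ((480−96)//24 + 1) · ((640−96)//24 + 1) = 17 · 23 = 391 windows per scan, which is why halving the stride roughly quadruples the scan time.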
pages/6_Stereo_Geometry.py
CHANGED
@@ -0,0 +1,327 @@
+import streamlit as st
+import cv2
+import numpy as np
+import re
+import plotly.graph_objects as go
+import sys, os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+st.set_page_config(page_title="Stereo Geometry", layout="wide")
+st.title("📐 Stereo Geometry: Distance Estimation")
+
+# ---------------------------------------------------------------------------
+# Guard
+# ---------------------------------------------------------------------------
+if "pipeline_data" not in st.session_state or "left" not in st.session_state.get("pipeline_data", {}):
+    st.error("Complete **Data Lab** first.")
+    st.stop()
+
+assets = st.session_state["pipeline_data"]
+img_l = assets["left"]
+img_r = assets["right"]
+gt_left = assets.get("gt_left")    # float32 depth map from PFM
+gt_right = assets.get("gt_right")
+conf_raw = assets.get("conf_raw", "")
+crop_bbox = assets.get("crop_bbox")  # (x0, y0, x1, y1) on LEFT image
+
+rce_dets = st.session_state.get("rce_dets", [])  # list of (x1,y1,x2,y2,label,conf)
+cnn_dets = st.session_state.get("cnn_dets", [])
+
+
+# ===================================================================
+# Parse Middlebury-style camera config
+# ===================================================================
+def parse_config(text: str) -> dict:
+    """
+    Parse a Middlebury .txt / .conf calibration file.
+    Expected keys: cam0, cam1, doffs, baseline, width, height, ndisp, vmin, vmax
+    cam0 / cam1 are 3×3 matrices in bracket notation: [f 0 cx; 0 f cy; 0 0 1]
+    """
+    params = {}
+    for line in text.strip().splitlines():
+        line = line.strip()
+        if "=" not in line:
+            continue
+        key, val = line.split("=", 1)
+        key = key.strip()
+        val = val.strip()
+        # Matrix?
+        if "[" in val:
+            nums = list(map(float, re.findall(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?", val)))
+            if len(nums) == 9:
+                params[key] = np.array(nums).reshape(3, 3)
+            else:
+                params[key] = nums
+        else:
+            try:
+                params[key] = float(val)
+            except ValueError:
+                params[key] = val
+    return params
+
+
+calib = parse_config(conf_raw)
+
+# Extract intrinsics
+focal = calib.get("cam0", np.eye(3))[0, 0] if isinstance(calib.get("cam0"), np.ndarray) else 0.0
+doffs = float(calib.get("doffs", 0.0))
+baseline = float(calib.get("baseline", 1.0))
+ndisp = int(calib.get("ndisp", 128))
+
+st.subheader("Camera Calibration")
+cc1, cc2, cc3, cc4 = st.columns(4)
+cc1.metric("Focal Length (px)", f"{focal:.1f}")
+cc2.metric("Baseline (mm)", f"{baseline:.1f}")
+cc3.metric("Doffs (px)", f"{doffs:.2f}")
+cc4.metric("ndisp", str(ndisp))
+
+with st.expander("Full Calibration"):
+    st.json({k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in calib.items()})
+
+st.divider()
+
+
+# ===================================================================
+# Step 1 — Compute Disparity Map
+# ===================================================================
+st.subheader("Step 1: Disparity Map (StereoSGBM)")
+
+sc1, sc2, sc3 = st.columns(3)
+block_size = sc1.slider("Block Size", 3, 21, 5, step=2)
+p1_mult = sc2.slider("P1 multiplier", 1, 32, 8)
+p2_mult = sc3.slider("P2 multiplier", 1, 128, 32)
+
+@st.cache_data
+def compute_disparity(_left, _right, ndisp, block_size, p1m, p2m):
+    # Only the images keep a leading underscore (Streamlit skips hashing them);
+    # the scalar parameters stay hashable so slider changes invalidate the cache.
+    gray_l = cv2.cvtColor(_left, cv2.COLOR_BGR2GRAY)
+    gray_r = cv2.cvtColor(_right, cv2.COLOR_BGR2GRAY)
+
+    # Align ndisp to 16
+    nd = max(16, (ndisp // 16) * 16)
+    channels = 1
+    sgbm = cv2.StereoSGBM_create(
+        minDisparity=0,
+        numDisparities=nd,
+        blockSize=block_size,
+        P1=p1m * channels * block_size ** 2,
+        P2=p2m * channels * block_size ** 2,
+        disp12MaxDiff=1,
+        uniquenessRatio=10,
+        speckleWindowSize=100,
+        speckleRange=32,
+        mode=cv2.STEREO_SGBM_MODE_SGBM_3WAY,
+    )
+    disp = sgbm.compute(gray_l, gray_r).astype(np.float32) / 16.0
+    return disp
+
+
+with st.spinner("Computing disparity..."):
+    disp = compute_disparity(img_l, img_r, ndisp, block_size, p1_mult, p2_mult)
+
+# Visualize disparity
+disp_vis = disp.copy()
+disp_vis[disp_vis <= 0] = 0
+disp_max = disp_vis.max() if disp_vis.max() > 0 else 1.0
+disp_norm = (disp_vis / disp_max * 255).astype(np.uint8)
+disp_color = cv2.applyColorMap(disp_norm, cv2.COLORMAP_INFERNO)
+
+dc1, dc2 = st.columns(2)
+dc1.image(cv2.cvtColor(img_l, cv2.COLOR_BGR2RGB), caption="Left Image", use_container_width=True)
+dc2.image(cv2.cvtColor(disp_color, cv2.COLOR_BGR2RGB), caption="Disparity Map (SGBM)", use_container_width=True)
+
+
+# ===================================================================
+# Step 2 — Depth Map from Disparity
+# ===================================================================
+st.divider()
+st.subheader("Step 2: Depth Map from Disparity")
+
+st.latex(r"Z = \frac{f \times B}{d + d_{\text{offs}}}")
+st.caption("Z = depth (mm), f = focal length (px), B = baseline (mm), d = disparity (px), d_offs = optical center offset (px)")
+
+# Compute depth from disparity
+valid = (disp + doffs) > 0
+depth_map = np.zeros_like(disp)
+depth_map[valid] = (focal * baseline) / (disp[valid] + doffs)
+depth_map[~valid] = 0
+
+# Visualize
+depth_vis = depth_map.copy()
+finite = depth_vis[depth_vis > 0]
+if len(finite) > 0:
+    clip_max = np.percentile(finite, 98)
+    depth_vis = np.clip(depth_vis, 0, clip_max)
+    depth_norm = (depth_vis / clip_max * 255).astype(np.uint8)
+else:
+    depth_norm = np.zeros_like(depth_map, dtype=np.uint8)
+
+depth_color = cv2.applyColorMap(depth_norm, cv2.COLORMAP_TURBO)
+
+zc1, zc2 = st.columns(2)
+zc1.image(cv2.cvtColor(depth_color, cv2.COLOR_BGR2RGB),
+          caption="Estimated Depth (SGBM)", use_container_width=True)
+
+# Ground truth comparison
+if gt_left is not None:
+    gt_vis = gt_left.copy()
+    gt_finite = gt_vis[np.isfinite(gt_vis) & (gt_vis > 0)]
+    if len(gt_finite) > 0:
+        gt_clip = np.percentile(gt_finite, 98)
+        gt_vis = np.clip(np.nan_to_num(gt_vis, nan=0), 0, gt_clip)
+        gt_norm = (gt_vis / gt_clip * 255).astype(np.uint8)
+    else:
+        gt_norm = np.zeros_like(gt_vis, dtype=np.uint8)
+    gt_color = cv2.applyColorMap(gt_norm, cv2.COLORMAP_TURBO)
+    zc2.image(cv2.cvtColor(gt_color, cv2.COLOR_BGR2RGB),
+              caption="Ground Truth Depth", use_container_width=True)
+
+
+# ===================================================================
+# Step 3 — Error Map (SGBM vs Ground Truth)
+# ===================================================================
+if gt_left is not None:
+    st.divider()
+    st.subheader("Step 3: Error Analysis (SGBM vs Ground Truth)")
+
+    # The GT is disparity in PFM for Middlebury; convert to depth for comparison.
+    # Middlebury PFM stores DISPARITY, not depth, so convert before comparing:
+    gt_disp = gt_left  # Middlebury standard: PFM = disparity map
+    gt_depth_from_disp = np.zeros_like(gt_disp)
+    gt_valid = np.isfinite(gt_disp) & (gt_disp + doffs > 0) & (gt_disp != np.inf)
+    gt_depth_from_disp[gt_valid] = (focal * baseline) / (gt_disp[gt_valid] + doffs)
+
+    # Crop to common valid region
+    both_valid = valid & gt_valid
+    if both_valid.any():
+        # Disparity error
+        disp_err = np.abs(disp - gt_disp)
+        disp_err[~both_valid] = 0
+
+        # Stats
+        err_vals = disp_err[both_valid]
+        mae = float(np.mean(err_vals))
+        rmse = float(np.sqrt(np.mean(err_vals ** 2)))
+        bad_2 = float(np.mean(err_vals > 2.0)) * 100  # % of pixels with error > 2px
+
+        em1, em2, em3 = st.columns(3)
+        em1.metric("MAE (px)", f"{mae:.2f}")
+        em2.metric("RMSE (px)", f"{rmse:.2f}")
+        em3.metric("Bad-2.0 (%)", f"{bad_2:.1f}%")
+
+        # Error heatmap
+        err_clip = np.clip(disp_err, 0, 10)
+        err_norm = (err_clip / 10 * 255).astype(np.uint8)
+        err_color = cv2.applyColorMap(err_norm, cv2.COLORMAP_HOT)
+        st.image(cv2.cvtColor(err_color, cv2.COLOR_BGR2RGB),
+                 caption="Disparity Error Map (red = high error, clipped at 10 px)",
+                 use_container_width=True)
+
+        # Histogram
+        fig = go.Figure(data=[go.Histogram(x=err_vals, nbinsx=50,
+                                           marker_color="#ff6361")])
+        fig.update_layout(title="Disparity Error Distribution",
+                          xaxis_title="Absolute Error (px)",
+                          yaxis_title="Pixel Count",
+                          template="plotly_dark", height=300)
+        st.plotly_chart(fig, use_container_width=True)
+    else:
+        st.warning("No overlapping valid pixels between SGBM disparity and ground truth.")
+
+
+# ===================================================================
+# Step 4 — Object Distance from Detections
+# ===================================================================
+st.divider()
+st.subheader("Step 4: Object Distance Estimation")
+
+all_dets = []
+if rce_dets:
+    for d in rce_dets:
+        all_dets.append(("RCE", *d))
+if cnn_dets:
+    for d in cnn_dets:
+        all_dets.append(("CNN", *d))
+
+if not all_dets and crop_bbox is not None:
+    st.info("No detections from the Real-Time Detection page. Using the **crop bounding box on the left image** as a fallback.")
+    x0, y0, x1, y1 = crop_bbox
+    all_dets.append(("Crop (left)", x0, y0, x1, y1, "object", 1.0))
+elif not all_dets:
+    st.warning("No detections found. Run **Real-Time Detection** first, or define a crop in **Data Lab**.")
+    st.stop()
+
+
+# For each detection, compute median depth inside the bounding box
+import pandas as pd
+
+rows = []
+det_overlay = img_l.copy() if all_dets and all_dets[0][0] == "Crop (left)" else img_r.copy()
+
+for source, dx1, dy1, dx2, dy2, lbl, conf in all_dets:
+    dx1, dy1, dx2, dy2 = int(dx1), int(dy1), int(dx2), int(dy2)
+
+    # Clamp to image bounds
+    H, W = depth_map.shape[:2]
+    dx1c = max(0, min(dx1, W-1))
+    dy1c = max(0, min(dy1, H-1))
+    dx2c = max(0, min(dx2, W))
+    dy2c = max(0, min(dy2, H))
+
+    roi_depth = depth_map[dy1c:dy2c, dx1c:dx2c]
+    roi_disp = disp[dy1c:dy2c, dx1c:dx2c]
+    roi_valid = roi_depth[roi_depth > 0]
+
+    if len(roi_valid) > 0:
+        med_depth = float(np.median(roi_valid))
+        mean_depth = float(np.mean(roi_valid))
+        med_disp = float(np.median(roi_disp[roi_disp > 0])) if (roi_disp > 0).any() else 0
+    else:
+        med_depth = mean_depth = med_disp = 0.0
+
+    # Ground truth depth at this region (for comparison)
+    gt_depth_val = 0.0
+    if gt_left is not None:
+        gt_roi = gt_left[dy1c:dy2c, dx1c:dx2c]
+        gt_roi_valid = gt_roi[np.isfinite(gt_roi) & (gt_roi > 0)]
+        if len(gt_roi_valid) > 0:
+            # Convert GT disparity → depth
+            gt_med_disp = float(np.median(gt_roi_valid))
+            gt_depth_val = (focal * baseline) / (gt_med_disp + doffs) if (gt_med_disp + doffs) > 0 else 0
+
+    error_mm = abs(med_depth - gt_depth_val) if gt_depth_val > 0 else float('nan')
+
+    rows.append({
+        "Source": source,
+        "Box": f"({dx1},{dy1})→({dx2},{dy2})",
+        "Confidence": f"{conf:.1%}" if isinstance(conf, float) else str(conf),
+        "Med Disparity": f"{med_disp:.1f} px",
+        "Med Depth": f"{med_depth:.0f} mm",
+        "Mean Depth": f"{mean_depth:.0f} mm",
+        "GT Depth": f"{gt_depth_val:.0f} mm" if gt_depth_val > 0 else "N/A",
+        "Error": f"{error_mm:.0f} mm" if not np.isnan(error_mm) else "N/A",
+    })
+
+    # Draw on overlay
+    color = (0, 255, 0) if "RCE" in source else (0, 0, 255) if "CNN" in source else (255, 255, 0)
+    cv2.rectangle(det_overlay, (dx1c, dy1c), (dx2c, dy2c), color, 2)
+    depth_str = f"{med_depth/1000:.2f}m" if med_depth > 0 else "?"
+    cv2.putText(det_overlay, f"{source} {depth_str}",
+                (dx1c, dy1c - 6), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
+
+# Show overlay
+st.image(cv2.cvtColor(det_overlay, cv2.COLOR_BGR2RGB),
+         caption="Detections with Estimated Distance",
+         use_container_width=True)
+
+# Table
+st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
+
+# Big metric cards for the best detection
+if rows:
+    best = rows[0]
+    st.divider()
+    st.subheader("🎯 Primary Detection — Distance")
+    bc1, bc2, bc3 = st.columns(3)
+    bc1.metric("Estimated Depth", best["Med Depth"])
+    bc2.metric("Ground Truth", best["GT Depth"])
+    bc3.metric("Absolute Error", best["Error"])
|
src/models.py
ADDED
@@ -0,0 +1,250 @@
"""
src/models.py — Central Model Registry
======================================
Downloads backbone weights **once** from the internet (PyTorch Hub / timm),
freezes every feature-extraction layer, and caches the result in RAM with
Streamlit's ``@st.cache_resource``.

Strategy
--------
1. **Freeze the Backbone** → ``requires_grad = False`` on every parameter.
   The backbone is a pure feature extractor — no gradient updates, ever.
2. **Cache the Resource** → ``@st.cache_resource`` keeps the heavy model
   in RAM even when you switch pages.
3. **Define the Head** → ``RecognitionHead``: a tiny sklearn
   LogisticRegression that takes the backbone's feature vector and
   produces a recognition score. Lives only in ``st.session_state``.
"""

import streamlit as st
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import timm
import cv2
import numpy as np

# ---------------------------------------------------------------------------
# Device selection (MPS > CUDA > CPU)
# ---------------------------------------------------------------------------
DEVICE = (
    "mps" if torch.backends.mps.is_available() else
    "cuda" if torch.cuda.is_available() else
    "cpu"
)

# ---------------------------------------------------------------------------
# Shared ImageNet preprocessing
# ---------------------------------------------------------------------------
_IMAGENET_TRANSFORM = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


# ===================================================================
# Base class
# ===================================================================
class _FrozenBackbone:
    """Shared helpers: freeze, normalise activation maps."""

    DIM: int = 0  # overridden by subclasses

    # --- freeze every parameter ---
    def _freeze(self, model: nn.Module) -> nn.Module:
        model.eval()
        for p in model.parameters():
            p.requires_grad = False
        return model.to(DEVICE)

    # --- public interface ---
    def get_features(self, img_bgr: np.ndarray) -> np.ndarray:
        """Return a 1-D float32 feature vector for *img_bgr* (BGR uint8)."""
        raise NotImplementedError

    def get_activation_maps(self, img_bgr: np.ndarray,
                            n_maps: int = 6) -> list[np.ndarray]:
        """Return *n_maps* normalised float32 spatial activation maps."""
        raise NotImplementedError

    @staticmethod
    def _norm(m: np.ndarray) -> np.ndarray:
        lo, hi = m.min(), m.max()
        return ((m - lo) / (hi - lo + 1e-5)).astype(np.float32)


# ===================================================================
# ResNet-18
# ===================================================================
class ResNet18Backbone(_FrozenBackbone):
    """ResNet-18 downloaded from PyTorch Hub, frozen, classifier removed."""

    DIM = 512

    def __init__(self):
        full = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.backbone = self._freeze(full)
        self.extractor = nn.Sequential(*list(full.children())[:-1]).to(DEVICE)
        self.transform = _IMAGENET_TRANSFORM

    def get_features(self, img_bgr):
        t = self.transform(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
        with torch.no_grad():
            return self.extractor(t.unsqueeze(0).to(DEVICE)).cpu().numpy().flatten()

    def get_activation_maps(self, img_bgr, n_maps=6):
        cap = {}
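        # register_forward_hook captures layer4's output during the forward
        # pass (shape [1, 512, 7, 7] for a 224x224 input) without modifying
        # the frozen weights; the same capture pattern repeats below.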
        hook = self.backbone.layer4.register_forward_hook(
            lambda _m, _i, o: cap.update(feat=o))
        t = self.transform(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
        with torch.no_grad():
            self.backbone(t.unsqueeze(0).to(DEVICE))
        hook.remove()
        acts = cap["feat"][0].cpu().numpy()
        return [self._norm(acts[i]) for i in range(min(n_maps, acts.shape[0]))]


# ===================================================================
# MobileNetV3-Small
# ===================================================================
class MobileNetV3Backbone(_FrozenBackbone):
    """MobileNetV3-Small from PyTorch Hub, frozen, classifier = Identity."""

    DIM = 576

    def __init__(self):
        self.backbone = models.mobilenet_v3_small(
            weights=models.MobileNet_V3_Small_Weights.DEFAULT)
        self.backbone.classifier = nn.Identity()
        self._freeze(self.backbone)
        self.transform = _IMAGENET_TRANSFORM

    def get_features(self, img_bgr):
        t = self.transform(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
        with torch.no_grad():
            return self.backbone(t.unsqueeze(0).to(DEVICE)).cpu().numpy().flatten()

    def get_activation_maps(self, img_bgr, n_maps=6):
        cap = {}
        hook = self.backbone.features[-1].register_forward_hook(
            lambda _m, _i, o: cap.update(feat=o))
        t = self.transform(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
        with torch.no_grad():
            self.backbone(t.unsqueeze(0).to(DEVICE))
        hook.remove()
        acts = cap["feat"][0].cpu().numpy()
        return [self._norm(acts[i]) for i in range(min(n_maps, acts.shape[0]))]


# ===================================================================
# MobileViT-XXS
# ===================================================================
class MobileViTBackbone(_FrozenBackbone):
    """MobileViT-XXS from timm (Apple Research), frozen."""

    DIM = 320

    def __init__(self):
        self.backbone = timm.create_model(
            "mobilevit_xxs.cvnets_in1k", pretrained=True, num_classes=0)
        self._freeze(self.backbone)
        cfg = timm.data.resolve_model_data_config(self.backbone)
        self.transform = timm.data.create_transform(**cfg, is_training=False)
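        # timm resolves the checkpoint's expected input size and
        # normalisation stats, so preprocessing matches pretraining.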
    def _to_tensor(self, img_bgr):
        from PIL import Image
        pil = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
        return self.transform(pil).unsqueeze(0).to(DEVICE)

    def get_features(self, img_bgr):
        with torch.no_grad():
            return self.backbone(self._to_tensor(img_bgr)).cpu().numpy().flatten()

    def get_activation_maps(self, img_bgr, n_maps=6):
        cap = {}
        hook = self.backbone.stages[-1].register_forward_hook(
            lambda _m, _i, o: cap.update(feat=o))
        with torch.no_grad():
            self.backbone(self._to_tensor(img_bgr))
        hook.remove()
        acts = cap["feat"][0].cpu().numpy()
        return [self._norm(acts[i]) for i in range(min(n_maps, acts.shape[0]))]


# ===================================================================
# Lightweight Head (lives in session state, never on disk)
# ===================================================================
class RecognitionHead:
    """
    A tiny trainable layer on top of a frozen backbone.
    Wraps sklearn ``LogisticRegression`` for binary classification.
    Stored in ``st.session_state`` — never saved to disk.
    """

    def __init__(self, C: float = 1.0, max_iter: int = 1000):
        from sklearn.linear_model import LogisticRegression
        self.model = LogisticRegression(C=C, max_iter=max_iter)
        self.is_trained = False

    def fit(self, X, y):
        self.model.fit(X, y)
        self.is_trained = True
        return self

    def predict(self, features: np.ndarray):
        """Return *(label, confidence)* for a single feature vector."""
        probs = self.model.predict_proba([features])[0]
        idx = int(np.argmax(probs))
        return self.model.classes_[idx], probs[idx]

    def predict_proba(self, X):
        return self.model.predict_proba(X)

    @property
    def classes_(self):
        return self.model.classes_


# ===================================================================
# Cached loaders — @st.cache_resource keeps models in RAM
# ===================================================================
@st.cache_resource
def get_resnet() -> ResNet18Backbone:
    """Download & freeze ResNet-18. Stays in RAM across page switches."""
    return ResNet18Backbone()

@st.cache_resource
def get_mobilenet() -> MobileNetV3Backbone:
    """Download & freeze MobileNetV3-Small. Stays in RAM."""
    return MobileNetV3Backbone()

@st.cache_resource
def get_mobilevit() -> MobileViTBackbone:
    """Download & freeze MobileViT-XXS. Stays in RAM."""
    return MobileViTBackbone()


# ===================================================================
# BACKBONES — The Registry Dict
# ===================================================================
BACKBONES = {
    "ResNet-18": {
        "loader": get_resnet,
        "dim": ResNet18Backbone.DIM,
        "hook_layer": "layer4 (last conv block)",
    },
    "MobileNetV3": {
        "loader": get_mobilenet,
        "dim": MobileNetV3Backbone.DIM,
        "hook_layer": "features[-1] (last features block)",
    },
    "MobileViT-XXS": {
        "loader": get_mobilevit,
        "dim": MobileViTBackbone.DIM,
        "hook_layer": "stages[-1] (last transformer stage)",
    },
}
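

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how a page is expected to consume the
# registry. `crop_bgr` and `bg_bgr` stand in for session-state images and
# are not defined in this module.
#
#   entry    = BACKBONES["ResNet-18"]
#   backbone = entry["loader"]()       # cached; weights download only once
#   X = np.stack([backbone.get_features(im) for im in (crop_bgr, bg_bgr)])
#   head = RecognitionHead(C=1.0).fit(X, np.array(["object", "background"]))
#   label, conf = head.predict(backbone.get_features(crop_bgr))
# ---------------------------------------------------------------------------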