Tan Zi Xu committed on
Commit ·
dd63ee6
1
Parent(s): a84e477
reflect detection model conf in crop tracking tab
Browse files- core/drawing.py +71 -0
- ui/workflow.py +217 -143
core/drawing.py
CHANGED
|
@@ -40,3 +40,74 @@ def boxes_to_canvas_json(boxes, scale: float = 1.0, stroke_color: str = "#FF9900
|
|
| 40 |
})
|
| 41 |
return {"objects": objects}
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
})
|
| 41 |
return {"objects": objects}
|
| 42 |
|
| 43 |
+
def rects_only_json(j: dict) -> dict:
    """Keep only Fabric rectangle objects; drop label textboxes and anything else."""
    source_objects = (j or {}).get("objects") or []
    rects = []
    for obj in source_objects:
        if obj.get("type") == "rect":
            rects.append(obj)
    return {"objects": rects}
|
| 48 |
+
|
| 49 |
+
def get_rect_conf(rect_obj: dict) -> float | None:
|
| 50 |
+
"""Read detector confidence from canvas rect's custom metadata, if any."""
|
| 51 |
+
meta = rect_obj.get("meta") or {}
|
| 52 |
+
c = meta.get("det_conf", None)
|
| 53 |
+
try:
|
| 54 |
+
return float(c) if c is not None else None
|
| 55 |
+
except Exception:
|
| 56 |
+
return None
|
| 57 |
+
|
| 58 |
+
def inject_index_labels(live_json: dict, offset_px: int = 4, show_conf: bool = True):
    """Overlay a numeric label on each rect (optionally with detector confidence).

    Produces a new objects list in which every rectangle is followed by a
    small non-interactive Fabric 'textbox' placed offset_px inside its
    top-left corner.  Rect dicts themselves are reused, not deep-copied.
    When show_conf is True and the rect carries meta.det_conf, the label
    reads e.g. '2 · 0.82'; otherwise it is just the 1-based index.
    """
    if not live_json:
        return {"objects": []}
    decorated = []
    rect_no = 0
    for obj in live_json.get("objects") or []:
        decorated.append(obj)
        if obj.get("type") != "rect":
            continue
        rect_no += 1
        text = str(rect_no)
        if show_conf:
            conf = get_rect_conf(obj)
            if conf is not None:
                text = f"{rect_no} · {conf:.2f}"
        decorated.append({
            "type": "textbox", "text": text,
            "left": float(obj.get("left", 0.0)) + offset_px,
            "top": float(obj.get("top", 0.0)) + offset_px,
            "width": max(24, 10 * len(text)), "height": 20,
            "fontSize": 14, "fill": "#ffffff", "backgroundColor": "#111827",
            "opacity": 0.9, "selectable": False, "evented": False,
            "hasControls": False, "editable": False, "textAlign": "center",
            "fontFamily": "sans-serif",
        })
    return {"objects": decorated}
|
| 88 |
+
|
| 89 |
+
def seed_canvas_from_boxes(boxes_xyxy, scale: float, det_scores=None):
    """Build a canvas JSON (rect objects) from xyxy boxes at a display scale.

    Args:
        boxes_xyxy: sequence of (x0, y0, x1, y1) boxes in original-image
            coordinates.
        scale: multiplier mapping image coordinates onto the display canvas.
        det_scores: optional per-box detector confidences; each convertible
            score is stored as rect.meta['det_conf'].  A falsy value (None or
            an empty list) means "no scores" for every box.

    Returns:
        dict with an "objects" list of Fabric-style rect dicts.
    """
    objects = []
    # Falsy det_scores (None or []) -> pad with None so zip keeps every box.
    det_scores = det_scores or [None] * len(boxes_xyxy)
    for (x0, y0, x1, y1), score in zip(boxes_xyxy, det_scores):
        meta = {"origin": "det"}  # mark the rect as detector-produced
        if score is not None:
            try:
                meta["det_conf"] = float(score)
            except (TypeError, ValueError):
                # Narrowed from bare `except Exception`: a non-numeric score
                # simply yields a rect without a stored confidence.
                pass
        objects.append({
            "type": "rect",
            "left": float(x0) * scale,
            "top": float(y0) * scale,
            "width": float(x1 - x0) * scale,
            "height": float(y1 - y0) * scale,
            "strokeWidth": 3, "stroke": "#FF9900", "fill": "rgba(0,0,0,0)",
            "angle": 0, "meta": meta,
        })
    return {"objects": objects}
|
ui/workflow.py
CHANGED
|
@@ -1,17 +1,21 @@
|
|
| 1 |
# ui/workflow.py
|
| 2 |
-
import io
|
| 3 |
import streamlit as st
|
| 4 |
-
import numpy as np
|
| 5 |
from PIL import Image, ImageOps
|
| 6 |
-
import uuid
|
| 7 |
from streamlit_drawable_canvas import st_canvas
|
| 8 |
-
from core.drawing import
|
| 9 |
-
from
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
from core.detect_infer import DetConfig
|
| 12 |
import torch
|
| 13 |
|
| 14 |
-
# ---- helpers ----
|
| 15 |
def _auto_canvas_width():
|
| 16 |
try:
|
| 17 |
from streamlit_js_eval import get_page_info
|
|
@@ -31,17 +35,50 @@ def _det_cfg_from_state():
|
|
| 31 |
half=bool(c.get("half", True)),
|
| 32 |
)
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
def _run_detection(item):
|
| 35 |
det = st.session_state.get("detector")
|
| 36 |
-
if det is None:
|
|
|
|
| 37 |
cfg = _det_cfg_from_state()
|
| 38 |
-
return det.predict_one(item["pil"], cfg) or [] # list of dicts {'bbox':[x1,y1,x2,y2], 'score',
|
| 39 |
-
|
| 40 |
-
def _seed_canvas_from_boxes(item, disp_w, base_w):
|
| 41 |
-
"""Write boxes to canvas JSON at display scale (non-destructive)."""
|
| 42 |
-
if not item.get("boxes"): return
|
| 43 |
-
scale = float(disp_w) / float(base_w)
|
| 44 |
-
item["canvas_json"] = boxes_to_canvas_json(item["boxes"], scale=scale)
|
| 45 |
|
| 46 |
def _parse_canvas_into_boxes(item, scale, W, H):
|
| 47 |
"""Read current canvas JSON back into xyxy boxes on the original image grid."""
|
|
@@ -50,79 +87,36 @@ def _parse_canvas_into_boxes(item, scale, W, H):
|
|
| 50 |
|
| 51 |
def _classify_now(item, predict_fn):
|
| 52 |
"""Run classifier on current boxes, then sync to session samples."""
|
| 53 |
-
if not item.get("boxes"):
|
| 54 |
item["preds"] = []; item["user_labels"] = []; item["actions"] = []
|
| 55 |
sync_samples_with_state(); return
|
| 56 |
crops = [item["pil"].crop(b) for b in item["boxes"]]
|
| 57 |
with torch.inference_mode():
|
| 58 |
item["preds"] = predict_fn(crops, topk=3)
|
|
|
|
| 59 |
set_defaults_from_preds(item)
|
| 60 |
item["actions"] = ["pending"] * len(item["boxes"])
|
| 61 |
sync_samples_with_state()
|
| 62 |
|
| 63 |
def _bump_on_image_switch(ak: str, item: dict):
|
| 64 |
-
"""Force a canvas remount the
|
| 65 |
if st.session_state.get("_last_active_key") != ak:
|
| 66 |
item["canvas_rev"] = (item.get("canvas_rev", 0) + 1)
|
| 67 |
st.session_state["_last_active_key"] = ak
|
| 68 |
|
| 69 |
-
|
| 70 |
-
"""
|
| 71 |
-
Return a copy of live_json with non-selectable numeric labels injected
|
| 72 |
-
for each rectangle in the order they appear.
|
| 73 |
-
Labels are Fabric 'textbox' objects and will be ignored by parse_boxes().
|
| 74 |
-
"""
|
| 75 |
-
if not live_json:
|
| 76 |
-
return {"objects": []}
|
| 77 |
-
objs = list(live_json.get("objects") or [])
|
| 78 |
-
labeled = []
|
| 79 |
-
idx = 1
|
| 80 |
-
for o in objs:
|
| 81 |
-
labeled.append(o)
|
| 82 |
-
if o.get("type") == "rect":
|
| 83 |
-
# Place label near rect's top-left in canvas coords
|
| 84 |
-
left = float(o.get("left", 0.0)) + offset_px
|
| 85 |
-
top = float(o.get("top", 0.0)) + offset_px
|
| 86 |
-
label = {
|
| 87 |
-
"type": "textbox",
|
| 88 |
-
"text": str(idx),
|
| 89 |
-
"left": left,
|
| 90 |
-
"top": top,
|
| 91 |
-
"width": 24,
|
| 92 |
-
"height": 20,
|
| 93 |
-
"fontSize": 16,
|
| 94 |
-
"fill": "#ffffff",
|
| 95 |
-
"backgroundColor": "#111827", # slate-900
|
| 96 |
-
"opacity": 0.85,
|
| 97 |
-
"selectable": False,
|
| 98 |
-
"evented": False,
|
| 99 |
-
"hasControls": False,
|
| 100 |
-
"editable": False,
|
| 101 |
-
"textAlign": "center",
|
| 102 |
-
"fontFamily": "sans-serif",
|
| 103 |
-
}
|
| 104 |
-
labeled.append(label)
|
| 105 |
-
idx += 1
|
| 106 |
-
return {"objects": labeled}
|
| 107 |
-
|
| 108 |
-
def rects_only_json(j: dict) -> dict:
|
| 109 |
-
"""Return a JSON with only rectangle objects (drop labels/others)."""
|
| 110 |
-
if not j:
|
| 111 |
-
return {"objects": []}
|
| 112 |
-
return {"objects": [o for o in (j.get("objects") or []) if o.get("type") == "rect"]}
|
| 113 |
-
|
| 114 |
-
# ---- main UI ----
|
| 115 |
def render_workflow_tab(predict_fn, class_names, class_to_id, model_meta, export_root):
|
| 116 |
CANVAS_W = _auto_canvas_width()
|
| 117 |
if "uploader_rev" not in st.session_state:
|
| 118 |
st.session_state.uploader_rev = 0
|
|
|
|
| 119 |
# =============== (1) Load images ===============
|
| 120 |
st.subheader("1) Load image(s)")
|
| 121 |
ups = st.file_uploader(
|
| 122 |
"Upload one or more images",
|
| 123 |
type=["jpg", "jpeg", "png", "bmp"],
|
| 124 |
accept_multiple_files=True,
|
| 125 |
-
key=f"uploader_{st.session_state.uploader_rev}" #
|
| 126 |
)
|
| 127 |
for up in (ups or []):
|
| 128 |
key = f"{up.name}-{up.size}"
|
|
@@ -131,7 +125,8 @@ def render_workflow_tab(predict_fn, class_names, class_to_id, model_meta, export
|
|
| 131 |
st.session_state.images[key] = {
|
| 132 |
"key": key, "name": up.name, "pil": pil,
|
| 133 |
"boxes": [], "preds": [], "user_labels": [], "actions": [],
|
| 134 |
-
"
|
|
|
|
| 135 |
"canvas_rev": 0,
|
| 136 |
}
|
| 137 |
if st.session_state.active_key is None:
|
|
@@ -142,7 +137,6 @@ def render_workflow_tab(predict_fn, class_names, class_to_id, model_meta, export
|
|
| 142 |
st.info("Upload images or load from S3/Drive in the sidebar.")
|
| 143 |
return
|
| 144 |
|
| 145 |
-
|
| 146 |
names = [st.session_state.images[k]["name"] for k in keys]
|
| 147 |
idx = keys.index(st.session_state.active_key) if st.session_state.active_key in keys else 0
|
| 148 |
chosen = st.selectbox("Active image", names, index=idx)
|
|
@@ -161,16 +155,28 @@ def render_workflow_tab(predict_fn, class_names, class_to_id, model_meta, export
|
|
| 161 |
if st.session_state.assist_plus and not item.get("boxes") and not item.get("_skip_autodetect_once"):
|
| 162 |
det_out = _run_detection(item)
|
| 163 |
if det_out:
|
| 164 |
-
|
| 165 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
item["canvas_mode"] = "edit"
|
| 167 |
-
item
|
| 168 |
-
item["_skip_autodetect_once"] = False
|
| 169 |
-
_parse_canvas_into_boxes(item, scale, base.width, base.height) # ensure boxes align with canvas JSON
|
| 170 |
_classify_now(item, predict_fn)
|
| 171 |
st.rerun()
|
| 172 |
|
| 173 |
-
|
| 174 |
bg_pil = ImageOps.exif_transpose(base).resize((disp_w, disp_h), Image.BILINEAR).convert("RGB").copy()
|
| 175 |
|
| 176 |
# 2) Draw / edit rectangles
|
|
@@ -197,83 +203,124 @@ def render_workflow_tab(predict_fn, class_names, class_to_id, model_meta, export
|
|
| 197 |
else:
|
| 198 |
st.caption("Tip: Drag on the image to draw new boxes.")
|
| 199 |
|
| 200 |
-
# ----------
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
display_json = _inject_index_labels(base_json)
|
| 205 |
|
| 206 |
# ---------- Canvas ----------
|
| 207 |
canvas = st_canvas(
|
| 208 |
fill_color="rgba(0,0,0,0)",
|
| 209 |
stroke_width=3,
|
| 210 |
stroke_color="#FF9900",
|
| 211 |
-
background_image=bg_pil,
|
| 212 |
update_streamlit=True,
|
| 213 |
height=disp_h,
|
| 214 |
width=disp_w,
|
| 215 |
-
drawing_mode=drawing_mode,
|
| 216 |
display_toolbar=True,
|
| 217 |
-
initial_drawing=display_json,
|
| 218 |
key=f"canvas_{ak}_{item.get('canvas_rev', 0)}_{disp_w}x{disp_h}"
|
| 219 |
)
|
| 220 |
|
| 221 |
-
# ---------- Live canvas snapshot
|
| 222 |
live_json = canvas.json_data or base_json
|
| 223 |
objs = list(live_json.get("objects") or [])
|
| 224 |
-
live_rects_only = rects_only_json(live_json) # keep only rectangles
|
| 225 |
-
|
| 226 |
-
# Re-index labels whenever crop count changes (e.g., a new rectangle drawn)
|
| 227 |
-
prev_cnt = item.get("_prev_rect_count", None)
|
| 228 |
-
cur_cnt = len(live_rects_only["objects"])
|
| 229 |
-
if prev_cnt is None:
|
| 230 |
-
item["_prev_rect_count"] = cur_cnt
|
| 231 |
-
elif cur_cnt != prev_cnt:
|
| 232 |
-
# persist rects-only, bump rev to redraw with fresh labels 1..N
|
| 233 |
-
item["canvas_json"] = live_rects_only
|
| 234 |
-
item["canvas_rev"] = (item.get("canvas_rev", 0) + 1)
|
| 235 |
-
item["_prev_rect_count"] = cur_cnt
|
| 236 |
-
st.rerun()
|
| 237 |
|
| 238 |
-
#
|
| 239 |
rect_indices = [k for k, o in enumerate(objs) if o.get("type") == "rect"]
|
|
|
|
| 240 |
|
| 241 |
-
# Live boxes
|
| 242 |
live_boxes = parse_boxes(live_json, scale, base.width, base.height)
|
| 243 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
# ---- Crops on canvas (live) ----
|
| 245 |
if live_boxes:
|
| 246 |
with st.expander(f"Crops on canvas (live) — {len(live_boxes)}", expanded=False):
|
| 247 |
for i, b in enumerate(live_boxes):
|
| 248 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
with c1:
|
| 250 |
st.write(f"#{i+1} [x0={b[0]}, y0={b[1]}, x1={b[2]}, y1={b[3]}]")
|
|
|
|
| 251 |
with c2:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
if st.button("Remove", key=f"rm_live_{ak}_{i}"):
|
| 253 |
-
#
|
| 254 |
-
|
| 255 |
-
if 0 <=
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
item["canvas_mode"] = "edit"
|
| 260 |
item["canvas_rev"] = (item.get("canvas_rev", 0) + 1)
|
| 261 |
-
# predictions are stale after geometry changes
|
| 262 |
item["preds"] = []; item["user_labels"] = []; item["actions"] = []
|
| 263 |
sync_samples_with_state()
|
| 264 |
st.rerun()
|
| 265 |
|
| 266 |
-
|
| 267 |
-
# Keep the on-screen count truly live
|
| 268 |
st.caption(f"Crops in this image: **{len(live_boxes)}**")
|
| 269 |
|
| 270 |
-
# Commit canvas into boxes (
|
| 271 |
-
# (was: Save/Lock, Clear, Remove). Now: Clear + Remove only.
|
| 272 |
col_clear, col_remove = st.columns([1, 1])
|
| 273 |
|
| 274 |
if col_clear.button("Clear boxes (THIS image)", key=f"clear_{ak}"):
|
| 275 |
item["canvas_json"] = {"objects": []}
|
| 276 |
item["boxes"] = []; item["preds"] = []; item["user_labels"] = []; item["actions"] = []
|
|
|
|
| 277 |
item["canvas_mode"] = "draw"
|
| 278 |
item["canvas_rev"] = (item.get("canvas_rev", 0) + 1)
|
| 279 |
item["_skip_autodetect_once"] = True
|
|
@@ -281,45 +328,32 @@ def render_workflow_tab(predict_fn, class_names, class_to_id, model_meta, export
|
|
| 281 |
st.rerun()
|
| 282 |
|
| 283 |
if col_remove.button("Remove image", key=f"remove_{ak}"):
|
| 284 |
-
# Choose the next image BEFORE deletion
|
| 285 |
remaining_keys = [k for k in st.session_state.images.keys() if k != ak]
|
| 286 |
next_key = remaining_keys[0] if remaining_keys else None
|
| 287 |
-
|
| 288 |
-
# Delete the image and its samples
|
| 289 |
del st.session_state.images[ak]
|
| 290 |
st.session_state.all_samples = [s for s in st.session_state.all_samples if s.image_key != ak]
|
| 291 |
-
|
| 292 |
-
# Clear the uploader so removed files don't get re-added on rerun
|
| 293 |
st.session_state.uploader_rev += 1
|
| 294 |
-
|
| 295 |
-
# Switch to the next image (if any) but DO NOT autorun detection once
|
| 296 |
st.session_state.active_key = next_key
|
| 297 |
if next_key:
|
| 298 |
nxt = st.session_state.images[next_key]
|
| 299 |
nxt["_skip_autodetect_once"] = True
|
| 300 |
nxt["canvas_mode"] = "edit" if nxt.get("boxes") else "draw"
|
| 301 |
-
nxt["canvas_rev"] = nxt.get("canvas_rev", 0) + 1
|
| 302 |
-
|
| 303 |
st.rerun()
|
| 304 |
|
| 305 |
-
|
| 306 |
-
|
| 307 |
# =============== (3) Detect / Classify controls ===============
|
| 308 |
st.subheader("3) Detect & Classify")
|
| 309 |
|
| 310 |
assist_on = bool(st.session_state.get("assist_plus"))
|
| 311 |
has_boxes = bool(item.get("boxes"))
|
| 312 |
|
| 313 |
-
# detect whether the user has drawn rectangles but not saved yet
|
| 314 |
canvas_objects = (canvas.json_data or {}).get("objects") or []
|
| 315 |
has_canvas_rects = any(obj.get("type") == "rect" for obj in canvas_objects)
|
| 316 |
can_classify = has_boxes or has_canvas_rects
|
| 317 |
|
| 318 |
# ---------- DETECT ----------
|
| 319 |
if assist_on:
|
| 320 |
-
# Only show detect controls in Assist+ mode
|
| 321 |
st.caption("Detect: find objects and write boxes onto the canvas.")
|
| 322 |
-
|
| 323 |
dcol1, dcol2 = st.columns([1, 1])
|
| 324 |
|
| 325 |
# First-time Detect (hidden if Assist+ already produced boxes)
|
|
@@ -332,8 +366,22 @@ def render_workflow_tab(predict_fn, class_names, class_to_id, model_meta, export
|
|
| 332 |
):
|
| 333 |
det_out = _run_detection(item)
|
| 334 |
if det_out:
|
| 335 |
-
|
| 336 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
item["canvas_mode"] = "edit"
|
| 338 |
item["canvas_rev"] = (item.get("canvas_rev", 0) + 1)
|
| 339 |
item["preds"] = []; item["user_labels"] = []; item["actions"] = []
|
|
@@ -358,14 +406,34 @@ def render_workflow_tab(predict_fn, class_names, class_to_id, model_meta, export
|
|
| 358 |
):
|
| 359 |
det_out = _run_detection(item)
|
| 360 |
new_boxes = [tuple(int(v) for v in o["bbox"]) for o in det_out]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
if new_boxes:
|
| 362 |
if rd_mode == "Replace":
|
| 363 |
item["boxes"] = new_boxes
|
|
|
|
|
|
|
|
|
|
| 364 |
else:
|
| 365 |
-
|
| 366 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
item["canvas_mode"] = "edit"
|
| 368 |
-
item["canvas_rev"] = (item.get("canvas_rev", 0) + 1)
|
| 369 |
item["preds"] = []; item["user_labels"] = []; item["actions"] = []
|
| 370 |
sync_samples_with_state()
|
| 371 |
st.rerun()
|
|
@@ -396,27 +464,27 @@ def render_workflow_tab(predict_fn, class_names, class_to_id, model_meta, export
|
|
| 396 |
_classify_now(item, predict_fn)
|
| 397 |
st.rerun()
|
| 398 |
|
| 399 |
-
|
| 400 |
-
# Quick count
|
| 401 |
-
st.caption(f"Crops in this image: **{len(live_boxes)}**")
|
| 402 |
-
|
| 403 |
_render_active_summary(class_names)
|
| 404 |
|
| 405 |
def _render_active_summary(class_names):
|
| 406 |
st.subheader("Summary for Active Image")
|
| 407 |
-
ak=st.session_state.active_key
|
| 408 |
if not st.session_state.all_samples or not ak:
|
| 409 |
-
st.info("No samples available. Run predictions on images first.")
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
for sample in active:
|
| 414 |
-
crop_image=img_data["pil"].crop(sample.bbox)
|
| 415 |
-
col_img,col_details,col_relabel=st.columns([1,2,1.5])
|
| 416 |
with col_img:
|
| 417 |
st.image(crop_image, caption=f"Crop {sample.crop_idx+1}", width=180)
|
| 418 |
with col_details:
|
| 419 |
-
top3_str=", ".join([f"{n} ({p:.3f})" for n,p in sample.top3[:3]])
|
| 420 |
st.write(f"**Current Label:** {sample.current_label}")
|
| 421 |
st.write(f"**Top-3:** {top3_str}")
|
| 422 |
st.write(f"**Confidence:** {sample.confidence:.3f}")
|
|
@@ -430,10 +498,16 @@ def _render_active_summary(class_names):
|
|
| 430 |
idx = class_names.index(sample.current_label) if sample.current_label in class_names else 0
|
| 431 |
except Exception:
|
| 432 |
idx = 0
|
| 433 |
-
new_label = st.selectbox(
|
| 434 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 435 |
if new_label != sample.current_label:
|
| 436 |
-
sample.user_label
|
|
|
|
| 437 |
from core.state import update_image_state_from_samples
|
| 438 |
-
update_image_state_from_samples()
|
|
|
|
| 439 |
st.divider()
|
|
|
|
| 1 |
# ui/workflow.py
|
|
|
|
| 2 |
import streamlit as st
|
|
|
|
| 3 |
from PIL import Image, ImageOps
|
|
|
|
| 4 |
from streamlit_drawable_canvas import st_canvas
|
| 5 |
+
from core.drawing import (
|
| 6 |
+
parse_boxes, # from canvas JSON -> list[xyxy]
|
| 7 |
+
boxes_to_canvas_json, # (kept for compatibility; not used here)
|
| 8 |
+
rects_only_json,
|
| 9 |
+
get_rect_conf,
|
| 10 |
+
inject_index_labels,
|
| 11 |
+
seed_canvas_from_boxes, # list[xyxy] -> canvas JSON
|
| 12 |
+
)
|
| 13 |
+
from core.state import set_defaults_from_preds, sync_samples_with_state
|
| 14 |
+
from core.exports import export_session # not used here but retained
|
| 15 |
from core.detect_infer import DetConfig
|
| 16 |
import torch
|
| 17 |
|
| 18 |
+
# ---------------- helpers ----------------
|
| 19 |
def _auto_canvas_width():
|
| 20 |
try:
|
| 21 |
from streamlit_js_eval import get_page_info
|
|
|
|
| 35 |
half=bool(c.get("half", True)),
|
| 36 |
)
|
| 37 |
|
| 38 |
+
def _iou_xyxy(a, b, eps=1e-6):
|
| 39 |
+
ax1, ay1, ax2, ay2 = map(float, a)
|
| 40 |
+
bx1, by1, bx2, by2 = map(float, b)
|
| 41 |
+
ix1, iy1 = max(ax1, bx1), max(ay1, by1)
|
| 42 |
+
ix2, iy2 = min(ax2, bx2), min(ay2, by2)
|
| 43 |
+
iw, ih = max(0.0, ix2 - ix1), max(0.0, iy2 - iy1)
|
| 44 |
+
inter = iw * ih
|
| 45 |
+
aa = max(0.0, (ax2 - ax1)) * max(0.0, (ay2 - ay1))
|
| 46 |
+
ba = max(0.0, (bx2 - bx1)) * max(0.0, (by2 - by1))
|
| 47 |
+
union = aa + ba - inter + eps
|
| 48 |
+
return inter / union
|
| 49 |
+
|
| 50 |
+
def _match_det_info_to_live_boxes(item, live_boxes, iou_thresh=0.90):
    """Align per-box detector info (conf/cls/name) to the current live boxes.

    Each live box is matched by best IoU against item['boxes'] (the boxes
    detector_info was recorded for); boxes that overlap no stored box at
    >= iou_thresh map to None.

    Delegates to _reindex_det_info, which implements the identical best-IoU
    matching loop, instead of duplicating it here.
    """
    return _reindex_det_info(
        item.get("boxes") or [],
        item.get("detector_info") or [],
        live_boxes,
        iou_thresh=iou_thresh,
    )
|
| 63 |
+
|
| 64 |
+
def _reindex_det_info(old_boxes, old_info, new_boxes, iou_thresh=0.90):
    """Re-align detector-info entries to new_boxes by best-IoU matching.

    For each box in new_boxes, find the old box with the highest IoU; copy
    its info entry when the overlap reaches iou_thresh and the index exists
    in old_info, otherwise record None.
    """
    result = []
    for box in new_boxes:
        best_score = 0.0
        best_j = -1
        for j, old_box in enumerate(old_boxes):
            score = _iou_xyxy(box, old_box)
            if score > best_score:
                best_score = score
                best_j = j
        if best_score >= iou_thresh and 0 <= best_j < len(old_info):
            result.append(old_info[best_j])
        else:
            result.append(None)
    return result
|
| 75 |
+
|
| 76 |
def _run_detection(item):
    """Run the session's detector on the item's image.

    Returns a (possibly empty) list of detection dicts; an empty list when
    no detector is loaded in this session or nothing is detected.
    """
    det = st.session_state.get("detector")
    # No detector loaded -> behave as "no detections" instead of raising.
    if det is None:
        return []
    cfg = _det_cfg_from_state()
    return det.predict_one(item["pil"], cfg) or []  # list of dicts {'bbox':[x1,y1,x2,y2], 'score','cls','name',...}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
def _parse_canvas_into_boxes(item, scale, W, H):
|
| 84 |
"""Read current canvas JSON back into xyxy boxes on the original image grid."""
|
|
|
|
| 87 |
|
| 88 |
def _classify_now(item, predict_fn):
    """Run classifier on current boxes, then sync to session samples.

    item: per-image state dict (reads 'pil' and 'boxes'; writes 'preds',
        'user_labels', 'actions').
    predict_fn: callable invoked as predict_fn(crops, topk=3) on PIL crops.
    """
    # No boxes -> clear all per-crop state and push the (empty) result out.
    if not item.get("boxes"):
        item["preds"] = []; item["user_labels"] = []; item["actions"] = []
        sync_samples_with_state(); return
    # Crop each bbox region out of the original PIL image.
    crops = [item["pil"].crop(b) for b in item["boxes"]]
    with torch.inference_mode():  # inference only; no autograd bookkeeping
        item["preds"] = predict_fn(crops, topk=3)
    # set_defaults fills current_label/confidence/margin/badges and seeds samples
    set_defaults_from_preds(item)
    # Every crop starts out awaiting a user decision.
    item["actions"] = ["pending"] * len(item["boxes"])
    sync_samples_with_state()
|
| 100 |
|
| 101 |
def _bump_on_image_switch(ak: str, item: dict):
|
| 102 |
+
"""Force a canvas remount when the active image changes."""
|
| 103 |
if st.session_state.get("_last_active_key") != ak:
|
| 104 |
item["canvas_rev"] = (item.get("canvas_rev", 0) + 1)
|
| 105 |
st.session_state["_last_active_key"] = ak
|
| 106 |
|
| 107 |
+
# ---------------- main UI ----------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
def render_workflow_tab(predict_fn, class_names, class_to_id, model_meta, export_root):
|
| 109 |
CANVAS_W = _auto_canvas_width()
|
| 110 |
if "uploader_rev" not in st.session_state:
|
| 111 |
st.session_state.uploader_rev = 0
|
| 112 |
+
|
| 113 |
# =============== (1) Load images ===============
|
| 114 |
st.subheader("1) Load image(s)")
|
| 115 |
ups = st.file_uploader(
|
| 116 |
"Upload one or more images",
|
| 117 |
type=["jpg", "jpeg", "png", "bmp"],
|
| 118 |
accept_multiple_files=True,
|
| 119 |
+
key=f"uploader_{st.session_state.uploader_rev}" # resettable
|
| 120 |
)
|
| 121 |
for up in (ups or []):
|
| 122 |
key = f"{up.name}-{up.size}"
|
|
|
|
| 125 |
st.session_state.images[key] = {
|
| 126 |
"key": key, "name": up.name, "pil": pil,
|
| 127 |
"boxes": [], "preds": [], "user_labels": [], "actions": [],
|
| 128 |
+
"detector_info": [], # <— conf/cls/name aligned with boxes
|
| 129 |
+
"canvas_json": {"objects": []},
|
| 130 |
"canvas_rev": 0,
|
| 131 |
}
|
| 132 |
if st.session_state.active_key is None:
|
|
|
|
| 137 |
st.info("Upload images or load from S3/Drive in the sidebar.")
|
| 138 |
return
|
| 139 |
|
|
|
|
| 140 |
names = [st.session_state.images[k]["name"] for k in keys]
|
| 141 |
idx = keys.index(st.session_state.active_key) if st.session_state.active_key in keys else 0
|
| 142 |
chosen = st.selectbox("Active image", names, index=idx)
|
|
|
|
| 155 |
if st.session_state.assist_plus and not item.get("boxes") and not item.get("_skip_autodetect_once"):
|
| 156 |
det_out = _run_detection(item)
|
| 157 |
if det_out:
|
| 158 |
+
boxes_xyxy = [tuple(int(v) for v in o["bbox"]) for o in det_out]
|
| 159 |
+
det_info = []
|
| 160 |
+
for o in det_out:
|
| 161 |
+
conf = o.get("score", o.get("conf", o.get("confidence")))
|
| 162 |
+
clsid = o.get("cls", o.get("class"))
|
| 163 |
+
name = o.get("name")
|
| 164 |
+
det_info.append({
|
| 165 |
+
"conf": float(conf) if conf is not None else None,
|
| 166 |
+
"cls": int(clsid) if clsid is not None else None,
|
| 167 |
+
"name": name if name is not None else None,
|
| 168 |
+
})
|
| 169 |
+
item["boxes"] = boxes_xyxy
|
| 170 |
+
item["detector_info"] = det_info
|
| 171 |
+
item["canvas_json"] = seed_canvas_from_boxes(
|
| 172 |
+
boxes_xyxy, scale, det_scores=[d["conf"] for d in det_info]
|
| 173 |
+
)
|
| 174 |
item["canvas_mode"] = "edit"
|
| 175 |
+
_parse_canvas_into_boxes(item, scale, base.width, base.height)
|
|
|
|
|
|
|
| 176 |
_classify_now(item, predict_fn)
|
| 177 |
st.rerun()
|
| 178 |
|
| 179 |
+
# Build display background
|
| 180 |
bg_pil = ImageOps.exif_transpose(base).resize((disp_w, disp_h), Image.BILINEAR).convert("RGB").copy()
|
| 181 |
|
| 182 |
# 2) Draw / edit rectangles
|
|
|
|
| 203 |
else:
|
| 204 |
st.caption("Tip: Drag on the image to draw new boxes.")
|
| 205 |
|
| 206 |
+
# ---------- initial drawing ----------
|
| 207 |
+
base_json = item.get("canvas_json") or {"objects": []}
|
| 208 |
+
show_labels = st.checkbox("Show box labels on canvas", value=True, key=f"show_labels_{ak}")
|
| 209 |
+
display_json = inject_index_labels(base_json, show_conf=False) if show_labels else base_json
|
|
|
|
| 210 |
|
| 211 |
# ---------- Canvas ----------
|
| 212 |
canvas = st_canvas(
|
| 213 |
fill_color="rgba(0,0,0,0)",
|
| 214 |
stroke_width=3,
|
| 215 |
stroke_color="#FF9900",
|
| 216 |
+
background_image=bg_pil,
|
| 217 |
update_streamlit=True,
|
| 218 |
height=disp_h,
|
| 219 |
width=disp_w,
|
| 220 |
+
drawing_mode=drawing_mode,
|
| 221 |
display_toolbar=True,
|
| 222 |
+
initial_drawing=display_json,
|
| 223 |
key=f"canvas_{ak}_{item.get('canvas_rev', 0)}_{disp_w}x{disp_h}"
|
| 224 |
)
|
| 225 |
|
| 226 |
+
# ---------- Live canvas snapshot ----------
|
| 227 |
live_json = canvas.json_data or base_json
|
| 228 |
objs = list(live_json.get("objects") or [])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
+
# Rect indices & objects (ignore overlay textboxes)
|
| 231 |
rect_indices = [k for k, o in enumerate(objs) if o.get("type") == "rect"]
|
| 232 |
+
rect_objs = [objs[k] for k in rect_indices]
|
| 233 |
|
| 234 |
+
# Live boxes in original image coords
|
| 235 |
live_boxes = parse_boxes(live_json, scale, base.width, base.height)
|
| 236 |
|
| 237 |
+
# Detector summary (use detector_info matched to live boxes)
|
| 238 |
+
infos = _match_det_info_to_live_boxes(item, live_boxes)
|
| 239 |
+
confs = [d["conf"] for d in infos if d and d.get("conf") is not None]
|
| 240 |
+
if confs:
|
| 241 |
+
n = len(confs)
|
| 242 |
+
cmin, cmax = min(confs), max(confs)
|
| 243 |
+
cmean = sum(confs)/n
|
| 244 |
+
th = float(st.session_state.get("det_cfg", {}).get("conf", 0.25))
|
| 245 |
+
below = sum(1 for c in confs if c < th)
|
| 246 |
+
st.markdown(
|
| 247 |
+
f"**Active image (detector)** — boxes: {n} · mean **{cmean:.3f}** · "
|
| 248 |
+
f"min **{cmin:.3f}** · max **{cmax:.3f}** · below {th:.2f}: **{below}/{n}**"
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
# ---- Crops on canvas (live) ----
|
| 252 |
if live_boxes:
|
| 253 |
with st.expander(f"Crops on canvas (live) — {len(live_boxes)}", expanded=False):
|
| 254 |
for i, b in enumerate(live_boxes):
|
| 255 |
+
info = infos[i] if i < len(infos) else None
|
| 256 |
+
det_conf = (info or {}).get("conf", None)
|
| 257 |
+
# resolve detector class label
|
| 258 |
+
det_label = None
|
| 259 |
+
if info:
|
| 260 |
+
if info.get("name") is not None:
|
| 261 |
+
det_label = str(info["name"])
|
| 262 |
+
elif info.get("cls") is not None:
|
| 263 |
+
names = st.session_state.get("detector_names") or []
|
| 264 |
+
cid = info["cls"]
|
| 265 |
+
if isinstance(names, (list, tuple)) and isinstance(cid, int) and 0 <= cid < len(names):
|
| 266 |
+
det_label = str(names[cid])
|
| 267 |
+
else:
|
| 268 |
+
det_label = str(cid)
|
| 269 |
+
|
| 270 |
+
c1, c2, c3 = st.columns([6, 3, 1])
|
| 271 |
with c1:
|
| 272 |
st.write(f"#{i+1} [x0={b[0]}, y0={b[1]}, x1={b[2]}, y1={b[3]}]")
|
| 273 |
+
|
| 274 |
with c2:
|
| 275 |
+
if det_conf is None and not det_label:
|
| 276 |
+
st.caption("det: —")
|
| 277 |
+
else:
|
| 278 |
+
th = float(st.session_state.get("det_cfg", {}).get("conf", 0.25))
|
| 279 |
+
status = " ⚠️" if (det_conf is not None and det_conf < th) else ""
|
| 280 |
+
parts = []
|
| 281 |
+
if det_conf is not None: parts.append(f"{det_conf:.2f}")
|
| 282 |
+
if det_label: parts.append(det_label)
|
| 283 |
+
st.caption("det: " + " · ".join(parts) + status)
|
| 284 |
+
|
| 285 |
+
with c3:
|
| 286 |
if st.button("Remove", key=f"rm_live_{ak}_{i}"):
|
| 287 |
+
# Remove by rect index, then remount canvas
|
| 288 |
+
true_idx = rect_indices[i] if i < len(rect_indices) else None
|
| 289 |
+
if true_idx is not None and 0 <= true_idx < len(objs):
|
| 290 |
+
# snapshot old for detector-info reindex
|
| 291 |
+
old_boxes = item.get("boxes") or []
|
| 292 |
+
old_info = item.get("detector_info") or []
|
| 293 |
+
|
| 294 |
+
objs.pop(true_idx)
|
| 295 |
+
# if a label textbox immediately follows, drop it too (best-effort)
|
| 296 |
+
if true_idx < len(objs) and objs[true_idx].get("type") == "textbox":
|
| 297 |
+
objs.pop(true_idx)
|
| 298 |
+
|
| 299 |
+
# update canvas + boxes
|
| 300 |
+
live_json["objects"] = objs
|
| 301 |
+
item["canvas_json"] = live_json
|
| 302 |
+
_parse_canvas_into_boxes(item, scale, base.width, base.height)
|
| 303 |
+
|
| 304 |
+
# reindex detector_info to new boxes
|
| 305 |
+
new_boxes = item.get("boxes") or []
|
| 306 |
+
item["detector_info"] = _reindex_det_info(old_boxes, old_info, new_boxes)
|
| 307 |
+
|
| 308 |
item["canvas_mode"] = "edit"
|
| 309 |
item["canvas_rev"] = (item.get("canvas_rev", 0) + 1)
|
|
|
|
| 310 |
item["preds"] = []; item["user_labels"] = []; item["actions"] = []
|
| 311 |
sync_samples_with_state()
|
| 312 |
st.rerun()
|
| 313 |
|
| 314 |
+
# Always show a live count under the canvas
|
|
|
|
| 315 |
st.caption(f"Crops in this image: **{len(live_boxes)}**")
|
| 316 |
|
| 317 |
+
# Commit canvas into boxes (Clear / Remove)
|
|
|
|
| 318 |
col_clear, col_remove = st.columns([1, 1])
|
| 319 |
|
| 320 |
if col_clear.button("Clear boxes (THIS image)", key=f"clear_{ak}"):
|
| 321 |
item["canvas_json"] = {"objects": []}
|
| 322 |
item["boxes"] = []; item["preds"] = []; item["user_labels"] = []; item["actions"] = []
|
| 323 |
+
item["detector_info"] = [] # also clear detector stats
|
| 324 |
item["canvas_mode"] = "draw"
|
| 325 |
item["canvas_rev"] = (item.get("canvas_rev", 0) + 1)
|
| 326 |
item["_skip_autodetect_once"] = True
|
|
|
|
| 328 |
st.rerun()
|
| 329 |
|
| 330 |
if col_remove.button("Remove image", key=f"remove_{ak}"):
|
|
|
|
| 331 |
remaining_keys = [k for k in st.session_state.images.keys() if k != ak]
|
| 332 |
next_key = remaining_keys[0] if remaining_keys else None
|
|
|
|
|
|
|
| 333 |
del st.session_state.images[ak]
|
| 334 |
st.session_state.all_samples = [s for s in st.session_state.all_samples if s.image_key != ak]
|
|
|
|
|
|
|
| 335 |
st.session_state.uploader_rev += 1
|
|
|
|
|
|
|
| 336 |
st.session_state.active_key = next_key
|
| 337 |
if next_key:
|
| 338 |
nxt = st.session_state.images[next_key]
|
| 339 |
nxt["_skip_autodetect_once"] = True
|
| 340 |
nxt["canvas_mode"] = "edit" if nxt.get("boxes") else "draw"
|
| 341 |
+
nxt["canvas_rev"] = nxt.get("canvas_rev", 0) + 1
|
|
|
|
| 342 |
st.rerun()
|
| 343 |
|
|
|
|
|
|
|
| 344 |
# =============== (3) Detect / Classify controls ===============
|
| 345 |
st.subheader("3) Detect & Classify")
|
| 346 |
|
| 347 |
assist_on = bool(st.session_state.get("assist_plus"))
|
| 348 |
has_boxes = bool(item.get("boxes"))
|
| 349 |
|
|
|
|
| 350 |
canvas_objects = (canvas.json_data or {}).get("objects") or []
|
| 351 |
has_canvas_rects = any(obj.get("type") == "rect" for obj in canvas_objects)
|
| 352 |
can_classify = has_boxes or has_canvas_rects
|
| 353 |
|
| 354 |
# ---------- DETECT ----------
|
| 355 |
if assist_on:
|
|
|
|
| 356 |
st.caption("Detect: find objects and write boxes onto the canvas.")
|
|
|
|
| 357 |
dcol1, dcol2 = st.columns([1, 1])
|
| 358 |
|
| 359 |
# First-time Detect (hidden if Assist+ already produced boxes)
|
|
|
|
| 366 |
):
|
| 367 |
det_out = _run_detection(item)
|
| 368 |
if det_out:
|
| 369 |
+
boxes_xyxy = [tuple(int(v) for v in o["bbox"]) for o in det_out]
|
| 370 |
+
det_info = []
|
| 371 |
+
for o in det_out:
|
| 372 |
+
conf = o.get("score", o.get("conf", o.get("confidence")))
|
| 373 |
+
clsid = o.get("cls", o.get("class"))
|
| 374 |
+
name = o.get("name")
|
| 375 |
+
det_info.append({
|
| 376 |
+
"conf": float(conf) if conf is not None else None,
|
| 377 |
+
"cls": int(clsid) if clsid is not None else None,
|
| 378 |
+
"name": name if name is not None else None,
|
| 379 |
+
})
|
| 380 |
+
item["boxes"] = boxes_xyxy
|
| 381 |
+
item["detector_info"] = det_info
|
| 382 |
+
item["canvas_json"] = seed_canvas_from_boxes(
|
| 383 |
+
boxes_xyxy, scale, det_scores=[d["conf"] for d in det_info]
|
| 384 |
+
)
|
| 385 |
item["canvas_mode"] = "edit"
|
| 386 |
item["canvas_rev"] = (item.get("canvas_rev", 0) + 1)
|
| 387 |
item["preds"] = []; item["user_labels"] = []; item["actions"] = []
|
|
|
|
| 406 |
):
|
| 407 |
det_out = _run_detection(item)
|
| 408 |
new_boxes = [tuple(int(v) for v in o["bbox"]) for o in det_out]
|
| 409 |
+
new_info = []
|
| 410 |
+
for o in det_out:
|
| 411 |
+
conf = o.get("score", o.get("conf", o.get("confidence")))
|
| 412 |
+
clsid = o.get("cls", o.get("class"))
|
| 413 |
+
name = o.get("name")
|
| 414 |
+
new_info.append({
|
| 415 |
+
"conf": float(conf) if conf is not None else None,
|
| 416 |
+
"cls": int(clsid) if clsid is not None else None,
|
| 417 |
+
"name": name if name is not None else None,
|
| 418 |
+
})
|
| 419 |
if new_boxes:
|
| 420 |
if rd_mode == "Replace":
|
| 421 |
item["boxes"] = new_boxes
|
| 422 |
+
item["detector_info"] = new_info
|
| 423 |
+
merged_boxes = new_boxes
|
| 424 |
+
merged_scores = [d["conf"] for d in new_info]
|
| 425 |
else:
|
| 426 |
+
old_boxes = list(item.get("boxes") or [])
|
| 427 |
+
old_info = list(item.get("detector_info") or [])
|
| 428 |
+
merged_boxes = old_boxes + new_boxes
|
| 429 |
+
item["boxes"] = merged_boxes
|
| 430 |
+
item["detector_info"] = old_info + new_info
|
| 431 |
+
merged_scores = [d["conf"] for d in (old_info + new_info)]
|
| 432 |
+
|
| 433 |
+
item["canvas_json"] = seed_canvas_from_boxes(
|
| 434 |
+
merged_boxes, scale, det_scores=merged_scores
|
| 435 |
+
)
|
| 436 |
item["canvas_mode"] = "edit"
|
|
|
|
| 437 |
item["preds"] = []; item["user_labels"] = []; item["actions"] = []
|
| 438 |
sync_samples_with_state()
|
| 439 |
st.rerun()
|
|
|
|
| 464 |
_classify_now(item, predict_fn)
|
| 465 |
st.rerun()
|
| 466 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 467 |
_render_active_summary(class_names)
|
| 468 |
|
| 469 |
def _render_active_summary(class_names):
|
| 470 |
st.subheader("Summary for Active Image")
|
| 471 |
+
ak = st.session_state.active_key
|
| 472 |
if not st.session_state.all_samples or not ak:
|
| 473 |
+
st.info("No samples available. Run predictions on images first.")
|
| 474 |
+
return
|
| 475 |
+
active = [s for s in st.session_state.all_samples if s.image_key == ak]
|
| 476 |
+
if not active:
|
| 477 |
+
st.info("No samples processed for this image yet. Run predictions.")
|
| 478 |
+
return
|
| 479 |
+
active.sort(key=lambda s: s.crop_idx)
|
| 480 |
+
img_data = st.session_state.images[ak]
|
| 481 |
for sample in active:
|
| 482 |
+
crop_image = img_data["pil"].crop(sample.bbox)
|
| 483 |
+
col_img, col_details, col_relabel = st.columns([1, 2, 1.5])
|
| 484 |
with col_img:
|
| 485 |
st.image(crop_image, caption=f"Crop {sample.crop_idx+1}", width=180)
|
| 486 |
with col_details:
|
| 487 |
+
top3_str = ", ".join([f"{n} ({p:.3f})" for n, p in sample.top3[:3]])
|
| 488 |
st.write(f"**Current Label:** {sample.current_label}")
|
| 489 |
st.write(f"**Top-3:** {top3_str}")
|
| 490 |
st.write(f"**Confidence:** {sample.confidence:.3f}")
|
|
|
|
| 498 |
idx = class_names.index(sample.current_label) if sample.current_label in class_names else 0
|
| 499 |
except Exception:
|
| 500 |
idx = 0
|
| 501 |
+
new_label = st.selectbox(
|
| 502 |
+
f"Relabel Crop {sample.crop_idx+1}",
|
| 503 |
+
class_names,
|
| 504 |
+
index=idx,
|
| 505 |
+
key=f"wf_summary_relabel_select_{sample.image_key}_{sample.crop_idx}",
|
| 506 |
+
)
|
| 507 |
if new_label != sample.current_label:
|
| 508 |
+
sample.user_label = new_label
|
| 509 |
+
sample.action = "relabel"
|
| 510 |
from core.state import update_image_state_from_samples
|
| 511 |
+
update_image_state_from_samples()
|
| 512 |
+
st.rerun()
|
| 513 |
st.divider()
|