Tan Zi Xu commited on
Commit
dd63ee6
·
1 Parent(s): a84e477

reflect detection model conf in crop tracking tab

Browse files
Files changed (2) hide show
  1. core/drawing.py +71 -0
  2. ui/workflow.py +217 -143
core/drawing.py CHANGED
@@ -40,3 +40,74 @@ def boxes_to_canvas_json(boxes, scale: float = 1.0, stroke_color: str = "#FF9900
40
  })
41
  return {"objects": objects}
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  })
41
  return {"objects": objects}
42
 
43
+ def rects_only_json(j: dict) -> dict:
44
+ """Return a JSON with only rectangle objects (drop labels/others)."""
45
+ if not j:
46
+ return {"objects": []}
47
+ return {"objects": [o for o in (j.get("objects") or []) if o.get("type") == "rect"]}
48
+
49
+ def get_rect_conf(rect_obj: dict) -> float | None:
50
+ """Read detector confidence from canvas rect's custom metadata, if any."""
51
+ meta = rect_obj.get("meta") or {}
52
+ c = meta.get("det_conf", None)
53
+ try:
54
+ return float(c) if c is not None else None
55
+ except Exception:
56
+ return None
57
+
58
+ def inject_index_labels(live_json: dict, offset_px: int = 4, show_conf: bool = True):
59
+ """
60
+ Return a copy of live_json with small text labels injected at top-left of each rect.
61
+ If a rect has meta.det_conf, label becomes 'idx · 0.82' when show_conf=True.
62
+ """
63
+ if not live_json:
64
+ return {"objects": []}
65
+ objs = list(live_json.get("objects") or [])
66
+ labeled, idx = [], 1
67
+ for o in objs:
68
+ labeled.append(o)
69
+ if o.get("type") == "rect":
70
+ text = f"{idx}"
71
+ if show_conf:
72
+ conf = get_rect_conf(o)
73
+ if conf is not None:
74
+ text = f"{idx} · {conf:.2f}"
75
+ left = float(o.get("left", 0.0)) + offset_px
76
+ top = float(o.get("top", 0.0)) + offset_px
77
+ labeled.append({
78
+ "type": "textbox", "text": text,
79
+ "left": left, "top": top,
80
+ "width": max(24, 10*len(text)), "height": 20,
81
+ "fontSize": 14, "fill": "#ffffff", "backgroundColor": "#111827",
82
+ "opacity": 0.9, "selectable": False, "evented": False,
83
+ "hasControls": False, "editable": False, "textAlign": "center",
84
+ "fontFamily": "sans-serif",
85
+ })
86
+ idx += 1
87
+ return {"objects": labeled}
88
+
89
+ def seed_canvas_from_boxes(boxes_xyxy, scale: float, det_scores=None):
90
+ """
91
+ Build a canvas JSON (rect objects) from xyxy boxes at given display scale.
92
+ If det_scores provided, store them as rect.meta['det_conf'].
93
+ """
94
+ objects = []
95
+ det_scores = det_scores or [None] * len(boxes_xyxy)
96
+ for (x0, y0, x1, y1), sc in zip(boxes_xyxy, det_scores):
97
+ left = float(x0) * scale
98
+ top = float(y0) * scale
99
+ width = float(x1 - x0) * scale
100
+ height = float(y1 - y0) * scale
101
+ meta = {"origin": "det"}
102
+ if sc is not None:
103
+ try:
104
+ meta["det_conf"] = float(sc)
105
+ except Exception:
106
+ pass
107
+ objects.append({
108
+ "type": "rect",
109
+ "left": left, "top": top, "width": width, "height": height,
110
+ "strokeWidth": 3, "stroke": "#FF9900", "fill": "rgba(0,0,0,0)",
111
+ "angle": 0, "meta": meta
112
+ })
113
+ return {"objects": objects}
ui/workflow.py CHANGED
@@ -1,17 +1,21 @@
1
  # ui/workflow.py
2
- import io
3
  import streamlit as st
4
- import numpy as np
5
  from PIL import Image, ImageOps
6
- import uuid
7
  from streamlit_drawable_canvas import st_canvas
8
- from core.drawing import parse_boxes, boxes_to_canvas_json
9
- from core.state import set_defaults_from_preds, sync_samples_with_state, update_image_state_from_samples
10
- from core.exports import export_session
 
 
 
 
 
 
 
11
  from core.detect_infer import DetConfig
12
  import torch
13
 
14
- # ---- helpers ----
15
  def _auto_canvas_width():
16
  try:
17
  from streamlit_js_eval import get_page_info
@@ -31,17 +35,50 @@ def _det_cfg_from_state():
31
  half=bool(c.get("half", True)),
32
  )
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def _run_detection(item):
35
  det = st.session_state.get("detector")
36
- if det is None: return []
 
37
  cfg = _det_cfg_from_state()
38
- return det.predict_one(item["pil"], cfg) or [] # list of dicts {'bbox':[x1,y1,x2,y2], 'score', ...}
39
-
40
- def _seed_canvas_from_boxes(item, disp_w, base_w):
41
- """Write boxes to canvas JSON at display scale (non-destructive)."""
42
- if not item.get("boxes"): return
43
- scale = float(disp_w) / float(base_w)
44
- item["canvas_json"] = boxes_to_canvas_json(item["boxes"], scale=scale)
45
 
46
  def _parse_canvas_into_boxes(item, scale, W, H):
47
  """Read current canvas JSON back into xyxy boxes on the original image grid."""
@@ -50,79 +87,36 @@ def _parse_canvas_into_boxes(item, scale, W, H):
50
 
51
  def _classify_now(item, predict_fn):
52
  """Run classifier on current boxes, then sync to session samples."""
53
- if not item.get("boxes"):
54
  item["preds"] = []; item["user_labels"] = []; item["actions"] = []
55
  sync_samples_with_state(); return
56
  crops = [item["pil"].crop(b) for b in item["boxes"]]
57
  with torch.inference_mode():
58
  item["preds"] = predict_fn(crops, topk=3)
 
59
  set_defaults_from_preds(item)
60
  item["actions"] = ["pending"] * len(item["boxes"])
61
  sync_samples_with_state()
62
 
63
  def _bump_on_image_switch(ak: str, item: dict):
64
- """Force a canvas remount the moment active image changes."""
65
  if st.session_state.get("_last_active_key") != ak:
66
  item["canvas_rev"] = (item.get("canvas_rev", 0) + 1)
67
  st.session_state["_last_active_key"] = ak
68
 
69
- def _inject_index_labels(live_json: dict, offset_px: int = 4):
70
- """
71
- Return a copy of live_json with non-selectable numeric labels injected
72
- for each rectangle in the order they appear.
73
- Labels are Fabric 'textbox' objects and will be ignored by parse_boxes().
74
- """
75
- if not live_json:
76
- return {"objects": []}
77
- objs = list(live_json.get("objects") or [])
78
- labeled = []
79
- idx = 1
80
- for o in objs:
81
- labeled.append(o)
82
- if o.get("type") == "rect":
83
- # Place label near rect's top-left in canvas coords
84
- left = float(o.get("left", 0.0)) + offset_px
85
- top = float(o.get("top", 0.0)) + offset_px
86
- label = {
87
- "type": "textbox",
88
- "text": str(idx),
89
- "left": left,
90
- "top": top,
91
- "width": 24,
92
- "height": 20,
93
- "fontSize": 16,
94
- "fill": "#ffffff",
95
- "backgroundColor": "#111827", # slate-900
96
- "opacity": 0.85,
97
- "selectable": False,
98
- "evented": False,
99
- "hasControls": False,
100
- "editable": False,
101
- "textAlign": "center",
102
- "fontFamily": "sans-serif",
103
- }
104
- labeled.append(label)
105
- idx += 1
106
- return {"objects": labeled}
107
-
108
- def rects_only_json(j: dict) -> dict:
109
- """Return a JSON with only rectangle objects (drop labels/others)."""
110
- if not j:
111
- return {"objects": []}
112
- return {"objects": [o for o in (j.get("objects") or []) if o.get("type") == "rect"]}
113
-
114
- # ---- main UI ----
115
  def render_workflow_tab(predict_fn, class_names, class_to_id, model_meta, export_root):
116
  CANVAS_W = _auto_canvas_width()
117
  if "uploader_rev" not in st.session_state:
118
  st.session_state.uploader_rev = 0
 
119
  # =============== (1) Load images ===============
120
  st.subheader("1) Load image(s)")
121
  ups = st.file_uploader(
122
  "Upload one or more images",
123
  type=["jpg", "jpeg", "png", "bmp"],
124
  accept_multiple_files=True,
125
- key=f"uploader_{st.session_state.uploader_rev}" # <-- resettable
126
  )
127
  for up in (ups or []):
128
  key = f"{up.name}-{up.size}"
@@ -131,7 +125,8 @@ def render_workflow_tab(predict_fn, class_names, class_to_id, model_meta, export
131
  st.session_state.images[key] = {
132
  "key": key, "name": up.name, "pil": pil,
133
  "boxes": [], "preds": [], "user_labels": [], "actions": [],
134
- "canvas_json": None,
 
135
  "canvas_rev": 0,
136
  }
137
  if st.session_state.active_key is None:
@@ -142,7 +137,6 @@ def render_workflow_tab(predict_fn, class_names, class_to_id, model_meta, export
142
  st.info("Upload images or load from S3/Drive in the sidebar.")
143
  return
144
 
145
-
146
  names = [st.session_state.images[k]["name"] for k in keys]
147
  idx = keys.index(st.session_state.active_key) if st.session_state.active_key in keys else 0
148
  chosen = st.selectbox("Active image", names, index=idx)
@@ -161,16 +155,28 @@ def render_workflow_tab(predict_fn, class_names, class_to_id, model_meta, export
161
  if st.session_state.assist_plus and not item.get("boxes") and not item.get("_skip_autodetect_once"):
162
  det_out = _run_detection(item)
163
  if det_out:
164
- item["boxes"] = [tuple(int(v) for v in o["bbox"]) for o in det_out]
165
- _seed_canvas_from_boxes(item, disp_w, base.width)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  item["canvas_mode"] = "edit"
167
- item["canvas_rev"] = (item.get("canvas_rev", 0) + 1)
168
- item["_skip_autodetect_once"] = False
169
- _parse_canvas_into_boxes(item, scale, base.width, base.height) # ensure boxes align with canvas JSON
170
  _classify_now(item, predict_fn)
171
  st.rerun()
172
 
173
- # Build an orientation-safe PIL background at display size
174
  bg_pil = ImageOps.exif_transpose(base).resize((disp_w, disp_h), Image.BILINEAR).convert("RGB").copy()
175
 
176
  # 2) Draw / edit rectangles
@@ -197,83 +203,124 @@ def render_workflow_tab(predict_fn, class_names, class_to_id, model_meta, export
197
  else:
198
  st.caption("Tip: Drag on the image to draw new boxes.")
199
 
200
- # ---------- IMPORTANT: prepare initial drawing BEFORE canvas ----------
201
- # Use whatever we have persisted so far (no live JSON yet)
202
- base_json = rects_only_json(item.get("canvas_json") or {})
203
- item["canvas_json"] = base_json # keep only rects in state
204
- display_json = _inject_index_labels(base_json)
205
 
206
  # ---------- Canvas ----------
207
  canvas = st_canvas(
208
  fill_color="rgba(0,0,0,0)",
209
  stroke_width=3,
210
  stroke_color="#FF9900",
211
- background_image=bg_pil, # PIL image; reliable first paint
212
  update_streamlit=True,
213
  height=disp_h,
214
  width=disp_w,
215
- drawing_mode=drawing_mode, # Draw vs Edit (transform)
216
  display_toolbar=True,
217
- initial_drawing=display_json, # labeled copy (not persisted)
218
  key=f"canvas_{ak}_{item.get('canvas_rev', 0)}_{disp_w}x{disp_h}"
219
  )
220
 
221
- # ---------- Live canvas snapshot (now canvas.json_data exists) ----------
222
  live_json = canvas.json_data or base_json
223
  objs = list(live_json.get("objects") or [])
224
- live_rects_only = rects_only_json(live_json) # keep only rectangles
225
-
226
- # Re-index labels whenever crop count changes (e.g., a new rectangle drawn)
227
- prev_cnt = item.get("_prev_rect_count", None)
228
- cur_cnt = len(live_rects_only["objects"])
229
- if prev_cnt is None:
230
- item["_prev_rect_count"] = cur_cnt
231
- elif cur_cnt != prev_cnt:
232
- # persist rects-only, bump rev to redraw with fresh labels 1..N
233
- item["canvas_json"] = live_rects_only
234
- item["canvas_rev"] = (item.get("canvas_rev", 0) + 1)
235
- item["_prev_rect_count"] = cur_cnt
236
- st.rerun()
237
 
238
- # Derive rect indices so Remove buttons target the right object (labels are 'textbox')
239
  rect_indices = [k for k, o in enumerate(objs) if o.get("type") == "rect"]
 
240
 
241
- # Live boxes (original image coords)
242
  live_boxes = parse_boxes(live_json, scale, base.width, base.height)
243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  # ---- Crops on canvas (live) ----
245
  if live_boxes:
246
  with st.expander(f"Crops on canvas (live) — {len(live_boxes)}", expanded=False):
247
  for i, b in enumerate(live_boxes):
248
- c1, c2 = st.columns([6, 1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  with c1:
250
  st.write(f"#{i+1} [x0={b[0]}, y0={b[1]}, x1={b[2]}, y1={b[3]}]")
 
251
  with c2:
 
 
 
 
 
 
 
 
 
 
 
252
  if st.button("Remove", key=f"rm_live_{ak}_{i}"):
253
- # Start from the full live JSON, drop the i-th rectangle (labels ignored)
254
- rects = [o for o in (live_json.get("objects") or []) if o.get("type") == "rect"]
255
- if 0 <= i < len(rects):
256
- rects.pop(i)
257
- # Persist rects-only; labels will be re-injected next render and re-indexed 1..N
258
- item["canvas_json"] = {"objects": rects}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  item["canvas_mode"] = "edit"
260
  item["canvas_rev"] = (item.get("canvas_rev", 0) + 1)
261
- # predictions are stale after geometry changes
262
  item["preds"] = []; item["user_labels"] = []; item["actions"] = []
263
  sync_samples_with_state()
264
  st.rerun()
265
 
266
-
267
- # Keep the on-screen count truly live
268
  st.caption(f"Crops in this image: **{len(live_boxes)}**")
269
 
270
- # Commit canvas into boxes (Save / Lock)
271
- # (was: Save/Lock, Clear, Remove). Now: Clear + Remove only.
272
  col_clear, col_remove = st.columns([1, 1])
273
 
274
  if col_clear.button("Clear boxes (THIS image)", key=f"clear_{ak}"):
275
  item["canvas_json"] = {"objects": []}
276
  item["boxes"] = []; item["preds"] = []; item["user_labels"] = []; item["actions"] = []
 
277
  item["canvas_mode"] = "draw"
278
  item["canvas_rev"] = (item.get("canvas_rev", 0) + 1)
279
  item["_skip_autodetect_once"] = True
@@ -281,45 +328,32 @@ def render_workflow_tab(predict_fn, class_names, class_to_id, model_meta, export
281
  st.rerun()
282
 
283
  if col_remove.button("Remove image", key=f"remove_{ak}"):
284
- # Choose the next image BEFORE deletion
285
  remaining_keys = [k for k in st.session_state.images.keys() if k != ak]
286
  next_key = remaining_keys[0] if remaining_keys else None
287
-
288
- # Delete the image and its samples
289
  del st.session_state.images[ak]
290
  st.session_state.all_samples = [s for s in st.session_state.all_samples if s.image_key != ak]
291
-
292
- # Clear the uploader so removed files don't get re-added on rerun
293
  st.session_state.uploader_rev += 1
294
-
295
- # Switch to the next image (if any) but DO NOT autorun detection once
296
  st.session_state.active_key = next_key
297
  if next_key:
298
  nxt = st.session_state.images[next_key]
299
  nxt["_skip_autodetect_once"] = True
300
  nxt["canvas_mode"] = "edit" if nxt.get("boxes") else "draw"
301
- nxt["canvas_rev"] = nxt.get("canvas_rev", 0) + 1 # fresh canvas
302
-
303
  st.rerun()
304
 
305
-
306
-
307
  # =============== (3) Detect / Classify controls ===============
308
  st.subheader("3) Detect & Classify")
309
 
310
  assist_on = bool(st.session_state.get("assist_plus"))
311
  has_boxes = bool(item.get("boxes"))
312
 
313
- # detect whether the user has drawn rectangles but not saved yet
314
  canvas_objects = (canvas.json_data or {}).get("objects") or []
315
  has_canvas_rects = any(obj.get("type") == "rect" for obj in canvas_objects)
316
  can_classify = has_boxes or has_canvas_rects
317
 
318
  # ---------- DETECT ----------
319
  if assist_on:
320
- # Only show detect controls in Assist+ mode
321
  st.caption("Detect: find objects and write boxes onto the canvas.")
322
-
323
  dcol1, dcol2 = st.columns([1, 1])
324
 
325
  # First-time Detect (hidden if Assist+ already produced boxes)
@@ -332,8 +366,22 @@ def render_workflow_tab(predict_fn, class_names, class_to_id, model_meta, export
332
  ):
333
  det_out = _run_detection(item)
334
  if det_out:
335
- item["boxes"] = [tuple(int(v) for v in o["bbox"]) for o in det_out]
336
- _seed_canvas_from_boxes(item, disp_w, base.width)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  item["canvas_mode"] = "edit"
338
  item["canvas_rev"] = (item.get("canvas_rev", 0) + 1)
339
  item["preds"] = []; item["user_labels"] = []; item["actions"] = []
@@ -358,14 +406,34 @@ def render_workflow_tab(predict_fn, class_names, class_to_id, model_meta, export
358
  ):
359
  det_out = _run_detection(item)
360
  new_boxes = [tuple(int(v) for v in o["bbox"]) for o in det_out]
 
 
 
 
 
 
 
 
 
 
361
  if new_boxes:
362
  if rd_mode == "Replace":
363
  item["boxes"] = new_boxes
 
 
 
364
  else:
365
- item["boxes"] = list(item.get("boxes") or []) + new_boxes
366
- _seed_canvas_from_boxes(item, disp_w, base.width)
 
 
 
 
 
 
 
 
367
  item["canvas_mode"] = "edit"
368
- item["canvas_rev"] = (item.get("canvas_rev", 0) + 1)
369
  item["preds"] = []; item["user_labels"] = []; item["actions"] = []
370
  sync_samples_with_state()
371
  st.rerun()
@@ -396,27 +464,27 @@ def render_workflow_tab(predict_fn, class_names, class_to_id, model_meta, export
396
  _classify_now(item, predict_fn)
397
  st.rerun()
398
 
399
-
400
- # Quick count
401
- st.caption(f"Crops in this image: **{len(live_boxes)}**")
402
-
403
  _render_active_summary(class_names)
404
 
405
  def _render_active_summary(class_names):
406
  st.subheader("Summary for Active Image")
407
- ak=st.session_state.active_key
408
  if not st.session_state.all_samples or not ak:
409
- st.info("No samples available. Run predictions on images first."); return
410
- active=[s for s in st.session_state.all_samples if s.image_key==ak]
411
- if not active: st.info("No samples processed for this image yet. Run predictions."); return
412
- active.sort(key=lambda s: s.crop_idx); img_data=st.session_state.images[ak]
 
 
 
 
413
  for sample in active:
414
- crop_image=img_data["pil"].crop(sample.bbox)
415
- col_img,col_details,col_relabel=st.columns([1,2,1.5])
416
  with col_img:
417
  st.image(crop_image, caption=f"Crop {sample.crop_idx+1}", width=180)
418
  with col_details:
419
- top3_str=", ".join([f"{n} ({p:.3f})" for n,p in sample.top3[:3]])
420
  st.write(f"**Current Label:** {sample.current_label}")
421
  st.write(f"**Top-3:** {top3_str}")
422
  st.write(f"**Confidence:** {sample.confidence:.3f}")
@@ -430,10 +498,16 @@ def _render_active_summary(class_names):
430
  idx = class_names.index(sample.current_label) if sample.current_label in class_names else 0
431
  except Exception:
432
  idx = 0
433
- new_label = st.selectbox(f"Relabel Crop {sample.crop_idx+1}", class_names, index=idx,
434
- key=f"wf_summary_relabel_select_{sample.image_key}_{sample.crop_idx}")
 
 
 
 
435
  if new_label != sample.current_label:
436
- sample.user_label=new_label; sample.action="relabel"
 
437
  from core.state import update_image_state_from_samples
438
- update_image_state_from_samples(); st.rerun()
 
439
  st.divider()
 
1
  # ui/workflow.py
 
2
  import streamlit as st
 
3
  from PIL import Image, ImageOps
 
4
  from streamlit_drawable_canvas import st_canvas
5
+ from core.drawing import (
6
+ parse_boxes, # from canvas JSON -> list[xyxy]
7
+ boxes_to_canvas_json, # (kept for compatibility; not used here)
8
+ rects_only_json,
9
+ get_rect_conf,
10
+ inject_index_labels,
11
+ seed_canvas_from_boxes, # list[xyxy] -> canvas JSON
12
+ )
13
+ from core.state import set_defaults_from_preds, sync_samples_with_state
14
+ from core.exports import export_session # not used here but retained
15
  from core.detect_infer import DetConfig
16
  import torch
17
 
18
+ # ---------------- helpers ----------------
19
  def _auto_canvas_width():
20
  try:
21
  from streamlit_js_eval import get_page_info
 
35
  half=bool(c.get("half", True)),
36
  )
37
 
38
+ def _iou_xyxy(a, b, eps=1e-6):
39
+ ax1, ay1, ax2, ay2 = map(float, a)
40
+ bx1, by1, bx2, by2 = map(float, b)
41
+ ix1, iy1 = max(ax1, bx1), max(ay1, by1)
42
+ ix2, iy2 = min(ax2, bx2), min(ay2, by2)
43
+ iw, ih = max(0.0, ix2 - ix1), max(0.0, iy2 - iy1)
44
+ inter = iw * ih
45
+ aa = max(0.0, (ax2 - ax1)) * max(0.0, (ay2 - ay1))
46
+ ba = max(0.0, (bx2 - bx1)) * max(0.0, (by2 - by1))
47
+ union = aa + ba - inter + eps
48
+ return inter / union
49
+
50
+ def _match_det_info_to_live_boxes(item, live_boxes, iou_thresh=0.90):
51
+ """Align per-box detector info (conf/cls/name) to current live boxes."""
52
+ base_boxes = item.get("boxes") or []
53
+ info = item.get("detector_info") or []
54
+ out = []
55
+ for b in live_boxes:
56
+ best_iou, best_idx = 0.0, -1
57
+ for j, bb in enumerate(base_boxes):
58
+ i = _iou_xyxy(b, bb)
59
+ if i > best_iou:
60
+ best_iou, best_idx = i, j
61
+ out.append(info[best_idx] if (best_iou >= iou_thresh and 0 <= best_idx < len(info)) else None)
62
+ return out
63
+
64
+ def _reindex_det_info(old_boxes, old_info, new_boxes, iou_thresh=0.90):
65
+ """Rebuild detector_info aligned to new_boxes by matching from (old_boxes, old_info)."""
66
+ info_new = []
67
+ for b in new_boxes:
68
+ best_iou, best_idx = 0.0, -1
69
+ for j, ob in enumerate(old_boxes):
70
+ i = _iou_xyxy(b, ob)
71
+ if i > best_iou:
72
+ best_iou, best_idx = i, j
73
+ info_new.append(old_info[best_idx] if (best_iou >= iou_thresh and 0 <= best_idx < len(old_info)) else None)
74
+ return info_new
75
+
76
  def _run_detection(item):
77
  det = st.session_state.get("detector")
78
+ if det is None:
79
+ return []
80
  cfg = _det_cfg_from_state()
81
+ return det.predict_one(item["pil"], cfg) or [] # list of dicts {'bbox':[x1,y1,x2,y2], 'score','cls','name',...}
 
 
 
 
 
 
82
 
83
  def _parse_canvas_into_boxes(item, scale, W, H):
84
  """Read current canvas JSON back into xyxy boxes on the original image grid."""
 
87
 
88
  def _classify_now(item, predict_fn):
89
  """Run classifier on current boxes, then sync to session samples."""
90
+ if not item.get("boxes"):
91
  item["preds"] = []; item["user_labels"] = []; item["actions"] = []
92
  sync_samples_with_state(); return
93
  crops = [item["pil"].crop(b) for b in item["boxes"]]
94
  with torch.inference_mode():
95
  item["preds"] = predict_fn(crops, topk=3)
96
+ # set_defaults fills current_label/confidence/margin/badges and seeds samples
97
  set_defaults_from_preds(item)
98
  item["actions"] = ["pending"] * len(item["boxes"])
99
  sync_samples_with_state()
100
 
101
  def _bump_on_image_switch(ak: str, item: dict):
102
+ """Force a canvas remount when the active image changes."""
103
  if st.session_state.get("_last_active_key") != ak:
104
  item["canvas_rev"] = (item.get("canvas_rev", 0) + 1)
105
  st.session_state["_last_active_key"] = ak
106
 
107
+ # ---------------- main UI ----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  def render_workflow_tab(predict_fn, class_names, class_to_id, model_meta, export_root):
109
  CANVAS_W = _auto_canvas_width()
110
  if "uploader_rev" not in st.session_state:
111
  st.session_state.uploader_rev = 0
112
+
113
  # =============== (1) Load images ===============
114
  st.subheader("1) Load image(s)")
115
  ups = st.file_uploader(
116
  "Upload one or more images",
117
  type=["jpg", "jpeg", "png", "bmp"],
118
  accept_multiple_files=True,
119
+ key=f"uploader_{st.session_state.uploader_rev}" # resettable
120
  )
121
  for up in (ups or []):
122
  key = f"{up.name}-{up.size}"
 
125
  st.session_state.images[key] = {
126
  "key": key, "name": up.name, "pil": pil,
127
  "boxes": [], "preds": [], "user_labels": [], "actions": [],
128
+ "detector_info": [], # <— conf/cls/name aligned with boxes
129
+ "canvas_json": {"objects": []},
130
  "canvas_rev": 0,
131
  }
132
  if st.session_state.active_key is None:
 
137
  st.info("Upload images or load from S3/Drive in the sidebar.")
138
  return
139
 
 
140
  names = [st.session_state.images[k]["name"] for k in keys]
141
  idx = keys.index(st.session_state.active_key) if st.session_state.active_key in keys else 0
142
  chosen = st.selectbox("Active image", names, index=idx)
 
155
  if st.session_state.assist_plus and not item.get("boxes") and not item.get("_skip_autodetect_once"):
156
  det_out = _run_detection(item)
157
  if det_out:
158
+ boxes_xyxy = [tuple(int(v) for v in o["bbox"]) for o in det_out]
159
+ det_info = []
160
+ for o in det_out:
161
+ conf = o.get("score", o.get("conf", o.get("confidence")))
162
+ clsid = o.get("cls", o.get("class"))
163
+ name = o.get("name")
164
+ det_info.append({
165
+ "conf": float(conf) if conf is not None else None,
166
+ "cls": int(clsid) if clsid is not None else None,
167
+ "name": name if name is not None else None,
168
+ })
169
+ item["boxes"] = boxes_xyxy
170
+ item["detector_info"] = det_info
171
+ item["canvas_json"] = seed_canvas_from_boxes(
172
+ boxes_xyxy, scale, det_scores=[d["conf"] for d in det_info]
173
+ )
174
  item["canvas_mode"] = "edit"
175
+ _parse_canvas_into_boxes(item, scale, base.width, base.height)
 
 
176
  _classify_now(item, predict_fn)
177
  st.rerun()
178
 
179
+ # Build display background
180
  bg_pil = ImageOps.exif_transpose(base).resize((disp_w, disp_h), Image.BILINEAR).convert("RGB").copy()
181
 
182
  # 2) Draw / edit rectangles
 
203
  else:
204
  st.caption("Tip: Drag on the image to draw new boxes.")
205
 
206
+ # ---------- initial drawing ----------
207
+ base_json = item.get("canvas_json") or {"objects": []}
208
+ show_labels = st.checkbox("Show box labels on canvas", value=True, key=f"show_labels_{ak}")
209
+ display_json = inject_index_labels(base_json, show_conf=False) if show_labels else base_json
 
210
 
211
  # ---------- Canvas ----------
212
  canvas = st_canvas(
213
  fill_color="rgba(0,0,0,0)",
214
  stroke_width=3,
215
  stroke_color="#FF9900",
216
+ background_image=bg_pil,
217
  update_streamlit=True,
218
  height=disp_h,
219
  width=disp_w,
220
+ drawing_mode=drawing_mode,
221
  display_toolbar=True,
222
+ initial_drawing=display_json,
223
  key=f"canvas_{ak}_{item.get('canvas_rev', 0)}_{disp_w}x{disp_h}"
224
  )
225
 
226
+ # ---------- Live canvas snapshot ----------
227
  live_json = canvas.json_data or base_json
228
  objs = list(live_json.get("objects") or [])
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
+ # Rect indices & objects (ignore overlay textboxes)
231
  rect_indices = [k for k, o in enumerate(objs) if o.get("type") == "rect"]
232
+ rect_objs = [objs[k] for k in rect_indices]
233
 
234
+ # Live boxes in original image coords
235
  live_boxes = parse_boxes(live_json, scale, base.width, base.height)
236
 
237
+ # Detector summary (use detector_info matched to live boxes)
238
+ infos = _match_det_info_to_live_boxes(item, live_boxes)
239
+ confs = [d["conf"] for d in infos if d and d.get("conf") is not None]
240
+ if confs:
241
+ n = len(confs)
242
+ cmin, cmax = min(confs), max(confs)
243
+ cmean = sum(confs)/n
244
+ th = float(st.session_state.get("det_cfg", {}).get("conf", 0.25))
245
+ below = sum(1 for c in confs if c < th)
246
+ st.markdown(
247
+ f"**Active image (detector)** — boxes: {n} · mean **{cmean:.3f}** · "
248
+ f"min **{cmin:.3f}** · max **{cmax:.3f}** · below {th:.2f}: **{below}/{n}**"
249
+ )
250
+
251
  # ---- Crops on canvas (live) ----
252
  if live_boxes:
253
  with st.expander(f"Crops on canvas (live) — {len(live_boxes)}", expanded=False):
254
  for i, b in enumerate(live_boxes):
255
+ info = infos[i] if i < len(infos) else None
256
+ det_conf = (info or {}).get("conf", None)
257
+ # resolve detector class label
258
+ det_label = None
259
+ if info:
260
+ if info.get("name") is not None:
261
+ det_label = str(info["name"])
262
+ elif info.get("cls") is not None:
263
+ names = st.session_state.get("detector_names") or []
264
+ cid = info["cls"]
265
+ if isinstance(names, (list, tuple)) and isinstance(cid, int) and 0 <= cid < len(names):
266
+ det_label = str(names[cid])
267
+ else:
268
+ det_label = str(cid)
269
+
270
+ c1, c2, c3 = st.columns([6, 3, 1])
271
  with c1:
272
  st.write(f"#{i+1} [x0={b[0]}, y0={b[1]}, x1={b[2]}, y1={b[3]}]")
273
+
274
  with c2:
275
+ if det_conf is None and not det_label:
276
+ st.caption("det: —")
277
+ else:
278
+ th = float(st.session_state.get("det_cfg", {}).get("conf", 0.25))
279
+ status = " ⚠️" if (det_conf is not None and det_conf < th) else ""
280
+ parts = []
281
+ if det_conf is not None: parts.append(f"{det_conf:.2f}")
282
+ if det_label: parts.append(det_label)
283
+ st.caption("det: " + " · ".join(parts) + status)
284
+
285
+ with c3:
286
  if st.button("Remove", key=f"rm_live_{ak}_{i}"):
287
+ # Remove by rect index, then remount canvas
288
+ true_idx = rect_indices[i] if i < len(rect_indices) else None
289
+ if true_idx is not None and 0 <= true_idx < len(objs):
290
+ # snapshot old for detector-info reindex
291
+ old_boxes = item.get("boxes") or []
292
+ old_info = item.get("detector_info") or []
293
+
294
+ objs.pop(true_idx)
295
+ # if a label textbox immediately follows, drop it too (best-effort)
296
+ if true_idx < len(objs) and objs[true_idx].get("type") == "textbox":
297
+ objs.pop(true_idx)
298
+
299
+ # update canvas + boxes
300
+ live_json["objects"] = objs
301
+ item["canvas_json"] = live_json
302
+ _parse_canvas_into_boxes(item, scale, base.width, base.height)
303
+
304
+ # reindex detector_info to new boxes
305
+ new_boxes = item.get("boxes") or []
306
+ item["detector_info"] = _reindex_det_info(old_boxes, old_info, new_boxes)
307
+
308
  item["canvas_mode"] = "edit"
309
  item["canvas_rev"] = (item.get("canvas_rev", 0) + 1)
 
310
  item["preds"] = []; item["user_labels"] = []; item["actions"] = []
311
  sync_samples_with_state()
312
  st.rerun()
313
 
314
+ # Always show a live count under the canvas
 
315
  st.caption(f"Crops in this image: **{len(live_boxes)}**")
316
 
317
+ # Commit canvas into boxes (Clear / Remove)
 
318
  col_clear, col_remove = st.columns([1, 1])
319
 
320
  if col_clear.button("Clear boxes (THIS image)", key=f"clear_{ak}"):
321
  item["canvas_json"] = {"objects": []}
322
  item["boxes"] = []; item["preds"] = []; item["user_labels"] = []; item["actions"] = []
323
+ item["detector_info"] = [] # also clear detector stats
324
  item["canvas_mode"] = "draw"
325
  item["canvas_rev"] = (item.get("canvas_rev", 0) + 1)
326
  item["_skip_autodetect_once"] = True
 
328
  st.rerun()
329
 
330
  if col_remove.button("Remove image", key=f"remove_{ak}"):
 
331
  remaining_keys = [k for k in st.session_state.images.keys() if k != ak]
332
  next_key = remaining_keys[0] if remaining_keys else None
 
 
333
  del st.session_state.images[ak]
334
  st.session_state.all_samples = [s for s in st.session_state.all_samples if s.image_key != ak]
 
 
335
  st.session_state.uploader_rev += 1
 
 
336
  st.session_state.active_key = next_key
337
  if next_key:
338
  nxt = st.session_state.images[next_key]
339
  nxt["_skip_autodetect_once"] = True
340
  nxt["canvas_mode"] = "edit" if nxt.get("boxes") else "draw"
341
+ nxt["canvas_rev"] = nxt.get("canvas_rev", 0) + 1
 
342
  st.rerun()
343
 
 
 
344
  # =============== (3) Detect / Classify controls ===============
345
  st.subheader("3) Detect & Classify")
346
 
347
  assist_on = bool(st.session_state.get("assist_plus"))
348
  has_boxes = bool(item.get("boxes"))
349
 
 
350
  canvas_objects = (canvas.json_data or {}).get("objects") or []
351
  has_canvas_rects = any(obj.get("type") == "rect" for obj in canvas_objects)
352
  can_classify = has_boxes or has_canvas_rects
353
 
354
  # ---------- DETECT ----------
355
  if assist_on:
 
356
  st.caption("Detect: find objects and write boxes onto the canvas.")
 
357
  dcol1, dcol2 = st.columns([1, 1])
358
 
359
  # First-time Detect (hidden if Assist+ already produced boxes)
 
366
  ):
367
  det_out = _run_detection(item)
368
  if det_out:
369
+ boxes_xyxy = [tuple(int(v) for v in o["bbox"]) for o in det_out]
370
+ det_info = []
371
+ for o in det_out:
372
+ conf = o.get("score", o.get("conf", o.get("confidence")))
373
+ clsid = o.get("cls", o.get("class"))
374
+ name = o.get("name")
375
+ det_info.append({
376
+ "conf": float(conf) if conf is not None else None,
377
+ "cls": int(clsid) if clsid is not None else None,
378
+ "name": name if name is not None else None,
379
+ })
380
+ item["boxes"] = boxes_xyxy
381
+ item["detector_info"] = det_info
382
+ item["canvas_json"] = seed_canvas_from_boxes(
383
+ boxes_xyxy, scale, det_scores=[d["conf"] for d in det_info]
384
+ )
385
  item["canvas_mode"] = "edit"
386
  item["canvas_rev"] = (item.get("canvas_rev", 0) + 1)
387
  item["preds"] = []; item["user_labels"] = []; item["actions"] = []
 
406
  ):
407
  det_out = _run_detection(item)
408
  new_boxes = [tuple(int(v) for v in o["bbox"]) for o in det_out]
409
+ new_info = []
410
+ for o in det_out:
411
+ conf = o.get("score", o.get("conf", o.get("confidence")))
412
+ clsid = o.get("cls", o.get("class"))
413
+ name = o.get("name")
414
+ new_info.append({
415
+ "conf": float(conf) if conf is not None else None,
416
+ "cls": int(clsid) if clsid is not None else None,
417
+ "name": name if name is not None else None,
418
+ })
419
  if new_boxes:
420
  if rd_mode == "Replace":
421
  item["boxes"] = new_boxes
422
+ item["detector_info"] = new_info
423
+ merged_boxes = new_boxes
424
+ merged_scores = [d["conf"] for d in new_info]
425
  else:
426
+ old_boxes = list(item.get("boxes") or [])
427
+ old_info = list(item.get("detector_info") or [])
428
+ merged_boxes = old_boxes + new_boxes
429
+ item["boxes"] = merged_boxes
430
+ item["detector_info"] = old_info + new_info
431
+ merged_scores = [d["conf"] for d in (old_info + new_info)]
432
+
433
+ item["canvas_json"] = seed_canvas_from_boxes(
434
+ merged_boxes, scale, det_scores=merged_scores
435
+ )
436
  item["canvas_mode"] = "edit"
 
437
  item["preds"] = []; item["user_labels"] = []; item["actions"] = []
438
  sync_samples_with_state()
439
  st.rerun()
 
464
  _classify_now(item, predict_fn)
465
  st.rerun()
466
 
 
 
 
 
467
  _render_active_summary(class_names)
468
 
469
  def _render_active_summary(class_names):
470
  st.subheader("Summary for Active Image")
471
+ ak = st.session_state.active_key
472
  if not st.session_state.all_samples or not ak:
473
+ st.info("No samples available. Run predictions on images first.")
474
+ return
475
+ active = [s for s in st.session_state.all_samples if s.image_key == ak]
476
+ if not active:
477
+ st.info("No samples processed for this image yet. Run predictions.")
478
+ return
479
+ active.sort(key=lambda s: s.crop_idx)
480
+ img_data = st.session_state.images[ak]
481
  for sample in active:
482
+ crop_image = img_data["pil"].crop(sample.bbox)
483
+ col_img, col_details, col_relabel = st.columns([1, 2, 1.5])
484
  with col_img:
485
  st.image(crop_image, caption=f"Crop {sample.crop_idx+1}", width=180)
486
  with col_details:
487
+ top3_str = ", ".join([f"{n} ({p:.3f})" for n, p in sample.top3[:3]])
488
  st.write(f"**Current Label:** {sample.current_label}")
489
  st.write(f"**Top-3:** {top3_str}")
490
  st.write(f"**Confidence:** {sample.confidence:.3f}")
 
498
  idx = class_names.index(sample.current_label) if sample.current_label in class_names else 0
499
  except Exception:
500
  idx = 0
501
+ new_label = st.selectbox(
502
+ f"Relabel Crop {sample.crop_idx+1}",
503
+ class_names,
504
+ index=idx,
505
+ key=f"wf_summary_relabel_select_{sample.image_key}_{sample.crop_idx}",
506
+ )
507
  if new_label != sample.current_label:
508
+ sample.user_label = new_label
509
+ sample.action = "relabel"
510
  from core.state import update_image_state_from_samples
511
+ update_image_state_from_samples()
512
+ st.rerun()
513
  st.divider()