righthook75 commited on
Commit
26a0321
·
verified ·
1 Parent(s): bf335e8

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +425 -347
app.py CHANGED
@@ -1,9 +1,10 @@
1
  import streamlit as st
 
2
  from PIL import Image
3
  from streamlit_drawable_canvas import st_canvas
4
 
5
  from sam3_engine import get_device, load_model, load_model_for_training, combined_prompt_inference
6
- from viz import overlay_masks, overlay_boxes, overlay_accepted
7
  from manifest import build_manifest, manifest_to_json, deduplicate
8
  from training import SAM3FineTuneDataset, freeze_encoder, run_training, get_model_zip_bytes
9
 
@@ -15,18 +16,19 @@ CANVAS_MAX_WIDTH = 700
15
 
16
  # --- Session state defaults ---
17
  defaults = {
18
- "step": 1,
19
  "image": None,
20
  "filename": None,
21
  "images": [], # list of (filename, PIL.Image) tuples
22
  "image_index": 0, # current position in batch
23
  "all_image_detections": [], # accumulated detections across ALL images
24
- "accepted_detections": [], # per-image accumulated across rounds
25
- "prompts": [], # list of prompt dicts for current image
26
- "prompt_counter": 0, # monotonic counter for prompt IDs
27
- "sam_results": [], # latest SAM3 results for current image
28
  "label_round": 0, # iteration counter for canvas key stability
29
  "canvas_scale": 1.0, # image-to-canvas scale factor
 
 
30
  "training_loss_history": [],
31
  "training_complete": False,
32
  "finetuned_model_bytes": None,
@@ -36,16 +38,6 @@ for key, val in defaults.items():
36
  st.session_state[key] = val
37
 
38
 
39
- def _reset_per_image_state():
40
- """Reset state that is specific to a single image."""
41
- st.session_state.accepted_detections = []
42
- st.session_state.prompts = []
43
- st.session_state.prompt_counter = 0
44
- st.session_state.sam_results = []
45
- st.session_state.label_round = 0
46
- st.session_state.canvas_scale = 1.0
47
-
48
-
49
  def _load_image_at_index(idx: int):
50
  """Load the image at the given batch index into session state."""
51
  filename, image = st.session_state.images[idx]
@@ -58,6 +50,26 @@ def go_to(step: int):
58
  st.session_state.step = step
59
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  # --- Coordinate scaling helpers ---
62
  def _canvas_to_image(obj: dict, scale: float):
63
  """Convert a Fabric.js canvas object to image-space coordinates."""
@@ -80,7 +92,6 @@ def _canvas_to_image(obj: dict, scale: float):
80
  ],
81
  }
82
  elif obj_type == "circle":
83
- # Points rendered as small circles
84
  r = obj.get("radius", 0)
85
  cx = (left + r * sx) / scale
86
  cy = (top + r * sy) / scale
@@ -91,56 +102,78 @@ def _canvas_to_image(obj: dict, scale: float):
91
  return None
92
 
93
 
94
- def _prompts_to_fabric_json(prompts: list, scale: float) -> dict:
95
- """Convert prompt list to Fabric.js JSON for initial_drawing."""
96
- objects = []
97
- for p in prompts:
98
- if p["type"] == "box":
99
- x1, y1, x2, y2 = p["coords"]
100
- objects.append({
101
- "type": "rect",
102
- "left": x1 * scale,
103
- "top": y1 * scale,
104
- "width": (x2 - x1) * scale,
105
- "height": (y2 - y1) * scale,
106
- "fill": "rgba(255, 0, 0, 0.1)",
107
- "stroke": "red",
108
- "strokeWidth": 2,
109
- "scaleX": 1,
110
- "scaleY": 1,
111
- })
112
- elif p["type"] == "point":
113
- cx, cy = p["coords"]
114
- color = "lime" if p.get("point_label", 1) == 1 else "red"
115
- objects.append({
116
- "type": "circle",
117
- "left": cx * scale - 5,
118
- "top": cy * scale - 5,
119
- "radius": 5,
120
- "fill": color,
121
- "stroke": color,
122
- "strokeWidth": 1,
123
- "scaleX": 1,
124
- "scaleY": 1,
125
- })
126
- return {"version": "5.3.0", "objects": objects}
127
-
128
-
129
- def _next_prompt_id() -> str:
130
- st.session_state.prompt_counter += 1
131
- return f"p{st.session_state.prompt_counter}"
132
-
133
-
134
- def _accepted_to_prompts(detections: list):
135
- """Convert accepted detections into box prompts and add to prompt list."""
136
- for det in detections:
137
- st.session_state.prompts.append({
138
- "id": _next_prompt_id(),
139
- "type": "box",
140
- "coords": det["box"],
141
- "label": det.get("label", ""),
142
- "point_label": None,
143
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
 
146
  # --- Sidebar ---
@@ -149,11 +182,14 @@ with st.sidebar:
149
  device = get_device()
150
  st.caption(f"Device: **{device}**")
151
  st.caption("Model: `facebook/sam3`")
 
 
 
152
  st.divider()
153
 
154
- step_labels = ["Upload", "Label", "Export", "Train"]
155
  current = st.session_state.step
156
- for i, label in enumerate(step_labels, start=1):
157
  if current == i:
158
  marker = f"-> {i}. {label}"
159
  else:
@@ -165,14 +201,10 @@ with st.sidebar:
165
  st.divider()
166
  st.metric("Image", f"{st.session_state.image_index + 1} of {n_images}")
167
 
168
- total_all = len(st.session_state.all_image_detections) + len(st.session_state.accepted_detections)
169
- accepted = st.session_state.accepted_detections
170
- if total_all:
171
  st.divider()
172
- if n_images > 1:
173
- st.metric("Accepted (all images)", total_all)
174
- else:
175
- st.metric("Total accepted", len(accepted))
176
 
177
  st.divider()
178
  if st.button("Start over"):
@@ -182,155 +214,84 @@ with st.sidebar:
182
 
183
 
184
  # =============================================================================
185
- # Step 1: Upload
186
- # =============================================================================
187
- if st.session_state.step == 1:
188
- st.header("Step 1: Upload Images")
189
- uploaded_files = st.file_uploader(
190
- "Choose one or more images (PNG/JPG)",
191
- type=["png", "jpg", "jpeg"],
192
- accept_multiple_files=True,
193
- )
194
- if uploaded_files:
195
- images = [(f.name, Image.open(f).convert("RGB")) for f in uploaded_files]
196
- st.session_state.images = images
197
- st.session_state.image_index = 0
198
-
199
- # Show thumbnail grid
200
- n = len(images)
201
- cols = st.columns(min(n, 4))
202
- for i, (name, img) in enumerate(images):
203
- with cols[i % len(cols)]:
204
- st.image(img, caption=name, width="stretch")
205
-
206
- # Load first image
207
- _load_image_at_index(0)
208
-
209
- label = f"Next: Label images (1 of {n})" if n > 1 else "Next: Label image"
210
- if st.button(label):
211
- go_to(2)
212
- st.rerun()
213
-
214
- # =============================================================================
215
- # Step 2: Label (interactive canvas + prompts + SAM3)
216
  # =============================================================================
217
- elif st.session_state.step == 2:
218
- image = st.session_state.image
219
- if image is None:
220
- st.warning("No image loaded. Go back to Upload.")
221
- if st.button("Back to Upload"):
222
- go_to(1)
223
- st.rerun()
224
- else:
225
- n_images = len(st.session_state.images)
226
- img_idx = st.session_state.image_index
227
- img_label = f" ({img_idx + 1} of {n_images})" if n_images > 1 else ""
228
- st.header(f"Step 2: Label — {st.session_state.filename}{img_label}")
229
-
230
- # Compute canvas dimensions
231
- img_w, img_h = image.size
232
- canvas_w = min(img_w, CANVAS_MAX_WIDTH)
233
- scale = canvas_w / img_w
234
- canvas_h = int(img_h * scale)
235
- st.session_state.canvas_scale = scale
236
-
237
- # Build background image with accepted detections + SAM results overlay
238
- bg = image.copy()
239
- if st.session_state.accepted_detections:
240
- bg = overlay_accepted(bg, st.session_state.accepted_detections)
241
- if st.session_state.sam_results:
242
- masks = [d["mask"] for d in st.session_state.sam_results]
243
- boxes = [d["box"] for d in st.session_state.sam_results]
244
- bg = overlay_masks(bg, masks)
245
- bg = overlay_boxes(bg, boxes)
246
- bg_rgb = bg.convert("RGB")
247
-
248
- # --- Two-column layout ---
249
- col_canvas, col_controls = st.columns([3, 2])
250
-
251
- with col_controls:
252
- st.subheader("Prompts")
253
-
254
- # Text prompt input
255
- text_col, btn_col = st.columns([3, 1])
256
- with text_col:
257
- text_input = st.text_input("Text prompt", key="text_prompt_input", label_visibility="collapsed", placeholder="Describe objects to find...")
258
- with btn_col:
259
- if st.button("Add text", disabled=not text_input):
260
- st.session_state.prompts.append({
261
- "id": _next_prompt_id(),
262
- "type": "text",
263
- "coords": [],
264
- "label": text_input,
265
- "point_label": None,
266
- })
267
- st.rerun()
268
-
269
- # Prompt table
270
- prompts = st.session_state.prompts
271
- if prompts:
272
- st.caption(f"{len(prompts)} prompt(s)")
273
- for i, p in enumerate(prompts):
274
- pcol1, pcol2, pcol3, pcol4 = st.columns([1, 2, 3, 1])
275
- with pcol1:
276
- st.text(p["id"])
277
- with pcol2:
278
- st.text(p["type"])
279
- with pcol3:
280
- new_label = st.text_input(
281
- "label", value=p.get("label", ""), key=f"plabel_{p['id']}",
282
- label_visibility="collapsed",
283
- )
284
- if new_label != p.get("label", ""):
285
- st.session_state.prompts[i]["label"] = new_label
286
- with pcol4:
287
- if p["type"] == "point":
288
- is_pos = p.get("point_label", 1) == 1
289
- toggled = st.checkbox("+", value=is_pos, key=f"ptoggle_{p['id']}")
290
- st.session_state.prompts[i]["point_label"] = 1 if toggled else 0
291
- if st.button("X", key=f"pdel_{p['id']}"):
292
- st.session_state.prompts.pop(i)
293
- st.rerun()
294
- else:
295
- st.caption("No prompts yet. Draw boxes or points on the canvas, or add text prompts above.")
296
-
297
- # Threshold
298
- threshold = st.slider("Confidence threshold", 0.0, 1.0, 0.5, 0.05, key="label_threshold")
299
-
300
- # Run SAM3 button
301
- @st.fragment
302
- def run_sam3():
303
- prompts = st.session_state.prompts
304
- has_prompts = len(prompts) > 0
305
- if st.button("Run SAM3", type="primary", disabled=not has_prompts):
306
- # Gather prompts by type
307
- text_parts = [p["label"] for p in prompts if p["type"] == "text" and p["label"]]
308
- text_combined = ". ".join(text_parts) if text_parts else None
309
-
310
- box_list = [p["coords"] for p in prompts if p["type"] == "box" and len(p["coords"]) == 4]
311
- box_list = box_list if box_list else None
312
-
313
- pt_prompts = [p for p in prompts if p["type"] == "point" and len(p["coords"]) == 2]
314
- points = [p["coords"] for p in pt_prompts] if pt_prompts else None
315
- point_labels = [p.get("point_label", 1) for p in pt_prompts] if pt_prompts else None
316
-
317
- status = st.status("Running SAM3 inference...", expanded=True)
318
- status.write(f"Running on {get_device()} with {len(prompts)} prompt(s)...")
319
- results = combined_prompt_inference(
320
- image, text=text_combined, boxes=box_list,
321
- points=points, point_labels=point_labels,
322
- threshold=threshold,
323
- )
324
- status.write(f"Found {len(results)} objects!")
325
- status.update(label="Inference complete", state="complete")
326
- st.session_state.sam_results = results
327
- st.session_state.label_round += 1
328
- st.rerun(scope="app")
329
-
330
- run_sam3()
331
 
332
- with col_canvas:
333
- # Drawing mode selector
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  drawing_mode = st.radio(
335
  "Drawing mode",
336
  ["rect", "point", "transform"],
@@ -338,12 +299,6 @@ elif st.session_state.step == 2:
338
  key="drawing_mode",
339
  )
340
 
341
- # Build initial_drawing from existing prompts
342
- initial = _prompts_to_fabric_json(
343
- [p for p in st.session_state.prompts if p["type"] in ("box", "point")],
344
- scale,
345
- )
346
-
347
  canvas_result = st_canvas(
348
  fill_color="rgba(255, 0, 0, 0.1)",
349
  stroke_width=2,
@@ -353,118 +308,257 @@ elif st.session_state.step == 2:
353
  height=canvas_h,
354
  drawing_mode=drawing_mode,
355
  point_display_radius=5,
356
- initial_drawing=initial,
357
  key=f"canvas_{img_idx}_{st.session_state.label_round}",
358
  )
359
 
360
- # Sync canvas objects back to prompts
361
  if canvas_result.json_data is not None:
362
  canvas_objects = canvas_result.json_data.get("objects", [])
363
- # Count non-text prompts currently in state
364
- spatial_prompts = [p for p in st.session_state.prompts if p["type"] in ("box", "point")]
365
- n_existing = len(spatial_prompts)
366
  n_canvas = len(canvas_objects)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
 
368
- if n_canvas > n_existing:
369
- # New objects drawn on canvas — append them
370
- for obj in canvas_objects[n_existing:]:
371
- converted = _canvas_to_image(obj, scale)
372
- if converted:
373
- st.session_state.prompts.append({
374
- "id": _next_prompt_id(),
375
- "type": converted["type"],
376
- "coords": converted["coords"],
377
- "label": "",
378
- "point_label": 1 if converted["type"] == "point" else None,
379
- })
380
- st.rerun()
381
-
382
- # --- SAM3 Results Section ---
383
- if st.session_state.sam_results:
384
- st.divider()
385
- st.subheader(f"SAM3 Results {len(st.session_state.sam_results)} detections")
386
-
387
- # Show overlay
388
- results_vis = image.copy()
389
- if st.session_state.accepted_detections:
390
- results_vis = overlay_accepted(results_vis, st.session_state.accepted_detections)
391
- masks = [d["mask"] for d in st.session_state.sam_results]
392
- boxes = [d["box"] for d in st.session_state.sam_results]
393
- det_labels = [f"Detection {i+1}" for i in range(len(boxes))]
394
- results_vis = overlay_masks(results_vis, masks)
395
- results_vis = overlay_boxes(results_vis, boxes, labels=det_labels)
396
- st.image(results_vis, caption="SAM3 results", width="stretch")
397
-
398
- # Batch accept
399
- batch_label = st.text_input("Label for all detections", key="batch_accept_label")
400
- col_accept, col_discard = st.columns(2)
401
- with col_accept:
402
- if st.button("Accept all", type="primary", disabled=not batch_label):
403
- for det in st.session_state.sam_results:
404
- det["accepted"] = True
405
- det["label"] = batch_label
406
- new_accepted = st.session_state.sam_results
407
- unique = deduplicate(new_accepted, st.session_state.accepted_detections)
408
- st.session_state.accepted_detections.extend(unique)
409
- _accepted_to_prompts(unique)
410
- st.session_state.sam_results = []
411
- st.rerun()
412
- with col_discard:
413
- if st.button("Discard all"):
414
- st.session_state.sam_results = []
415
- st.rerun()
416
-
417
- # Individual review in expander
418
- with st.expander("Review individual detections"):
419
- for i, det in enumerate(st.session_state.sam_results):
420
- det_col1, det_col2 = st.columns([3, 1])
421
- with det_col1:
422
- st.text(f"Detection {i+1} — score: {det['score']:.3f} — box: [{', '.join(f'{c:.0f}' for c in det['box'])}]")
423
- with det_col2:
424
- ind_label = st.text_input("Label", key=f"det_label_{i}", label_visibility="collapsed")
425
- if st.button("Accept", key=f"det_accept_{i}"):
426
- det["accepted"] = True
427
- det["label"] = ind_label
428
- unique = deduplicate([det], st.session_state.accepted_detections)
429
- st.session_state.accepted_detections.extend(unique)
430
- _accepted_to_prompts(unique)
431
- st.session_state.sam_results.pop(i)
432
- st.rerun()
433
-
434
- # Accepted count
435
- if st.session_state.accepted_detections:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436
  st.divider()
437
- st.info(f"**{len(st.session_state.accepted_detections)}** accepted detections for this image")
438
 
439
- # --- Navigation ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
440
  st.divider()
441
- nav_cols = st.columns(3)
442
- with nav_cols[0]:
443
- if st.button("Back to Upload"):
444
- go_to(1)
445
- st.rerun()
446
- with nav_cols[1]:
447
- has_next_image = n_images > 1 and img_idx < n_images - 1
448
- if has_next_image:
449
- next_name = st.session_state.images[img_idx + 1][0]
450
- if st.button(f"Next image: {next_name}"):
451
- # Stamp image_path and merge
452
- for det in st.session_state.accepted_detections:
453
- det["image_path"] = st.session_state.filename
454
- st.session_state.all_image_detections.extend(st.session_state.accepted_detections)
455
- _reset_per_image_state()
456
- _load_image_at_index(img_idx + 1)
457
- st.rerun()
458
- with nav_cols[2]:
459
- total = len(st.session_state.all_image_detections) + len(st.session_state.accepted_detections)
460
- if st.button(f"Done — Export ({total} detections)" if total else "Done — Export"):
461
- # Stamp and merge current image
462
- for det in st.session_state.accepted_detections:
463
- det["image_path"] = st.session_state.filename
464
- st.session_state.all_image_detections.extend(st.session_state.accepted_detections)
465
- st.session_state.accepted_detections = []
466
- go_to(3)
467
- st.rerun()
468
 
469
  # =============================================================================
470
  # Step 3: Export
@@ -473,7 +567,6 @@ elif st.session_state.step == 3:
473
  st.header("Step 3: Export Manifest")
474
 
475
  combined = list(st.session_state.all_image_detections)
476
- # Re-index combined IDs
477
  for i, det in enumerate(combined):
478
  det["id"] = i
479
 
@@ -514,14 +607,11 @@ elif st.session_state.step == 3:
514
  elif st.session_state.step == 4:
515
  st.header("Step 4: Fine-Tune SAM3")
516
 
517
- # Build combined detections list
518
  combined_dets = list(st.session_state.all_image_detections)
519
- # Stamp image_path on detections if not set
520
  for det in combined_dets:
521
  if "image_path" not in det:
522
  det["image_path"] = st.session_state.filename
523
 
524
- # Only keep detections with masks
525
  train_dets = [d for d in combined_dets if d.get("accepted") and d.get("mask") is not None]
526
  image_names = list(set(d["image_path"] for d in train_dets))
527
 
@@ -533,7 +623,6 @@ elif st.session_state.step == 4:
533
  go_to(3)
534
  st.rerun()
535
  else:
536
- # Hyperparameters
537
  col_ep, col_lr = st.columns(2)
538
  with col_ep:
539
  epochs = st.slider("Epochs", 1, 50, 5, key="train_epochs")
@@ -554,7 +643,6 @@ elif st.session_state.step == 4:
554
  processor = None
555
  result = None
556
  try:
557
- # 1. Free GPU memory from cached inference model
558
  status = st.status("Preparing for training...", expanded=True)
559
  status.write("Clearing cached inference model to free GPU memory...")
560
  load_model.clear()
@@ -563,20 +651,16 @@ elif st.session_state.step == 4:
563
  elif _torch.backends.mps.is_available():
564
  _torch.mps.empty_cache()
565
 
566
- # 2. Load fresh trainable model
567
  status.write("Loading fresh model for training...")
568
  processor, model = load_model_for_training()
569
 
570
- # 3. Freeze encoder
571
  trainable, total = freeze_encoder(model)
572
  status.write(f"Frozen encoder. Trainable params: {trainable:,} / {total:,}")
573
 
574
- # 4. Build dataset
575
  images_dict = {name: img for name, img in st.session_state.images}
576
  dataset = SAM3FineTuneDataset(images_dict, train_dets, processor)
577
  status.write(f"Dataset ready: {len(dataset)} samples")
578
 
579
- # 5. Train with progress bar
580
  status.update(label="Training...", state="running")
581
  progress_bar = st.progress(0, text="Starting training...")
582
 
@@ -589,14 +673,12 @@ elif st.session_state.step == 4:
589
 
590
  st.session_state.training_loss_history = result["loss_history"]
591
 
592
- # 6. Save model zip
593
  status.write("Packaging fine-tuned model...")
594
  st.session_state.finetuned_model_bytes = get_model_zip_bytes(result["model"], processor)
595
 
596
  st.session_state.training_complete = True
597
  status.update(label="Training complete!", state="complete")
598
  finally:
599
- # Always clean up GPU memory, even if stopped/interrupted
600
  del model, processor, result
601
  if _torch.cuda.is_available():
602
  _torch.cuda.empty_cache()
@@ -605,17 +687,13 @@ elif st.session_state.step == 4:
605
 
606
  st.rerun()
607
  else:
608
- # Post-training UI
609
  st.success("Training complete!")
610
 
611
- # Loss curve
612
  loss_hist = st.session_state.training_loss_history
613
  if loss_hist:
614
- import pandas as pd
615
  df = pd.DataFrame({"Epoch": range(1, len(loss_hist) + 1), "Avg Loss": loss_hist})
616
  st.line_chart(df, x="Epoch", y="Avg Loss")
617
 
618
- # Download button
619
  if st.session_state.finetuned_model_bytes:
620
  st.download_button(
621
  label="Download fine-tuned model (.zip)",
 
1
  import streamlit as st
2
+ import pandas as pd
3
  from PIL import Image
4
  from streamlit_drawable_canvas import st_canvas
5
 
6
  from sam3_engine import get_device, load_model, load_model_for_training, combined_prompt_inference
7
+ from viz import overlay_detections_by_class, _hex_to_rgb, CLASS_COLORS
8
  from manifest import build_manifest, manifest_to_json, deduplicate
9
  from training import SAM3FineTuneDataset, freeze_encoder, run_training, get_model_zip_bytes
10
 
 
16
 
17
  # --- Session state defaults ---
18
  defaults = {
19
+ "step": 2,
20
  "image": None,
21
  "filename": None,
22
  "images": [], # list of (filename, PIL.Image) tuples
23
  "image_index": 0, # current position in batch
24
  "all_image_detections": [], # accumulated detections across ALL images
25
+ "classes": [], # list of class dicts
26
+ "pending_box_coords": None, # drawn box awaiting class assignment
27
+ "detection_id_counter": 0, # monotonic ID for detections
 
28
  "label_round": 0, # iteration counter for canvas key stability
29
  "canvas_scale": 1.0, # image-to-canvas scale factor
30
+ "_last_canvas_count": 0, # track canvas object count for new-drawing detection
31
+ "selected_detection_id": None, # ID of detection selected for highlighting
32
  "training_loss_history": [],
33
  "training_complete": False,
34
  "finetuned_model_bytes": None,
 
38
  st.session_state[key] = val
39
 
40
 
 
 
 
 
 
 
 
 
 
 
41
  def _load_image_at_index(idx: int):
42
  """Load the image at the given batch index into session state."""
43
  filename, image = st.session_state.images[idx]
 
50
  st.session_state.step = step
51
 
52
 
53
+ def _next_detection_id() -> int:
54
+ st.session_state.detection_id_counter += 1
55
+ return st.session_state.detection_id_counter
56
+
57
+
58
+ def _get_current_image_detections(visible_only=False):
59
+ """Get all detections for the current image across all classes."""
60
+ fname = st.session_state.filename
61
+ if not fname:
62
+ return []
63
+ dets = []
64
+ for cls in st.session_state.classes:
65
+ if visible_only and not cls["visible"]:
66
+ continue
67
+ for det in cls["detections"]:
68
+ if det.get("image_path") == fname:
69
+ dets.append(det)
70
+ return dets
71
+
72
+
73
  # --- Coordinate scaling helpers ---
74
  def _canvas_to_image(obj: dict, scale: float):
75
  """Convert a Fabric.js canvas object to image-space coordinates."""
 
92
  ],
93
  }
94
  elif obj_type == "circle":
 
95
  r = obj.get("radius", 0)
96
  cx = (left + r * sx) / scale
97
  cy = (top + r * sy) / scale
 
102
  return None
103
 
104
 
105
+ def _add_class(name: str):
106
+ """Create a new class and return it."""
107
+ color = CLASS_COLORS[len(st.session_state.classes) % len(CLASS_COLORS)]
108
+ cls = {
109
+ "name": name,
110
+ "color": color,
111
+ "visible": True,
112
+ "threshold": 0.85,
113
+ "detections": [],
114
+ }
115
+ st.session_state.classes.append(cls)
116
+ return cls
117
+
118
+
119
+ @st.dialog("Assign to Class")
120
+ def assign_drawing_dialog():
121
+ """Modal dialog for assigning a drawn box/point to a class."""
122
+ pending = st.session_state.pending_box_coords
123
+ if pending is None:
124
+ st.warning("No pending drawing.")
125
+ return
126
+
127
+ st.write(f"New **{pending['type']}** drawn. Choose a class to assign it to:")
128
+
129
+ # Existing class selector
130
+ class_names = [c["name"] for c in st.session_state.classes]
131
+ chosen_existing = None
132
+ if class_names:
133
+ chosen_existing = st.selectbox("Existing class", class_names, key="dlg_class_select")
134
+
135
+ # Or create a new class
136
+ st.divider()
137
+ new_name = st.text_input("Or create a new class", key="dlg_new_class", placeholder="e.g. Cable, Label...")
138
+
139
+ st.divider()
140
+ assign_col, cancel_col = st.columns(2)
141
+ with assign_col:
142
+ can_assign = bool(new_name) or bool(chosen_existing)
143
+ if st.button("Assign", type="primary", disabled=not can_assign, use_container_width=True):
144
+ # Determine target class
145
+ if new_name:
146
+ existing_names = {c["name"] for c in st.session_state.classes}
147
+ if new_name not in existing_names:
148
+ target_cls = _add_class(new_name)
149
+ else:
150
+ target_cls = next(c for c in st.session_state.classes if c["name"] == new_name)
151
+ else:
152
+ target_cls = next(c for c in st.session_state.classes if c["name"] == chosen_existing)
153
+
154
+ det = {
155
+ "id": _next_detection_id(),
156
+ "mask": None,
157
+ "box": pending["coords"] if pending["type"] == "box" else [
158
+ pending["coords"][0] - 10, pending["coords"][1] - 10,
159
+ pending["coords"][0] + 10, pending["coords"][1] + 10,
160
+ ],
161
+ "score": 1.0,
162
+ "label": target_cls["name"],
163
+ "accepted": True,
164
+ "image_path": st.session_state.filename,
165
+ }
166
+ target_cls["detections"].append(det)
167
+ st.session_state.pending_box_coords = None
168
+ st.session_state.label_round += 1
169
+ st.session_state._last_canvas_count = 0
170
+ st.rerun()
171
+ with cancel_col:
172
+ if st.button("Cancel", use_container_width=True):
173
+ st.session_state.pending_box_coords = None
174
+ st.session_state.label_round += 1
175
+ st.session_state._last_canvas_count = 0
176
+ st.rerun()
177
 
178
 
179
  # --- Sidebar ---
 
182
  device = get_device()
183
  st.caption(f"Device: **{device}**")
184
  st.caption("Model: `facebook/sam3`")
185
+ with st.spinner("Loading SAM3 model..."):
186
+ load_model()
187
+ st.caption("Model loaded")
188
  st.divider()
189
 
190
+ step_labels = ["Label", "Export", "Train"]
191
  current = st.session_state.step
192
+ for i, label in enumerate(step_labels, start=2):
193
  if current == i:
194
  marker = f"-> {i}. {label}"
195
  else:
 
201
  st.divider()
202
  st.metric("Image", f"{st.session_state.image_index + 1} of {n_images}")
203
 
204
+ total_dets = sum(len(c["detections"]) for c in st.session_state.classes)
205
+ if total_dets:
 
206
  st.divider()
207
+ st.metric("Total detections", total_dets)
 
 
 
208
 
209
  st.divider()
210
  if st.button("Start over"):
 
214
 
215
 
216
  # =============================================================================
217
+ # Step 2: Label (3-column class-centric layout)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  # =============================================================================
219
+ if st.session_state.step == 2:
220
+ col_files, col_canvas, col_controls = st.columns([1, 3, 2])
221
+
222
+ # --- Left column: File list ---
223
+ with col_files:
224
+ st.subheader("Images")
225
+ uploaded_files = st.file_uploader(
226
+ "Upload images",
227
+ type=["png", "jpg", "jpeg"],
228
+ accept_multiple_files=True,
229
+ label_visibility="collapsed",
230
+ )
231
+ if uploaded_files:
232
+ existing_names = {name for name, _ in st.session_state.images}
233
+ for f in uploaded_files:
234
+ if f.name not in existing_names:
235
+ st.session_state.images.append((f.name, Image.open(f).convert("RGB")))
236
+ existing_names.add(f.name)
237
+ # Auto-load first image if none loaded
238
+ if st.session_state.image is None and st.session_state.images:
239
+ _load_image_at_index(0)
240
+ st.rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
+ # Show file list with thumbnails
243
+ if st.session_state.images:
244
+ filenames = [name for name, _ in st.session_state.images]
245
+ for i, (name, img) in enumerate(st.session_state.images):
246
+ st.image(img, width=100)
247
+ is_current = (i == st.session_state.image_index)
248
+ if st.button(
249
+ name,
250
+ key=f"file_select_{i}",
251
+ type="primary" if is_current else "secondary",
252
+ use_container_width=True,
253
+ ):
254
+ if not is_current:
255
+ _load_image_at_index(i)
256
+ st.session_state.label_round += 1
257
+ st.session_state._last_canvas_count = 0
258
+ st.session_state.pending_box_coords = None
259
+ st.session_state.selected_detection_id = None
260
+ st.rerun()
261
+
262
+ # --- Center column: Canvas ---
263
+ with col_canvas:
264
+ image = st.session_state.image
265
+ if image is None:
266
+ st.info("Upload images in the left panel to get started.")
267
+ else:
268
+ img_idx = st.session_state.image_index
269
+ n_images = len(st.session_state.images)
270
+ img_label = f" ({img_idx + 1} of {n_images})" if n_images > 1 else ""
271
+ st.subheader(f"{st.session_state.filename}{img_label}")
272
+
273
+ # Compute canvas dimensions
274
+ img_w, img_h = image.size
275
+ canvas_w = min(img_w, CANVAS_MAX_WIDTH)
276
+ scale = canvas_w / img_w
277
+ canvas_h = int(img_h * scale)
278
+ st.session_state.canvas_scale = scale
279
+
280
+ # Build background with visible detections overlaid
281
+ visible_dets = _get_current_image_detections(visible_only=True)
282
+ bg = image.copy()
283
+ if visible_dets:
284
+ # Build color map from class definitions
285
+ color_map = {}
286
+ for cls in st.session_state.classes:
287
+ if cls["visible"]:
288
+ color_map[cls["name"]] = _hex_to_rgb(cls["color"])
289
+ color_map[""] = (180, 180, 180)
290
+ hl_ids = {st.session_state.selected_detection_id} if st.session_state.selected_detection_id is not None else None
291
+ bg = overlay_detections_by_class(bg, visible_dets, color_override=color_map, highlight_ids=hl_ids)
292
+ bg_rgb = bg.convert("RGB")
293
+
294
+ # Drawing mode
295
  drawing_mode = st.radio(
296
  "Drawing mode",
297
  ["rect", "point", "transform"],
 
299
  key="drawing_mode",
300
  )
301
 
 
 
 
 
 
 
302
  canvas_result = st_canvas(
303
  fill_color="rgba(255, 0, 0, 0.1)",
304
  stroke_width=2,
 
308
  height=canvas_h,
309
  drawing_mode=drawing_mode,
310
  point_display_radius=5,
 
311
  key=f"canvas_{img_idx}_{st.session_state.label_round}",
312
  )
313
 
314
+ # Detect new drawings
315
  if canvas_result.json_data is not None:
316
  canvas_objects = canvas_result.json_data.get("objects", [])
 
 
 
317
  n_canvas = len(canvas_objects)
318
+ last_count = st.session_state._last_canvas_count
319
+
320
+ if n_canvas > last_count and st.session_state.pending_box_coords is None:
321
+ # New object drawn — convert the last one
322
+ new_obj = canvas_objects[-1]
323
+ converted = _canvas_to_image(new_obj, scale)
324
+ if converted:
325
+ st.session_state.pending_box_coords = converted
326
+ st.session_state._last_canvas_count = n_canvas
327
+ st.rerun()
328
+
329
+ # Open assignment dialog when a new drawing is pending
330
+ if st.session_state.pending_box_coords is not None:
331
+ assign_drawing_dialog()
332
+
333
# --- Right column: Class controls ---
with col_controls:
    st.subheader("Classes")

    # Class input. text_input returns its current value on every rerun, so a
    # non-empty value registers the class immediately; the membership check
    # prevents duplicate registration while the text stays in the widget.
    new_class = st.text_input("New class name", key="new_class_input", placeholder="e.g. Server, Cable, Label...")
    if new_class:
        existing_names = {c["name"] for c in st.session_state.classes}
        if new_class not in existing_names:
            # Cycle through the palette so each class gets a stable color.
            color = CLASS_COLORS[len(st.session_state.classes) % len(CLASS_COLORS)]
            st.session_state.classes.append({
                "name": new_class,
                "color": color,
                "visible": True,
                "threshold": 0.85,
                "detections": [],
            })
            st.rerun()

    # Class cards. Deletions are collected during the render loop and applied
    # afterwards so list indices stay valid while iterating.
    classes_to_delete = []
    dets_to_delete = []  # list of (class_idx, det_id)
    find_single_class_idx = None  # index of class to run per-class find

    for ci, cls in enumerate(st.session_state.classes):
        with st.container(border=True):
            # Header row: colored class name, visibility toggle, delete button.
            hcol_name, hcol_vis, hcol_del = st.columns([3, 1, 1])
            with hcol_name:
                st.markdown(
                    f"<span style='color:{cls['color']};font-weight:bold;font-size:1.1em'>"
                    f"{cls['name']}</span>",
                    unsafe_allow_html=True,
                )
            with hcol_vis:
                vis = st.checkbox("👁", value=cls["visible"], key=f"vis_{ci}", label_visibility="collapsed")
                if vis != cls["visible"]:
                    st.session_state.classes[ci]["visible"] = vis
                    st.rerun()
            with hcol_del:
                if st.button("🗑", key=f"del_class_{ci}"):
                    classes_to_delete.append(ci)

            # Detections for the current image — rendered as colored chips.
            fname = st.session_state.filename
            if fname:
                img_dets = [d for d in cls["detections"] if d.get("image_path") == fname]
                if img_dets:
                    for det in img_dets:
                        dcol_label, dcol_del = st.columns([4, 1])
                        with dcol_label:
                            is_sel = st.session_state.selected_detection_id == det["id"]
                            # Colored detection chip via markdown; a yellow
                            # border marks the currently selected detection.
                            border_style = "3px solid yellow" if is_sel else f"2px solid {cls['color']}"
                            st.markdown(
                                f"<div style='background:{cls['color']}22;border:{border_style};"
                                f"border-radius:6px;padding:4px 8px;text-align:center;"
                                f"color:{cls['color']};font-weight:600;cursor:default'>"
                                f"{cls['name']} {det['id']} {det['score']:.0%}</div>",
                                unsafe_allow_html=True,
                            )
                            if st.button(
                                "Select" if not is_sel else "Deselect",
                                key=f"sel_det_{ci}_{det['id']}",
                                use_container_width=True,
                            ):
                                # Toggle selection state for this detection.
                                if is_sel:
                                    st.session_state.selected_detection_id = None
                                else:
                                    st.session_state.selected_detection_id = det["id"]
                                st.rerun()
                        with dcol_del:
                            if st.button("🗑", key=f"del_det_{ci}_{det['id']}"):
                                dets_to_delete.append((ci, det["id"]))
                else:
                    st.caption("No detections on this image")

            # Per-class confidence threshold for SAM3 inference.
            new_thresh = st.slider(
                "Confidence threshold", 0.0, 1.0, cls["threshold"], 0.05,
                key=f"thresh_{ci}",
            )
            # Fixed: was f"Default 85%" — an f-string with no placeholders.
            st.caption("Default 85%")
            if new_thresh != cls["threshold"]:
                st.session_state.classes[ci]["threshold"] = new_thresh

            # Per-class Find Objects button; the actual inference runs after
            # the render loop, keyed off find_single_class_idx.
            if st.session_state.image is not None:
                # Fixed: dropped pointless f-prefix on a literal with no placeholders.
                if st.button("🔍 Find Objects for this Class", key=f"find_class_{ci}", use_container_width=True):
                    find_single_class_idx = ci

    # Process deletions after the loop. Classes are popped in reverse index
    # order so earlier pops don't shift the indices of later ones.
    if classes_to_delete:
        for ci in sorted(classes_to_delete, reverse=True):
            st.session_state.classes.pop(ci)
        st.rerun()

    if dets_to_delete:
        for ci, det_id in dets_to_delete:
            # Clear the selection if the selected detection is being removed.
            if st.session_state.selected_detection_id == det_id:
                st.session_state.selected_detection_id = None
            st.session_state.classes[ci]["detections"] = [
                d for d in st.session_state.classes[ci]["detections"] if d["id"] != det_id
            ]
        # Bump the canvas key and reset the drawing counter so stale canvas
        # objects don't re-trigger the new-drawing detection.
        st.session_state.label_round += 1
        st.session_state._last_canvas_count = 0
        st.rerun()
440
+
441
# --- Per-class Find Objects execution ---
# Runs SAM3 inference for a single class when its "Find Objects" button was
# pressed during the render loop above.
if find_single_class_idx is not None:
    cls = st.session_state.classes[find_single_class_idx]
    image = st.session_state.image
    fname = st.session_state.filename
    status = st.status(f"Finding {cls['name']}...", expanded=True)
    status.write(f"Running on {get_device()} (threshold {cls['threshold']:.0%})...")

    # Compute the current image's detections once; the original code ran this
    # identical filter twice (for box prompts and again for deduplication).
    existing = [
        d for d in cls["detections"]
        if d.get("image_path") == fname
    ]
    existing_boxes = [d["box"] for d in existing]
    dets = combined_prompt_inference(
        image,
        text=cls["name"],
        # Seed inference with already-accepted boxes as visual exemplars.
        boxes=existing_boxes if existing_boxes else None,
        threshold=cls["threshold"],
    )
    # Tag each new detection with class label, image, and a fresh unique id.
    for d in dets:
        d["label"] = cls["name"]
        d["accepted"] = True
        d["image_path"] = fname
        d["id"] = _next_detection_id()

    # Drop detections that duplicate ones already recorded for this image.
    unique = deduplicate(dets, existing) if existing else dets
    cls["detections"].extend(unique)

    status.write(f"Found {len(unique)} new {cls['name']} detection(s)")
    status.update(label=f"Found {len(unique)} {cls['name']}", state="complete")
    # Refresh the canvas so new overlays render and old drawings don't re-fire.
    st.session_state.label_round += 1
    st.session_state._last_canvas_count = 0
    st.rerun()
477
+
478
# --- Find Objects for ALL classes button (with confirmation) ---
if st.session_state.classes and st.session_state.image is not None:
    st.divider()

    @st.fragment
    def find_all_objects():
        """Two-step confirm-then-run UI for batch SAM3 inference over every class.

        Runs as a Streamlit fragment so the confirm/cancel toggles rerun only
        this widget (scope="fragment"); a completed inference pass reruns the
        whole app (scope="app") to refresh the canvas overlays.
        """
        if "confirm_find_all" not in st.session_state:
            st.session_state.confirm_find_all = False

        if not st.session_state.confirm_find_all:
            if st.button("Find Objects for all classes", use_container_width=True):
                st.session_state.confirm_find_all = True
                st.rerun(scope="fragment")
        else:
            st.warning(f"This will run SAM3 for **{len(st.session_state.classes)}** class(es). Continue?")
            yes_col, no_col = st.columns(2)
            with yes_col:
                if st.button("Yes, find all", type="primary", use_container_width=True):
                    st.session_state.confirm_find_all = False
                    image = st.session_state.image
                    fname = st.session_state.filename
                    status = st.status("Running SAM3 inference...", expanded=True)
                    status.write(f"Running on {get_device()}...")

                    for cls in st.session_state.classes:
                        status.write(f"Finding **{cls['name']}** (threshold {cls['threshold']:.0%})...")

                        # Compute this image's detections once; the original
                        # ran the same filter twice per class (box prompts +
                        # dedup reference set).
                        existing = [
                            d for d in cls["detections"]
                            if d.get("image_path") == fname
                        ]
                        existing_boxes = [d["box"] for d in existing]
                        dets = combined_prompt_inference(
                            image,
                            text=cls["name"],
                            boxes=existing_boxes if existing_boxes else None,
                            threshold=cls["threshold"],
                        )
                        # Tag new detections with label, image, and fresh ids.
                        for d in dets:
                            d["label"] = cls["name"]
                            d["accepted"] = True
                            d["image_path"] = fname
                            d["id"] = _next_detection_id()

                        unique = deduplicate(dets, existing) if existing else dets
                        cls["detections"].extend(unique)
                        status.write(f" → {len(unique)} new {cls['name']} detection(s)")

                    status.update(label="Inference complete", state="complete")
                    # Refresh canvas key/counter so overlays redraw cleanly.
                    st.session_state.label_round += 1
                    st.session_state._last_canvas_count = 0
                    st.rerun(scope="app")
            with no_col:
                if st.button("Cancel", use_container_width=True):
                    st.session_state.confirm_find_all = False
                    st.rerun(scope="fragment")

    find_all_objects()
539
+
540
# --- Update Label Manifest button ---
# Snapshots every class's detections into the flat manifest list used by the
# export step, without leaving the labeling screen.
if st.session_state.classes:
    st.divider()
    if st.button("Update Label Manifest", use_container_width=True):
        all_dets = [
            det
            for klass in st.session_state.classes
            for det in klass["detections"]
        ]
        st.session_state.all_image_detections = all_dets
        st.success(f"Manifest updated: {len(all_dets)} detections")
549
+
550
# --- Navigation ---
# "Done" flattens all per-class detections into the global manifest list and
# advances the workflow to the export step.
if st.session_state.image is not None:
    st.divider()
    total = sum(len(c["detections"]) for c in st.session_state.classes)
    button_label = f"Done — Export ({total} detections)" if total else "Done — Export"
    if st.button(button_label):
        # Flatten all class detections into all_image_detections
        flattened = [
            det
            for klass in st.session_state.classes
            for det in klass["detections"]
        ]
        st.session_state.all_image_detections = flattened
        go_to(3)
        st.rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
562
 
563
  # =============================================================================
564
  # Step 3: Export
 
567
  st.header("Step 3: Export Manifest")
568
 
569
  combined = list(st.session_state.all_image_detections)
 
570
  for i, det in enumerate(combined):
571
  det["id"] = i
572
 
 
607
  elif st.session_state.step == 4:
608
  st.header("Step 4: Fine-Tune SAM3")
609
 
 
610
  combined_dets = list(st.session_state.all_image_detections)
 
611
  for det in combined_dets:
612
  if "image_path" not in det:
613
  det["image_path"] = st.session_state.filename
614
 
 
615
  train_dets = [d for d in combined_dets if d.get("accepted") and d.get("mask") is not None]
616
  image_names = list(set(d["image_path"] for d in train_dets))
617
 
 
623
  go_to(3)
624
  st.rerun()
625
  else:
 
626
  col_ep, col_lr = st.columns(2)
627
  with col_ep:
628
  epochs = st.slider("Epochs", 1, 50, 5, key="train_epochs")
 
643
  processor = None
644
  result = None
645
  try:
 
646
  status = st.status("Preparing for training...", expanded=True)
647
  status.write("Clearing cached inference model to free GPU memory...")
648
  load_model.clear()
 
651
  elif _torch.backends.mps.is_available():
652
  _torch.mps.empty_cache()
653
 
 
654
  status.write("Loading fresh model for training...")
655
  processor, model = load_model_for_training()
656
 
 
657
  trainable, total = freeze_encoder(model)
658
  status.write(f"Frozen encoder. Trainable params: {trainable:,} / {total:,}")
659
 
 
660
  images_dict = {name: img for name, img in st.session_state.images}
661
  dataset = SAM3FineTuneDataset(images_dict, train_dets, processor)
662
  status.write(f"Dataset ready: {len(dataset)} samples")
663
 
 
664
  status.update(label="Training...", state="running")
665
  progress_bar = st.progress(0, text="Starting training...")
666
 
 
673
 
674
  st.session_state.training_loss_history = result["loss_history"]
675
 
 
676
  status.write("Packaging fine-tuned model...")
677
  st.session_state.finetuned_model_bytes = get_model_zip_bytes(result["model"], processor)
678
 
679
  st.session_state.training_complete = True
680
  status.update(label="Training complete!", state="complete")
681
  finally:
 
682
  del model, processor, result
683
  if _torch.cuda.is_available():
684
  _torch.cuda.empty_cache()
 
687
 
688
  st.rerun()
689
  else:
 
690
  st.success("Training complete!")
691
 
 
692
  loss_hist = st.session_state.training_loss_history
693
  if loss_hist:
 
694
  df = pd.DataFrame({"Epoch": range(1, len(loss_hist) + 1), "Avg Loss": loss_hist})
695
  st.line_chart(df, x="Epoch", y="Avg Loss")
696
 
 
697
  if st.session_state.finetuned_model_bytes:
698
  st.download_button(
699
  label="Download fine-tuned model (.zip)",