Upload app.py with huggingface_hub
app.py
CHANGED
@@ -1,9 +1,10 @@
 import streamlit as st
 from PIL import Image
 from streamlit_drawable_canvas import st_canvas

 from sam3_engine import get_device, load_model, load_model_for_training, combined_prompt_inference
-from viz import
 from manifest import build_manifest, manifest_to_json, deduplicate
 from training import SAM3FineTuneDataset, freeze_encoder, run_training, get_model_zip_bytes

@@ -15,18 +16,19 @@ CANVAS_MAX_WIDTH = 700

 # --- Session state defaults ---
 defaults = {
-    "step":
     "image": None,
     "filename": None,
     "images": [],  # list of (filename, PIL.Image) tuples
     "image_index": 0,  # current position in batch
     "all_image_detections": [],  # accumulated detections across ALL images
-    "
-    "
-    "
-    "sam_results": [],  # latest SAM3 results for current image
     "label_round": 0,  # iteration counter for canvas key stability
     "canvas_scale": 1.0,  # image-to-canvas scale factor
     "training_loss_history": [],
     "training_complete": False,
     "finetuned_model_bytes": None,
@@ -36,16 +38,6 @@ for key, val in defaults.items():
         st.session_state[key] = val


-def _reset_per_image_state():
-    """Reset state that is specific to a single image."""
-    st.session_state.accepted_detections = []
-    st.session_state.prompts = []
-    st.session_state.prompt_counter = 0
-    st.session_state.sam_results = []
-    st.session_state.label_round = 0
-    st.session_state.canvas_scale = 1.0
-
-
 def _load_image_at_index(idx: int):
     """Load the image at the given batch index into session state."""
     filename, image = st.session_state.images[idx]
@@ -58,6 +50,26 @@ def go_to(step: int):
     st.session_state.step = step


 # --- Coordinate scaling helpers ---
 def _canvas_to_image(obj: dict, scale: float):
     """Convert a Fabric.js canvas object to image-space coordinates."""
@@ -80,7 +92,6 @@ def _canvas_to_image(obj: dict, scale: float):
             ],
         }
     elif obj_type == "circle":
-        # Points rendered as small circles
         r = obj.get("radius", 0)
         cx = (left + r * sx) / scale
         cy = (top + r * sy) / scale
@@ -91,56 +102,78 @@
     return None


-def
-    """


 # --- Sidebar ---
@@ -149,11 +182,14 @@ with st.sidebar:
     device = get_device()
     st.caption(f"Device: **{device}**")
     st.caption("Model: `facebook/sam3`")
     st.divider()

-    step_labels = ["
     current = st.session_state.step
-    for i, label in enumerate(step_labels, start=
         if current == i:
             marker = f"-> {i}. {label}"
         else:
@@ -165,14 +201,10 @@ with st.sidebar:
     st.divider()
     st.metric("Image", f"{st.session_state.image_index + 1} of {n_images}")

-
-
-    if total_all:
         st.divider()
-
-        st.metric("Accepted (all images)", total_all)
-    else:
-        st.metric("Total accepted", len(accepted))

     st.divider()
     if st.button("Start over"):
@@ -182,155 +214,84 @@ with st.sidebar:


 # =============================================================================
-# Step
-# =============================================================================
-if st.session_state.step == 1:
-    st.header("Step 1: Upload Images")
-    uploaded_files = st.file_uploader(
-        "Choose one or more images (PNG/JPG)",
-        type=["png", "jpg", "jpeg"],
-        accept_multiple_files=True,
-    )
-    if uploaded_files:
-        images = [(f.name, Image.open(f).convert("RGB")) for f in uploaded_files]
-        st.session_state.images = images
-        st.session_state.image_index = 0
-
-        # Show thumbnail grid
-        n = len(images)
-        cols = st.columns(min(n, 4))
-        for i, (name, img) in enumerate(images):
-            with cols[i % len(cols)]:
-                st.image(img, caption=name, width="stretch")
-
-        # Load first image
-        _load_image_at_index(0)
-
-        label = f"Next: Label images (1 of {n})" if n > 1 else "Next: Label image"
-        if st.button(label):
-            go_to(2)
-            st.rerun()
-
-# =============================================================================
-# Step 2: Label (interactive canvas + prompts + SAM3)
 # =============================================================================
-
-    if st.session_state.accepted_detections:
-        bg = overlay_accepted(bg, st.session_state.accepted_detections)
-    if st.session_state.sam_results:
-        masks = [d["mask"] for d in st.session_state.sam_results]
-        boxes = [d["box"] for d in st.session_state.sam_results]
-        bg = overlay_masks(bg, masks)
-        bg = overlay_boxes(bg, boxes)
-    bg_rgb = bg.convert("RGB")
-
-    # --- Two-column layout ---
-    col_canvas, col_controls = st.columns([3, 2])
-
-    with col_controls:
-        st.subheader("Prompts")
-
-        # Text prompt input
-        text_col, btn_col = st.columns([3, 1])
-        with text_col:
-            text_input = st.text_input("Text prompt", key="text_prompt_input", label_visibility="collapsed", placeholder="Describe objects to find...")
-        with btn_col:
-            if st.button("Add text", disabled=not text_input):
-                st.session_state.prompts.append({
-                    "id": _next_prompt_id(),
-                    "type": "text",
-                    "coords": [],
-                    "label": text_input,
-                    "point_label": None,
-                })
-                st.rerun()
-
-        # Prompt table
-        prompts = st.session_state.prompts
-        if prompts:
-            st.caption(f"{len(prompts)} prompt(s)")
-            for i, p in enumerate(prompts):
-                pcol1, pcol2, pcol3, pcol4 = st.columns([1, 2, 3, 1])
-                with pcol1:
-                    st.text(p["id"])
-                with pcol2:
-                    st.text(p["type"])
-                with pcol3:
-                    new_label = st.text_input(
-                        "label", value=p.get("label", ""), key=f"plabel_{p['id']}",
-                        label_visibility="collapsed",
-                    )
-                    if new_label != p.get("label", ""):
-                        st.session_state.prompts[i]["label"] = new_label
-                with pcol4:
-                    if p["type"] == "point":
-                        is_pos = p.get("point_label", 1) == 1
-                        toggled = st.checkbox("+", value=is_pos, key=f"ptoggle_{p['id']}")
-                        st.session_state.prompts[i]["point_label"] = 1 if toggled else 0
-                    if st.button("X", key=f"pdel_{p['id']}"):
-                        st.session_state.prompts.pop(i)
-                        st.rerun()
-        else:
-            st.caption("No prompts yet. Draw boxes or points on the canvas, or add text prompts above.")
-
-        # Threshold
-        threshold = st.slider("Confidence threshold", 0.0, 1.0, 0.5, 0.05, key="label_threshold")
-
-        # Run SAM3 button
-        @st.fragment
-        def run_sam3():
-            prompts = st.session_state.prompts
-            has_prompts = len(prompts) > 0
-            if st.button("Run SAM3", type="primary", disabled=not has_prompts):
-                # Gather prompts by type
-                text_parts = [p["label"] for p in prompts if p["type"] == "text" and p["label"]]
-                text_combined = ". ".join(text_parts) if text_parts else None
-
-                box_list = [p["coords"] for p in prompts if p["type"] == "box" and len(p["coords"]) == 4]
-                box_list = box_list if box_list else None
-
-                pt_prompts = [p for p in prompts if p["type"] == "point" and len(p["coords"]) == 2]
-                points = [p["coords"] for p in pt_prompts] if pt_prompts else None
-                point_labels = [p.get("point_label", 1) for p in pt_prompts] if pt_prompts else None
-
-                status = st.status("Running SAM3 inference...", expanded=True)
-                status.write(f"Running on {get_device()} with {len(prompts)} prompt(s)...")
-                results = combined_prompt_inference(
-                    image, text=text_combined, boxes=box_list,
-                    points=points, point_labels=point_labels,
-                    threshold=threshold,
-                )
-                status.write(f"Found {len(results)} objects!")
-                status.update(label="Inference complete", state="complete")
-                st.session_state.sam_results = results
-                st.session_state.label_round += 1
-                st.rerun(scope="app")
-
-        run_sam3()

-    with
-
         drawing_mode = st.radio(
             "Drawing mode",
             ["rect", "point", "transform"],
@@ -338,12 +299,6 @@ elif st.session_state.step == 2:
             key="drawing_mode",
         )

-        # Build initial_drawing from existing prompts
-        initial = _prompts_to_fabric_json(
-            [p for p in st.session_state.prompts if p["type"] in ("box", "point")],
-            scale,
-        )
-
         canvas_result = st_canvas(
             fill_color="rgba(255, 0, 0, 0.1)",
             stroke_width=2,
@@ -353,118 +308,257 @@ elif st.session_state.step == 2:
             height=canvas_h,
             drawing_mode=drawing_mode,
             point_display_radius=5,
-            initial_drawing=initial,
             key=f"canvas_{img_idx}_{st.session_state.label_round}",
         )

-        #
         if canvas_result.json_data is not None:
             canvas_objects = canvas_result.json_data.get("objects", [])
-            # Count non-text prompts currently in state
-            spatial_prompts = [p for p in st.session_state.prompts if p["type"] in ("box", "point")]
-            n_existing = len(spatial_prompts)
             n_canvas = len(canvas_objects)

         st.divider()
-        st.info(f"**{len(st.session_state.accepted_detections)}** accepted detections for this image")

         st.divider()
-
-        if st.button(f"Next image: {next_name}"):
-            # Stamp image_path and merge
-            for det in st.session_state.accepted_detections:
-                det["image_path"] = st.session_state.filename
-            st.session_state.all_image_detections.extend(st.session_state.accepted_detections)
-            _reset_per_image_state()
-            _load_image_at_index(img_idx + 1)
-            st.rerun()
-    with nav_cols[2]:
-        total = len(st.session_state.all_image_detections) + len(st.session_state.accepted_detections)
-        if st.button(f"Done — Export ({total} detections)" if total else "Done — Export"):
-            # Stamp and merge current image
-            for det in st.session_state.accepted_detections:
-                det["image_path"] = st.session_state.filename
-            st.session_state.all_image_detections.extend(st.session_state.accepted_detections)
-            st.session_state.accepted_detections = []
-            go_to(3)
-            st.rerun()

 # =============================================================================
 # Step 3: Export
@@ -473,7 +567,6 @@ elif st.session_state.step == 3:
     st.header("Step 3: Export Manifest")

     combined = list(st.session_state.all_image_detections)
-    # Re-index combined IDs
     for i, det in enumerate(combined):
         det["id"] = i

@@ -514,14 +607,11 @@ elif st.session_state.step == 3:
 elif st.session_state.step == 4:
     st.header("Step 4: Fine-Tune SAM3")

-    # Build combined detections list
     combined_dets = list(st.session_state.all_image_detections)
-    # Stamp image_path on detections if not set
     for det in combined_dets:
         if "image_path" not in det:
             det["image_path"] = st.session_state.filename

-    # Only keep detections with masks
     train_dets = [d for d in combined_dets if d.get("accepted") and d.get("mask") is not None]
     image_names = list(set(d["image_path"] for d in train_dets))

@@ -533,7 +623,6 @@ elif st.session_state.step == 4:
         go_to(3)
         st.rerun()
     else:
-        # Hyperparameters
         col_ep, col_lr = st.columns(2)
         with col_ep:
             epochs = st.slider("Epochs", 1, 50, 5, key="train_epochs")
@@ -554,7 +643,6 @@ elif st.session_state.step == 4:
         processor = None
         result = None
        try:
-            # 1. Free GPU memory from cached inference model
             status = st.status("Preparing for training...", expanded=True)
             status.write("Clearing cached inference model to free GPU memory...")
             load_model.clear()
@@ -563,20 +651,16 @@ elif st.session_state.step == 4:
             elif _torch.backends.mps.is_available():
                 _torch.mps.empty_cache()

-            # 2. Load fresh trainable model
             status.write("Loading fresh model for training...")
             processor, model = load_model_for_training()

-            # 3. Freeze encoder
             trainable, total = freeze_encoder(model)
             status.write(f"Frozen encoder. Trainable params: {trainable:,} / {total:,}")

-            # 4. Build dataset
             images_dict = {name: img for name, img in st.session_state.images}
             dataset = SAM3FineTuneDataset(images_dict, train_dets, processor)
             status.write(f"Dataset ready: {len(dataset)} samples")

-            # 5. Train with progress bar
             status.update(label="Training...", state="running")
             progress_bar = st.progress(0, text="Starting training...")

@@ -589,14 +673,12 @@ elif st.session_state.step == 4:

             st.session_state.training_loss_history = result["loss_history"]

-            # 6. Save model zip
             status.write("Packaging fine-tuned model...")
             st.session_state.finetuned_model_bytes = get_model_zip_bytes(result["model"], processor)

             st.session_state.training_complete = True
             status.update(label="Training complete!", state="complete")
         finally:
-            # Always clean up GPU memory, even if stopped/interrupted
             del model, processor, result
             if _torch.cuda.is_available():
                 _torch.cuda.empty_cache()
@@ -605,17 +687,13 @@ elif st.session_state.step == 4:

             st.rerun()
         else:
-            # Post-training UI
             st.success("Training complete!")

-            # Loss curve
             loss_hist = st.session_state.training_loss_history
             if loss_hist:
-                import pandas as pd
                 df = pd.DataFrame({"Epoch": range(1, len(loss_hist) + 1), "Avg Loss": loss_hist})
                 st.line_chart(df, x="Epoch", y="Avg Loss")

-            # Download button
             if st.session_state.finetuned_model_bytes:
                 st.download_button(
                     label="Download fine-tuned model (.zip)",

 import streamlit as st
+import pandas as pd
 from PIL import Image
 from streamlit_drawable_canvas import st_canvas

 from sam3_engine import get_device, load_model, load_model_for_training, combined_prompt_inference
+from viz import overlay_detections_by_class, _hex_to_rgb, CLASS_COLORS
 from manifest import build_manifest, manifest_to_json, deduplicate
 from training import SAM3FineTuneDataset, freeze_encoder, run_training, get_model_zip_bytes


 # --- Session state defaults ---
 defaults = {
+    "step": 2,
     "image": None,
     "filename": None,
     "images": [],  # list of (filename, PIL.Image) tuples
     "image_index": 0,  # current position in batch
     "all_image_detections": [],  # accumulated detections across ALL images
+    "classes": [],  # list of class dicts
+    "pending_box_coords": None,  # drawn box awaiting class assignment
+    "detection_id_counter": 0,  # monotonic ID for detections
     "label_round": 0,  # iteration counter for canvas key stability
     "canvas_scale": 1.0,  # image-to-canvas scale factor
+    "_last_canvas_count": 0,  # track canvas object count for new-drawing detection
+    "selected_detection_id": None,  # ID of detection selected for highlighting
     "training_loss_history": [],
     "training_complete": False,
     "finetuned_model_bytes": None,
         st.session_state[key] = val


 def _load_image_at_index(idx: int):
     """Load the image at the given batch index into session state."""
     filename, image = st.session_state.images[idx]

     st.session_state.step = step


+def _next_detection_id() -> int:
+    st.session_state.detection_id_counter += 1
+    return st.session_state.detection_id_counter
+
+
+def _get_current_image_detections(visible_only=False):
+    """Get all detections for the current image across all classes."""
+    fname = st.session_state.filename
+    if not fname:
+        return []
+    dets = []
+    for cls in st.session_state.classes:
+        if visible_only and not cls["visible"]:
+            continue
+        for det in cls["detections"]:
+            if det.get("image_path") == fname:
+                dets.append(det)
+    return dets
+
+
 # --- Coordinate scaling helpers ---
 def _canvas_to_image(obj: dict, scale: float):
     """Convert a Fabric.js canvas object to image-space coordinates."""

             ],
         }
     elif obj_type == "circle":
         r = obj.get("radius", 0)
         cx = (left + r * sx) / scale
         cy = (top + r * sy) / scale

     return None


+def _add_class(name: str):
+    """Create a new class and return it."""
+    color = CLASS_COLORS[len(st.session_state.classes) % len(CLASS_COLORS)]
+    cls = {
+        "name": name,
+        "color": color,
+        "visible": True,
+        "threshold": 0.85,
+        "detections": [],
+    }
+    st.session_state.classes.append(cls)
+    return cls
+
+
+@st.dialog("Assign to Class")
+def assign_drawing_dialog():
+    """Modal dialog for assigning a drawn box/point to a class."""
+    pending = st.session_state.pending_box_coords
+    if pending is None:
+        st.warning("No pending drawing.")
+        return
+
+    st.write(f"New **{pending['type']}** drawn. Choose a class to assign it to:")
+
+    # Existing class selector
+    class_names = [c["name"] for c in st.session_state.classes]
+    chosen_existing = None
+    if class_names:
+        chosen_existing = st.selectbox("Existing class", class_names, key="dlg_class_select")
+
+    # Or create a new class
+    st.divider()
+    new_name = st.text_input("Or create a new class", key="dlg_new_class", placeholder="e.g. Cable, Label...")
+
+    st.divider()
+    assign_col, cancel_col = st.columns(2)
+    with assign_col:
+        can_assign = bool(new_name) or bool(chosen_existing)
+        if st.button("Assign", type="primary", disabled=not can_assign, use_container_width=True):
+            # Determine target class
+            if new_name:
+                existing_names = {c["name"] for c in st.session_state.classes}
+                if new_name not in existing_names:
+                    target_cls = _add_class(new_name)
+                else:
+                    target_cls = next(c for c in st.session_state.classes if c["name"] == new_name)
+            else:
+                target_cls = next(c for c in st.session_state.classes if c["name"] == chosen_existing)
+
+            det = {
+                "id": _next_detection_id(),
+                "mask": None,
+                "box": pending["coords"] if pending["type"] == "box" else [
+                    pending["coords"][0] - 10, pending["coords"][1] - 10,
+                    pending["coords"][0] + 10, pending["coords"][1] + 10,
+                ],
+                "score": 1.0,
+                "label": target_cls["name"],
+                "accepted": True,
+                "image_path": st.session_state.filename,
+            }
+            target_cls["detections"].append(det)
+            st.session_state.pending_box_coords = None
+            st.session_state.label_round += 1
+            st.session_state._last_canvas_count = 0
+            st.rerun()
+    with cancel_col:
+        if st.button("Cancel", use_container_width=True):
+            st.session_state.pending_box_coords = None
+            st.session_state.label_round += 1
+            st.session_state._last_canvas_count = 0
+            st.rerun()


 # --- Sidebar ---

     device = get_device()
     st.caption(f"Device: **{device}**")
     st.caption("Model: `facebook/sam3`")
+    with st.spinner("Loading SAM3 model..."):
+        load_model()
+    st.caption("Model loaded")
     st.divider()

+    step_labels = ["Label", "Export", "Train"]
     current = st.session_state.step
+    for i, label in enumerate(step_labels, start=2):
         if current == i:
             marker = f"-> {i}. {label}"
         else:
     st.divider()
     st.metric("Image", f"{st.session_state.image_index + 1} of {n_images}")

+    total_dets = sum(len(c["detections"]) for c in st.session_state.classes)
+    if total_dets:
         st.divider()
+        st.metric("Total detections", total_dets)

     st.divider()
     if st.button("Start over"):


 # =============================================================================
+# Step 2: Label (3-column class-centric layout)
 # =============================================================================
+if st.session_state.step == 2:
+    col_files, col_canvas, col_controls = st.columns([1, 3, 2])
+
+    # --- Left column: File list ---
+    with col_files:
+        st.subheader("Images")
+        uploaded_files = st.file_uploader(
+            "Upload images",
+            type=["png", "jpg", "jpeg"],
+            accept_multiple_files=True,
+            label_visibility="collapsed",
+        )
+        if uploaded_files:
+            existing_names = {name for name, _ in st.session_state.images}
+            for f in uploaded_files:
+                if f.name not in existing_names:
+                    st.session_state.images.append((f.name, Image.open(f).convert("RGB")))
+                    existing_names.add(f.name)
+            # Auto-load first image if none loaded
+            if st.session_state.image is None and st.session_state.images:
+                _load_image_at_index(0)
+                st.rerun()
+
+        # Show file list with thumbnails
+        if st.session_state.images:
+            filenames = [name for name, _ in st.session_state.images]
+            for i, (name, img) in enumerate(st.session_state.images):
+                st.image(img, width=100)
+                is_current = (i == st.session_state.image_index)
+                if st.button(
+                    name,
+                    key=f"file_select_{i}",
+                    type="primary" if is_current else "secondary",
+                    use_container_width=True,
+                ):
+                    if not is_current:
+                        _load_image_at_index(i)
+                        st.session_state.label_round += 1
+                        st.session_state._last_canvas_count = 0
+                        st.session_state.pending_box_coords = None
+                        st.session_state.selected_detection_id = None
+                        st.rerun()
+
+    # --- Center column: Canvas ---
+    with col_canvas:
+        image = st.session_state.image
+        if image is None:
+            st.info("Upload images in the left panel to get started.")
+        else:
+            img_idx = st.session_state.image_index
+            n_images = len(st.session_state.images)
+            img_label = f" ({img_idx + 1} of {n_images})" if n_images > 1 else ""
+            st.subheader(f"{st.session_state.filename}{img_label}")
+
+            # Compute canvas dimensions
+            img_w, img_h = image.size
+            canvas_w = min(img_w, CANVAS_MAX_WIDTH)
+            scale = canvas_w / img_w
+            canvas_h = int(img_h * scale)
+            st.session_state.canvas_scale = scale
+
+            # Build background with visible detections overlaid
+            visible_dets = _get_current_image_detections(visible_only=True)
+            bg = image.copy()
+            if visible_dets:
+                # Build color map from class definitions
+                color_map = {}
+                for cls in st.session_state.classes:
+                    if cls["visible"]:
+                        color_map[cls["name"]] = _hex_to_rgb(cls["color"])
+                color_map[""] = (180, 180, 180)
+                hl_ids = {st.session_state.selected_detection_id} if st.session_state.selected_detection_id is not None else None
+                bg = overlay_detections_by_class(bg, visible_dets, color_override=color_map, highlight_ids=hl_ids)
+            bg_rgb = bg.convert("RGB")
+
+            # Drawing mode
             drawing_mode = st.radio(
                 "Drawing mode",
                 ["rect", "point", "transform"],
                 key="drawing_mode",
             )

             canvas_result = st_canvas(
                 fill_color="rgba(255, 0, 0, 0.1)",
                 stroke_width=2,
                 height=canvas_h,
                 drawing_mode=drawing_mode,
                 point_display_radius=5,
                 key=f"canvas_{img_idx}_{st.session_state.label_round}",
             )

+            # Detect new drawings
             if canvas_result.json_data is not None:
                 canvas_objects = canvas_result.json_data.get("objects", [])
                 n_canvas = len(canvas_objects)
+                last_count = st.session_state._last_canvas_count
+
+                if n_canvas > last_count and st.session_state.pending_box_coords is None:
+                    # New object drawn — convert the last one
+                    new_obj = canvas_objects[-1]
+                    converted = _canvas_to_image(new_obj, scale)
+                    if converted:
+                        st.session_state.pending_box_coords = converted
+                    st.session_state._last_canvas_count = n_canvas
+                    st.rerun()
+
+            # Open assignment dialog when a new drawing is pending
+            if st.session_state.pending_box_coords is not None:
+                assign_drawing_dialog()
+
+    # --- Right column: Class controls ---
+    with col_controls:
+        st.subheader("Classes")
+
+        # Class input
+        new_class = st.text_input("New class name", key="new_class_input", placeholder="e.g. Server, Cable, Label...")
+        if new_class:
+            existing_names = {c["name"] for c in st.session_state.classes}
+            if new_class not in existing_names:
+                color = CLASS_COLORS[len(st.session_state.classes) % len(CLASS_COLORS)]
+                st.session_state.classes.append({
+                    "name": new_class,
+                    "color": color,
+                    "visible": True,
+                    "threshold": 0.85,
+                    "detections": [],
+                })
+                st.rerun()
+
+        # Class cards
+        classes_to_delete = []
+        dets_to_delete = []  # list of (class_idx, det_id)
+        find_single_class_idx = None  # index of class to run per-class find
+
+        for ci, cls in enumerate(st.session_state.classes):
+            with st.container(border=True):
+                # Header row
+                hcol_name, hcol_vis, hcol_del = st.columns([3, 1, 1])
+                with hcol_name:
+                    st.markdown(
+                        f"<span style='color:{cls['color']};font-weight:bold;font-size:1.1em'>"
+                        f"{cls['name']}</span>",
+                        unsafe_allow_html=True,
+                    )
+                with hcol_vis:
+                    vis = st.checkbox("👁", value=cls["visible"], key=f"vis_{ci}", label_visibility="collapsed")
+                    if vis != cls["visible"]:
+                        st.session_state.classes[ci]["visible"] = vis
+                        st.rerun()
+                with hcol_del:
+                    if st.button("🗑", key=f"del_class_{ci}"):
+                        classes_to_delete.append(ci)
+
+                # Detections for current image — colored buttons
+                fname = st.session_state.filename
+                if fname:
+                    img_dets = [d for d in cls["detections"] if d.get("image_path") == fname]
+                    if img_dets:
+                        for det in img_dets:
+                            dcol_label, dcol_del = st.columns([4, 1])
+                            with dcol_label:
+                                is_sel = st.session_state.selected_detection_id == det["id"]
+                                # Colored detection button via markdown + button
+                                border_style = "3px solid yellow" if is_sel else f"2px solid {cls['color']}"
+                                st.markdown(
+                                    f"<div style='background:{cls['color']}22;border:{border_style};"
+                                    f"border-radius:6px;padding:4px 8px;text-align:center;"
+                                    f"color:{cls['color']};font-weight:600;cursor:default'>"
+                                    f"{cls['name']} {det['id']} — {det['score']:.0%}</div>",
+                                    unsafe_allow_html=True,
+                                )
+                                if st.button(
+                                    "Select" if not is_sel else "Deselect",
+                                    key=f"sel_det_{ci}_{det['id']}",
+                                    use_container_width=True,
+                                ):
+                                    if is_sel:
+                                        st.session_state.selected_detection_id = None
+                                    else:
+                                        st.session_state.selected_detection_id = det["id"]
+                                    st.rerun()
+                            with dcol_del:
+                                if st.button("🗑", key=f"del_det_{ci}_{det['id']}"):
+                                    dets_to_delete.append((ci, det["id"]))
+                    else:
+                        st.caption("No detections on this image")
+
+                # Per-class confidence threshold
+                new_thresh = st.slider(
+                    "Confidence threshold", 0.0, 1.0, cls["threshold"], 0.05,
+                    key=f"thresh_{ci}",
+                )
+                st.caption(f"Default 85%")
+                if new_thresh != cls["threshold"]:
+                    st.session_state.classes[ci]["threshold"] = new_thresh
+
+                # Per-class Find Objects button
+                if st.session_state.image is not None:
+                    if st.button(f"🔍 Find Objects for this Class", key=f"find_class_{ci}", use_container_width=True):
+                        find_single_class_idx = ci
+
+        # Process deletions
+        if classes_to_delete:
+            for ci in sorted(classes_to_delete, reverse=True):
+                st.session_state.classes.pop(ci)
+            st.rerun()
+
+        if dets_to_delete:
+            for ci, det_id in dets_to_delete:
+                if st.session_state.selected_detection_id == det_id:
+                    st.session_state.selected_detection_id = None
+                st.session_state.classes[ci]["detections"] = [
+                    d for d in st.session_state.classes[ci]["detections"] if d["id"] != det_id
+                ]
+            st.session_state.label_round += 1
+            st.session_state._last_canvas_count = 0
+            st.rerun()
+
+        # --- Per-class Find Objects execution ---
+        if find_single_class_idx is not None:
+            cls = st.session_state.classes[find_single_class_idx]
+            image = st.session_state.image
+            fname = st.session_state.filename
+            status = st.status(f"Finding {cls['name']}...", expanded=True)
+            status.write(f"Running on {get_device()} (threshold {cls['threshold']:.0%})...")
+
+            existing_boxes = [
+                d["box"] for d in cls["detections"]
+                if d.get("image_path") == fname
+            ]
+            dets = combined_prompt_inference(
+                image,
+                text=cls["name"],
+                boxes=existing_boxes if existing_boxes else None,
+                threshold=cls["threshold"],
+            )
+            for d in dets:
+                d["label"] = cls["name"]
+                d["accepted"] = True
+                d["image_path"] = fname
+                d["id"] = _next_detection_id()
+
+            existing_for_class = [
+                d for d in cls["detections"]
+                if d.get("image_path") == fname
+            ]
+            unique = deduplicate(dets, existing_for_class) if existing_for_class else dets
+            cls["detections"].extend(unique)
+
+            status.write(f"Found {len(unique)} new {cls['name']} detection(s)")
+            status.update(label=f"Found {len(unique)} {cls['name']}", state="complete")
+            st.session_state.label_round += 1
+            st.session_state._last_canvas_count = 0
+            st.rerun()
+
+        # --- Find Objects for ALL classes button (with confirmation) ---
+        if st.session_state.classes and st.session_state.image is not None:
             st.divider()

+            @st.fragment
+            def find_all_objects():
+                if "confirm_find_all" not in st.session_state:
+                    st.session_state.confirm_find_all = False
+
+                if not st.session_state.confirm_find_all:
+                    if st.button("Find Objects for all classes", use_container_width=True):
+                        st.session_state.confirm_find_all = True
+                        st.rerun(scope="fragment")
+                else:
+                    st.warning(f"This will run SAM3 for **{len(st.session_state.classes)}** class(es). Continue?")
+                    yes_col, no_col = st.columns(2)
+                    with yes_col:
+                        if st.button("Yes, find all", type="primary", use_container_width=True):
+                            st.session_state.confirm_find_all = False
+                            image = st.session_state.image
+                            fname = st.session_state.filename
+                            status = st.status("Running SAM3 inference...", expanded=True)
+                            status.write(f"Running on {get_device()}...")
+
+                            for cls in st.session_state.classes:
+                                status.write(f"Finding **{cls['name']}** (threshold {cls['threshold']:.0%})...")
+
+                                existing_boxes = [
+                                    d["box"] for d in cls["detections"]
+                                    if d.get("image_path") == fname
+                                ]
+                                dets = combined_prompt_inference(
+                                    image,
+                                    text=cls["name"],
+                                    boxes=existing_boxes if existing_boxes else None,
+                                    threshold=cls["threshold"],
+                                )
+                                for d in dets:
+                                    d["label"] = cls["name"]
+                                    d["accepted"] = True
+                                    d["image_path"] = fname
+                                    d["id"] = _next_detection_id()
+
+                                existing_for_class = [
+                                    d for d in cls["detections"]
+                                    if d.get("image_path") == fname
+                                ]
+                                unique = deduplicate(dets, existing_for_class) if existing_for_class else dets
+                                cls["detections"].extend(unique)
+                                status.write(f" → {len(unique)} new {cls['name']} detection(s)")
+
+                            status.update(label="Inference complete", state="complete")
+                            st.session_state.label_round += 1
+                            st.session_state._last_canvas_count = 0
+                            st.rerun(scope="app")
+                    with no_col:
+                        if st.button("Cancel", use_container_width=True):
+                            st.session_state.confirm_find_all = False
+                            st.rerun(scope="fragment")
+
+            find_all_objects()
+
+        # --- Update Label Manifest button ---
+        if st.session_state.classes:
+            st.divider()
+            if st.button("Update Label Manifest", use_container_width=True):
+                all_dets = []
+                for cls in st.session_state.classes:
+                    all_dets.extend(cls["detections"])
+                st.session_state.all_image_detections = all_dets
+                st.success(f"Manifest updated: {len(all_dets)} detections")
+
+        # --- Navigation ---
+        if st.session_state.image is not None:
             st.divider()
+            total = sum(len(c["detections"]) for c in st.session_state.classes)
+            if st.button(f"Done — Export ({total} detections)" if total else "Done — Export"):
+                # Flatten all class detections into all_image_detections
+                all_dets = []
+                for cls in st.session_state.classes:
+                    all_dets.extend(cls["detections"])
+                st.session_state.all_image_detections = all_dets
+                go_to(3)
+                st.rerun()

 # =============================================================================
 # Step 3: Export

     st.header("Step 3: Export Manifest")

     combined = list(st.session_state.all_image_detections)
     for i, det in enumerate(combined):
         det["id"] = i

 elif st.session_state.step == 4:
     st.header("Step 4: Fine-Tune SAM3")

     combined_dets = list(st.session_state.all_image_detections)
     for det in combined_dets:
         if "image_path" not in det:
             det["image_path"] = st.session_state.filename

     train_dets = [d for d in combined_dets if d.get("accepted") and d.get("mask") is not None]
     image_names = list(set(d["image_path"] for d in train_dets))

         go_to(3)
         st.rerun()
     else:
         col_ep, col_lr = st.columns(2)
         with col_ep:
             epochs = st.slider("Epochs", 1, 50, 5, key="train_epochs")

         processor = None
         result = None
         try:
             status = st.status("Preparing for training...", expanded=True)
             status.write("Clearing cached inference model to free GPU memory...")
             load_model.clear()

             elif _torch.backends.mps.is_available():
                 _torch.mps.empty_cache()

             status.write("Loading fresh model for training...")
             processor, model = load_model_for_training()

             trainable, total = freeze_encoder(model)
             status.write(f"Frozen encoder. Trainable params: {trainable:,} / {total:,}")

             images_dict = {name: img for name, img in st.session_state.images}
             dataset = SAM3FineTuneDataset(images_dict, train_dets, processor)
             status.write(f"Dataset ready: {len(dataset)} samples")

             status.update(label="Training...", state="running")
             progress_bar = st.progress(0, text="Starting training...")

             st.session_state.training_loss_history = result["loss_history"]

             status.write("Packaging fine-tuned model...")
             st.session_state.finetuned_model_bytes = get_model_zip_bytes(result["model"], processor)

             st.session_state.training_complete = True
             status.update(label="Training complete!", state="complete")
         finally:
             del model, processor, result
             if _torch.cuda.is_available():
                 _torch.cuda.empty_cache()

             st.rerun()
         else:
             st.success("Training complete!")

             loss_hist = st.session_state.training_loss_history
             if loss_hist:
                 df = pd.DataFrame({"Epoch": range(1, len(loss_hist) + 1), "Avg Loss": loss_hist})
                 st.line_chart(df, x="Epoch", y="Avg Loss")

             if st.session_state.finetuned_model_bytes:
                 st.download_button(
                     label="Download fine-tuned model (.zip)",