# demomule / app.py
# righthook75's picture
# Upload app.py with huggingface_hub
# 26a0321 verified
import streamlit as st
import pandas as pd
from PIL import Image
from streamlit_drawable_canvas import st_canvas
from sam3_engine import get_device, load_model, load_model_for_training, combined_prompt_inference
from viz import overlay_detections_by_class, _hex_to_rgb, CLASS_COLORS
from manifest import build_manifest, manifest_to_json, deduplicate
from training import SAM3FineTuneDataset, freeze_encoder, run_training, get_model_zip_bytes
# --- Page config ---
# Wide layout gives room for the 3-column labeling workspace in step 2.
st.set_page_config(page_title="SAM3 Training Data Accelerator", layout="wide")
# --- Constants ---
# Maximum display width (px) of the drawable canvas; larger images are scaled down.
CANVAS_MAX_WIDTH = 700
# --- Session state defaults ---
defaults = {
    "step": 2,
    "image": None,
    "filename": None,
    "images": [],  # list of (filename, PIL.Image) tuples
    "image_index": 0,  # current position in batch
    "all_image_detections": [],  # accumulated detections across ALL images
    "classes": [],  # list of class dicts
    "pending_box_coords": None,  # drawn box awaiting class assignment
    "detection_id_counter": 0,  # monotonic ID for detections
    "label_round": 0,  # iteration counter for canvas key stability
    "canvas_scale": 1.0,  # image-to-canvas scale factor
    "_last_canvas_count": 0,  # track canvas object count for new-drawing detection
    "selected_detection_id": None,  # ID of detection selected for highlighting
    "training_loss_history": [],
    "training_complete": False,
    "finetuned_model_bytes": None,
}
for key, val in defaults.items():
    if key not in st.session_state:
        # Copy list defaults on assignment. Storing the list object itself would
        # let later session mutations (e.g. images.append) corrupt the pristine
        # `defaults` dict that the "Start over" buttons rely on for resetting.
        st.session_state[key] = list(val) if isinstance(val, list) else val
def _load_image_at_index(idx: int):
    """Make the batch image at position *idx* the active one.

    Copies the (filename, PIL.Image) pair stored in ``st.session_state.images``
    into the flat ``image`` / ``filename`` slots and records the index.
    """
    name, pil_img = st.session_state.images[idx]
    state = st.session_state
    state.filename = name
    state.image = pil_img
    state.image_index = idx
def go_to(step: int):
    """Jump the wizard to the given step (2=Label, 3=Export, 4=Train)."""
    st.session_state["step"] = step
def _next_detection_id() -> int:
    """Return a fresh, monotonically increasing detection ID."""
    new_id = st.session_state.detection_id_counter + 1
    st.session_state.detection_id_counter = new_id
    return new_id
def _get_current_image_detections(visible_only=False):
    """Collect every detection attached to the currently loaded image.

    Walks all classes (optionally skipping hidden ones when *visible_only*
    is True) and returns the detection dicts whose ``image_path`` matches
    the active filename. Returns [] when no image is loaded.
    """
    current = st.session_state.filename
    if not current:
        return []
    return [
        det
        for cls in st.session_state.classes
        if not (visible_only and not cls["visible"])
        for det in cls["detections"]
        if det.get("image_path") == current
    ]
# --- Coordinate scaling helpers ---
def _canvas_to_image(obj: dict, scale: float):
"""Convert a Fabric.js canvas object to image-space coordinates."""
obj_type = obj.get("type", "")
sx = obj.get("scaleX", 1.0)
sy = obj.get("scaleY", 1.0)
left = obj.get("left", 0)
top = obj.get("top", 0)
if obj_type == "rect":
w = obj.get("width", 0) * sx
h = obj.get("height", 0) * sy
return {
"type": "box",
"coords": [
left / scale,
top / scale,
(left + w) / scale,
(top + h) / scale,
],
}
elif obj_type == "circle":
r = obj.get("radius", 0)
cx = (left + r * sx) / scale
cy = (top + r * sy) / scale
return {
"type": "point",
"coords": [cx, cy],
}
return None
def _add_class(name: str):
    """Append a new class entry (auto-assigned palette color) and return it."""
    palette_idx = len(st.session_state.classes) % len(CLASS_COLORS)
    new_cls = {
        "name": name,
        "color": CLASS_COLORS[palette_idx],
        "visible": True,
        "threshold": 0.85,
        "detections": [],
    }
    st.session_state.classes.append(new_cls)
    return new_cls
@st.dialog("Assign to Class")
def assign_drawing_dialog():
    """Modal dialog for assigning a drawn box/point to a class."""
    pending = st.session_state.pending_box_coords
    if pending is None:
        st.warning("No pending drawing.")
        return
    st.write(f"New **{pending['type']}** drawn. Choose a class to assign it to:")
    # Existing class selector
    class_names = [c["name"] for c in st.session_state.classes]
    chosen_existing = None
    if class_names:
        chosen_existing = st.selectbox("Existing class", class_names, key="dlg_class_select")
    # Or create a new class
    st.divider()
    new_name = st.text_input("Or create a new class", key="dlg_new_class", placeholder="e.g. Cable, Label...")
    st.divider()
    assign_col, cancel_col = st.columns(2)
    with assign_col:
        # Enabled once the user typed a new name or picked an existing class.
        can_assign = bool(new_name) or bool(chosen_existing)
        if st.button("Assign", type="primary", disabled=not can_assign, use_container_width=True):
            # Determine target class — a typed name wins over the selectbox pick.
            if new_name:
                existing_names = {c["name"] for c in st.session_state.classes}
                if new_name not in existing_names:
                    target_cls = _add_class(new_name)
                else:
                    target_cls = next(c for c in st.session_state.classes if c["name"] == new_name)
            else:
                target_cls = next(c for c in st.session_state.classes if c["name"] == chosen_existing)
            # A drawn point is stored as a 20x20 box centred on the click so
            # downstream code can treat every detection uniformly as a box.
            det = {
                "id": _next_detection_id(),
                "mask": None,
                "box": pending["coords"] if pending["type"] == "box" else [
                    pending["coords"][0] - 10, pending["coords"][1] - 10,
                    pending["coords"][0] + 10, pending["coords"][1] + 10,
                ],
                "score": 1.0,  # manual annotations count as full confidence
                "label": target_cls["name"],
                "accepted": True,
                "image_path": st.session_state.filename,
            }
            target_cls["detections"].append(det)
            # Clear the pending drawing and bump label_round so the canvas widget
            # gets a new key (fresh empty canvas over the updated overlay).
            st.session_state.pending_box_coords = None
            st.session_state.label_round += 1
            st.session_state._last_canvas_count = 0
            st.rerun()
    with cancel_col:
        if st.button("Cancel", use_container_width=True):
            # Discard the pending drawing and force a fresh canvas.
            st.session_state.pending_box_coords = None
            st.session_state.label_round += 1
            st.session_state._last_canvas_count = 0
            st.rerun()
# --- Sidebar ---
with st.sidebar:
    st.title("SAM3 Accelerator")
    device = get_device()
    st.caption(f"Device: **{device}**")
    st.caption("Model: `facebook/sam3`")
    # load_model() is cached by the engine (it exposes .clear() — see step 4),
    # so the spinner is only slow on the very first run.
    with st.spinner("Loading SAM3 model..."):
        load_model()
    st.caption("Model loaded")
    st.divider()
    # Step indicator: numbering starts at 2 to match st.session_state.step.
    step_labels = ["Label", "Export", "Train"]
    current = st.session_state.step
    for i, label in enumerate(step_labels, start=2):
        if current == i:
            marker = f"-> {i}. {label}"
        else:
            marker = f" {i}. {label}"
        st.text(marker)
    n_images = len(st.session_state.images)
    if n_images > 1:
        st.divider()
        st.metric("Image", f"{st.session_state.image_index + 1} of {n_images}")
    total_dets = sum(len(c["detections"]) for c in st.session_state.classes)
    if total_dets:
        st.divider()
        st.metric("Total detections", total_dets)
    st.divider()
    if st.button("Start over"):
        # Copy list defaults: re-assigning the shared list objects from
        # `defaults` would let subsequent appends mutate `defaults` itself,
        # so the *next* reset would restore stale data instead of empty lists.
        for key, val in defaults.items():
            st.session_state[key] = list(val) if isinstance(val, list) else val
        st.rerun()
# =============================================================================
# Step 2: Label (3-column class-centric layout)
# =============================================================================
if st.session_state.step == 2:
    # Three-pane workspace: file list | drawable canvas | class controls.
    col_files, col_canvas, col_controls = st.columns([1, 3, 2])
# --- Left column: File list ---
with col_files:
st.subheader("Images")
uploaded_files = st.file_uploader(
"Upload images",
type=["png", "jpg", "jpeg"],
accept_multiple_files=True,
label_visibility="collapsed",
)
if uploaded_files:
existing_names = {name for name, _ in st.session_state.images}
for f in uploaded_files:
if f.name not in existing_names:
st.session_state.images.append((f.name, Image.open(f).convert("RGB")))
existing_names.add(f.name)
# Auto-load first image if none loaded
if st.session_state.image is None and st.session_state.images:
_load_image_at_index(0)
st.rerun()
# Show file list with thumbnails
if st.session_state.images:
filenames = [name for name, _ in st.session_state.images]
for i, (name, img) in enumerate(st.session_state.images):
st.image(img, width=100)
is_current = (i == st.session_state.image_index)
if st.button(
name,
key=f"file_select_{i}",
type="primary" if is_current else "secondary",
use_container_width=True,
):
if not is_current:
_load_image_at_index(i)
st.session_state.label_round += 1
st.session_state._last_canvas_count = 0
st.session_state.pending_box_coords = None
st.session_state.selected_detection_id = None
st.rerun()
    # --- Center column: Canvas ---
    with col_canvas:
        image = st.session_state.image
        if image is None:
            st.info("Upload images in the left panel to get started.")
        else:
            img_idx = st.session_state.image_index
            n_images = len(st.session_state.images)
            img_label = f" ({img_idx + 1} of {n_images})" if n_images > 1 else ""
            st.subheader(f"{st.session_state.filename}{img_label}")
            # Compute canvas dimensions: cap at CANVAS_MAX_WIDTH, keep aspect ratio.
            img_w, img_h = image.size
            canvas_w = min(img_w, CANVAS_MAX_WIDTH)
            scale = canvas_w / img_w
            canvas_h = int(img_h * scale)
            st.session_state.canvas_scale = scale
            # Build background with visible detections overlaid
            visible_dets = _get_current_image_detections(visible_only=True)
            bg = image.copy()
            if visible_dets:
                # Build color map from class definitions
                color_map = {}
                for cls in st.session_state.classes:
                    if cls["visible"]:
                        color_map[cls["name"]] = _hex_to_rgb(cls["color"])
                # Fallback gray for detections with an empty label.
                color_map[""] = (180, 180, 180)
                hl_ids = {st.session_state.selected_detection_id} if st.session_state.selected_detection_id is not None else None
                bg = overlay_detections_by_class(bg, visible_dets, color_override=color_map, highlight_ids=hl_ids)
            bg_rgb = bg.convert("RGB")
            # Drawing mode
            drawing_mode = st.radio(
                "Drawing mode",
                ["rect", "point", "transform"],
                horizontal=True,
                key="drawing_mode",
            )
            # The key embeds image index + label_round so any state change
            # (new detection, deletion, image switch) yields a fresh, empty
            # canvas drawn over the updated background overlay.
            canvas_result = st_canvas(
                fill_color="rgba(255, 0, 0, 0.1)",
                stroke_width=2,
                stroke_color="red",
                background_image=bg_rgb,
                width=canvas_w,
                height=canvas_h,
                drawing_mode=drawing_mode,
                point_display_radius=5,
                key=f"canvas_{img_idx}_{st.session_state.label_round}",
            )
            # Detect new drawings by comparing the Fabric object count with the
            # last count we recorded.
            if canvas_result.json_data is not None:
                canvas_objects = canvas_result.json_data.get("objects", [])
                n_canvas = len(canvas_objects)
                last_count = st.session_state._last_canvas_count
                if n_canvas > last_count and st.session_state.pending_box_coords is None:
                    # New object drawn — convert the last one
                    new_obj = canvas_objects[-1]
                    converted = _canvas_to_image(new_obj, scale)
                    if converted:
                        st.session_state.pending_box_coords = converted
                    st.session_state._last_canvas_count = n_canvas
                    st.rerun()
            # Open assignment dialog when a new drawing is pending
            if st.session_state.pending_box_coords is not None:
                assign_drawing_dialog()
# --- Right column: Class controls ---
with col_controls:
st.subheader("Classes")
# Class input
new_class = st.text_input("New class name", key="new_class_input", placeholder="e.g. Server, Cable, Label...")
if new_class:
existing_names = {c["name"] for c in st.session_state.classes}
if new_class not in existing_names:
color = CLASS_COLORS[len(st.session_state.classes) % len(CLASS_COLORS)]
st.session_state.classes.append({
"name": new_class,
"color": color,
"visible": True,
"threshold": 0.85,
"detections": [],
})
st.rerun()
# Class cards
classes_to_delete = []
dets_to_delete = [] # list of (class_idx, det_id)
find_single_class_idx = None # index of class to run per-class find
for ci, cls in enumerate(st.session_state.classes):
with st.container(border=True):
# Header row
hcol_name, hcol_vis, hcol_del = st.columns([3, 1, 1])
with hcol_name:
st.markdown(
f"<span style='color:{cls['color']};font-weight:bold;font-size:1.1em'>"
f"{cls['name']}</span>",
unsafe_allow_html=True,
)
with hcol_vis:
vis = st.checkbox("👁", value=cls["visible"], key=f"vis_{ci}", label_visibility="collapsed")
if vis != cls["visible"]:
st.session_state.classes[ci]["visible"] = vis
st.rerun()
with hcol_del:
if st.button("🗑", key=f"del_class_{ci}"):
classes_to_delete.append(ci)
# Detections for current image — colored buttons
fname = st.session_state.filename
if fname:
img_dets = [d for d in cls["detections"] if d.get("image_path") == fname]
if img_dets:
for det in img_dets:
dcol_label, dcol_del = st.columns([4, 1])
with dcol_label:
is_sel = st.session_state.selected_detection_id == det["id"]
# Colored detection button via markdown + button
border_style = "3px solid yellow" if is_sel else f"2px solid {cls['color']}"
st.markdown(
f"<div style='background:{cls['color']}22;border:{border_style};"
f"border-radius:6px;padding:4px 8px;text-align:center;"
f"color:{cls['color']};font-weight:600;cursor:default'>"
f"{cls['name']} {det['id']}{det['score']:.0%}</div>",
unsafe_allow_html=True,
)
if st.button(
"Select" if not is_sel else "Deselect",
key=f"sel_det_{ci}_{det['id']}",
use_container_width=True,
):
if is_sel:
st.session_state.selected_detection_id = None
else:
st.session_state.selected_detection_id = det["id"]
st.rerun()
with dcol_del:
if st.button("🗑", key=f"del_det_{ci}_{det['id']}"):
dets_to_delete.append((ci, det["id"]))
else:
st.caption("No detections on this image")
# Per-class confidence threshold
new_thresh = st.slider(
"Confidence threshold", 0.0, 1.0, cls["threshold"], 0.05,
key=f"thresh_{ci}",
)
st.caption(f"Default 85%")
if new_thresh != cls["threshold"]:
st.session_state.classes[ci]["threshold"] = new_thresh
# Per-class Find Objects button
if st.session_state.image is not None:
if st.button(f"🔍 Find Objects for this Class", key=f"find_class_{ci}", use_container_width=True):
find_single_class_idx = ci
# Process deletions
if classes_to_delete:
for ci in sorted(classes_to_delete, reverse=True):
st.session_state.classes.pop(ci)
st.rerun()
if dets_to_delete:
for ci, det_id in dets_to_delete:
if st.session_state.selected_detection_id == det_id:
st.session_state.selected_detection_id = None
st.session_state.classes[ci]["detections"] = [
d for d in st.session_state.classes[ci]["detections"] if d["id"] != det_id
]
st.session_state.label_round += 1
st.session_state._last_canvas_count = 0
st.rerun()
    # --- Per-class Find Objects execution ---
    # Deferred until after the card loop so inference doesn't interrupt rendering.
    if find_single_class_idx is not None:
        cls = st.session_state.classes[find_single_class_idx]
        image = st.session_state.image
        fname = st.session_state.filename
        status = st.status(f"Finding {cls['name']}...", expanded=True)
        status.write(f"Running on {get_device()} (threshold {cls['threshold']:.0%})...")
        # Feed boxes already on this image back in as visual prompts.
        existing_boxes = [
            d["box"] for d in cls["detections"]
            if d.get("image_path") == fname
        ]
        dets = combined_prompt_inference(
            image,
            text=cls["name"],
            boxes=existing_boxes if existing_boxes else None,
            threshold=cls["threshold"],
        )
        # Tag every new detection with class metadata and a fresh ID.
        for d in dets:
            d["label"] = cls["name"]
            d["accepted"] = True
            d["image_path"] = fname
            d["id"] = _next_detection_id()
        existing_for_class = [
            d for d in cls["detections"]
            if d.get("image_path") == fname
        ]
        # deduplicate() presumably filters results overlapping existing
        # detections — see manifest module; TODO confirm exact criterion.
        unique = deduplicate(dets, existing_for_class) if existing_for_class else dets
        cls["detections"].extend(unique)
        status.write(f"Found {len(unique)} new {cls['name']} detection(s)")
        status.update(label=f"Found {len(unique)} {cls['name']}", state="complete")
        # Bump the canvas key so the new overlays are drawn.
        st.session_state.label_round += 1
        st.session_state._last_canvas_count = 0
        st.rerun()
    # --- Find Objects for ALL classes button (with confirmation) ---
    if st.session_state.classes and st.session_state.image is not None:
        st.divider()
        @st.fragment
        def find_all_objects():
            # Two-phase confirm flow inside a fragment so the confirmation
            # prompt can rerun without re-executing the whole page.
            if "confirm_find_all" not in st.session_state:
                st.session_state.confirm_find_all = False
            if not st.session_state.confirm_find_all:
                if st.button("Find Objects for all classes", use_container_width=True):
                    st.session_state.confirm_find_all = True
                    st.rerun(scope="fragment")
            else:
                st.warning(f"This will run SAM3 for **{len(st.session_state.classes)}** class(es). Continue?")
                yes_col, no_col = st.columns(2)
                with yes_col:
                    if st.button("Yes, find all", type="primary", use_container_width=True):
                        st.session_state.confirm_find_all = False
                        image = st.session_state.image
                        fname = st.session_state.filename
                        status = st.status("Running SAM3 inference...", expanded=True)
                        status.write(f"Running on {get_device()}...")
                        # Same per-class flow as the single-class button: prompt
                        # with existing boxes, tag results, then deduplicate.
                        for cls in st.session_state.classes:
                            status.write(f"Finding **{cls['name']}** (threshold {cls['threshold']:.0%})...")
                            existing_boxes = [
                                d["box"] for d in cls["detections"]
                                if d.get("image_path") == fname
                            ]
                            dets = combined_prompt_inference(
                                image,
                                text=cls["name"],
                                boxes=existing_boxes if existing_boxes else None,
                                threshold=cls["threshold"],
                            )
                            for d in dets:
                                d["label"] = cls["name"]
                                d["accepted"] = True
                                d["image_path"] = fname
                                d["id"] = _next_detection_id()
                            existing_for_class = [
                                d for d in cls["detections"]
                                if d.get("image_path") == fname
                            ]
                            unique = deduplicate(dets, existing_for_class) if existing_for_class else dets
                            cls["detections"].extend(unique)
                            status.write(f" → {len(unique)} new {cls['name']} detection(s)")
                        status.update(label="Inference complete", state="complete")
                        st.session_state.label_round += 1
                        st.session_state._last_canvas_count = 0
                        # Full-app rerun so the canvas column picks up new detections.
                        st.rerun(scope="app")
                with no_col:
                    if st.button("Cancel", use_container_width=True):
                        st.session_state.confirm_find_all = False
                        st.rerun(scope="fragment")
        find_all_objects()
    # --- Update Label Manifest button ---
    if st.session_state.classes:
        st.divider()
        if st.button("Update Label Manifest", use_container_width=True):
            # Flatten per-class detections into the global accumulator.
            all_dets = []
            for cls in st.session_state.classes:
                all_dets.extend(cls["detections"])
            st.session_state.all_image_detections = all_dets
            st.success(f"Manifest updated: {len(all_dets)} detections")
    # --- Navigation ---
    if st.session_state.image is not None:
        st.divider()
        total = sum(len(c["detections"]) for c in st.session_state.classes)
        if st.button(f"Done — Export ({total} detections)" if total else "Done — Export"):
            # Flatten all class detections into all_image_detections
            all_dets = []
            for cls in st.session_state.classes:
                all_dets.extend(cls["detections"])
            st.session_state.all_image_detections = all_dets
            go_to(3)
            st.rerun()
# =============================================================================
# Step 3: Export
# =============================================================================
elif st.session_state.step == 3:
st.header("Step 3: Export Manifest")
combined = list(st.session_state.all_image_detections)
for i, det in enumerate(combined):
det["id"] = i
manifest = build_manifest(combined)
if not manifest:
st.warning("No accepted and labeled detections to export.")
else:
st.success(f"Combined manifest: **{len(manifest)}** entries across **{len(set(e['image_path'] for e in manifest))}** image(s)")
json_str = manifest_to_json(manifest)
st.code(json_str, language="json")
st.download_button(
label="Download manifest JSON",
data=json_str,
file_name="sam3_manifest.json",
mime="application/json",
)
col1, col2, col3 = st.columns(3)
with col1:
if st.button("Back: Label more"):
go_to(2)
st.rerun()
with col2:
if manifest and st.button("Next: Fine-tune model", type="primary"):
go_to(4)
st.rerun()
with col3:
if st.button("Start over"):
for key, val in defaults.items():
st.session_state[key] = val
st.rerun()
# =============================================================================
# Step 4: Fine-Tune SAM3
# =============================================================================
elif st.session_state.step == 4:
st.header("Step 4: Fine-Tune SAM3")
combined_dets = list(st.session_state.all_image_detections)
for det in combined_dets:
if "image_path" not in det:
det["image_path"] = st.session_state.filename
train_dets = [d for d in combined_dets if d.get("accepted") and d.get("mask") is not None]
image_names = list(set(d["image_path"] for d in train_dets))
st.info(f"Training data: **{len(train_dets)}** detections across **{len(image_names)}** image(s)")
if not train_dets:
st.warning("No accepted detections with masks available. Go back and label some objects first.")
if st.button("Back to Export"):
go_to(3)
st.rerun()
else:
col_ep, col_lr = st.columns(2)
with col_ep:
epochs = st.slider("Epochs", 1, 50, 5, key="train_epochs")
with col_lr:
lr = st.select_slider(
"Learning rate",
options=[1e-6, 5e-6, 1e-5, 5e-5, 1e-4],
value=1e-5,
format_func=lambda x: f"{x:.0e}",
key="train_lr",
)
if not st.session_state.training_complete:
if st.button("Start training", type="primary"):
import torch as _torch
model = None
processor = None
result = None
try:
status = st.status("Preparing for training...", expanded=True)
status.write("Clearing cached inference model to free GPU memory...")
load_model.clear()
if _torch.cuda.is_available():
_torch.cuda.empty_cache()
elif _torch.backends.mps.is_available():
_torch.mps.empty_cache()
status.write("Loading fresh model for training...")
processor, model = load_model_for_training()
trainable, total = freeze_encoder(model)
status.write(f"Frozen encoder. Trainable params: {trainable:,} / {total:,}")
images_dict = {name: img for name, img in st.session_state.images}
dataset = SAM3FineTuneDataset(images_dict, train_dets, processor)
status.write(f"Dataset ready: {len(dataset)} samples")
status.update(label="Training...", state="running")
progress_bar = st.progress(0, text="Starting training...")
def on_progress(epoch, step, total_steps, loss_val):
pct = (step + 1) / total_steps
progress_bar.progress(pct, text=f"Epoch {epoch + 1}/{epochs} | Step {step + 1}/{total_steps} | Loss: {loss_val:.4f}")
result = run_training(model, processor, dataset, epochs, lr, progress_callback=on_progress)
progress_bar.progress(1.0, text="Training complete!")
st.session_state.training_loss_history = result["loss_history"]
status.write("Packaging fine-tuned model...")
st.session_state.finetuned_model_bytes = get_model_zip_bytes(result["model"], processor)
st.session_state.training_complete = True
status.update(label="Training complete!", state="complete")
finally:
del model, processor, result
if _torch.cuda.is_available():
_torch.cuda.empty_cache()
elif _torch.backends.mps.is_available():
_torch.mps.empty_cache()
st.rerun()
else:
st.success("Training complete!")
loss_hist = st.session_state.training_loss_history
if loss_hist:
df = pd.DataFrame({"Epoch": range(1, len(loss_hist) + 1), "Avg Loss": loss_hist})
st.line_chart(df, x="Epoch", y="Avg Loss")
if st.session_state.finetuned_model_bytes:
st.download_button(
label="Download fine-tuned model (.zip)",
data=st.session_state.finetuned_model_bytes,
file_name="sam3_finetuned.zip",
mime="application/zip",
)
st.divider()
col1, col2 = st.columns(2)
with col1:
if st.button("Back to Export"):
go_to(3)
st.rerun()
with col2:
if st.button("Start over"):
for key, val in defaults.items():
st.session_state[key] = val
st.rerun()