# SAM3 Training Data Accelerator — Streamlit app
# (label images with SAM3-assisted detections, export a manifest, fine-tune the model)
| import streamlit as st | |
| import pandas as pd | |
| from PIL import Image | |
| from streamlit_drawable_canvas import st_canvas | |
| from sam3_engine import get_device, load_model, load_model_for_training, combined_prompt_inference | |
| from viz import overlay_detections_by_class, _hex_to_rgb, CLASS_COLORS | |
| from manifest import build_manifest, manifest_to_json, deduplicate | |
| from training import SAM3FineTuneDataset, freeze_encoder, run_training, get_model_zip_bytes | |
# --- Page config ---
st.set_page_config(page_title="SAM3 Training Data Accelerator", layout="wide")
# --- Constants ---
CANVAS_MAX_WIDTH = 700  # max on-screen canvas width in px; wider images are scaled down
# --- Session state defaults ---
# Seeded once per session; the "Start over" buttons reset state back to these values.
defaults = {
    "step": 2,  # workflow step: 2=Label, 3=Export, 4=Train (numbering starts at 2)
    "image": None,  # currently active PIL.Image
    "filename": None,  # filename of the currently active image
    "images": [],  # list of (filename, PIL.Image) tuples
    "image_index": 0,  # current position in batch
    "all_image_detections": [],  # accumulated detections across ALL images
    "classes": [],  # list of class dicts
    "pending_box_coords": None,  # drawn box awaiting class assignment
    "detection_id_counter": 0,  # monotonic ID for detections
    "label_round": 0,  # iteration counter for canvas key stability
    "canvas_scale": 1.0,  # image-to-canvas scale factor
    "_last_canvas_count": 0,  # track canvas object count for new-drawing detection
    "selected_detection_id": None,  # ID of detection selected for highlighting
    "training_loss_history": [],  # per-epoch average losses from the last training run
    "training_complete": False,  # gates the train-vs-results view in Step 4
    "finetuned_model_bytes": None,  # zipped fine-tuned model ready for download
}
for key, val in defaults.items():
    if key not in st.session_state:
        st.session_state[key] = val
def _load_image_at_index(idx: int):
    """Make the batch entry at *idx* the active image in session state."""
    name, img = st.session_state.images[idx]
    st.session_state.filename = name
    st.session_state.image = img
    st.session_state.image_index = idx
def go_to(step: int):
    """Jump the wizard to the given step number."""
    st.session_state["step"] = step
def _next_detection_id() -> int:
    """Return a fresh, session-unique detection ID (monotonically increasing)."""
    new_id = st.session_state.detection_id_counter + 1
    st.session_state.detection_id_counter = new_id
    return new_id
def _get_current_image_detections(visible_only=False):
    """Collect every detection attached to the current image, across classes.

    When *visible_only* is True, classes whose visibility is toggled off
    are skipped. Returns an empty list when no image is loaded.
    """
    fname = st.session_state.filename
    if not fname:
        return []
    classes = st.session_state.classes
    pool = [c for c in classes if c["visible"]] if visible_only else classes
    return [
        det
        for cls in pool
        for det in cls["detections"]
        if det.get("image_path") == fname
    ]
| # --- Coordinate scaling helpers --- | |
| def _canvas_to_image(obj: dict, scale: float): | |
| """Convert a Fabric.js canvas object to image-space coordinates.""" | |
| obj_type = obj.get("type", "") | |
| sx = obj.get("scaleX", 1.0) | |
| sy = obj.get("scaleY", 1.0) | |
| left = obj.get("left", 0) | |
| top = obj.get("top", 0) | |
| if obj_type == "rect": | |
| w = obj.get("width", 0) * sx | |
| h = obj.get("height", 0) * sy | |
| return { | |
| "type": "box", | |
| "coords": [ | |
| left / scale, | |
| top / scale, | |
| (left + w) / scale, | |
| (top + h) / scale, | |
| ], | |
| } | |
| elif obj_type == "circle": | |
| r = obj.get("radius", 0) | |
| cx = (left + r * sx) / scale | |
| cy = (top + r * sy) / scale | |
| return { | |
| "type": "point", | |
| "coords": [cx, cy], | |
| } | |
| return None | |
def _add_class(name: str):
    """Register a new labeling class (cycling through the palette) and return it."""
    palette_idx = len(st.session_state.classes) % len(CLASS_COLORS)
    new_cls = {
        "name": name,
        "color": CLASS_COLORS[palette_idx],
        "visible": True,
        "threshold": 0.85,
        "detections": [],
    }
    st.session_state.classes.append(new_cls)
    return new_cls
def assign_drawing_dialog():
    """Modal dialog for assigning a drawn box/point to a class.

    Reads the pending drawing from ``st.session_state.pending_box_coords``.
    On Assign, a detection dict is appended to the chosen (or newly created)
    class; on Assign or Cancel the canvas is remounted (``label_round`` bump)
    so the raw drawing is cleared.
    """
    pending = st.session_state.pending_box_coords
    if pending is None:
        st.warning("No pending drawing.")
        return
    st.write(f"New **{pending['type']}** drawn. Choose a class to assign it to:")
    # Existing class selector
    class_names = [c["name"] for c in st.session_state.classes]
    chosen_existing = None
    if class_names:
        chosen_existing = st.selectbox("Existing class", class_names, key="dlg_class_select")
    # Or create a new class
    st.divider()
    new_name = st.text_input("Or create a new class", key="dlg_new_class", placeholder="e.g. Cable, Label...")
    st.divider()
    assign_col, cancel_col = st.columns(2)
    with assign_col:
        can_assign = bool(new_name) or bool(chosen_existing)
        if st.button("Assign", type="primary", disabled=not can_assign, use_container_width=True):
            # Determine target class — a typed new-class name takes priority
            # over the selectbox choice.
            if new_name:
                existing_names = {c["name"] for c in st.session_state.classes}
                if new_name not in existing_names:
                    target_cls = _add_class(new_name)
                else:
                    target_cls = next(c for c in st.session_state.classes if c["name"] == new_name)
            else:
                target_cls = next(c for c in st.session_state.classes if c["name"] == chosen_existing)
            det = {
                "id": _next_detection_id(),
                "mask": None,  # manual drawings carry no mask
                # Points are stored as a 20px box centered on the click so
                # downstream code can treat every detection as a box.
                "box": pending["coords"] if pending["type"] == "box" else [
                    pending["coords"][0] - 10, pending["coords"][1] - 10,
                    pending["coords"][0] + 10, pending["coords"][1] + 10,
                ],
                "score": 1.0,  # manual annotations are treated as certain
                "label": target_cls["name"],
                "accepted": True,
                "image_path": st.session_state.filename,
            }
            target_cls["detections"].append(det)
            st.session_state.pending_box_coords = None
            # Remount the canvas with a fresh key and reset the object count
            # so this drawing is not re-detected as new.
            st.session_state.label_round += 1
            st.session_state._last_canvas_count = 0
            st.rerun()
    with cancel_col:
        if st.button("Cancel", use_container_width=True):
            # Discard the pending drawing and clear the canvas.
            st.session_state.pending_box_coords = None
            st.session_state.label_round += 1
            st.session_state._last_canvas_count = 0
            st.rerun()
# --- Sidebar ---
with st.sidebar:
    st.title("SAM3 Accelerator")
    device = get_device()
    st.caption(f"Device: **{device}**")
    st.caption("Model: `facebook/sam3`")
    # load_model() appears cached by the engine (load_model.clear() is called
    # in Step 4), so this spinner should only do real work on first run.
    with st.spinner("Loading SAM3 model..."):
        load_model()
    st.caption("Model loaded")
    st.divider()
    # Step indicator — numbering starts at 2 to match st.session_state.step.
    step_labels = ["Label", "Export", "Train"]
    current = st.session_state.step
    for i, label in enumerate(step_labels, start=2):
        if current == i:
            marker = f"-> {i}. {label}"
        else:
            marker = f" {i}. {label}"
        st.text(marker)
    n_images = len(st.session_state.images)
    if n_images > 1:
        st.divider()
        st.metric("Image", f"{st.session_state.image_index + 1} of {n_images}")
    total_dets = sum(len(c["detections"]) for c in st.session_state.classes)
    if total_dets:
        st.divider()
        st.metric("Total detections", total_dets)
    st.divider()
    # Full reset: restore every session key to its default value.
    if st.button("Start over"):
        for key, val in defaults.items():
            st.session_state[key] = val
        st.rerun()
# =============================================================================
# Step 2: Label (3-column class-centric layout)
# =============================================================================
if st.session_state.step == 2:
    col_files, col_canvas, col_controls = st.columns([1, 3, 2])
    # --- Left column: File list ---
    with col_files:
        st.subheader("Images")
        uploaded_files = st.file_uploader(
            "Upload images",
            type=["png", "jpg", "jpeg"],
            accept_multiple_files=True,
            label_visibility="collapsed",
        )
        if uploaded_files:
            # Dedupe by filename so Streamlit's rerun of the uploader does not
            # re-append the same files.
            existing_names = {name for name, _ in st.session_state.images}
            for f in uploaded_files:
                if f.name not in existing_names:
                    st.session_state.images.append((f.name, Image.open(f).convert("RGB")))
                    existing_names.add(f.name)
        # Auto-load first image if none loaded
        if st.session_state.image is None and st.session_state.images:
            _load_image_at_index(0)
            st.rerun()
        # Show file list with thumbnails
        if st.session_state.images:
            # NOTE(review): 'filenames' is never used below — candidate for removal.
            filenames = [name for name, _ in st.session_state.images]
            for i, (name, img) in enumerate(st.session_state.images):
                st.image(img, width=100)
                is_current = (i == st.session_state.image_index)
                if st.button(
                    name,
                    key=f"file_select_{i}",
                    type="primary" if is_current else "secondary",
                    use_container_width=True,
                ):
                    if not is_current:
                        # Switch images and reset all per-canvas state.
                        _load_image_at_index(i)
                        st.session_state.label_round += 1
                        st.session_state._last_canvas_count = 0
                        st.session_state.pending_box_coords = None
                        st.session_state.selected_detection_id = None
                        st.rerun()
    # --- Center column: Canvas ---
    with col_canvas:
        image = st.session_state.image
        if image is None:
            st.info("Upload images in the left panel to get started.")
        else:
            img_idx = st.session_state.image_index
            n_images = len(st.session_state.images)
            img_label = f" ({img_idx + 1} of {n_images})" if n_images > 1 else ""
            st.subheader(f"{st.session_state.filename}{img_label}")
            # Compute canvas dimensions (downscale to fit, preserving aspect).
            img_w, img_h = image.size
            canvas_w = min(img_w, CANVAS_MAX_WIDTH)
            scale = canvas_w / img_w
            canvas_h = int(img_h * scale)
            st.session_state.canvas_scale = scale
            # Build background with visible detections overlaid
            visible_dets = _get_current_image_detections(visible_only=True)
            bg = image.copy()
            if visible_dets:
                # Build color map from class definitions
                color_map = {}
                for cls in st.session_state.classes:
                    if cls["visible"]:
                        color_map[cls["name"]] = _hex_to_rgb(cls["color"])
                color_map[""] = (180, 180, 180)  # fallback gray for unlabeled detections
                hl_ids = {st.session_state.selected_detection_id} if st.session_state.selected_detection_id is not None else None
                bg = overlay_detections_by_class(bg, visible_dets, color_override=color_map, highlight_ids=hl_ids)
            bg_rgb = bg.convert("RGB")
            # Drawing mode
            drawing_mode = st.radio(
                "Drawing mode",
                ["rect", "point", "transform"],
                horizontal=True,
                key="drawing_mode",
            )
            canvas_result = st_canvas(
                fill_color="rgba(255, 0, 0, 0.1)",
                stroke_width=2,
                stroke_color="red",
                background_image=bg_rgb,
                width=canvas_w,
                height=canvas_h,
                drawing_mode=drawing_mode,
                point_display_radius=5,
                # Key includes the image index and label_round so the canvas
                # remounts (clearing raw drawings) after switches/assignments.
                key=f"canvas_{img_idx}_{st.session_state.label_round}",
            )
            # Detect new drawings by comparing the canvas object count against
            # the last seen count.
            if canvas_result.json_data is not None:
                canvas_objects = canvas_result.json_data.get("objects", [])
                n_canvas = len(canvas_objects)
                last_count = st.session_state._last_canvas_count
                if n_canvas > last_count and st.session_state.pending_box_coords is None:
                    # New object drawn — convert the last one
                    new_obj = canvas_objects[-1]
                    converted = _canvas_to_image(new_obj, scale)
                    if converted:
                        st.session_state.pending_box_coords = converted
                    # Record the new count (even for unconvertible objects) and
                    # rerun so the assignment dialog opens immediately.
                    st.session_state._last_canvas_count = n_canvas
                    st.rerun()
            # Open assignment dialog when a new drawing is pending
            if st.session_state.pending_box_coords is not None:
                assign_drawing_dialog()
    # --- Right column: Class controls ---
    with col_controls:
        st.subheader("Classes")
        # Class input
        new_class = st.text_input("New class name", key="new_class_input", placeholder="e.g. Server, Cable, Label...")
        if new_class:
            # NOTE(review): duplicates _add_class() — consider calling it here.
            existing_names = {c["name"] for c in st.session_state.classes}
            if new_class not in existing_names:
                color = CLASS_COLORS[len(st.session_state.classes) % len(CLASS_COLORS)]
                st.session_state.classes.append({
                    "name": new_class,
                    "color": color,
                    "visible": True,
                    "threshold": 0.85,
                    "detections": [],
                })
                st.rerun()
        # Class cards — deletions and per-class find are deferred until after
        # the loop so we never mutate st.session_state.classes mid-iteration.
        classes_to_delete = []
        dets_to_delete = []  # list of (class_idx, det_id)
        find_single_class_idx = None  # index of class to run per-class find
        for ci, cls in enumerate(st.session_state.classes):
            with st.container(border=True):
                # Header row
                hcol_name, hcol_vis, hcol_del = st.columns([3, 1, 1])
                with hcol_name:
                    st.markdown(
                        f"<span style='color:{cls['color']};font-weight:bold;font-size:1.1em'>"
                        f"{cls['name']}</span>",
                        unsafe_allow_html=True,
                    )
                with hcol_vis:
                    vis = st.checkbox("👁", value=cls["visible"], key=f"vis_{ci}", label_visibility="collapsed")
                    if vis != cls["visible"]:
                        st.session_state.classes[ci]["visible"] = vis
                        st.rerun()
                with hcol_del:
                    if st.button("🗑", key=f"del_class_{ci}"):
                        classes_to_delete.append(ci)
                # Detections for current image — colored buttons
                fname = st.session_state.filename
                if fname:
                    img_dets = [d for d in cls["detections"] if d.get("image_path") == fname]
                    if img_dets:
                        for det in img_dets:
                            dcol_label, dcol_del = st.columns([4, 1])
                            with dcol_label:
                                is_sel = st.session_state.selected_detection_id == det["id"]
                                # Colored detection button via markdown + button
                                border_style = "3px solid yellow" if is_sel else f"2px solid {cls['color']}"
                                st.markdown(
                                    f"<div style='background:{cls['color']}22;border:{border_style};"
                                    f"border-radius:6px;padding:4px 8px;text-align:center;"
                                    f"color:{cls['color']};font-weight:600;cursor:default'>"
                                    f"{cls['name']} {det['id']} — {det['score']:.0%}</div>",
                                    unsafe_allow_html=True,
                                )
                                if st.button(
                                    "Select" if not is_sel else "Deselect",
                                    key=f"sel_det_{ci}_{det['id']}",
                                    use_container_width=True,
                                ):
                                    # Toggle highlight selection for this detection.
                                    if is_sel:
                                        st.session_state.selected_detection_id = None
                                    else:
                                        st.session_state.selected_detection_id = det["id"]
                                    st.rerun()
                            with dcol_del:
                                if st.button("🗑", key=f"del_det_{ci}_{det['id']}"):
                                    dets_to_delete.append((ci, det["id"]))
                    else:
                        st.caption("No detections on this image")
                # Per-class confidence threshold
                new_thresh = st.slider(
                    "Confidence threshold", 0.0, 1.0, cls["threshold"], 0.05,
                    key=f"thresh_{ci}",
                )
                st.caption(f"Default 85%")
                if new_thresh != cls["threshold"]:
                    st.session_state.classes[ci]["threshold"] = new_thresh
                # Per-class Find Objects button
                if st.session_state.image is not None:
                    if st.button(f"🔍 Find Objects for this Class", key=f"find_class_{ci}", use_container_width=True):
                        find_single_class_idx = ci
        # Process deletions (reverse index order keeps earlier indices valid).
        if classes_to_delete:
            for ci in sorted(classes_to_delete, reverse=True):
                st.session_state.classes.pop(ci)
            st.rerun()
        if dets_to_delete:
            for ci, det_id in dets_to_delete:
                if st.session_state.selected_detection_id == det_id:
                    st.session_state.selected_detection_id = None
                st.session_state.classes[ci]["detections"] = [
                    d for d in st.session_state.classes[ci]["detections"] if d["id"] != det_id
                ]
            st.session_state.label_round += 1
            st.session_state._last_canvas_count = 0
            st.rerun()
        # --- Per-class Find Objects execution ---
        if find_single_class_idx is not None:
            cls = st.session_state.classes[find_single_class_idx]
            image = st.session_state.image
            fname = st.session_state.filename
            status = st.status(f"Finding {cls['name']}...", expanded=True)
            status.write(f"Running on {get_device()} (threshold {cls['threshold']:.0%})...")
            # Existing boxes on this image act as exemplar prompts.
            existing_boxes = [
                d["box"] for d in cls["detections"]
                if d.get("image_path") == fname
            ]
            dets = combined_prompt_inference(
                image,
                text=cls["name"],
                boxes=existing_boxes if existing_boxes else None,
                threshold=cls["threshold"],
            )
            # Tag model outputs with class/image metadata and fresh IDs.
            for d in dets:
                d["label"] = cls["name"]
                d["accepted"] = True
                d["image_path"] = fname
                d["id"] = _next_detection_id()
            existing_for_class = [
                d for d in cls["detections"]
                if d.get("image_path") == fname
            ]
            # Drop model detections that duplicate what is already stored.
            unique = deduplicate(dets, existing_for_class) if existing_for_class else dets
            cls["detections"].extend(unique)
            status.write(f"Found {len(unique)} new {cls['name']} detection(s)")
            status.update(label=f"Found {len(unique)} {cls['name']}", state="complete")
            st.session_state.label_round += 1
            st.session_state._last_canvas_count = 0
            st.rerun()
        # --- Find Objects for ALL classes button (with confirmation) ---
        if st.session_state.classes and st.session_state.image is not None:
            st.divider()
            def find_all_objects():
                # Two-phase confirm: first click arms, second click runs inference.
                if "confirm_find_all" not in st.session_state:
                    st.session_state.confirm_find_all = False
                if not st.session_state.confirm_find_all:
                    if st.button("Find Objects for all classes", use_container_width=True):
                        st.session_state.confirm_find_all = True
                        # NOTE(review): scope="fragment" only works inside an
                        # @st.fragment — confirm a decorator wasn't dropped here.
                        st.rerun(scope="fragment")
                else:
                    st.warning(f"This will run SAM3 for **{len(st.session_state.classes)}** class(es). Continue?")
                    yes_col, no_col = st.columns(2)
                    with yes_col:
                        if st.button("Yes, find all", type="primary", use_container_width=True):
                            st.session_state.confirm_find_all = False
                            image = st.session_state.image
                            fname = st.session_state.filename
                            status = st.status("Running SAM3 inference...", expanded=True)
                            status.write(f"Running on {get_device()}...")
                            # Same per-class pipeline as the single-class path above.
                            for cls in st.session_state.classes:
                                status.write(f"Finding **{cls['name']}** (threshold {cls['threshold']:.0%})...")
                                existing_boxes = [
                                    d["box"] for d in cls["detections"]
                                    if d.get("image_path") == fname
                                ]
                                dets = combined_prompt_inference(
                                    image,
                                    text=cls["name"],
                                    boxes=existing_boxes if existing_boxes else None,
                                    threshold=cls["threshold"],
                                )
                                for d in dets:
                                    d["label"] = cls["name"]
                                    d["accepted"] = True
                                    d["image_path"] = fname
                                    d["id"] = _next_detection_id()
                                existing_for_class = [
                                    d for d in cls["detections"]
                                    if d.get("image_path") == fname
                                ]
                                unique = deduplicate(dets, existing_for_class) if existing_for_class else dets
                                cls["detections"].extend(unique)
                                status.write(f" → {len(unique)} new {cls['name']} detection(s)")
                            status.update(label="Inference complete", state="complete")
                            st.session_state.label_round += 1
                            st.session_state._last_canvas_count = 0
                            st.rerun(scope="app")
                    with no_col:
                        if st.button("Cancel", use_container_width=True):
                            st.session_state.confirm_find_all = False
                            st.rerun(scope="fragment")
            find_all_objects()
        # --- Update Label Manifest button ---
        if st.session_state.classes:
            st.divider()
            if st.button("Update Label Manifest", use_container_width=True):
                all_dets = []
                for cls in st.session_state.classes:
                    all_dets.extend(cls["detections"])
                st.session_state.all_image_detections = all_dets
                st.success(f"Manifest updated: {len(all_dets)} detections")
        # --- Navigation ---
        if st.session_state.image is not None:
            st.divider()
            total = sum(len(c["detections"]) for c in st.session_state.classes)
            if st.button(f"Done — Export ({total} detections)" if total else "Done — Export"):
                # Flatten all class detections into all_image_detections
                all_dets = []
                for cls in st.session_state.classes:
                    all_dets.extend(cls["detections"])
                st.session_state.all_image_detections = all_dets
                go_to(3)
                st.rerun()
# =============================================================================
# Step 3: Export
# =============================================================================
elif st.session_state.step == 3:
    st.header("Step 3: Export Manifest")
    combined = list(st.session_state.all_image_detections)
    # Re-number IDs sequentially for a clean manifest.
    # NOTE(review): this mutates the shared detection dicts in place, so the
    # session-wide IDs shown by the Step 2 UI are rewritten too — confirm intended.
    for i, det in enumerate(combined):
        det["id"] = i
    manifest = build_manifest(combined)
    if not manifest:
        st.warning("No accepted and labeled detections to export.")
    else:
        st.success(f"Combined manifest: **{len(manifest)}** entries across **{len(set(e['image_path'] for e in manifest))}** image(s)")
        json_str = manifest_to_json(manifest)
        st.code(json_str, language="json")
        st.download_button(
            label="Download manifest JSON",
            data=json_str,
            file_name="sam3_manifest.json",
            mime="application/json",
        )
    col1, col2, col3 = st.columns(3)
    with col1:
        if st.button("Back: Label more"):
            go_to(2)
            st.rerun()
    with col2:
        # Fine-tuning is only reachable when the manifest is non-empty.
        if manifest and st.button("Next: Fine-tune model", type="primary"):
            go_to(4)
            st.rerun()
    with col3:
        if st.button("Start over"):
            for key, val in defaults.items():
                st.session_state[key] = val
            st.rerun()
# =============================================================================
# Step 4: Fine-Tune SAM3
# =============================================================================
elif st.session_state.step == 4:
    st.header("Step 4: Fine-Tune SAM3")
    combined_dets = list(st.session_state.all_image_detections)
    # Backfill image_path for detections created before batch support.
    for det in combined_dets:
        if "image_path" not in det:
            det["image_path"] = st.session_state.filename
    # Only accepted detections that carry a mask are usable for training.
    train_dets = [d for d in combined_dets if d.get("accepted") and d.get("mask") is not None]
    image_names = list(set(d["image_path"] for d in train_dets))
    st.info(f"Training data: **{len(train_dets)}** detections across **{len(image_names)}** image(s)")
    if not train_dets:
        st.warning("No accepted detections with masks available. Go back and label some objects first.")
        if st.button("Back to Export"):
            go_to(3)
            st.rerun()
    else:
        # Hyperparameter controls
        col_ep, col_lr = st.columns(2)
        with col_ep:
            epochs = st.slider("Epochs", 1, 50, 5, key="train_epochs")
        with col_lr:
            lr = st.select_slider(
                "Learning rate",
                options=[1e-6, 5e-6, 1e-5, 5e-5, 1e-4],
                value=1e-5,
                format_func=lambda x: f"{x:.0e}",
                key="train_lr",
            )
        if not st.session_state.training_complete:
            if st.button("Start training", type="primary"):
                import torch as _torch
                # Pre-bind to None so the finally block's del is always safe.
                model = None
                processor = None
                result = None
                try:
                    status = st.status("Preparing for training...", expanded=True)
                    status.write("Clearing cached inference model to free GPU memory...")
                    # Drop the cached inference model before loading a trainable copy.
                    load_model.clear()
                    if _torch.cuda.is_available():
                        _torch.cuda.empty_cache()
                    elif _torch.backends.mps.is_available():
                        _torch.mps.empty_cache()
                    status.write("Loading fresh model for training...")
                    processor, model = load_model_for_training()
                    trainable, total = freeze_encoder(model)
                    status.write(f"Frozen encoder. Trainable params: {trainable:,} / {total:,}")
                    images_dict = {name: img for name, img in st.session_state.images}
                    dataset = SAM3FineTuneDataset(images_dict, train_dets, processor)
                    status.write(f"Dataset ready: {len(dataset)} samples")
                    status.update(label="Training...", state="running")
                    progress_bar = st.progress(0, text="Starting training...")
                    # NOTE(review): pct is per-epoch (resets each epoch), not
                    # global progress — cosmetic, but confirm it's intended.
                    def on_progress(epoch, step, total_steps, loss_val):
                        pct = (step + 1) / total_steps
                        progress_bar.progress(pct, text=f"Epoch {epoch + 1}/{epochs} | Step {step + 1}/{total_steps} | Loss: {loss_val:.4f}")
                    result = run_training(model, processor, dataset, epochs, lr, progress_callback=on_progress)
                    progress_bar.progress(1.0, text="Training complete!")
                    st.session_state.training_loss_history = result["loss_history"]
                    status.write("Packaging fine-tuned model...")
                    st.session_state.finetuned_model_bytes = get_model_zip_bytes(result["model"], processor)
                    st.session_state.training_complete = True
                    status.update(label="Training complete!", state="complete")
                finally:
                    # Release model memory whether or not training succeeded.
                    del model, processor, result
                    if _torch.cuda.is_available():
                        _torch.cuda.empty_cache()
                    elif _torch.backends.mps.is_available():
                        _torch.mps.empty_cache()
                st.rerun()
        else:
            # Results view: loss curve plus model download.
            st.success("Training complete!")
            loss_hist = st.session_state.training_loss_history
            if loss_hist:
                df = pd.DataFrame({"Epoch": range(1, len(loss_hist) + 1), "Avg Loss": loss_hist})
                st.line_chart(df, x="Epoch", y="Avg Loss")
            if st.session_state.finetuned_model_bytes:
                st.download_button(
                    label="Download fine-tuned model (.zip)",
                    data=st.session_state.finetuned_model_bytes,
                    file_name="sam3_finetuned.zip",
                    mime="application/zip",
                )
        st.divider()
        col1, col2 = st.columns(2)
        with col1:
            if st.button("Back to Export"):
                go_to(3)
                st.rerun()
        with col2:
            if st.button("Start over"):
                for key, val in defaults.items():
                    st.session_state[key] = val
                st.rerun()