# Page-header text captured with this file (apparently committer, commit message, revision):
# G-Paris
# Initialize Space with LFS for images
# 19f31ed
import gradio as gr
import numpy as np
import os
import glob
from PIL import Image, ImageDraw, ImageFont
from src.sam3_annotation_tool.theme import CustomBlueTheme
from src.sam3_annotation_tool.controller import controller
from src.sam3_annotation_tool.inference import load_models
from src.sam3_annotation_tool.utils import apply_mask_overlay, get_bbox_from_mask
from src.sam3_annotation_tool.view_helpers import (
draw_boxes_on_image, format_box_list, parse_dataframe, on_dataframe_change,
delete_checked_boxes, on_upload, on_input_image_select, undo_last_click,
on_crop_dataframe_change, format_crop_box
)
# Example images bundled in example_img/ (jpg, png, jpeg), offered as
# gr.Examples for the single-image input. Sorted by path.
EXAMPLE_IMAGES = sorted(
    found
    for ext in ("jpg", "png", "jpeg")
    for found in glob.glob(f"example_img/*.{ext}")
)
# Load the models once at import time, before the UI below is built.
load_models()
# NOTE(review): app_theme is not referenced in the visible part of this file;
# presumably it is passed to Blocks/launch further down.
app_theme = CustomBlueTheme()
# --- Helper Functions ---
# (Generic view helpers were moved to src/sam3_annotation_tool/view_helpers.py.)
def run_inference_step1(clean_image, text_prompt, boxes, labels, class_name_override, crop_box=None):
    """Step 1 of inference: validate inputs, run the search and switch to Results.

    Args:
        clean_image: The unannotated input image; required.
        text_prompt: Text prompt for the search; may be empty when boxes are given.
        boxes: Box prompts from the Input screen; may be empty/None when a
            text prompt is given.
        labels: Labels paired with ``boxes``; forwarded to the controller.
        class_name_override: Optional class name forwarded to the controller.
        crop_box: Optional crop region forwarded to the controller.

    Returns:
        ``(candidates, clean_image, tabs_update, prompt_history_update,
        class_name_history_update)``.

    Raises:
        gr.Error: If inputs are missing or the controller search fails.
    """
    # Count defensively: this log line runs before validation, and len(None)
    # would raise a TypeError instead of the intended user-facing error.
    box_count = len(boxes) if boxes else 0
    print(f"🖱️ Run Inference Clicked! Prompt: '{text_prompt}', Override: '{class_name_override}', Boxes: {box_count}, Crop: {crop_box}")
    if clean_image is None:
        raise gr.Error("Please upload an image.")
    if not text_prompt and not boxes:
        raise gr.Error("Please enter a text prompt or add box prompts.")
    # In playlist mode the controller already holds the current image; setting
    # it again would reset the project state.
    if not controller.project.playlist:
        controller.set_image(clean_image)
    try:
        candidates = controller.search_and_add(text_prompt, boxes, labels, class_name_override, crop_box)
        print(f"✅ Search returned {len(candidates)} candidates.")
    except Exception as e:
        print(f"❌ Error during search: {e}")
        # Chain the original exception so its traceback is kept as __cause__.
        raise gr.Error(f"Inference failed: {str(e)}") from e
    return (
        candidates,
        clean_image,
        gr.update(selected=2),  # switch to the Results tab
        gr.update(choices=controller.project.prompt_history),
        gr.update(choices=controller.project.class_name_history),
    )
def render_results_step2(candidates, image):
    """Step 2 of inference: fill the candidate checklist and the preview image.

    Returns:
        ``(checkbox_update, preview_update, selected_indices)``. Every candidate
        starts unselected. When ``image`` is None nothing is re-rendered.
    """
    if image is not None:
        print("🖼️ Rendering results...")
        # With no selection, all candidates are drawn dimmed.
        preview = controller.get_candidate_preview(candidates, selected_index=None)
        labels = [
            f"ID {number}: {cand.class_name} ({cand.score:.2f})"
            for number, cand in enumerate(candidates, start=1)
        ]
        return gr.update(choices=labels, value=[]), gr.update(value=preview), set()
    return gr.update(), gr.update(), set()
def on_selection_change(selected_values, candidates):
    """Map checked labels back to candidate indices and refresh the preview.

    Labels have the form ``"ID <n>: <class> (<score>)"`` with a 1-based ``n``,
    so each one identifies a single candidate.

    Returns:
        ``(preview_update, selected_indices)`` where ``selected_indices`` is a
        set of 0-based candidate positions.
    """
    index_by_label = {
        f"ID {pos + 1}: {cand.class_name} ({cand.score:.2f})": pos
        for pos, cand in enumerate(candidates)
    }
    chosen = {index_by_label[value] for value in selected_values if value in index_by_label}
    # Highlight exactly the chosen candidates.
    preview = controller.get_candidate_preview(candidates, selected_index=chosen)
    return gr.update(value=preview), chosen
def select_all_candidates(candidates):
    """Check every candidate and highlight all of them in the preview.

    Returns:
        ``(preview_update, checkbox_update, selected_indices)``; three no-op
        values when there are no candidates.
    """
    if not candidates:
        return gr.update(), gr.update(), set()
    everything = set(range(len(candidates)))
    labels = [
        f"ID {number}: {cand.class_name} ({cand.score:.2f})"
        for number, cand in enumerate(candidates, start=1)
    ]
    preview = controller.get_candidate_preview(candidates, selected_index=everything)
    return gr.update(value=preview), gr.update(value=labels), everything
def deselect_all_candidates(candidates):
    """Uncheck every candidate and redraw the preview with nothing highlighted.

    Returns:
        ``(preview_update, checkbox_update, selected_indices)``.
    """
    cleared = set()
    if candidates:
        preview = controller.get_candidate_preview(candidates, selected_index=None)
        return gr.update(value=preview), gr.update(value=[]), cleared
    return gr.update(), gr.update(), cleared
def add_to_store_wrapper(candidates, selected_indices):
    """Add the selected candidates to the store, in candidate-list order.

    Args:
        candidates: Candidates produced by the last inference run.
        selected_indices: Set of 0-based positions into ``candidates``.

    Returns:
        Whatever ``add_to_store`` returns (status message and tab update).

    Raises:
        gr.Error: If nothing is selected.
    """
    if not selected_indices:
        raise gr.Error("No masks selected.")
    # Set iteration order is an implementation detail; sort so objects are
    # always added in the order they appear in the candidate list.
    return add_to_store(candidates, sorted(selected_indices))
def toggle_click_mode(current_mode):
    """Return the opposite refinement click-mode label.

    Any label containing "Include" maps to "Exclude (Red)"; every other
    label maps to "Include (Green)".
    """
    return "Include (Green)" if "Include" not in current_mode else "Exclude (Red)"
def revert_object_refinement(obj_id):
    """Reset ``obj_id`` to its initial state and return the redrawn overlay.

    Returns a no-op update when no object is selected.
    """
    if obj_id:
        controller.revert_object(obj_id)
        overlay, _radio = init_editor(obj_id)
        return overlay
    return gr.update()
def export_results(output_path, export_type="YOLO", zip_output=False):
    """Export the project's annotations to ``output_path``.

    Args:
        output_path: Destination folder. The controller is called with
            ``purge=True``, so existing contents may be removed (the UI warns
            about this). Surrounding whitespace is ignored.
        export_type: Format label from the UI. Labels containing
            "Not supported yet" are rejected; otherwise it is only logged.
        zip_output: Forwarded to the controller as ``zip_output``.

    Returns:
        A status message for the export status box.

    Raises:
        gr.Error: If an unsupported export type is selected.
    """
    print(f"📦 Exporting results to {output_path} (Type: {export_type}, Zip: {zip_output})...")
    if "Not supported yet" in export_type:
        raise gr.Error(f"Export type '{export_type}' is not supported yet.")
    # An empty folder would hand a purge request for an unspecified location
    # (possibly the working directory, depending on the controller); refuse it.
    target = (output_path or "").strip()
    if not target:
        return "Export failed: Please specify an output folder."
    try:
        res = controller.export_data(target, purge=True, zip_output=zip_output)
        if not res:
            return "Export failed: No data to export."
        _, msg = res
        return msg
    except Exception as e:  # report any export error in the UI rather than crash the handler
        return f"Export failed: {e}"
def add_to_store(candidates, selected_indices):
    """Commit the chosen candidates to the object store and open the Editor.

    Returns:
        ``(status_message, tabs_update)``.

    Raises:
        gr.Error: If ``selected_indices`` is empty.
    """
    if selected_indices:
        controller.add_candidates_to_store(candidates, selected_indices)
        return "Added to Store!", gr.update(selected=3)  # Editor tab
    raise gr.Error("No masks selected.")
# --- UI Layout ---
# Stylesheet injected into the page through gr.HTML("<style>...</style>") inside
# the Blocks layout. Selectors target elem_id / elem_classes values assigned to
# components there (e.g. #input_image, .zoom-image, .box-list-df, .nav-row).
# NOTE(review): the `thead th:first-child input[type="checkbox"]` rule is not
# scoped to .box-list-df, so it applies to the first header cell of every table.
custom_css="""
#col-container { margin: 0 auto; max-width: 1400px; }
#main-title h1 { font-size: 2.1em !important; }
#input_image { position: relative; overflow: hidden; }
#input_image button, #input_image img, #input_image canvas { cursor: crosshair !important; }
.zoom-image img { transition: transform 0.1s ease-out; }
/* Use max-height to constrain large images without upscaling small ones */
/* This fixes zoom/click coordinate issues for images smaller than container */
#input_image img,
.zoom-image img {
max-height: 600px !important;
width: auto !important;
height: auto !important;
}
/* Dataframe Font Size */
.box-list-df td, .box-list-df th, .box-list-df td span, .box-list-df td input, .box-list-df td div { font-size: 10px !important; line-height: 1.0 !important; padding: 2px !important; }
/* Hide Checkbox in Header for 'Del' column (assuming it's the first column) */
thead th:first-child input[type="checkbox"] { display: none !important; }
/* Column Widths */
.box-list-df th:nth-child(1), .box-list-df td:nth-child(1) { width: 30px !important; min-width: 30px !important; }
.box-list-df th:nth-child(2), .box-list-df td:nth-child(2) { width: 80px !important; }
/* Export Status Font Size */
#export-status textarea { font-size: 0.8em !important; }
/* Horizontal Radio Buttons */
.horizontal-radio .wrap { display: flex !important; flex-direction: row !important; gap: 10px !important; }
.horizontal-radio label { margin-bottom: 0 !important; align-items: center !important; }
.horizontal-radio span { font-size: 0.8em !important; }
/* Scrollable Radio List */
.scrollable-radio { max-height: 200px !important; overflow-y: auto !important; border: 1px solid #e5e7eb; padding: 5px; border-radius: 5px; }
/* Hide Footer */
footer { display: none !important; }
/* Ensure the tab bar is a flex container and takes full width */
.main-tabs > div:first-of-type {
display: flex !important;
width: 100% !important;
}
/* Navigation Row Compactness */
.nav-row {
align-items: center !important;
gap: 5px !important;
margin-bottom: 5px !important;
}
.nav-status textarea {
font-size: 0.8em !important;
text-align: center !important;
height: 30px !important;
min-height: 30px !important;
padding: 5px !important;
}
"""
# JS for Crosshair and Zoom
# Defines setupInteractions(), which (1) draws a cyan crosshair that follows the
# cursor over #input_image and (2) adds wheel-zoom and drag-pan to every
# .zoom-image container. A MutationObserver on document.body re-runs both setup
# functions on DOM changes; dataset flags (crosshairSetup / zoomSetup) stop a
# container from being wired twice.
# NOTE(review): custom_js is not referenced in the visible part of this file —
# confirm it is passed to Blocks/launch (e.g. as `js=`/`head=`) further down.
# NOTE(review): zoom state (scale / pointX / pointY) lives in each container's
# closure and is never reset when Gradio swaps the <img> inside, so a newly
# loaded image may jump to the previous zoom/pan on the next wheel or drag.
custom_js = """
function setupInteractions() {
// Crosshair Logic
const setupCrosshair = () => {
const c = document.querySelector('#input_image');
if (c && !c.dataset.crosshairSetup) {
c.dataset.crosshairSetup = "true";
c.style.position = 'relative';
const createLine = (id, isH) => {
let l = document.createElement('div');
l.style.cssText = `position:absolute;background:cyan;pointer-events:none;z-index:10000;display:none;box-shadow:0 0 2px rgba(0,0,0,0.5);${isH ? 'height:1px;width:100%;' : 'width:1px;height:100%;top:0;'}`;
c.appendChild(l);
return l;
};
const h = createLine('h', true), v = createLine('v', false);
c.addEventListener('mousemove', (e) => {
const r = c.getBoundingClientRect();
const x = e.clientX - r.left, y = e.clientY - r.top;
if (x >= 0 && x <= r.width && y >= 0 && y <= r.height) {
h.style.display = v.style.display = 'block';
h.style.top = (y - 2) + 'px';
v.style.left = (x - 2) + 'px';
} else { h.style.display = v.style.display = 'none'; }
});
c.addEventListener('mouseleave', () => { h.style.display = v.style.display = 'none'; });
}
};
// Zoom Logic
const setupZoom = () => {
document.querySelectorAll('.zoom-image').forEach(container => {
if (container.dataset.zoomSetup) return;
container.dataset.zoomSetup = "true";
container.style.overflow = 'hidden';
let scale = 1, pointX = 0, pointY = 0, startX = 0, startY = 0, isDragging = false;
container.addEventListener('wheel', (e) => {
e.preventDefault();
const img = container.querySelector('img');
if (!img) return;
img.style.transformOrigin = "0 0";
img.style.transition = "transform 0.1s ease-out";
const rect = container.getBoundingClientRect();
const xs = (e.clientX - rect.left - pointX) / scale;
const ys = (e.clientY - rect.top - pointY) / scale;
const delta = -e.deltaY;
(delta > 0) ? (scale *= 1.2) : (scale /= 1.2);
if (scale < 1) scale = 1;
pointX = e.clientX - rect.left - xs * scale;
pointY = e.clientY - rect.top - ys * scale;
img.style.transform = `translate(${pointX}px, ${pointY}px) scale(${scale})`;
});
// Panning
container.addEventListener('mousedown', (e) => {
isDragging = true;
startX = e.clientX - pointX;
startY = e.clientY - pointY;
});
window.addEventListener('mousemove', (e) => {
if (!isDragging) return;
e.preventDefault();
const img = container.querySelector('img');
if (!img) return;
pointX = e.clientX - startX;
pointY = e.clientY - startY;
img.style.transform = `translate(${pointX}px, ${pointY}px) scale(${scale})`;
});
window.addEventListener('mouseup', () => { isDragging = false; });
});
};
// Observer
const observer = new MutationObserver(() => {
setupCrosshair();
setupZoom();
});
observer.observe(document.body, { childList: true, subtree: true });
setupCrosshair();
setupZoom();
}
"""
with gr.Blocks() as demo:
gr.HTML(f"<style>{custom_css}</style>")
# State Variables
st_boxes = gr.State([])
st_labels = gr.State([])
st_candidates = gr.State([])
st_selected_indices = gr.State(set()) # Track selected indices
st_current_image = gr.State(None)
st_selected_box_index = gr.State(None) # Track selected box for deletion
st_clean_input_image = gr.State(None) # Store original uploaded image
st_pending_point = gr.State(None) # Store first point of box click
st_crop_box = gr.State(None) # Store crop box [x1, y1, x2, y2]
# Hidden status box for messages
status_box = gr.Textbox(visible=False)
with gr.Column(elem_id="col-container"):
gr.Markdown("# **GP-engineering's SAM3 Annotator**", elem_id="main-title")
with gr.Tabs(elem_classes=["main-tabs"]) as tabs:
# --- SCREEN 0: SETUP ---
with gr.TabItem("Setup", id=0) as setup_screen:
with gr.Tabs():
with gr.TabItem("New Project"):
gr.Markdown("### Select Data Source")
with gr.Row():
with gr.Column():
gr.Markdown("#### Batch (Folder)")
upload_files = gr.File(label="Upload Folder", file_count="directory", file_types=["image"], height=200)
with gr.Column():
gr.Markdown("#### Single Image")
single_image_input = gr.Image(
label="Upload or Capture Image",
sources=["upload", "webcam", "clipboard"],
type="filepath",
height=400
)
if EXAMPLE_IMAGES:
gr.Examples(
examples=EXAMPLE_IMAGES,
inputs=single_image_input
)
gr.Markdown("### Project Details")
new_project_name = gr.Textbox(label="Project Name", placeholder="Enter a name for your new project")
start_btn = gr.Button("Start Annotation", variant="primary", interactive=False)
with gr.TabItem("Save / Load"):
gr.Markdown("### Project State")
with gr.Row():
with gr.Column(scale=1):
project_name_input = gr.Textbox(label="Project Name", placeholder="my_project", value="my_project")
save_project_btn = gr.Button("Save Project State")
with gr.Column(scale=1):
# Changed from File Upload to Dropdown for Local Projects
project_dropdown = gr.Dropdown(label="Load Existing Project", choices=[], interactive=True)
refresh_projects_btn = gr.Button("🔄 Refresh List", size="sm")
load_project_btn = gr.Button("Load Project State")
with gr.Row():
delete_project_btn = gr.Button("Delete Project", variant="stop", size="sm")
confirm_delete_btn = gr.Button("⚠️ Confirm Delete", variant="stop", size="sm", visible=False)
cancel_delete_btn = gr.Button("Cancel", size="sm", visible=False)
with gr.Row():
delete_all_btn = gr.Button("Delete ALL Projects", variant="stop", size="sm")
confirm_delete_all_btn = gr.Button("⚠️ Confirm Delete ALL", variant="stop", size="sm", visible=False)
cancel_delete_all_btn = gr.Button("Cancel", size="sm", visible=False)
project_status = gr.Textbox(label="Status", interactive=False, lines=10, scale=1)
# --- SCREEN 1: INPUT ---
with gr.TabItem("Input", id=1) as input_screen:
gr.Markdown("### Generate initial objects")
# Navigation (Full Width)
with gr.Row(elem_classes="nav-row"):
prev_btn = gr.Button("Previous", size="sm", scale=0)
nav_status = gr.Textbox(label="Status", show_label=False, value="0/0", interactive=False, scale=1, elem_classes="nav-status")
next_btn = gr.Button("Next", size="sm", scale=0)
# Prompt Row (Full Width)
with gr.Row():
txt_prompt = gr.Dropdown(label="Text Prompt", choices=[], allow_custom_value=True, show_label=True, scale=4)
txt_class_name = gr.Dropdown(label="Class Name Override", choices=[], allow_custom_value=True, show_label=True, scale=2)
run_btn = gr.Button("Run Inference", variant="primary", scale=1)
with gr.Row():
# Left Column: Image
with gr.Column(scale=3):
img_input = gr.Image(
label="Current Image (Click 2 Points for Box)",
type="pil",
interactive=True,
elem_id="input_image",
elem_classes="zoom-image"
)
# Right Column: Controls
with gr.Column(scale=1):
# Box Controls (Top Right)
with gr.Group():
# gr.Markdown("### Box Controls") # Removed header
click_effect = gr.Radio(["Crop Initial Image", "Include Area", "Exclude Area"], value="Include Area", label="Click Effect")
undo_click_btn = gr.Button("Undo Last Click", variant="secondary", size="sm")
# Crop List
crop_list_display = gr.Dataframe(
headers=["Del", "x1", "y1", "x2", "y2"],
datatype=["bool", "number", "number", "number", "number"],
column_count=5,
interactive=True,
label="Crop Area (Model runs at 1024x1024)",
wrap=True,
elem_classes="box-list-df"
)
# Box List (Moved here)
gr.Markdown("")
# [Delete?, Type, x1, y1, x2, y2]
box_list_display = gr.Dataframe(
headers=["Del", "Type", "x1", "y1", "x2", "y2"],
datatype=["bool", "str", "number", "number", "number", "number"],
column_count=6,
interactive=True,
label="Added Boxes",
wrap=True,
elem_classes="box-list-df"
)
delete_box_btn = gr.Button("Delete Checked Boxes", variant="stop", size="sm")
# --- SCREEN 2: RESULTS ---
with gr.TabItem("Results", id=2) as result_screen:
with gr.Row():
gr.Markdown("### Select relevant objects")
result_img_counter = gr.Markdown("Image 0/0", elem_id="result-img-counter")
with gr.Row():
with gr.Column(scale=3):
# Preview Image with ALL masks
preview_image = gr.Image(
label="Selected Candidates Preview",
type="pil",
interactive=False,
elem_classes="zoom-image"
)
with gr.Column(scale=1):
# Checkbox List of candidates
results_list = gr.CheckboxGroup(
label="Candidates",
choices=[],
value=[],
interactive=True,
elem_classes="candidate-list"
)
with gr.Row():
select_all_btn = gr.Button("Select All", size="sm", variant="secondary")
deselect_all_btn = gr.Button("Deselect All", size="sm", variant="secondary")
with gr.Row():
confirm_btn = gr.Button("Add Selected to Store", variant="primary")
with gr.Row():
add_more_btn = gr.Button("Add Selected & Generate More", variant="secondary")
# --- SCREEN 3: EDITOR ---
with gr.TabItem("Editor", id=3) as editor_screen:
with gr.Row():
gr.Markdown("### Refine individual objects")
editor_img_counter = gr.Markdown("Image 0/0", elem_id="editor-img-counter")
refresh_editor_btn = gr.Button("🔄 Refresh", size="sm", scale=0)
with gr.Row():
with gr.Column(scale=3):
# Main interactive image for refinement
refine_image = gr.Image(
label="Click to Refine",
type="pil",
interactive=False,
elem_classes="zoom-image"
)
# Export moved to separate tab
with gr.Column(scale=1):
gr.Markdown("")
with gr.Row():
object_list = gr.Radio(
label="Select Object",
choices=[],
interactive=True,
elem_classes="scrollable-radio"
)
with gr.Row():
# revert_btn = gr.Button("Revert", size="sm", variant="secondary") # Moved below
delete_btn = gr.Button("Delete", size="sm", variant="stop")
gr.Markdown("")
with gr.Row():
click_mode = gr.Radio(["Include (Green)", "Exclude (Red)"], value="Include (Green)", label="Click Mode", interactive=True, elem_classes="horizontal-radio", scale=2)
undo_btn = gr.Button("Undo Last Click", variant="secondary", size="sm", scale=1)
revert_btn = gr.Button("Revert Object", size="sm", variant="secondary")
gr.Markdown("")
with gr.Row():
finish_img_btn = gr.Button("Finish & Next Image", variant="primary")
finish_save_btn = gr.Button("Finish (Save)", variant="secondary")
# --- SCREEN 4: EXPORT ---
with gr.TabItem("Export", id=4) as export_screen:
gr.Markdown("### Export data")
with gr.Row():
with gr.Column():
# Project State Display
export_status_display = gr.JSON(label="Ready for Export", value={})
with gr.Row():
txt_output_dir = gr.Textbox(label="Output Folder", value="output", scale=3)
export_type = gr.Dropdown(label="Export Type", choices=["YOLO", "COCO (Not supported yet)"], value="YOLO", scale=1)
gr.Markdown("⚠️ **Warning:** Exporting will delete the current contents in the output directory before saving.")
with gr.Row():
zip_export = gr.Checkbox(label="Zip Output", value=False, scale=0)
export_btn = gr.Button("Export", scale=0, min_width=150)
reset_btn = gr.Button("Reset Project", variant="stop", scale=0, min_width=150)
gr.Markdown("ℹ️ **Note:** All classes in the project history will be included in `data.yaml` to ensure compatibility with tools like CVAT, make sure all labels in the CVAT project are added to the project history (prevents KeyError).")
export_status = gr.Textbox(label="Export Status", interactive=False, elem_id="export-status", lines=5)
# --- Helper Functions for Editor ---
def init_editor(selected_obj_id=None):
    """Render the Editor tab for the current image.

    Args:
        selected_obj_id: ID of the object to show. ``None`` or an ID that is no
            longer in the store falls back to the first stored object.

    Returns:
        ``(overlay_image, radio_update)``. ``overlay_image`` is ``None`` when no
        image is loaded. Otherwise it is a copy of the current image. When an
        object is selected, only that object's mask is overlaid, its
        refinement points are drawn (green for label 1, red otherwise) and a
        short ID label is written above its bounding box. ``radio_update``
        lists every stored object as ``"<class> (<short id>)"`` and selects the
        resolved object.
    """
    base_img = controller.current_image
    if base_img is None:
        return None, gr.update(choices=[])

    # Use one short-ID length for both the radio list and the on-image label,
    # so the two can be matched by eye.
    short_id_len = 4

    objects = controller.store.objects
    choices = [
        (f"{obj.class_name} ({obj_id[:short_id_len]})", obj_id)
        for obj_id, obj in objects.items()
    ]
    choice_ids = [obj_id for _, obj_id in choices]

    # Fall back to the first object when nothing (or a stale ID) is selected.
    if selected_obj_id not in choice_ids:
        selected_obj_id = choice_ids[0] if choice_ids else None

    overlay_img = base_img.copy()
    if selected_obj_id is not None:
        obj = objects[selected_obj_id]
        mask = obj.binary_mask
        # Overlay only the selected object so its mask and points are unambiguous.
        overlay_img = apply_mask_overlay(base_img, np.array([mask]), opacity=0.6)
        draw = ImageDraw.Draw(overlay_img)

        radius = 5
        for (x, y), lbl in zip(obj.input_points, obj.input_labels):
            color = "#00FF00" if lbl == 1 else "#FF0000"
            draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=color, outline="white")

        bbox = get_bbox_from_mask(mask)
        if bbox:
            x, y = bbox[0], bbox[1]
            try:
                font = ImageFont.truetype("arial.ttf", 20)
            except (OSError, ImportError):
                # Arial may be missing (or FreeType unavailable); use Pillow's bundled font.
                font = ImageFont.load_default()
            # Clamp so the label stays visible when the box touches the top edge.
            draw.text(
                (x, max(0, y - 20)),
                selected_obj_id[:short_id_len],
                fill="white",
                font=font,
                stroke_width=2,
                stroke_fill="black",
            )

    return overlay_img, gr.update(choices=choices, value=selected_obj_id)
def on_image_click(img, evt: gr.SelectData, obj_id, mode):
    """Add a refinement point at the clicked pixel and return the new overlay.

    A ``mode`` containing "Include" adds a positive point (label 1); any other
    mode adds a negative point (label 0). ``img`` is unused but kept because
    the event wiring passes the image component as the first input.

    Raises:
        gr.Error: If no object is selected.
    """
    if not obj_id:
        raise gr.Error("Please select an object to refine first.")
    click_x, click_y = evt.index[0], evt.index[1]
    polarity = 0
    if "Include" in mode:
        polarity = 1
    controller.refine_object(obj_id, [click_x, click_y], polarity)
    overlay, _radio = init_editor(obj_id)
    return overlay
def on_undo(obj_id):
    """Remove the last refinement point of ``obj_id`` and return the redrawn overlay.

    Returns a no-op update when no object is selected.
    """
    if obj_id:
        controller.undo_last_point(obj_id)
        overlay, _radio = init_editor(obj_id)
        return overlay
    return gr.update()
def on_delete(obj_id):
    """Delete the selected object, auto-save, and refresh the Editor.

    Returns:
        ``(overlay_image, radio_update)``; two no-op updates when nothing is
        selected.
    """
    if not obj_id:
        return gr.update(), gr.update()
    controller.remove_object(obj_id)
    controller.auto_save()
    # Passing None lets init_editor pick the first remaining object.
    overlay, radio = init_editor(None)
    return overlay, radio
# --- Event Wiring ---
# 1. Upload Files
def handle_upload(files):
    """Load an uploaded folder and update the Start button.

    Returns:
        A button update: enabled and showing the image count when the playlist
        is non-empty, otherwise disabled with the default label.
    """
    # on_upload presumably fills controller.project.playlist, which is read below.
    _img, _, _, _ = on_upload(files)
    count = len(controller.project.playlist)
    if not count:
        return gr.update(interactive=False, value="Start Annotation")
    return gr.update(interactive=True, value=f"Start Annotation ({count} images)")
upload_files.upload(
fn=handle_upload,
inputs=[upload_files],
outputs=[start_btn]
)
def handle_single_image(file_path):
    """Load one image as the playlist and update the Start button.

    Returns:
        An enabled button update when the image produced a playlist, otherwise
        a disabled one with the default label.
    """
    if file_path:
        # on_upload expects a list of paths, so wrap the single file.
        _img, _, _, _ = on_upload([file_path])
        if controller.project.playlist:
            return gr.update(interactive=True, value="Start Annotation (1 Image)")
    return gr.update(interactive=False, value="Start Annotation")
single_image_input.change(
fn=handle_single_image,
inputs=[single_image_input],
outputs=[start_btn]
)
# Project Save/Load
def on_save_project(name):
    """Save the current project state to ``saved_projects/<name>.json``.

    Args:
        name: Project name typed by the user. Surrounding whitespace is
            ignored. Names containing "/" or "\\" are rejected so the file
            cannot be written outside ``saved_projects``.

    Returns:
        A status message for the project status box.
    """
    name = (name or "").strip()
    if not name:
        return "Please enter a project name."
    # The name is free text from a (possibly public) web UI; refuse anything
    # that could escape the saved_projects directory.
    if "/" in name or "\\" in name:
        return f"Invalid project name: '{name}'. Do not use path separators."
    os.makedirs("saved_projects", exist_ok=True)
    filename = f"saved_projects/{name}.json"
    _success, msg = controller.save_project(filename)
    return msg
save_project_btn.click(
fn=on_save_project,
inputs=[project_name_input],
outputs=[project_status]
)
def list_projects():
    """Return the sorted names of ``*.json`` entries in ``saved_projects``.

    Returns an empty list when the directory does not exist.
    """
    folder = "saved_projects"
    if os.path.exists(folder):
        return sorted(entry for entry in os.listdir(folder) if entry.endswith(".json"))
    return []
def on_load_project(filename):
    """Load ``saved_projects/<filename>`` and reset the Input screen for it.

    Returns:
        A 14-tuple matching the load button's outputs: status message, tab
        switch, input image, clean-image state, boxes, labels, pending point,
        navigation counter, crop box, crop table, click-effect radio, export
        status JSON, prompt-history dropdown, class-name-history dropdown.
    """
    def _failure_outputs(message):
        # Response when nothing was loaded: the message, no-op updates for the
        # visible components, and reset values for the per-image state.
        return (message, gr.update(), gr.update(), None, [], [], None, "0/0",
                None, gr.update(), gr.update(), {}, gr.update(), gr.update())

    if not filename:
        return _failure_outputs("No project selected.")
    loaded, message = controller.load_project(f"saved_projects/{filename}")
    if not loaded:
        return _failure_outputs(message)

    image = controller.current_image
    counter = f"Image {controller.project.current_index + 1}/{len(controller.project.playlist)}"
    return (
        message,
        gr.update(selected=1),  # Input tab
        gr.update(value=image, interactive=True),
        image,
        [],
        [],
        None,
        counter,
        None,
        gr.update(value=[]),
        gr.update(value="Crop Initial Image"),
        get_export_status(),
        gr.update(choices=controller.project.prompt_history),
        gr.update(choices=controller.project.class_name_history),
    )
load_project_btn.click(
fn=on_load_project,
inputs=[project_dropdown],
outputs=[project_status, tabs, img_input, st_clean_input_image, st_boxes, st_labels, st_pending_point, nav_status, st_crop_box, crop_list_display, click_effect, export_status_display, txt_prompt, txt_class_name]
)
refresh_projects_btn.click(
fn=lambda: gr.update(choices=list_projects()),
inputs=[],
outputs=[project_dropdown]
)
# Delete Project Logic
def on_delete_click(filename):
    """Ask for confirmation before deleting the selected project.

    The Confirm/Cancel buttons are shown only when a project is selected.

    Returns:
        ``(confirm_btn_update, cancel_btn_update, status_message)``.
    """
    has_selection = bool(filename)
    if has_selection:
        prompt = f"Are you sure you want to delete '{filename}'? This cannot be undone."
    else:
        prompt = "Please select a project to delete."
    return gr.update(visible=has_selection), gr.update(visible=has_selection), prompt
delete_project_btn.click(
fn=on_delete_click,
inputs=[project_dropdown],
outputs=[confirm_delete_btn, cancel_delete_btn, project_status]
)
def on_cancel_delete():
    """Hide the Confirm/Cancel buttons and report that deletion was cancelled."""
    hidden = [gr.update(visible=False) for _ in range(2)]
    return (*hidden, "Deletion cancelled.")
cancel_delete_btn.click(
fn=on_cancel_delete,
inputs=[],
outputs=[confirm_delete_btn, cancel_delete_btn, project_status]
)
def on_confirm_delete(filename):
    """Delete the selected project's JSON file and its ``<stem>_assets`` folder.

    Args:
        filename: Entry from the project dropdown, i.e. a bare file name
            inside ``saved_projects``.

    Returns:
        ``(confirm_btn_update, cancel_btn_update, status_message, dropdown_update)``.
        The confirmation buttons are always hidden again.
    """
    import shutil

    if not filename:
        return gr.update(visible=False), gr.update(visible=False), "No project selected.", gr.update()
    # Only bare file names are valid; a path component could reach outside saved_projects.
    if "/" in filename or "\\" in filename:
        return (gr.update(visible=False), gr.update(visible=False),
                f"Invalid project file name: '{filename}'.", gr.update(choices=list_projects()))

    json_path = f"saved_projects/{filename}"
    # Strip only the trailing ".json": str.replace would also remove a ".json"
    # that occurs earlier in the name. Assumes the assets folder is named
    # "<stem>_assets", as the previous code did.
    stem = filename[:-len(".json")] if filename.endswith(".json") else filename
    assets_dir = f"saved_projects/{stem}_assets"

    deleted = []
    try:
        if os.path.exists(json_path):
            os.remove(json_path)
            deleted.append(f"Deleted {filename}")
        if os.path.exists(assets_dir):
            shutil.rmtree(assets_dir)
            deleted.append("Deleted assets folder")
    except OSError as e:
        return (gr.update(visible=False), gr.update(visible=False),
                f"Error deleting project: {e}", gr.update(choices=list_projects()))

    if not deleted:
        return (gr.update(visible=False), gr.update(visible=False),
                "Project files not found.", gr.update(choices=list_projects()))
    return (gr.update(visible=False), gr.update(visible=False),
            f"Success: {', '.join(deleted)}", gr.update(choices=list_projects(), value=None))
confirm_delete_btn.click(
fn=on_confirm_delete,
inputs=[project_dropdown],
outputs=[confirm_delete_btn, cancel_delete_btn, project_status, project_dropdown]
)
# Delete All Projects Logic
def on_delete_all_click():
    """Reveal the Confirm/Cancel buttons for deleting every saved project."""
    warning = "Are you sure you want to delete ALL projects? This cannot be undone."
    confirm_update = gr.update(visible=True)
    cancel_update = gr.update(visible=True)
    return confirm_update, cancel_update, warning
delete_all_btn.click(
fn=on_delete_all_click,
inputs=[],
outputs=[confirm_delete_all_btn, cancel_delete_all_btn, project_status]
)
def on_cancel_delete_all():
    """Hide the delete-all Confirm/Cancel buttons and report the cancellation."""
    confirm_update = gr.update(visible=False)
    cancel_update = gr.update(visible=False)
    return confirm_update, cancel_update, "Deletion cancelled."
cancel_delete_all_btn.click(
fn=on_cancel_delete_all,
inputs=[],
outputs=[confirm_delete_all_btn, cancel_delete_all_btn, project_status]
)
def on_confirm_delete_all():
    """Delete every file, link and folder directly under ``saved_projects``.

    Entries that cannot be removed are logged and named in the status message
    instead of being reported as deleted.

    Returns:
        ``(confirm_btn_update, cancel_btn_update, status_message, dropdown_update)``.
    """
    import shutil

    folder = "saved_projects"
    if not os.path.exists(folder):
        return gr.update(visible=False), gr.update(visible=False), "No projects found.", gr.update(choices=[])
    try:
        entries = os.listdir(folder)
    except OSError as e:
        # Listing failed, so re-listing for the dropdown would fail too; leave it as-is.
        return gr.update(visible=False), gr.update(visible=False), f"Error deleting projects: {e}", gr.update()

    failed = []
    for entry in entries:
        file_path = os.path.join(folder, entry)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except OSError as e:
            print(f"Failed to delete {file_path}. Reason: {e}")
            failed.append(entry)

    if failed:
        return (gr.update(visible=False), gr.update(visible=False),
                f"Some projects could not be deleted: {', '.join(failed)}",
                gr.update(choices=list_projects(), value=None))
    return gr.update(visible=False), gr.update(visible=False), "All projects deleted.", gr.update(choices=[], value=None)
confirm_delete_all_btn.click(
fn=on_confirm_delete_all,
inputs=[],
outputs=[confirm_delete_all_btn, cancel_delete_all_btn, project_status, project_dropdown]
)
# Init project list on load
demo.load(fn=lambda: gr.update(choices=list_projects()), inputs=[], outputs=[project_dropdown])
def start_session(project_name):
    """Start annotating the loaded playlist and switch to the Input tab.

    Args:
        project_name: Optional project name (surrounding whitespace ignored).
            When given, the project is bound to ``saved_projects/<name>.json``
            and saved immediately. When empty, no project file is bound.

    Returns:
        A 10-tuple matching the Start button's outputs (tab switch, input
        image, clean-image state, boxes, labels, pending point, navigation
        counter, crop box, crop table, click-effect radio).

    Raises:
        gr.Error: If no images are loaded, the name contains a path separator,
            or the first image cannot be loaded.
    """
    if not controller.project.playlist:
        raise gr.Error("No images loaded.")

    name = (project_name or "").strip()
    if name:
        # Free text from the UI: refuse anything that could place the project
        # file outside saved_projects.
        if "/" in name or "\\" in name:
            raise gr.Error(f"Invalid project name: '{name}'. Do not use path separators.")
        os.makedirs("saved_projects", exist_ok=True)
        filename = f"saved_projects/{name}.json"
        # No check for an existing file with the same name; the initial save
        # below presumably overwrites it.
        controller.active_project_path = filename
        controller.save_project(filename)
    else:
        controller.active_project_path = None

    # Make sure the first playlist image is actually loaded.
    if controller.current_image is None:
        print("⚠️ Current image is None, attempting to load index 0...")
        controller.load_image_at_index(0)
    img = controller.current_image
    if img is None:
        raise gr.Error("Failed to load first image.")

    status = f"Image {controller.project.current_index + 1}/{len(controller.project.playlist)}"
    return (
        gr.update(selected=1),                  # tabs -> Input
        gr.update(value=img, interactive=True),  # img_input
        img,                                    # st_clean_input_image
        [],                                     # st_boxes
        [],                                     # st_labels
        None,                                   # st_pending_point
        status,                                 # nav_status
        None,                                   # st_crop_box
        gr.update(value=[]),                    # crop_list_display
        gr.update(value="Crop Initial Image"),  # click_effect
    )
start_btn.click(
fn=start_session,
inputs=[new_project_name],
outputs=[tabs, img_input, st_clean_input_image, st_boxes, st_labels, st_pending_point, nav_status, st_crop_box, crop_list_display, click_effect]
)
# Navigation
def _nav_outputs(img):
    """Build the Input-tab outputs after navigating to ``img`` (None if navigation yielded no image).

    Boxes, labels and the pending click are cleared for the new image.
    NOTE(review): box_list_display is not among the navigation outputs, so the
    box table may keep showing the previous image's rows — confirm.
    """
    # Compare with None explicitly: truthiness is not a reliable "has image"
    # test (e.g. it raises for numpy arrays).
    if img is None:
        status = "0/0"
    else:
        status = f"Image {controller.project.current_index + 1}/{len(controller.project.playlist)}"
    return img, img, [], [], None, status

def on_nav_prev():
    """Move to the previous playlist image and reset the per-image prompt state."""
    return _nav_outputs(controller.prev_image())

def on_nav_next():
    """Move to the next playlist image and reset the per-image prompt state."""
    return _nav_outputs(controller.next_image())
prev_btn.click(
fn=on_nav_prev,
outputs=[img_input, st_clean_input_image, st_boxes, st_labels, st_pending_point, nav_status]
)
next_btn.click(
fn=on_nav_next,
outputs=[img_input, st_clean_input_image, st_boxes, st_labels, st_pending_point, nav_status]
)
# 2. Click on Image (Add Box)
img_input.select(
fn=on_input_image_select,
inputs=[st_pending_point, st_boxes, st_labels, click_effect, st_clean_input_image, st_crop_box],
outputs=[img_input, st_pending_point, st_boxes, st_labels, box_list_display, st_crop_box, crop_list_display]
)
# 2b. Undo Click
undo_click_btn.click(
fn=undo_last_click,
inputs=[st_pending_point, st_boxes, st_labels, st_clean_input_image, st_crop_box],
outputs=[img_input, st_pending_point, st_boxes, st_labels, box_list_display, st_crop_box, crop_list_display]
)
# 3. Dataframe Edits
box_list_display.change(
fn=on_dataframe_change,
inputs=[box_list_display, st_clean_input_image, st_crop_box],
outputs=[img_input, st_boxes, st_labels]
)
crop_list_display.change(
fn=on_crop_dataframe_change,
inputs=[crop_list_display, st_clean_input_image, st_boxes, st_labels],
outputs=[img_input, st_crop_box]
)
# 3b. Delete Checked
delete_box_btn.click(
fn=delete_checked_boxes,
inputs=[box_list_display, st_clean_input_image, st_crop_box],
outputs=[st_boxes, st_labels, box_list_display, img_input]
)
# 4. Run Inference (Button + Enter)
run_inference_fn = lambda img, txt, boxes, labels, cls_name, crop: run_inference_step1(img, txt, boxes, labels, cls_name, crop)
def start_inference(img, prompt, boxes):
    """Validate inference inputs and return a "busy" update.

    The update shows "Running Inference..." and is non-interactive; it
    presumably targets the Run button (that wiring is further down the file).

    Raises:
        gr.Error: If no image is loaded, or neither a prompt nor boxes are given.
    """
    if img is None:
        raise gr.Error("Please upload an image.")
    has_prompt_input = bool(prompt) or bool(boxes)
    if not has_prompt_input:
        raise gr.Error("Please enter a text prompt or add box prompts.")
    return gr.update(value="Running Inference...", interactive=False)
# Event listeners for inference are defined at the end of the file to include counter updates
# 3b. Select All
select_all_btn.click(
fn=select_all_candidates,
inputs=[st_candidates],
outputs=[preview_image, results_list, st_selected_indices]
)
deselect_all_btn.click(
fn=deselect_all_candidates,
inputs=[st_candidates],
outputs=[preview_image, results_list, st_selected_indices]
)
# 3c. List Select
results_list.change(
fn=on_selection_change,
inputs=[results_list, st_candidates],
outputs=[preview_image, st_selected_indices]
)
# 5. Confirm Selection -> Go to Editor
confirm_btn.click(
fn=add_to_store_wrapper,
inputs=[st_candidates, st_selected_indices],
outputs=[status_box, tabs]
).then(
fn=init_editor,
inputs=[],
outputs=[refine_image, object_list]
)
# 6. Editor Interactions
object_list.change(
fn=init_editor,
inputs=[object_list],
outputs=[refine_image, object_list]
)
refine_image.select(
fn=on_image_click,
inputs=[refine_image, object_list, click_mode],
outputs=[refine_image]
)
undo_btn.click(
fn=on_undo,
inputs=[object_list],
outputs=[refine_image]
)
revert_btn.click(
fn=revert_object_refinement,
inputs=[object_list],
outputs=[refine_image]
)
delete_btn.click(
fn=on_delete,
inputs=[object_list],
outputs=[refine_image, object_list]
)
export_btn.click(
fn=export_results,
inputs=[txt_output_dir, export_type, zip_export],
outputs=[export_status]
)
# Helper to get project status
def get_project_status():
    """Summarize annotation progress for the current project.

    Returns:
        {} when there is no project; otherwise a dict with "total_images"
        (playlist length), "current_index", "annotated_images" (number of
        annotation stores), "total_objects" (objects across all stores) and
        "details" (object count per file, keyed by base file name).
    """
    project = controller.project
    if not project:
        return {}
    # Build the dict with an explicit loop to avoid type inference issues.
    # os.path.basename (instead of splitting on "/") also handles Windows
    # separators and path-like keys. Files that share a base name in
    # different folders overwrite each other in this mapping.
    details = {}
    for path, store in project.annotations.items():
        details[os.path.basename(path)] = len(store.objects)
    return {
        "total_images": len(project.playlist),
        "current_index": project.current_index,
        "annotated_images": len(project.annotations),
        "total_objects": sum(len(s.objects) for s in project.annotations.values()),
        "details": details,
    }
# Helper for Export Status
def get_export_status():
    """Summarize which images currently have annotations to export.

    Only annotation stores holding at least one object are listed. Entries
    are ordered by playlist position; stores whose path is not in the
    playlist sort first and are reported with "index" 0.

    Returns:
        {} when there is no project; otherwise a dict with
        "total_objects_annotated", "finished_images_count" and
        "finished_images_list" (items with a 1-based "index", the base
        "filename" and the "object_count").
    """
    project = controller.project
    if not project:
        return {}
    # Playlist path -> 0-based position, used for both ordering and display.
    playlist_map = {path: i for i, path in enumerate(project.playlist)}
    finished_images = []
    total_objects = 0
    ordered = sorted(
        project.annotations.items(),
        key=lambda item: playlist_map.get(item[0], -1),
    )
    for path, store in ordered:
        count = len(store.objects)
        # Images that were visited/saved but hold no objects are not listed.
        if count == 0:
            continue
        total_objects += count
        idx = playlist_map.get(path, -1)
        finished_images.append({
            "index": idx + 1,  # 1-based index for display (0 = not in playlist)
            "filename": os.path.basename(path),
            "object_count": count,
        })
    return {
        "total_objects_annotated": total_objects,
        "finished_images_count": len(finished_images),
        "finished_images_list": finished_images,
    }
# Helper to get current image counter string
def get_image_counter():
    """Return the "Image <current>/<total>" navigation label.

    Falls back to "Image 0/0" when there is no project or its playlist is
    empty, instead of raising AttributeError on a missing project.
    """
    project = controller.project
    if not project or not project.playlist:
        return "Image 0/0"
    return f"Image {project.current_index + 1}/{len(project.playlist)}"
# Finish Image & Next
def on_finish_image():
    """Save the current image's annotations, then advance the playlist.

    Returns eleven values matching the finish_img_btn.click outputs.

    If controller.next_image() yields an image, the handler selects tab 1
    (Input) and loads that image into img_input and st_clean_input_image. It
    also clears boxes, labels, pending point and crop state, and shows the
    image counter.

    Otherwise it selects tab 3, clears boxes, labels and pending point, and
    sets nav_status to "Finished".

    Both paths refresh the export status.
    """
    print("🏁 Finishing image...")
    active_path = controller.current_image_path
    if active_path:
        # Store the annotations of the image being left in the project.
        controller.project.annotations[active_path] = controller.store
    controller.auto_save()
    upcoming = controller.next_image()
    if not upcoming:
        print("🛑 Playlist finished.")
        return (
            gr.update(selected=3),  # tabs: stay on Editor tab
            gr.update(),  # img_input unchanged
            gr.update(),  # st_clean_input_image unchanged
            [],  # st_boxes
            [],  # st_labels
            None,  # st_pending_point
            "Finished",  # nav_status
            gr.update(),  # st_crop_box unchanged
            gr.update(),  # crop_list_display unchanged
            gr.update(),  # click_effect unchanged
            get_export_status(),  # export_status_display
        )
    counter_text = get_image_counter()
    print(f"➡️ Next image loaded: {controller.project.current_index}")
    return (
        gr.update(selected=1),  # tabs: Input
        gr.update(value=upcoming, interactive=True),  # img_input
        upcoming,  # st_clean_input_image
        [],  # st_boxes
        [],  # st_labels
        None,  # st_pending_point
        counter_text,  # nav_status
        None,  # st_crop_box
        gr.update(value=[]),  # crop_list_display
        gr.update(value="Crop Initial Image"),  # click_effect
        get_export_status(),  # export_status_display
    )
finish_img_btn.click(
    fn=on_finish_image,
    outputs=[
        tabs, img_input, st_clean_input_image, st_boxes, st_labels, st_pending_point,
        nav_status, st_crop_box, crop_list_display, click_effect, export_status_display,
    ],
)
# Finish (Save Only) Logic
def on_finish_save():
    """Save the current image's annotations without changing tabs.

    Returns the refreshed export status for export_status_display.
    """
    print("💾 Saving current image state...")
    active_path = controller.current_image_path
    if active_path:
        controller.project.annotations[active_path] = controller.store
    controller.auto_save()
    summary = get_export_status()
    return summary
finish_save_btn.click(on_finish_save, outputs=[export_status_display])
# Add & Generate More Logic
def add_and_restart(candidates, selected_indices, clean_img):
    """Add the selected candidate masks to the store and return to the Input tab.

    Raises gr.Error when nothing is selected. Otherwise adds the selection,
    auto-saves, and returns updates that restore the clean image and clear
    the box prompts, labels, pending point and box list.
    """
    if not selected_indices:
        raise gr.Error("No masks selected.")
    controller.add_candidates_to_store(candidates, selected_indices)
    controller.auto_save()
    cleared_boxes, cleared_labels, cleared_pending = [], [], None
    return (
        gr.update(selected=1),  # tabs: Input
        gr.update(value=clean_img, interactive=True),  # img_input
        cleared_boxes,  # st_boxes
        cleared_labels,  # st_labels
        cleared_pending,  # st_pending_point
        gr.update(value=[]),  # box_list_display
        gr.update(value="Crop Initial Image"),  # click_effect
    )
add_more_btn.click(
    add_and_restart,
    [st_candidates, st_selected_indices, st_clean_input_image],
    [tabs, img_input, st_boxes, st_labels, st_pending_point, box_list_display, click_effect],
)
# Update status on enter editor
# Confirm chain: store the selected candidates and switch tabs, auto-save the
# project, initialize the editor view, then refresh the editor image counter.
# NOTE(review): make sure confirm_btn.click is not also wired elsewhere in
# this file; every registered listener runs, so add_to_store_wrapper would
# execute once per registration.
# NOTE(review): .then() steps are believed to run even when the previous step
# raised (Gradio's .success() is the success-only variant); confirm whether
# auto-save/init_editor should be skipped if add_to_store_wrapper fails.
confirm_btn.click(
    fn=add_to_store_wrapper,
    inputs=[st_candidates, st_selected_indices],
    outputs=[status_box, tabs]
).then(
    # The tuple-index trick calls auto_save() and returns None, matching outputs=[].
    fn=lambda: (controller.auto_save(), None)[1], # Auto-save on confirm
    inputs=[],
    outputs=[]
).then(
    fn=init_editor,
    inputs=[],
    outputs=[refine_image, object_list]
).then(
    fn=get_image_counter,
    outputs=[editor_img_counter]
)
# Update counter on enter results
# Inference pipeline wired to the Run button:
#   1. start_inference validates the inputs and disables/relabels the button.
#   2. run_inference_fn runs inference and switches to the Results tab.
#   3. render_results_step2 renders the candidates.
#   4. the results image counter is refreshed.
#   5. the button is restored to "Run Inference" and re-enabled.
run_btn.click(
    fn=start_inference,
    inputs=[st_clean_input_image, txt_prompt, st_boxes],
    outputs=[run_btn],
).success(
    # .success() rather than .then(): .then() steps also run after the
    # previous step raised, so a validation failure in start_inference would
    # fall through to run_inference_step1, which re-validates and shows the
    # same error a second time. start_inference raises before disabling the
    # button, so nothing needs restoring when this step is skipped.
    fn=run_inference_fn,
    inputs=[st_clean_input_image, txt_prompt, st_boxes, st_labels, txt_class_name, st_crop_box],
    outputs=[st_candidates, st_current_image, tabs, txt_prompt, txt_class_name],
).then(
    fn=render_results_step2,
    inputs=[st_candidates, st_current_image],
    outputs=[results_list, preview_image, st_selected_indices],
).then(
    fn=get_image_counter,
    outputs=[result_img_counter],
).then(
    # .then() keeps this running even if inference itself raised, so the
    # button is always re-enabled after an inference attempt.
    fn=lambda: gr.update(value="Run Inference", interactive=True),
    inputs=[],
    outputs=[run_btn],
)
# Reset Project Logic
def on_reset():
    """Reset the controller's project and clear every per-project UI field.

    Returns one value per component in the reset_btn.click outputs, in
    order: tabs, export_status_display, export_status, start_btn, img_input,
    st_clean_input_image, st_boxes, st_labels, st_pending_point, nav_status,
    st_crop_box, crop_list_display, click_effect, txt_prompt, txt_class_name.
    """
    controller.reset_project()
    return (
        gr.update(selected=0),  # Go to Setup
        gr.update(value={}),  # Clear export status display
        gr.update(value=""),  # Clear export log
        gr.update(interactive=False, value="Start Annotation"),  # Reset start button
        None,  # Clear displayed input image
        None,  # Clear clean-image state so a stale image cannot be sent to inference
        [], [], None,  # Clear boxes, labels, pending point
        "Image 0/0",  # Same format get_image_counter() uses for an empty playlist
        None,  # Clear crop box
        gr.update(value=[]),  # Clear crop list
        gr.update(value="Crop Initial Image"),  # Reset click effect
        gr.update(value="", choices=[]),  # Clear prompt history
        gr.update(value="", choices=[]),  # Clear class name history
    )
reset_btn.click(
    fn=on_reset,
    inputs=[],
    outputs=[
        tabs, export_status_display, export_status, start_btn, img_input,
        st_clean_input_image, st_boxes, st_labels, st_pending_point, nav_status,
        st_crop_box, crop_list_display, click_effect, txt_prompt, txt_class_name,
    ],
)
# Tab Switching Logic
def on_tab_select(evt: gr.SelectData):
    """Re-initialize the editor whenever the Editor tab (index 3) is opened.

    Returns updates for (refine_image, object_list, editor_img_counter); for
    any other tab all three are no-op updates.
    """
    print(f"Tab selected: {evt.value}, Index: {evt.index}")
    opened_editor = evt.value == "Editor" or evt.index == 3
    if not opened_editor:
        return gr.update(), gr.update(), gr.update()
    editor_img, editor_radio = init_editor(None)
    counter_text = get_image_counter()
    print(f"Editor Init: Image={editor_img is not None}, Counter={counter_text}")
    return editor_img, editor_radio, counter_text
tabs.select(
    on_tab_select,
    [],
    [refine_image, object_list, editor_img_counter],
)
# Refresh Button Logic
def on_refresh_editor():
    """Re-initialize the editor view and return it with the image counter."""
    editor_img, editor_radio = init_editor(None)
    return editor_img, editor_radio, get_image_counter()
refresh_editor_btn.click(
    on_refresh_editor,
    [],
    [refine_image, object_list, editor_img_counter],
)
# Load JS: run custom_js in the browser on page load (no Python callback).
demo.load(fn=None, inputs=None, outputs=None, js=custom_js)
if __name__ == "__main__":
    demo.launch(
        css=custom_css,
        theme=app_theme,
        ssr_mode=False,
        mcp_server=False,
        show_error=True,
    )