"""SVG-SAMurai: Streamlit front end for interactive SAM segmentation.

The user uploads a PNG/JPG/SVG image, clicks on it to place point prompts
for the Segment Anything Model, and saves each predicted mask (vectorized
with ``mask_to_svg_path``) into a working SVG document under a chosen
segment name. The resulting SVG can then be downloaded.
"""

import streamlit as st
from PIL import Image
import numpy as np
from streamlit_image_coordinates import streamlit_image_coordinates

from src.model import compute_image_embedding, predict_mask
from src.vectorizer import mask_to_svg_path
from src.xml_manager import load_image, create_base_svg, add_path_to_svg


def _default_session_state():
    """Return fresh default values for every per-image session-state key.

    A new dict (with new list/dict containers) is built on each call so a
    reset never shares mutable objects with a previous image.
    """
    return {
        "image": None,
        "image_embedding": None,
        "points": [],
        "labels": [],
        "current_mask": None,
        "segments": {},
        "original_svg": None,
    }


def _reset_image_state():
    """Overwrite all per-image session-state keys with their defaults.

    ``last_click`` is intentionally NOT reset: the click component may keep
    reporting its previous click after a new upload, and that stale click
    must not be applied to the new image.
    """
    for key, default in _default_session_state().items():
        st.session_state[key] = default


def _clear_selection():
    """Drop the in-progress point prompts and the mask predicted from them."""
    st.session_state.points = []
    st.session_state.labels = []
    st.session_state.current_mask = None


def _predict_current_mask():
    """Run SAM on the current point prompts; return None when there are none."""
    if not st.session_state.points:
        return None
    return predict_mask(
        st.session_state.image,
        st.session_state.image_embedding,
        st.session_state.points,
        st.session_state.labels,
    )


def _overlay_mask(image, mask):
    """Return a copy of *image* with *mask* drawn as a translucent blue layer.

    Pixels where ``mask > 0`` are tinted blue at 50% opacity and the result
    is returned as RGB. When *mask* is None an unmodified copy of *image* is
    returned. The mask is assumed to share the image's height and width
    (NOTE(review): confirm predict_mask guarantees this).
    """
    display_image = image.copy()
    if mask is None:
        return display_image
    overlay = np.zeros((*mask.shape, 4), dtype=np.uint8)
    overlay[mask > 0] = [0, 0, 255, 128]  # Blue, 50% opacity
    # A uint8 array of shape (H, W, 4) is inferred as RGBA by Pillow.
    overlay_image = Image.fromarray(overlay)
    display_image = display_image.convert("RGBA")
    display_image.paste(overlay_image, (0, 0), overlay_image)
    return display_image.convert("RGB")  # Convert back to RGB for display


st.set_page_config(page_title="SVG-SAMurai", layout="wide", page_icon="🗡️")

# Session State Initialization: only fill in keys that are missing so that
# values survive Streamlit's script reruns.
for _key, _default in _default_session_state().items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

st.title("SVG-SAMurai 🗡️")
st.markdown(
    "Transform raster and vector images into segmented SVG paths using the **Segment Anything Model (SAM)**."
)

# File uploader
uploaded_file = st.file_uploader(
    "Upload an Image (PNG, JPG, SVG)", type=["png", "jpg", "jpeg", "svg"]
)

if uploaded_file is not None:
    # Reset all per-image state when a different file is uploaded.
    if st.session_state.get("last_uploaded_file_id") != uploaded_file.file_id:
        st.session_state.last_uploaded_file_id = uploaded_file.file_id
        _reset_image_state()

    if st.session_state.image is None:
        # Load the image
        with st.spinner("Processing Image..."):
            image = load_image(uploaded_file)
            st.session_state.image = image

            # If the original file was an SVG, keep its markup as the working
            # document; otherwise start from a blank canvas of the same size.
            if uploaded_file.type == "image/svg+xml":
                st.session_state.original_svg = uploaded_file.getvalue().decode(
                    "utf-8", errors="replace"
                )
            else:
                width, height = image.size
                st.session_state.original_svg = create_base_svg(width, height)

            # Compute image embeddings once per image; every later prediction
            # reuses them.
            st.session_state.image_embedding = compute_image_embedding(image)
            st.success("Image embedded successfully!")

    col1, col2 = st.columns([2, 1])

    with col1:
        st.subheader("Interactive Segmentation")

        # Rendered unconditionally so the prompt type can be chosen before the
        # very first click. Checked -> the next click is a negative prompt.
        is_negative = st.sidebar.checkbox(
            "Next Click is Negative Prompt (Exclude)",
            value=False,
            key="neg_prompt_toggle",
        )

        # Display the image (with the current mask overlaid, if any) and
        # capture clicks on it.
        display_image = _overlay_mask(
            st.session_state.image, st.session_state.current_mask
        )
        # NOTE(review): the click coordinates are treated as original-image
        # pixels. This presumably holds because no display width is passed to
        # the component; confirm before adding width/scaling options.
        value = streamlit_image_coordinates(display_image, key="image_coord")

        # The component keeps returning its most recent click on every rerun.
        # Only act on a click value that has not been processed yet. Comparing
        # against the last processed click, rather than the last stored point,
        # stops Undo/Clear/Save (and a new upload) from resurrecting a stale
        # click.
        if value is not None and value != st.session_state.get("last_click"):
            st.session_state.last_click = value
            x, y = value["x"], value["y"]
            st.session_state.points.append([x, y])
            st.session_state.labels.append(0 if is_negative else 1)

            with st.spinner("Predicting Segment..."):
                st.session_state.current_mask = _predict_current_mask()
            st.rerun()

        # Tools for interacting with the points
        col_btn1, col_btn2 = st.columns(2)
        with col_btn1:
            if st.button("Undo Last Click") and st.session_state.points:
                st.session_state.points.pop()
                st.session_state.labels.pop()
                # Re-predict from the remaining points (None if none remain).
                with st.spinner("Predicting Segment..."):
                    st.session_state.current_mask = _predict_current_mask()
                st.rerun()
        with col_btn2:
            if st.button("Clear Current Selection"):
                _clear_selection()
                st.rerun()

    with col2:
        st.subheader("Segment Management")

        # Show a one-shot message queued before the previous st.rerun(); a
        # message rendered right before st.rerun() would be wiped immediately.
        flash_message = st.session_state.get("flash_message")
        if flash_message:
            st.success(flash_message)
            st.session_state.flash_message = None

        segment_name = st.text_input(
            "Segment Name", placeholder="e.g., car_body"
        ).strip()
        if segment_name and segment_name in st.session_state.segments:
            st.warning(
                f"A segment named '{segment_name}' already exists; saving will "
                "add another path with the same name."
            )

        epsilon_factor = st.slider(
            "Vectorization Simplification (epsilon)",
            min_value=0.001,
            max_value=0.05,
            value=0.005,
            step=0.001,
            format="%.3f",
        )

        if st.button(
            "Save Segment to SVG",
            disabled=st.session_state.current_mask is None or not segment_name,
        ):
            with st.spinner("Vectorizing..."):
                # 1. Convert mask to SVG path
                path_d = mask_to_svg_path(
                    st.session_state.current_mask, epsilon_factor=epsilon_factor
                )

                # 2. Inject the path into the working SVG string
                try:
                    updated_svg = add_path_to_svg(
                        st.session_state.original_svg,
                        path_d,
                        segment_name,
                        fill_color="#FF0000",
                        opacity=0.5,
                    )
                except Exception as e:
                    st.error(f"Failed to inject SVG: {e}")
                else:
                    st.session_state.original_svg = updated_svg
                    # 3. Record the segment only once it is really in the SVG,
                    # so the list below never shows a segment that failed.
                    st.session_state.segments[segment_name] = path_d
                    # Clear current selection for the next segment
                    _clear_selection()
                    st.session_state.flash_message = (
                        f"Segment '{segment_name}' saved!"
                    )
                    st.rerun()

        # Display saved segments list
        if st.session_state.segments:
            st.write("### Saved Segments")
            for name in st.session_state.segments:
                st.markdown(f"- **{name}**")

        # Provide download button for the final SVG
        st.download_button(
            label="Download Final SVG",
            data=st.session_state.original_svg,
            file_name="segmented_output.svg",
            mime="image/svg+xml",
        )