# Spaces: Running
import numpy as np
import streamlit as st
from PIL import Image
from streamlit_image_coordinates import streamlit_image_coordinates

from src.model import compute_image_embedding, predict_mask
from src.vectorizer import mask_to_svg_path
from src.xml_manager import add_path_to_svg, create_base_svg, load_image
# Configure the browser tab (title, icon) and use the wide page layout.
st.set_page_config(page_title="SVG-SAMurai", layout="wide", page_icon="🗡️")

# Session-state defaults. A key is only seeded when it is missing, so values
# written during earlier runs of this script are left untouched. The mapping is
# rebuilt on every run, which means the list/dict defaults are fresh objects.
_session_defaults = {
    "image": None,            # image returned by load_image for the upload
    "image_embedding": None,  # result of compute_image_embedding for that image
    "points": [],             # clicked [x, y] prompt coordinates
    "labels": [],             # one prompt label per entry in "points"
    "current_mask": None,     # mask predicted for the current prompts
    "segments": {},           # saved segments: name -> SVG path data
    "original_svg": None,     # working SVG document as a string
}
for _state_key, _state_default in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Page heading followed by a one-sentence description of the tool.
st.title("SVG-SAMurai 🗡️")
_intro_text = (
    "Transform raster and vector images into segmented SVG paths using the **Segment Anything Model (SAM)**."
)
st.markdown(_intro_text)

# Upload widget; the value is None until the user has picked a file.
_accepted_extensions = ["png", "jpg", "jpeg", "svg"]
uploaded_file = st.file_uploader(
    label="Upload an Image (PNG, JPG, SVG)",
    type=_accepted_extensions,
)
def _clear_selection():
    """Drop the in-progress prompt points, their labels and the preview mask."""
    st.session_state.points = []
    st.session_state.labels = []
    st.session_state.current_mask = None


def _predict_current_mask():
    """Predict a mask from the prompts in session state and store it there."""
    with st.spinner("Predicting Segment..."):
        st.session_state.current_mask = predict_mask(
            st.session_state.image,
            st.session_state.image_embedding,
            st.session_state.points,
            st.session_state.labels,
        )


def _overlay_mask(image, mask):
    """Return an RGB copy of *image* with the masked pixels tinted blue.

    NOTE(review): assumes *mask* is a 2-D array with the same height/width as
    *image* -- confirm against what predict_mask returns.
    """
    tint = np.zeros((*mask.shape, 4), dtype=np.uint8)
    tint[mask > 0] = [0, 0, 255, 128]  # Blue, 50% opacity
    # An (H, W, 4) uint8 array is read as RGBA without an explicit mode.
    tint_image = Image.fromarray(tint)
    composed = image.convert("RGBA")  # convert() returns a new image
    # The tint's own alpha channel is used as the paste mask.
    composed.paste(tint_image, (0, 0), tint_image)
    return composed.convert("RGB")  # Convert back to RGB for display


if uploaded_file is not None:
    # Reset all per-image state when a different file has been uploaded.
    if st.session_state.get("last_uploaded_file_id") != uploaded_file.file_id:
        st.session_state.last_uploaded_file_id = uploaded_file.file_id
        st.session_state.image = None
        st.session_state.image_embedding = None
        st.session_state.segments = {}
        st.session_state.original_svg = None
        _clear_selection()

    if st.session_state.image is None:
        with st.spinner("Processing Image..."):
            image = load_image(uploaded_file)
            st.session_state.image = image

            if uploaded_file.type == "image/svg+xml":
                # Keep the uploaded SVG text as the document segments are added to.
                st.session_state.original_svg = uploaded_file.getvalue().decode(
                    "utf-8", errors="replace"
                )
            else:
                # Raster upload: start from a base SVG with the image's dimensions.
                width, height = image.size
                st.session_state.original_svg = create_base_svg(width, height)

            # Compute the image embedding once; every prediction reuses it.
            st.session_state.image_embedding = compute_image_embedding(image)
        st.success("Image embedded successfully!")

    # Rendered unconditionally so the toggle already exists before the first
    # click (it used to be created only after a click had been registered).
    is_negative = st.sidebar.checkbox(
        "Next Click is Negative Prompt (Exclude)",
        value=False,
        key="neg_prompt_toggle",
    )

    col1, col2 = st.columns([2, 1])

    with col1:
        st.subheader("Interactive Segmentation")

        # Show the image, with the current mask (if any) drawn on top.
        current_mask = st.session_state.current_mask
        if current_mask is None:
            display_image = st.session_state.image.copy()
        else:
            display_image = _overlay_mask(st.session_state.image, current_mask)

        click = streamlit_image_coordinates(display_image, key="image_coord")

        # The component keeps reporting its last click on every rerun, so a
        # click is handled only when its payload differs from the one handled
        # last. Comparing against points[-1] instead would re-add the stale
        # click right after "Undo" / "Clear" (and after a new upload).
        # NOTE(review): if the installed component version puts a per-click
        # timestamp in the dict, repeated clicks on the same pixel are picked
        # up as well -- confirm.
        if click is not None and click != st.session_state.get("last_click"):
            st.session_state.last_click = click
            st.session_state.points.append([click["x"], click["y"]])
            # Label 0 marks a negative (exclude) prompt, 1 a positive one.
            st.session_state.labels.append(0 if is_negative else 1)
            _predict_current_mask()
            st.rerun()

        # Tools for editing the current set of prompt points.
        col_btn1, col_btn2 = st.columns(2)
        with col_btn1:
            if st.button("Undo Last Click") and st.session_state.points:
                st.session_state.points.pop()
                st.session_state.labels.pop()
                if st.session_state.points:
                    _predict_current_mask()
                else:
                    st.session_state.current_mask = None
                st.rerun()
        with col_btn2:
            if st.button("Clear Current Selection"):
                _clear_selection()
                st.rerun()

    with col2:
        st.subheader("Segment Management")

        # Confirmation stored by the save handler below. It is shown here, on
        # the run that follows st.rerun(), because anything rendered right
        # before a rerun is discarded.
        saved_notice = st.session_state.pop("saved_notice", None)
        if saved_notice:
            st.success(saved_notice)

        # Strip so that a whitespace-only name does not count as a name.
        segment_name = st.text_input(
            "Segment Name", placeholder="e.g., car_body"
        ).strip()
        epsilon_factor = st.slider(
            "Vectorization Simplification (epsilon)",
            min_value=0.001,
            max_value=0.05,
            value=0.005,
            step=0.001,
            format="%.3f",
        )

        if st.button(
            "Save Segment to SVG",
            disabled=st.session_state.current_mask is None or not segment_name,
        ):
            with st.spinner("Vectorizing..."):
                # 1. Convert the mask to SVG path data.
                path_d = mask_to_svg_path(
                    st.session_state.current_mask, epsilon_factor=epsilon_factor
                )
                # 2. Inject the path into the working SVG string. Only this
                #    call is guarded, so st.rerun() below is never issued from
                #    inside the try body.
                try:
                    updated_svg = add_path_to_svg(
                        st.session_state.original_svg,
                        path_d,
                        segment_name,
                        fill_color="#FF0000",
                        opacity=0.5,
                    )
                except Exception as e:
                    st.error(f"Failed to inject SVG: {e}")
                else:
                    # 3. Record the segment only once the SVG really contains
                    #    it, then reset the selection for the next segment.
                    st.session_state.original_svg = updated_svg
                    st.session_state.segments[segment_name] = path_d
                    _clear_selection()
                    st.session_state.saved_notice = (
                        f"Segment '{segment_name}' saved!"
                    )
                    st.rerun()

        # Display saved segments list.
        if st.session_state.segments:
            st.write("### Saved Segments")
            for name in st.session_state.segments:
                st.markdown(f"- **{name}**")

        # Download button for the working SVG, including every saved segment.
        st.download_button(
            label="Download Final SVG",
            data=st.session_state.original_svg,
            file_name="segmented_output.svg",
            mime="image/svg+xml",
        )