lukeafullard committed · Commit 51a42b8 · verified · 1 Parent(s): a650fe9

Upload 7 files

README.md CHANGED
@@ -1,20 +1,93 @@
1
- ---
2
- title: SVG SAMurai
3
- emoji: 🚀
4
- colorFrom: red
5
- colorTo: red
6
- sdk: docker
7
- app_port: 8501
8
- tags:
9
- - streamlit
10
- pinned: false
11
- short_description: Tool to turn an image to an SVG with named image sections
12
- license: mit
13
- ---
14
 
15
- # Welcome to Streamlit!
 
 
 
 
16
 
17
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
 
 
 
18
 
19
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
20
- forums](https://discuss.streamlit.io).
1
+ # SVG-SAMurai 🗡️
2
+
3
+ **SVG-SAMurai** is an interactive, Streamlit-based web application that uses Meta's **Segment Anything Model (SAM)** to turn raster (PNG, JPG) and vector (SVG) images into segmented, editable SVG paths.
4
+
5
+ Whether you're starting from a flat image or an existing SVG file, SVG-SAMurai allows you to click on regions of interest, predict their boundaries, and inject those precise vector paths back into a master SVG document.
6
+
7
+ ## 🌟 Features
8
+
9
+ - **Interactive Segmentation:** Click on any part of an uploaded image to instantly generate an accurate mask using the SAM Vision Transformer (`facebook/sam-vit-base`).
10
+ - **Support for Multiple Formats:** Upload PNG, JPG, or SVG files. Vector images are rasterized cleanly in the backend for processing, allowing you to segment them seamlessly.
11
+ - **Smart Vectorization:** Extracted masks are converted into optimized SVG `<path>` elements using the Ramer-Douglas-Peucker algorithm (via OpenCV) for smooth, simplified contours.
12
+ - **Adjustable Simplification:** Fine-tune the vectorization epsilon factor directly from the UI to control the complexity of the generated paths.
13
+ - **Segment Management:** Name your segments and save them into a live-updating SVG document.
14
+ - **In-Memory Caching:** Image embeddings are expensive to compute, so they are cached with Streamlit's `@st.cache_resource` (as is the loaded model); subsequent clicks reuse the cached embedding and only run the lightweight mask prediction.
15
+ - **Easy Export:** Download the final composed SVG with all your tagged, labeled segments neatly organized in `<g>` groups.
16
+
17
+ ## 🛠️ Tech Stack
18
+
19
+ * **Frontend:** [Streamlit](https://streamlit.io/), [streamlit-image-coordinates](https://pypi.org/project/streamlit-image-coordinates/)
20
+ * **Machine Learning:** [PyTorch](https://pytorch.org/), [Hugging Face Transformers](https://huggingface.co/docs/transformers/index) (Segment Anything Model)
21
+ * **Image Processing:** [OpenCV](https://opencv.org/) (Contour extraction & smoothing), [Pillow (PIL)](https://python-pillow.org/)
22
+ * **SVG / DOM Manipulation:** [lxml](https://lxml.de/) (XML parsing and injection), [CairoSVG](https://cairosvg.org/) (SVG rasterization)
23
+ * **Dependency Management:** [Poetry](https://python-poetry.org/)
24
+
25
+ ## 🚀 Quick Start
26
+
27
+ ### Prerequisites
28
+
29
+ - Python `>=3.10, <3.13` (Required for PyTorch and Triton compatibility)
30
+ - [Poetry](https://python-poetry.org/docs/#installation) installed on your system.
31
+ - System dependencies for CairoSVG and OpenCV (e.g., `libcairo2-dev`, `libgl1-mesa-glx` on Ubuntu/Debian).
32
+
33
+ ### Installation
34
 
35
+ 1. **Clone the repository:**
36
+ ```bash
37
+ git clone <repository-url>
38
+ cd svg-samurai
39
+ ```
40
 
41
+ 2. **Install dependencies using Poetry:**
42
+ ```bash
43
+ poetry install
44
+ ```
45
 
46
+ ### Running the App
47
+
48
+ Start the Streamlit development server:
49
+
50
+ ```bash
51
+ poetry run streamlit run app.py
52
+ ```
53
+
54
+ The application will launch in your default web browser at `http://localhost:8501`.
55
+
56
+ ## 📖 How to Use
57
+
58
+ 1. **Upload an Image:** Use the file uploader to select a PNG, JPG, or SVG file. The app will calculate the complex image embeddings once (this may take a few moments depending on your hardware).
59
+ 2. **Select Segments:** Click anywhere on the image in the left panel to prompt the model.
60
+ - *Tip:* You can toggle the "Next Click is Negative Prompt" checkbox in the sidebar to exclude specific regions from your mask.
61
+ 3. **Refine & Save:**
62
+ - Use the "Undo Last Click" or "Clear Current Selection" buttons to fix mistakes.
63
+ - Give your highlighted segment a descriptive name (e.g., `car_body`).
64
+ - Adjust the **Simplification (epsilon)** slider if you want fewer, smoother nodes in your final vector path.
65
+ - Click **Save Segment to SVG**.
66
+ 4. **Download:** Once you have saved all desired segments, click the **Download Final SVG** button to retrieve your newly layered vector graphic.
67
+
68
+ ## 📂 Project Structure
69
+
70
+ ```text
71
+ svg-samurai/
72
+ ├── app.py # Main Streamlit user interface and application state
73
+ ├── pyproject.toml # Poetry dependencies and project configuration
74
+ ├── src/ # Backend logic
75
+ │ ├── model.py # PyTorch SAM loading, embedding generation, and mask prediction
76
+ │ ├── vectorizer.py # OpenCV contour extraction and SVG path conversion
77
+ │ └── xml_manager.py # lxml DOM manipulation and CairoSVG rasterization utilities
78
+ └── tests/ # Unit tests for core logic
79
+ ├── test_model.py
80
+ ├── test_vectorizer.py
81
+ └── test_xml_manager.py
82
+ ```
83
+
84
+ ## 🧪 Testing
85
+
86
+ The project uses `pytest` for unit testing. To run the test suite, simply execute:
87
+
88
+ ```bash
89
+ poetry run pytest
90
+ ```
91
+
92
+ ---
93
+ *Developed with Streamlit and Meta's Segment Anything Model.*
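
The README's "Smart Vectorization" and "Adjustable Simplification" features rely on the Ramer-Douglas-Peucker algorithm, which the project runs through OpenCV. As a rough illustration of what the epsilon setting controls, here is a minimal pure-Python sketch of the algorithm for an open polyline. The names `rdp` and `_perp_distance` are hypothetical and not part of the repository; the app itself works on closed contours via OpenCV.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]


def _perp_distance(p: Point, a: Point, b: Point) -> float:
    """Distance from p to the line through a and b (or to a if a == b)."""
    (px, py), (ax, ay), (bx, by) = p, a, b
    dx, dy = bx - ax, by - ay
    length = math.hypot(dx, dy)
    if length == 0.0:
        return math.hypot(px - ax, py - ay)
    return abs(dy * (px - ax) - dx * (py - ay)) / length


def rdp(points: List[Point], epsilon: float) -> List[Point]:
    """Simplify an open polyline, dropping points closer than epsilon to the chord."""
    if len(points) < 3:
        return list(points)
    first, last = points[0], points[-1]
    # Find the interior point farthest from the chord first-last.
    index, dmax = 0, 0.0
    for i in range(1, len(points) - 1):
        d = _perp_distance(points[i], first, last)
        if d > dmax:
            index, dmax = i, d
    if dmax <= epsilon:
        # Every interior point is within tolerance: keep only the endpoints.
        return [first, last]
    # Otherwise keep the farthest point and simplify both halves recursively.
    left = rdp(points[: index + 1], epsilon)
    right = rdp(points[index:], epsilon)
    return left[:-1] + right


zigzag = [(0.0, 0.0), (1.0, 0.05), (2.0, 0.0), (3.0, 0.05), (4.0, 0.0)]
coarse = rdp(zigzag, 0.1)   # large epsilon: the small wiggles are dropped
fine = rdp(zigzag, 0.01)    # small epsilon: every point is kept
```

A larger epsilon yields fewer nodes and a smaller SVG, at the cost of fidelity to the mask boundary.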
requirements.txt CHANGED
@@ -1,3 +1,83 @@
1
- altair
2
- pandas
3
- streamlit
1
+ altair==4.2.2
2
+ annotated-doc==0.0.4
3
+ anyio==3.7.1
4
+ attrs==23.1.0
5
+ blinker==1.6.2
6
+ cachetools==5.3.1
7
+ cairocffi==1.6.1
8
+ CairoSVG==2.7.0
9
+ certifi==2023.7.22
10
+ cffi==1.15.1
11
+ charset-normalizer==3.2.0
12
+ click==8.1.7
13
+ cssselect2==0.7.0
14
+ defusedxml==0.7.1
15
+ entrypoints==0.4
16
+ filelock==3.12.2
17
+ fsspec==2023.6.0
18
+ gitdb==4.0.10
19
+ GitPython==3.1.32
20
+ h11==0.14.0
21
+ httpcore==0.17.3
22
+ httpx==0.24.1
23
+ huggingface-hub==0.16.4
24
+ idna==3.4
25
+ iniconfig==2.0.0
26
+ Jinja2==3.1.2
27
+ jsonschema==4.18.6
28
+ jsonschema-specifications==2023.7.1
29
+ lxml==4.9.3
30
+ markdown-it-py==3.0.0
31
+ MarkupSafe==2.1.3
32
+ mdurl==0.1.2
33
+ mpmath==1.3.0
34
+ networkx==3.1
35
+ numpy==1.25.2
36
+ opencv-python==4.8.0.76
37
+ packaging==23.1
38
+ pandas==2.0.3
39
+ Pillow==10.0.0
40
+ pluggy==1.2.0
41
+ protobuf==4.23.4
42
+ pyarrow==12.0.1
43
+ pycparser==2.21
44
+ pydeck==0.8.0b4
45
+ Pygments==2.16.1
46
+ pytest==7.4.0
47
+ pytest-mock==3.11.1
48
+ python-dateutil==2.8.2
49
+ pytz==2023.3
50
+ pytz-deprecation-shim==0.1.0.post0
51
+ PyYAML==6.0.1
52
+ referencing==0.30.2
53
+ regex==2023.8.8
54
+ requests==2.31.0
55
+ rich==13.5.2
56
+ rpds-py==0.9.2
57
+ safetensors==0.3.2
58
+ setuptools==68.0.0
59
+ six==1.16.0
60
+ smmap==5.0.0
61
+ streamlit==1.25.0
62
+ streamlit-image-coordinates==0.1.4
63
+ sympy==1.12
64
+ tenacity==8.2.3
65
+ tinycss2==1.2.1
66
+ tokenizers==0.13.3
67
+ toml==0.10.2
68
+ toolz==0.12.0
69
+ torch==2.0.1
70
+ tornado==6.3.3
71
+ tqdm==4.66.1
72
+ transformers==4.30.2
73
+ triton==2.0.0
74
+ typer==0.9.0
75
+ typing_extensions==4.7.1
76
+ tzdata==2023.3
77
+ tzlocal==5.0.1
78
+ urllib3==2.0.4
79
+ validators==0.21.2
80
+ watchdog==3.0.0
81
+ webencodings==0.5.1
82
+ wheel==0.41.1
83
+ zipp==3.16.2
src/src/__init__.py ADDED
File without changes
src/src/model.py ADDED
@@ -0,0 +1,95 @@
1
+ import torch
2
+ from transformers import SamModel, SamProcessor
3
+ import streamlit as st
4
+ import numpy as np
5
+ from PIL import Image
6
+ from typing import Tuple, List
7
+
8
+
9
+ # Use @st.cache_resource to avoid reloading the model on every rerun
10
+ @st.cache_resource(show_spinner="Loading Segment Anything Model (SAM)...")
11
+ def load_sam_model() -> Tuple[SamModel, SamProcessor, str]:
12
+ """Loads the SAM model and processor from Hugging Face."""
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ # Using facebook/sam-vit-base as the standard baseline
15
+ model_id = "facebook/sam-vit-base"
16
+ model = SamModel.from_pretrained(model_id).to(device)
17
+ processor = SamProcessor.from_pretrained(model_id)
18
+ return model, processor, device
19
+
20
+
21
+ @st.cache_resource(show_spinner="Computing Image Embeddings...")
22
+ def compute_image_embedding(image: Image.Image) -> torch.Tensor:
23
+ """
24
+ Computes and caches the SAM image embedding for a given image.
25
+ This is the heavy part of the computation.
26
+ """
27
+ model, processor, device = load_sam_model()
28
+
29
+ # Preprocess the image to get pixel values
30
+ inputs = processor(images=image, return_tensors="pt").to(device)
31
+
32
+ # Compute image embeddings
33
+ with torch.no_grad():
34
+ image_embeddings = model.get_image_embeddings(inputs.pixel_values)
35
+
36
+ return image_embeddings
37
+
38
+
39
+ def predict_mask(
40
+ image: Image.Image,
41
+ image_embeddings: torch.Tensor,
42
+ input_points: List[List[int]],
43
+ input_labels: List[int],
44
+ ) -> np.ndarray:
45
+ """
46
+ Predicts a binary mask given the image embeddings and prompt points.
47
+ input_points: list of [x, y] coordinates
48
+ input_labels: list of 1 (positive) or 0 (negative) for each point
49
+ """
50
+ model, processor, device = load_sam_model()
51
+
52
+ # Format inputs for the processor
53
+ # The processor expects points in the format [[[x1, y1], [x2, y2], ...]]
54
+ # and labels in [[1, 0, ...]] for a single batch
55
+ points = [input_points]
56
+ labels = [input_labels]
57
+
58
+ # Preprocess prompts
59
+ inputs = processor(
60
+ images=image, input_points=points, input_labels=labels, return_tensors="pt"
61
+ ).to(device)
62
+
63
+ # Run prediction using the cached embeddings
64
+ with torch.no_grad():
65
+ outputs = model(
66
+ image_embeddings=image_embeddings,
67
+ input_points=inputs.input_points,
68
+ input_labels=inputs.input_labels,
69
+ multimask_output=False, # We only want the best mask
70
+ )
71
+
72
+ # Process the predicted mask back to the original image size
73
+ # inputs contains original_sizes and reshaped_input_sizes from the processor call
74
+ masks = processor.image_processor.post_process_masks(
75
+ outputs.pred_masks.cpu(),
76
+ inputs["original_sizes"].cpu(),
77
+ inputs["reshaped_input_sizes"].cpu(),
78
+ )
79
+
80
+ # masks is a list of tensors, get the first one and squeeze it to a 2D array
81
+ mask = masks[0]
82
+ # Squeeze out the batch and channel dimensions if present, but keep spatial dims.
83
+ # Usually shape is (1, 1, H, W) or (1, H, W)
84
+ if mask.ndim > 2:
85
+ mask = mask.squeeze()
86
+ # If the image was 1x1, squeeze might have removed all dimensions.
87
+ if mask.ndim < 2:
88
+ mask = mask.view(masks[0].shape[-2], masks[0].shape[-1])
89
+
90
+ mask = mask.numpy()
91
+
92
+ # The mask is boolean, convert to uint8 for OpenCV (0 and 255)
93
+ binary_mask = (mask * 255).astype(np.uint8)
94
+
95
+ return binary_mask
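
The comments in `predict_mask` describe the prompt layout it passes to the processor: points nested as `[[[x1, y1], [x2, y2], ...]]` and labels as `[[1, 0, ...]]` for a single image. The sketch below shows that wrapping with some basic validation added. `format_prompts` is a hypothetical helper, not a function in this commit or in the Transformers library, and the validation rules are assumptions.

```python
from typing import List, Sequence, Tuple


def format_prompts(
    input_points: Sequence[Sequence[int]],
    input_labels: Sequence[int],
) -> Tuple[List[List[List[int]]], List[List[int]]]:
    """Validate click prompts and wrap them in a leading batch dimension."""
    if len(input_points) != len(input_labels):
        raise ValueError("Each point needs exactly one label.")
    for point in input_points:
        if len(point) != 2:
            raise ValueError(f"Expected an [x, y] pair, got {point!r}.")
    for label in input_labels:
        if label not in (0, 1):
            raise ValueError(f"Labels must be 0 or 1, got {label!r}.")
    # One image per batch, so both lists gain exactly one outer level.
    points = [[[int(x), int(y)] for x, y in input_points]]
    labels = [[int(label) for label in input_labels]]
    return points, labels


# One positive click and one negative click on the same image.
points, labels = format_prompts([[10, 20], [30, 40]], [1, 0])
```

Checking the lengths up front would surface a mismatched point/label pair before it reaches the model.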
src/src/vectorizer.py ADDED
@@ -0,0 +1,70 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+
5
+ def mask_to_svg_path(mask: np.ndarray, epsilon_factor: float = 0.005) -> str:
6
+ """
7
+ Converts a binary mask to an SVG path string.
8
+
9
+ Args:
10
+ mask (np.ndarray): The 2D binary mask.
11
+ epsilon_factor (float): The factor for approximating the contour with Ramer-Douglas-Peucker algorithm.
12
+ A higher value means more simplification (fewer points, smaller SVG size).
13
+
14
+ Returns:
15
+ str: An SVG path data string (`M x,y L x,y Z ...`).
16
+ """
17
+ if not isinstance(mask, np.ndarray) or mask.ndim != 2:
18
+ raise ValueError("Mask must be a 2D numpy array.")
19
+ # 1. Extract Contours
20
+ # RETR_CCOMP retrieves all of the contours and organizes them into a two-level hierarchy.
21
+ # At the top level, there are external boundaries of the components.
22
+ # At the second level, there are boundaries of the holes.
23
+ contours, hierarchy = cv2.findContours(
24
+ mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE
25
+ )
26
+
27
+ if contours is None or len(contours) == 0:
28
+ return ""
29
+
30
+ path_data = []
31
+
32
+ # 2. Iterate through contours to build the path (hierarchy is only checked for presence; holes rely on fill-rule="evenodd" at render time)
33
+ # The hierarchy array has shape (1, num_contours, 4)
34
+ # The 4 elements are: [Next, Previous, First_Child, Parent]
35
+ if hierarchy is None:
36
+ return ""
37
+
38
+ hierarchy = hierarchy[0]
39
+
40
+ for i, contour in enumerate(contours):
41
+ # Only process a contour if it has at least 3 points
42
+ if len(contour) < 3:
43
+ continue
44
+
45
+ # 3. Simplify Contour
46
+ # Calculate epsilon based on the contour's arc length
47
+ epsilon = epsilon_factor * cv2.arcLength(contour, True)
48
+ approx = cv2.approxPolyDP(contour, epsilon, True)
49
+
50
+ # We want to skip highly simplified contours that are just points or lines
51
+ if len(approx) < 3:
52
+ continue
53
+
54
+ # 4. Format to SVG path
55
+ # M = moveto (start point)
56
+ # L = lineto (subsequent points)
57
+ # Z = closepath (return to start)
58
+ pts = approx.reshape(-1, 2)
59
+
60
+ # Add the M command for the first point
61
+ path_data.append(f"M {pts[0][0]},{pts[0][1]}")
62
+
63
+ # Add the L commands for the rest
64
+ for x, y in pts[1:]:
65
+ path_data.append(f"L {x},{y}")
66
+
67
+ # Close the contour
68
+ path_data.append("Z")
69
+
70
+ return " ".join(path_data)
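
Two details of `mask_to_svg_path` are easy to miss: epsilon is scaled by each contour's perimeter, and all contours (outer boundaries and holes) end up as subpaths of one path string. The following pure-Python sketch mirrors both steps without OpenCV. `closed_perimeter` and `polygons_to_path_d` are hypothetical names used only for illustration, and the polygons are made-up sample data.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[int, int]


def closed_perimeter(points: Sequence[Point]) -> float:
    """Perimeter of a closed polygon (the last point connects back to the first)."""
    total = 0.0
    for i, (x0, y0) in enumerate(points):
        x1, y1 = points[(i + 1) % len(points)]
        total += math.hypot(x1 - x0, y1 - y0)
    return total


def polygons_to_path_d(polygons: Sequence[Sequence[Point]]) -> str:
    """Join polygons into one SVG path string made of M/L/Z subpaths."""
    parts: List[str] = []
    for pts in polygons:
        if len(pts) < 3:
            continue  # skip degenerate contours (points or lines)
        parts.append(f"M {pts[0][0]},{pts[0][1]}")
        parts.extend(f"L {x},{y}" for x, y in pts[1:])
        parts.append("Z")
    return " ".join(parts)


outer = [(0, 0), (10, 0), (10, 10), (0, 10)]
hole = [(3, 3), (7, 3), (7, 7), (3, 7)]

# With the default factor of 0.005, a 40 px perimeter gives a tolerance of about 0.2 px.
epsilon = 0.005 * closed_perimeter(outer)
path_d = polygons_to_path_d([outer, hole, [(1, 1), (2, 2)]])
```

Because the hole is just another subpath, it only renders as a hole when the path uses `fill-rule="evenodd"`, which is what `add_path_to_svg` sets.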
src/src/xml_manager.py ADDED
@@ -0,0 +1,94 @@
1
+ from lxml import etree
2
+ import cairosvg
3
+ import io
4
+ from PIL import Image
5
+ from typing import Any
6
+
7
+ # Namespace for SVG creation
8
+ SVG_NS = "http://www.w3.org/2000/svg"
9
+ NSMAP = {None: SVG_NS}
10
+
11
+
12
+ def create_base_svg(width: int, height: int) -> str:
13
+ """Creates a basic empty SVG string with specified dimensions."""
14
+ root = etree.Element(
15
+ "svg",
16
+ width=str(width),
17
+ height=str(height),
18
+ viewBox=f"0 0 {width} {height}",
19
+ nsmap=NSMAP,
20
+ )
21
+ return etree.tostring(root, pretty_print=True, encoding="unicode")
22
+
23
+
24
+ def add_path_to_svg(
25
+ svg_str: str,
26
+ path_d: str,
27
+ path_id: str,
28
+ fill_color: str = "#FF0000",
29
+ opacity: float = 0.5,
30
+ ) -> str:
31
+ """
32
+ Injects an SVG `<path>` into an existing SVG string within a `<g>` group using lxml.
33
+ """
34
+ if not path_d:
35
+ return svg_str
36
+
37
+ try:
38
+ # Provide a parser that handles basic errors and mitigates XXE injection
39
+ parser = etree.XMLParser(recover=True, resolve_entities=False, no_network=True)
40
+ root = etree.fromstring(
41
+ svg_str.encode("utf-8", errors="replace"), parser=parser
42
+ )
43
+ if root is None:
44
+ return svg_str
45
+ except Exception:
46
+ # If the string isn't an XML document or parsing fails
47
+ return svg_str
48
+
49
+ # Find the correct namespace for the root or default to SVG_NS
50
+ ns = SVG_NS
51
+ if root.nsmap and None in root.nsmap:
52
+ ns = root.nsmap[None]
53
+ elif root.tag.startswith("{"):
54
+ ns = root.tag[1:].split("}")[0]
55
+
56
+ # Clean the namespace map to avoid redundant ns0 prefixes
57
+ # Ensure xmlns is explicitly available in nsmap of new elements
58
+ new_nsmap = {None: ns} if ns else None
59
+
60
+ # Create the <g id="path_id">
61
+ group = etree.SubElement(
62
+ root, f"{{{ns}}}g" if ns else "g", id=path_id, nsmap=new_nsmap
63
+ )
64
+
65
+ # Create the <path>
66
+ # Using fill-rule="evenodd" is important when combining outer boundaries and inner holes
67
+ etree.SubElement(
68
+ group,
69
+ f"{{{ns}}}path" if ns else "path",
70
+ d=path_d,
71
+ fill=fill_color,
72
+ opacity=str(opacity),
73
+ attrib={"fill-rule": "evenodd"}, # Handles holes properly
74
+ )
75
+
76
+ return etree.tostring(root, pretty_print=True, encoding="unicode")
77
+
78
+
79
+ def parse_svg_to_image(svg_bytes: bytes) -> Image.Image:
80
+ """Converts uploaded SVG file bytes into a PIL Image."""
81
+ # Pass url_fetcher to block network and local file access from within SVG
82
+ png_bytes = cairosvg.svg2png(
83
+ bytestring=svg_bytes, url_fetcher=lambda *args, **kwargs: b""
84
+ )
85
+ return Image.open(io.BytesIO(png_bytes))
86
+
87
+
88
+ def load_image(uploaded_file: Any) -> Image.Image:
89
+ """Loads an uploaded image (Raster or Vector) and returns a PIL Image."""
90
+ if getattr(uploaded_file, "type", "") == "image/svg+xml":
91
+ return parse_svg_to_image(uploaded_file.getvalue())
92
+ else:
93
+ # Handle regular rasters (PNG, JPG)
94
+ return Image.open(uploaded_file).convert("RGB")
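
For readers without lxml installed, the injection step in `add_path_to_svg` can be approximated with the standard library. This is only a sketch under two assumptions: the root element is already in the SVG namespace, and the input is trusted (the stdlib parser is not hardened the way the lxml parser above is configured). `inject_path` is a hypothetical name.

```python
import xml.etree.ElementTree as ET

SVG_NS = "http://www.w3.org/2000/svg"
# Serialize SVG elements without an "ns0:" prefix.
ET.register_namespace("", SVG_NS)


def inject_path(
    svg_str: str,
    path_d: str,
    path_id: str,
    fill_color: str = "#FF0000",
    opacity: float = 0.5,
) -> str:
    """Append <g id=path_id><path .../></g> to the SVG root and return the new string."""
    if not path_d:
        return svg_str  # nothing to add
    root = ET.fromstring(svg_str)
    group = ET.SubElement(root, f"{{{SVG_NS}}}g", {"id": path_id})
    ET.SubElement(
        group,
        f"{{{SVG_NS}}}path",
        {
            "d": path_d,
            "fill": fill_color,
            "opacity": str(opacity),
            "fill-rule": "evenodd",  # lets inner subpaths render as holes
        },
    )
    return ET.tostring(root, encoding="unicode")


base = '<svg xmlns="http://www.w3.org/2000/svg" width="10" height="10"/>'
result = inject_path(base, "M 0,0 L 5,0 L 5,5 Z", "car_body")
```

Each saved segment becomes its own `<g>` whose `id` is the segment name, so the groups can be targeted later by ID in an editor or stylesheet.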
src/streamlit_app.py CHANGED
@@ -1,40 +1,223 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
1
  import streamlit as st
2
+ from PIL import Image
3
+ import numpy as np
4
+ from streamlit_image_coordinates import streamlit_image_coordinates
5
+
6
+ from src.model import compute_image_embedding, predict_mask
7
+ from src.vectorizer import mask_to_svg_path
8
+ from src.xml_manager import load_image, create_base_svg
9
+
10
+ st.set_page_config(page_title="SVG-SAMurai", layout="wide", page_icon="🗡️")
11
+
12
+ # Session State Initialization
13
+ if "image" not in st.session_state:
14
+ st.session_state.image = None
15
+ if "image_embedding" not in st.session_state:
16
+ st.session_state.image_embedding = None
17
+ if "points" not in st.session_state:
18
+ st.session_state.points = []
19
+ if "labels" not in st.session_state:
20
+ st.session_state.labels = []
21
+ if "current_mask" not in st.session_state:
22
+ st.session_state.current_mask = None
23
+ if "segments" not in st.session_state:
24
+ st.session_state.segments = {}
25
+ if "original_svg" not in st.session_state:
26
+ st.session_state.original_svg = None
27
+
28
+ st.title("SVG-SAMurai 🗡️")
29
+ st.markdown(
30
+ "Transform raster and vector images into segmented SVG paths using the **Segment Anything Model (SAM)**."
31
+ )
32
+
33
+ # File uploader
34
+ uploaded_file = st.file_uploader(
35
+ "Upload an Image (PNG, JPG, SVG)", type=["png", "jpg", "jpeg", "svg"]
36
+ )
37
+
38
+ if uploaded_file is not None:
39
+ # Reset state if a new file is uploaded
40
+ if (
41
+ "last_uploaded_file_id" not in st.session_state
42
+ or st.session_state.last_uploaded_file_id != uploaded_file.file_id
43
+ ):
44
+ st.session_state.last_uploaded_file_id = uploaded_file.file_id
45
+ st.session_state.image = None
46
+ st.session_state.image_embedding = None
47
+ st.session_state.points = []
48
+ st.session_state.labels = []
49
+ st.session_state.current_mask = None
50
+ st.session_state.segments = {}
51
+ st.session_state.original_svg = None
52
+
53
+ if st.session_state.image is None:
54
+ # Load the image
55
+ with st.spinner("Processing Image..."):
56
+ image = load_image(uploaded_file)
57
+ st.session_state.image = image
58
+
59
+ # If the original file was an SVG, save its string representation
60
+ if uploaded_file.type == "image/svg+xml":
61
+ st.session_state.original_svg = uploaded_file.getvalue().decode(
62
+ "utf-8", errors="replace"
63
+ )
64
+ else:
65
+ # Create a blank SVG canvas with the original raster image dimensions
66
+ width, height = image.size
67
+ st.session_state.original_svg = create_base_svg(width, height)
68
+
69
+ # Compute image embeddings once
70
+ st.session_state.image_embedding = compute_image_embedding(image)
71
+ st.success("Image embedded successfully!")
72
+
73
+ col1, col2 = st.columns([2, 1])
74
+
75
+ with col1:
76
+ st.subheader("Interactive Segmentation")
77
+
78
+ # Display the image with coordinates clicker
79
+ # If there's a mask, we overlay it
80
+ display_image = st.session_state.image.copy()
81
+ if st.session_state.current_mask is not None:
82
+ # Create a semi-transparent blue overlay for the current mask
83
+ overlay = np.zeros(
84
+ (*st.session_state.current_mask.shape, 4), dtype=np.uint8
85
+ )
86
+ overlay[st.session_state.current_mask > 0] = [
87
+ 0,
88
+ 0,
89
+ 255,
90
+ 128,
91
+ ] # Blue, 50% opacity
92
+ overlay_image = Image.fromarray(overlay, mode="RGBA")
93
+ display_image = display_image.convert("RGBA")
94
+ display_image.paste(overlay_image, (0, 0), overlay_image)
95
+ display_image = display_image.convert(
96
+ "RGB"
97
+ ) # Convert back to RGB for display
98
+
99
+ # Show the image using streamlit-image-coordinates
100
+ # Note: we need to handle scaling if the image is wider than the container
101
+ # streamlit-image-coordinates scales the image to the container width but gives the
102
+ # coordinates relative to the original image dimensions.
103
+ value = streamlit_image_coordinates(display_image, key="image_coord")
104
+
105
+ # Handle clicks
106
+ if value is not None:
107
+ # streamlit_image_coordinates returns x, y relative to the original image size
108
+ x, y = value["x"], value["y"]
109
+
110
+ # Determine if it's a positive or negative prompt
111
+ # Clicks are positive by default; the sidebar toggle marks the next click as negative
112
+ is_negative = st.sidebar.checkbox(
113
+ "Next Click is Negative Prompt (Exclude)",
114
+ value=False,
115
+ key="neg_prompt_toggle",
116
+ )
117
+ label = 0 if is_negative else 1
118
+
119
+ # Check if this is a new click (prevent reruns from adding the same point repeatedly)
120
+ new_point = [x, y]
121
+ if not st.session_state.points or st.session_state.points[-1] != new_point:
122
+ st.session_state.points.append(new_point)
123
+ st.session_state.labels.append(label)
124
+
125
+ # Predict new mask
126
+ with st.spinner("Predicting Segment..."):
127
+ mask = predict_mask(
128
+ st.session_state.image,
129
+ st.session_state.image_embedding,
130
+ st.session_state.points,
131
+ st.session_state.labels,
132
+ )
133
+ st.session_state.current_mask = mask
134
+ st.rerun()
135
+
136
+ # Tools for interacting with the points
137
+ col_btn1, col_btn2 = st.columns(2)
138
+ with col_btn1:
139
+ if st.button("Undo Last Click"):
140
+ if st.session_state.points:
141
+ st.session_state.points.pop()
142
+ st.session_state.labels.pop()
143
+ if st.session_state.points:
144
+ # Repredict
145
+ mask = predict_mask(
146
+ st.session_state.image,
147
+ st.session_state.image_embedding,
148
+ st.session_state.points,
149
+ st.session_state.labels,
150
+ )
151
+ st.session_state.current_mask = mask
152
+ else:
153
+ st.session_state.current_mask = None
154
+ st.rerun()
155
+
156
+ with col_btn2:
157
+ if st.button("Clear Current Selection"):
158
+ st.session_state.points = []
159
+ st.session_state.labels = []
160
+ st.session_state.current_mask = None
161
+ st.rerun()
162
+
163
+ with col2:
164
+ st.subheader("Segment Management")
165
+
166
+ segment_name = st.text_input("Segment Name", placeholder="e.g., car_body")
167
+ epsilon_factor = st.slider(
168
+ "Vectorization Simplification (epsilon)",
169
+ min_value=0.001,
170
+ max_value=0.05,
171
+ value=0.005,
172
+ step=0.001,
173
+ format="%.3f",
174
+ )
175
+
176
+ if st.button(
177
+ "Save Segment to SVG",
178
+ disabled=st.session_state.current_mask is None or not segment_name,
179
+ ):
180
+ with st.spinner("Vectorizing..."):
181
+ # 1. Convert mask to SVG path
182
+ path_d = mask_to_svg_path(
183
+ st.session_state.current_mask, epsilon_factor=epsilon_factor
184
+ )
185
+
186
+ # 2. Add to session state segments dictionary
187
+ st.session_state.segments[segment_name] = path_d
188
+
189
+ # 3. Inject the path into the working SVG string
190
+ from src.xml_manager import add_path_to_svg
191
+
192
+ try:
193
+ st.session_state.original_svg = add_path_to_svg(
194
+ st.session_state.original_svg,
195
+ path_d,
196
+ segment_name,
197
+ fill_color="#FF0000",
198
+ opacity=0.5,
199
+ )
200
+
201
+ # Clear current selection for the next segment
202
+ st.session_state.points = []
203
+ st.session_state.labels = []
204
+ st.session_state.current_mask = None
205
+
206
+ st.success(f"Segment '{segment_name}' saved!")
207
+ st.rerun()
208
+ except Exception as e:
209
+ st.error(f"Failed to inject SVG: {e}")
210
+
211
+ # Display saved segments list
212
+ if st.session_state.segments:
213
+ st.write("### Saved Segments")
214
+ for name in st.session_state.segments.keys():
215
+ st.markdown(f"- **{name}**")
216
 
217
+ # Provide download button for the final SVG
218
+ st.download_button(
219
+ label="Download Final SVG",
220
+ data=st.session_state.original_svg,
221
+ file_name="segmented_output.svg",
222
+ mime="image/svg+xml",
223
+ )
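
The click handling in `streamlit_app.py` keeps parallel `points` and `labels` lists in session state, ignores a click that repeats the previous point (so a rerun does not add it twice), and supports undo and clear. The sketch below isolates that bookkeeping in a plain class so it can be reasoned about outside Streamlit. `PromptState` is a hypothetical name and is not part of this commit.

```python
from typing import List


class PromptState:
    """Tracks click prompts the way the app's session state does."""

    def __init__(self) -> None:
        self.points: List[List[int]] = []
        self.labels: List[int] = []

    def add_click(self, x: int, y: int, negative: bool = False) -> bool:
        """Record a click; return False if it repeats the previous point."""
        new_point = [x, y]
        if self.points and self.points[-1] == new_point:
            return False  # same coordinates as the last click: treat as a rerun
        self.points.append(new_point)
        self.labels.append(0 if negative else 1)  # 1 = include, 0 = exclude
        return True

    def undo(self) -> None:
        """Drop the most recent click, if any."""
        if self.points:
            self.points.pop()
            self.labels.pop()

    def clear(self) -> None:
        """Forget all clicks for the current selection."""
        self.points.clear()
        self.labels.clear()


state = PromptState()
state.add_click(10, 20)                  # positive click
state.add_click(10, 20)                  # ignored: repeats the previous point
state.add_click(30, 40, negative=True)   # negative click
```

One consequence of comparing only against the last point is that clicking the same pixel twice in a row cannot add two prompts, which matches the behavior of the app code above.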