Spaces:
Sleeping
Sleeping
scott-ashton-tds Oz commited on
Commit ·
e6af059
1
Parent(s): a6acd5a
Add Docker Space app for StarVector
Browse filesCo-Authored-By: Oz <oz-agent@warp.dev>
- Dockerfile +32 -0
- README.md +11 -3
- app.py +85 -0
Dockerfile
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel
|
| 2 |
+
|
| 3 |
+
# System deps
|
| 4 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 5 |
+
git \
|
| 6 |
+
build-essential \
|
| 7 |
+
libcairo2 \
|
| 8 |
+
libglib2.0-0 \
|
| 9 |
+
libsm6 \
|
| 10 |
+
libxext6 \
|
| 11 |
+
libxrender1 \
|
| 12 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 13 |
+
|
| 14 |
+
WORKDIR /app
|
| 15 |
+
|
| 16 |
+
# Install StarVector from GitHub
|
| 17 |
+
# Pin to main branch by default; you can pin a commit for reproducibility.
|
| 18 |
+
RUN git clone --depth 1 https://github.com/joanrod/star-vector.git /app/star-vector
|
| 19 |
+
|
| 20 |
+
# Install python dependencies
|
| 21 |
+
RUN pip install --upgrade pip \
|
| 22 |
+
&& pip install -e /app/star-vector
|
| 23 |
+
|
| 24 |
+
# Copy Space app
|
| 25 |
+
COPY app.py /app/app.py
|
| 26 |
+
|
| 27 |
+
# Hugging Face Spaces sets PORT; default to 7860 locally.
|
| 28 |
+
ENV PORT=7860
|
| 29 |
+
|
| 30 |
+
EXPOSE 7860
|
| 31 |
+
|
| 32 |
+
CMD ["python", "/app/app.py"]
|
README.md
CHANGED
|
@@ -1,11 +1,19 @@
|
|
| 1 |
---
|
| 2 |
-
title: Starvector
|
| 3 |
emoji: 👀
|
| 4 |
colorFrom: indigo
|
| 5 |
colorTo: indigo
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
-
short_description:
|
| 9 |
---
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Starvector GPU
|
| 3 |
emoji: 👀
|
| 4 |
colorFrom: indigo
|
| 5 |
colorTo: indigo
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
short_description: StarVector Image→SVG on a CUDA GPU
|
| 9 |
---
|
| 10 |
|
| 11 |
+
This is a Docker Space that runs **StarVector** (image→SVG generation) with **CUDA**.
|
| 12 |
+
|
| 13 |
+
Setup
|
| 14 |
+
- In Space settings, pick a GPU hardware tier.
|
| 15 |
+
- Add a Space secret `HUGGING_FACE_HUB_TOKEN` with access to the gated model `bigcode/starcoderbase-1b`.
|
| 16 |
+
|
| 17 |
+
Env vars (optional)
|
| 18 |
+
- `STARVECTOR_MODEL` (default: `starvector/starvector-1b-im2svg`)
|
| 19 |
+
- `STARVECTOR_MAX_LENGTH` (default: `4000`)
|
app.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import traceback
|
| 3 |
+
|
| 4 |
+
import gradio as gr
|
| 5 |
+
import torch
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
from starvector.model.starvector_arch import StarVectorForCausalLM
|
| 9 |
+
from starvector.data.util import process_and_rasterize_svg
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _ensure_hf_token_env() -> None:
|
| 13 |
+
# The repo docs mention HUGGING_FACE_HUB_TOKEN; huggingface_hub also looks for HF_TOKEN.
|
| 14 |
+
tok = os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
|
| 15 |
+
if tok and not os.environ.get("HF_TOKEN"):
|
| 16 |
+
os.environ["HF_TOKEN"] = tok
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
_ensure_hf_token_env()
|
| 20 |
+
|
| 21 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 22 |
+
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
|
| 23 |
+
|
| 24 |
+
MODEL_NAME = os.environ.get("STARVECTOR_MODEL", "starvector/starvector-1b-im2svg")
|
| 25 |
+
MAX_LENGTH = int(os.environ.get("STARVECTOR_MAX_LENGTH", "4000"))
|
| 26 |
+
|
| 27 |
+
print(f"Starting StarVector Space on device={DEVICE} dtype={DTYPE} model={MODEL_NAME}", flush=True)
|
| 28 |
+
|
| 29 |
+
starvector = StarVectorForCausalLM.from_pretrained(MODEL_NAME)
|
| 30 |
+
starvector = starvector.to(device=DEVICE)
|
| 31 |
+
starvector.eval()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def im2svg(image: Image.Image, max_length: int = MAX_LENGTH):
|
| 35 |
+
if image is None:
|
| 36 |
+
return "", None, "No image provided."
|
| 37 |
+
|
| 38 |
+
try:
|
| 39 |
+
# Preprocess image
|
| 40 |
+
image_tensor = starvector.process_images([image])[0]
|
| 41 |
+
image_tensor = image_tensor.to(device=DEVICE)
|
| 42 |
+
if DEVICE == "cuda":
|
| 43 |
+
image_tensor = image_tensor.to(dtype=DTYPE)
|
| 44 |
+
|
| 45 |
+
batch = {"image": image_tensor}
|
| 46 |
+
|
| 47 |
+
# Generate raw svg
|
| 48 |
+
raw_svg = starvector.generate_im2svg(batch, max_length=max_length)[0]
|
| 49 |
+
|
| 50 |
+
# Clean + rasterize preview
|
| 51 |
+
svg, raster = process_and_rasterize_svg(raw_svg)
|
| 52 |
+
|
| 53 |
+
# raster may be a PIL image
|
| 54 |
+
return svg, raster, ""
|
| 55 |
+
|
| 56 |
+
except Exception:
|
| 57 |
+
return "", None, traceback.format_exc()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
with gr.Blocks() as demo:
|
| 61 |
+
gr.Markdown(
|
| 62 |
+
"# StarVector (GPU)\n"
|
| 63 |
+
"Upload an icon/logo/diagram-like image and generate SVG code.\n\n"
|
| 64 |
+
"Notes:\n"
|
| 65 |
+
"- This Space requires a GPU and a Hugging Face token with access to the gated `bigcode/starcoderbase-1b` model.\n"
|
| 66 |
+
"- Set Space secret `HUGGING_FACE_HUB_TOKEN`."
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
with gr.Row():
|
| 70 |
+
inp = gr.Image(type="pil", label="Input image")
|
| 71 |
+
preview = gr.Image(type="pil", label="Rasterized preview")
|
| 72 |
+
|
| 73 |
+
max_len = gr.Slider(256, 8000, value=MAX_LENGTH, step=128, label="max_length")
|
| 74 |
+
|
| 75 |
+
out_svg = gr.Code(language="xml", label="SVG")
|
| 76 |
+
err = gr.Textbox(label="Error", visible=True)
|
| 77 |
+
|
| 78 |
+
btn = gr.Button("Generate SVG")
|
| 79 |
+
btn.click(im2svg, inputs=[inp, max_len], outputs=[out_svg, preview, err])
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
if __name__ == "__main__":
|
| 83 |
+
port = int(os.environ.get("PORT", "7860"))
|
| 84 |
+
demo.queue(concurrency_count=1)
|
| 85 |
+
demo.launch(server_name="0.0.0.0", server_port=port)
|