File size: 2,906 Bytes
e6af059
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc71e86
e6af059
 
e5f9099
 
 
 
 
 
 
e6af059
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import traceback

import gradio as gr
import torch
from PIL import Image

from starvector.model.starvector_arch import StarVectorForCausalLM
from starvector.data.util import process_and_rasterize_svg


def _ensure_hf_token_env() -> None:
    # The repo docs mention HUGGING_FACE_HUB_TOKEN; huggingface_hub also looks for HF_TOKEN.
    tok = os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
    if tok and not os.environ.get("HF_TOKEN"):
        os.environ["HF_TOKEN"] = tok


_ensure_hf_token_env()

# Half precision on GPU; full precision on CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.float16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32

# Both settings can be overridden through the environment.
MODEL_NAME = os.environ.get("STARVECTOR_MODEL", "starvector/starvector-1b-im2svg")
MAX_LENGTH = int(os.environ.get("STARVECTOR_MAX_LENGTH", "4000"))

print(
    f"Starting StarVector Space on device={DEVICE} dtype={DTYPE} model={MODEL_NAME}",
    flush=True,
)

# On GPU, placement is delegated to transformers via device_map="auto";
# on CPU no device_map is given and the model is moved explicitly afterwards.
starvector = StarVectorForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    device_map=None if DEVICE != "cuda" else "auto",
)
if DEVICE == "cuda":
    pass
else:
    starvector = starvector.to(device=DEVICE)
starvector.eval()


def im2svg(image: Image.Image, max_length: int = MAX_LENGTH):
    """Generate SVG code (plus a rasterized preview) from an input image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing was uploaded.
        max_length: Length limit forwarded to ``starvector.generate_im2svg``.
            It comes from a Gradio slider, whose numeric payload type may vary
            by Gradio version, so it is coerced to ``int`` before use.

    Returns:
        A ``(svg_code, raster_preview, error_text)`` tuple. On success
        ``error_text`` is ``""``. With no image it is
        ``("", None, "No image provided.")``. On any exception it is
        ``("", None, <formatted traceback>)``.
    """
    if image is None:
        return "", None, "No image provided."

    try:
        # Generation length limits must be integers; inside the try so a bad
        # value is reported in the Error box like any other failure.
        max_length = int(max_length)

        # Preprocess image
        image_tensor = starvector.process_images([image])[0]
        image_tensor = image_tensor.to(device=DEVICE)
        if DEVICE == "cuda":
            # Match the half-precision weights loaded on GPU.
            image_tensor = image_tensor.to(dtype=DTYPE)

        batch = {"image": image_tensor}

        # Inference only: disable autograd tracking to save memory.
        with torch.no_grad():
            raw_svg = starvector.generate_im2svg(batch, max_length=max_length)[0]

        # Clean + rasterize preview (raster presumably a PIL image for gr.Image).
        svg, raster = process_and_rasterize_svg(raw_svg)
        return svg, raster, ""

    except Exception:
        # UI boundary: show the traceback instead of failing the request.
        return "", None, traceback.format_exc()


with gr.Blocks() as demo:
    # Header / usage notes shown at the top of the Space.
    gr.Markdown(
        "# StarVector (GPU)\n"
        "Upload an icon/logo/diagram-like image and generate SVG code.\n\n"
        "Notes:\n"
        "- This Space requires a GPU and a Hugging Face token with access to the gated `bigcode/starcoderbase-1b` model.\n"
        "- Set Space secret `HUGGING_FACE_HUB_TOKEN`."
    )

    with gr.Row():
        inp = gr.Image(type="pil", label="Input image")
        preview = gr.Image(type="pil", label="Rasterized preview")

    max_len = gr.Slider(256, 8000, value=MAX_LENGTH, step=128, label="max_length")

    # gr.Code validates `language` against a fixed list; "xml" is not in it
    # (NOTE(review): re-check on Gradio upgrades), which makes construction fail
    # at startup. SVG is tag markup, so the always-available "html" mode
    # highlights it well.
    out_svg = gr.Code(language="html", label="SVG")
    err = gr.Textbox(label="Error", visible=True)

    btn = gr.Button("Generate SVG")
    # Outputs map positionally onto im2svg's (svg, raster, error) return tuple.
    btn.click(im2svg, inputs=[inp, max_len], outputs=[out_svg, preview, err])


if __name__ == "__main__":
    port = int(os.environ.get("PORT", "7860"))
    demo.queue(concurrency_count=1)
    demo.launch(server_name="0.0.0.0", server_port=port)