starvector-gpu / app.py
scott-ashton-tds
Trigger rebuild with comment
dc71e86
import os
import traceback
import gradio as gr
import torch
from PIL import Image
from starvector.model.starvector_arch import StarVectorForCausalLM
from starvector.data.util import process_and_rasterize_svg
def _ensure_hf_token_env() -> None:
    """Copy ``HUGGING_FACE_HUB_TOKEN`` into ``HF_TOKEN`` when the latter is unset or empty.

    The repo docs mention ``HUGGING_FACE_HUB_TOKEN``, while huggingface_hub
    also looks for ``HF_TOKEN``. A non-empty ``HF_TOKEN`` is never overwritten.
    """
    if os.environ.get("HF_TOKEN"):
        # Already configured; nothing to mirror.
        return
    legacy_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
    if legacy_token:
        os.environ["HF_TOKEN"] = legacy_token
# Expose the Hub token under HF_TOKEN before the model is loaded.
_ensure_hf_token_env()

# float16 when CUDA is available, float32 on CPU.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.float16
else:
    DEVICE, DTYPE = "cpu", torch.float32

# Checkpoint name and default generation length can be overridden via the environment.
MODEL_NAME = os.getenv("STARVECTOR_MODEL", "starvector/starvector-1b-im2svg")
MAX_LENGTH = int(os.getenv("STARVECTOR_MAX_LENGTH", "4000"))

print(f"Starting StarVector Space on device={DEVICE} dtype={DTYPE} model={MODEL_NAME}", flush=True)

# On GPU, weight placement is delegated to device_map="auto"; on CPU the
# model is loaded without a device map and then moved explicitly.
if DEVICE == "cuda":
    starvector = StarVectorForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=DTYPE,
        device_map="auto",
    )
else:
    starvector = StarVectorForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=DTYPE,
        device_map=None,
    )
    starvector = starvector.to(device=DEVICE)

# Inference only: switch to evaluation mode.
starvector.eval()
def im2svg(image: Image.Image, max_length: int = MAX_LENGTH):
    """Convert an uploaded image to SVG code plus a rasterized preview.

    Args:
        image: PIL image from the Gradio input, or ``None`` when nothing was
            uploaded.
        max_length: Generation length limit forwarded to
            ``starvector.generate_im2svg``. Coerced to ``int`` because UI
            sliders / API callers may deliver a float.

    Returns:
        A ``(svg, raster, error)`` tuple. On success ``error`` is ``""``; on
        failure ``svg`` is ``""``, ``raster`` is ``None`` and ``error`` holds
        either a short message or the formatted traceback.
    """
    if image is None:
        return "", None, "No image provided."
    try:
        # Uploads can arrive as RGBA, palette or grayscale images; normalize
        # to 3-channel RGB before handing the image to the model's
        # preprocessing (the upstream usage example does the same).
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Preprocess image and move it next to the model.
        image_tensor = starvector.process_images([image])[0]
        image_tensor = image_tensor.to(device=DEVICE)
        if DEVICE == "cuda":
            # Match the half-precision weights used on GPU.
            image_tensor = image_tensor.to(dtype=DTYPE)
        batch = {"image": image_tensor}

        # Generate raw SVG for the single-image batch.
        raw_svg = starvector.generate_im2svg(batch, max_length=int(max_length))[0]

        # Clean + rasterize preview.
        svg, raster = process_and_rasterize_svg(raw_svg)
        return svg, raster, ""
    except Exception:
        # UI boundary: surface the failure in the "Error" textbox instead of
        # crashing the request.
        return "", None, traceback.format_exc()
# Gradio UI: image in, SVG code + rasterized preview + error text out.
# NOTE(review): nesting was reconstructed from a paste that lost its
# indentation; confirm the Row is meant to hold just the two image panes.
with gr.Blocks() as demo:
    gr.Markdown(
        "# StarVector (GPU)\n"
        "Upload an icon/logo/diagram-like image and generate SVG code.\n\n"
        "Notes:\n"
        "- This Space requires a GPU and a Hugging Face token with access to the gated `bigcode/starcoderbase-1b` model.\n"
        "- Set Space secret `HUGGING_FACE_HUB_TOKEN`."
    )
    with gr.Row():
        inp = gr.Image(type="pil", label="Input image")
        preview = gr.Image(type="pil", label="Rasterized preview")
    max_len = gr.Slider(256, 8000, value=MAX_LENGTH, step=128, label="max_length")
    # gr.Code validates `language` against a fixed list that, per the Gradio
    # docs, has no "xml" entry (an unsupported value raises at construction).
    # "html" is supported and highlights SVG markup.
    out_svg = gr.Code(language="html", label="SVG")
    err = gr.Textbox(label="Error", visible=True)
    btn = gr.Button("Generate SVG")
    # im2svg returns (svg, raster, error), matching the order of `outputs`.
    btn.click(im2svg, inputs=[inp, max_len], outputs=[out_svg, preview, err])
if __name__ == "__main__":
    # Port can be overridden via PORT; 7860 is the fallback.
    port = int(os.environ.get("PORT", "7860"))
    # Process one request at a time. Newer Gradio releases reject the old
    # `concurrency_count` argument in favour of `default_concurrency_limit`;
    # older releases do not know the new keyword and raise TypeError, in
    # which case we fall back to the legacy spelling.
    try:
        demo.queue(default_concurrency_limit=1)
    except TypeError:
        demo.queue(concurrency_count=1)
    # Bind on all interfaces so the server is reachable from outside the container.
    demo.launch(server_name="0.0.0.0", server_port=port)