"""Gradio front-end that turns an uploaded raster image into SVG code with StarVector.

The model is loaded once at import time; `im2svg` is the single callback wired
to the "Generate SVG" button.
"""

import os
import traceback

import gradio as gr
import torch
from PIL import Image

from starvector.model.starvector_arch import StarVectorForCausalLM
from starvector.data.util import process_and_rasterize_svg


def _ensure_hf_token_env() -> None:
    """Copy the Hugging Face token into ``HF_TOKEN`` when only the legacy name is set.

    Reads ``HUGGING_FACE_HUB_TOKEN`` (falling back to ``HF_TOKEN``) and, if a
    token was found and ``HF_TOKEN`` is empty or unset, exports it as
    ``HF_TOKEN``. An existing non-empty ``HF_TOKEN`` is never overwritten.
    """
    # The repo docs mention HUGGING_FACE_HUB_TOKEN; huggingface_hub also looks for HF_TOKEN.
    tok = os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
    if tok and not os.environ.get("HF_TOKEN"):
        os.environ["HF_TOKEN"] = tok


# Must run before the model download below so the token is visible to it.
_ensure_hf_token_env()

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU; CPU stays in float32.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Both settings can be overridden through environment variables.
MODEL_NAME = os.environ.get("STARVECTOR_MODEL", "starvector/starvector-1b-im2svg")
MAX_LENGTH = int(os.environ.get("STARVECTOR_MAX_LENGTH", "4000"))

# Load model with proper dtype for GPU/CPU compatibility
print(f"Starting StarVector Space on device={DEVICE} dtype={DTYPE} model={MODEL_NAME}", flush=True)
starvector = StarVectorForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    device_map="auto" if DEVICE == "cuda" else None,
)
if DEVICE != "cuda":
    # With device_map=None the model has to be placed explicitly.
    starvector = starvector.to(device=DEVICE)
starvector.eval()


def im2svg(image: Image.Image, max_length: int = MAX_LENGTH):
    """Generate SVG code for ``image`` and a rasterized preview of that SVG.

    Args:
        image: Input picture as a PIL image, or ``None`` when nothing was uploaded.
        max_length: Value forwarded to ``generate_im2svg`` as ``max_length``.
            It is coerced to ``int`` first because it comes from a UI slider.

    Returns:
        A 3-tuple ``(svg, raster, error)`` matching the three UI outputs.
        On success ``error`` is ``""``. When ``image`` is ``None`` or any
        exception is raised, ``svg`` is ``""``, ``raster`` is ``None`` and
        ``error`` holds the message or the formatted traceback.
    """
    if image is None:
        return "", None, "No image provided."

    # Top-level UI boundary: any failure is reported in the error textbox
    # instead of propagating to Gradio.
    try:
        # The slider may hand over a float; normalize before generation.
        max_length = int(max_length)

        # Preprocess image
        image_tensor = starvector.process_images([image])[0]
        image_tensor = image_tensor.to(device=DEVICE)
        if DEVICE == "cuda":
            # Match the half-precision weights used on GPU.
            image_tensor = image_tensor.to(dtype=DTYPE)
        batch = {"image": image_tensor}

        # Generate raw svg
        raw_svg = starvector.generate_im2svg(batch, max_length=max_length)[0]

        # Clean + rasterize preview
        svg, raster = process_and_rasterize_svg(raw_svg)

        # raster may be a PIL image
        return svg, raster, ""
    except Exception:
        return "", None, traceback.format_exc()


with gr.Blocks() as demo:
    gr.Markdown(
        "# StarVector (GPU)\n"
        "Upload an icon/logo/diagram-like image and generate SVG code.\n\n"
        "Notes:\n"
        "- This Space requires a GPU and a Hugging Face token with access to the gated `bigcode/starcoderbase-1b` model.\n"
        "- Set Space secret `HUGGING_FACE_HUB_TOKEN`."
    )
    with gr.Row():
        inp = gr.Image(type="pil", label="Input image")
        preview = gr.Image(type="pil", label="Rasterized preview")
    max_len = gr.Slider(256, 8000, value=MAX_LENGTH, step=128, label="max_length")
    out_svg = gr.Code(language="xml", label="SVG")
    err = gr.Textbox(label="Error", visible=True)
    btn = gr.Button("Generate SVG")
    btn.click(im2svg, inputs=[inp, max_len], outputs=[out_svg, preview, err])


def _enable_queue(app) -> None:
    """Enable request queuing with one job at a time, on old and new Gradio alike.

    Newer Gradio releases reject ``concurrency_count`` and expect
    ``default_concurrency_limit``; older releases only know
    ``concurrency_count``. Try the new keyword first and fall back when the
    installed version does not accept it.
    """
    try:
        app.queue(default_concurrency_limit=1)
    except TypeError:
        app.queue(concurrency_count=1)


if __name__ == "__main__":
    port = int(os.environ.get("PORT", "7860"))
    _enable_queue(demo)
    demo.launch(server_name="0.0.0.0", server_port=port)