Spaces:
Sleeping
Sleeping
| import os | |
| import traceback | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from starvector.model.starvector_arch import StarVectorForCausalLM | |
| from starvector.data.util import process_and_rasterize_svg | |
def _ensure_hf_token_env() -> None:
    """Copy a token from ``HUGGING_FACE_HUB_TOKEN`` into ``HF_TOKEN`` if needed.

    The repo docs mention ``HUGGING_FACE_HUB_TOKEN``, while huggingface_hub also
    looks for ``HF_TOKEN``. A non-empty ``HF_TOKEN`` always wins and is never
    overwritten; if neither variable holds a non-empty value, nothing changes.
    """
    if os.environ.get("HF_TOKEN"):
        # Already configured under the name huggingface_hub reads.
        return
    legacy_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
    if legacy_token:
        os.environ["HF_TOKEN"] = legacy_token
_ensure_hf_token_env()

# Hardware-dependent runtime settings: half precision only when CUDA exists.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
DTYPE = torch.float32 if DEVICE != "cuda" else torch.float16

# Checkpoint and default generation budget, both overridable via env vars.
MODEL_NAME = os.environ.get("STARVECTOR_MODEL", "starvector/starvector-1b-im2svg")
MAX_LENGTH = int(os.environ.get("STARVECTOR_MAX_LENGTH", "4000"))


def _load_starvector():
    """Load ``MODEL_NAME`` with ``DTYPE`` weights, placed for ``DEVICE``, in eval mode.

    On CUDA, placement is delegated to ``device_map="auto"``; otherwise the model
    is loaded without a device map and moved to ``DEVICE`` explicitly.
    """
    placement = "auto" if DEVICE == "cuda" else None
    model = StarVectorForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=DTYPE,
        device_map=placement,
    )
    if placement is None:
        model = model.to(device=DEVICE)
    model.eval()
    return model


# Load model with proper dtype for GPU/CPU compatibility
print(f"Starting StarVector Space on device={DEVICE} dtype={DTYPE} model={MODEL_NAME}", flush=True)
starvector = _load_starvector()
def im2svg(image: Image.Image, max_length: int = MAX_LENGTH):
    """Generate SVG code for an uploaded image with StarVector.

    Args:
        image: Picture from the ``gr.Image(type="pil")`` input, or ``None`` when
            nothing has been uploaded.
        max_length: Generation length limit forwarded to
            ``starvector.generate_im2svg``. Coerced to ``int`` because UI sliders
            may deliver numeric values as floats; ``None`` falls back to
            ``MAX_LENGTH``.

    Returns:
        A ``(svg, raster, error)`` triple. On success ``error`` is ``""`` and
        ``svg``/``raster`` come from ``process_and_rasterize_svg``. On failure
        ``svg`` is ``""``, ``raster`` is ``None`` and ``error`` holds the
        formatted traceback. A missing image returns ``("", None, "No image provided.")``.
    """
    if image is None:
        return "", None, "No image provided."
    try:
        length_limit = MAX_LENGTH if max_length is None else int(max_length)

        # Preprocess with the model's own image pipeline, then match the device
        # and the dtype the weights were loaded with (DTYPE on both CPU and CUDA).
        image_tensor = starvector.process_images([image])[0]
        image_tensor = image_tensor.to(device=DEVICE, dtype=DTYPE)
        batch = {"image": image_tensor}

        # Inference only: guarantee no autograd graph is recorded, regardless of
        # whether generate_im2svg disables gradients internally.
        with torch.no_grad():
            raw_svg = starvector.generate_im2svg(batch, max_length=length_limit)[0]

        # Clean the SVG text and rasterize it for the preview pane
        # (presumably a PIL image, since the preview is gr.Image(type="pil") — confirm).
        svg, raster = process_and_rasterize_svg(raw_svg)
        return svg, raster, ""
    except Exception:
        # UI-callback boundary: surface the failure in the Error box and also
        # print it to stderr so it is visible in the server logs.
        error_text = traceback.format_exc()
        traceback.print_exc()
        return "", None, error_text
# Gradio UI: image in; SVG code, rasterized preview and error text out.
with gr.Blocks() as demo:
    gr.Markdown(
        "# StarVector (GPU)\n"
        "Upload an icon/logo/diagram-like image and generate SVG code.\n\n"
        "Notes:\n"
        "- This Space requires a GPU and a Hugging Face token with access to the gated `bigcode/starcoderbase-1b` model.\n"
        "- Set Space secret `HUGGING_FACE_HUB_TOKEN`."
    )
    # Input and rendered result side by side.
    with gr.Row():
        inp = gr.Image(type="pil", label="Input image")
        preview = gr.Image(type="pil", label="Rasterized preview")
    # Widen the 256..8000 range when STARVECTOR_MAX_LENGTH is configured outside
    # it, so the configured default is always a selectable slider value.
    max_len = gr.Slider(
        min(256, MAX_LENGTH),
        max(8000, MAX_LENGTH),
        value=MAX_LENGTH,
        step=128,
        label="max_length",
    )
    # gr.Code rejects languages outside its supported list, and "xml" is not in
    # it; SVG markup highlights correctly as "html".
    out_svg = gr.Code(language="html", label="SVG")
    err = gr.Textbox(label="Error", visible=True)
    btn = gr.Button("Generate SVG")
    # Output order matches im2svg's (svg, raster, error) return tuple.
    btn.click(im2svg, inputs=[inp, max_len], outputs=[out_svg, preview, err])
if __name__ == "__main__":
    # Listen on all interfaces; the PORT env var overrides the 7860 default.
    port = int(os.environ.get("PORT", "7860"))
    # Serialize requests through the queue. Gradio 3.x defaults to
    # concurrency_count=1 and Gradio 4+ to default_concurrency_limit=1, so a
    # plain queue() keeps one generation at a time on both; passing
    # concurrency_count explicitly raises on Gradio 4+.
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=port)