# imgtosvg / app.py
# pabbelt's picture
# Update app.py
# 9b594c8 verified
# app.py
import os, io, base64
import gradio as gr
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoImageProcessor, AutoModelForCausalLM
# Hugging Face Hub checkpoint that the tokenizer, image processor and model
# are all loaded from.
MODEL_ID = "starvector/starvector-8b-im2svg"

# trust_remote_code=True lets the checkpoint repository's own Python code run
# during loading; use_fast=False requests the slow tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    use_fast=False,
    trust_remote_code=True,
)
# generate() is later given an explicit pad id, so fall back to EOS when the
# tokenizer does not define a pad token of its own.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

image_processor = AutoImageProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)

# bfloat16 weights, placed automatically across the available devices.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="auto",
    trust_remote_code=True,
)
# Switch to evaluation mode; this app only runs inference.
model = model.eval()
def run_starvector(image: Image.Image | None, text: str | None) -> str:
    """Generate SVG text from an optional image and/or an optional prompt.

    Args:
        image: PIL image supplied by the UI, or ``None`` when nothing was
            uploaded.
        text: Prompt text; ``None`` is treated as an empty string.

    Returns:
        The first sequence returned by ``model.generate``, decoded with
        special tokens stripped.

    Raises:
        gr.Error: If neither an image nor any text was supplied.
    """
    text = text or ""
    if image is None and not text:
        # Nothing to condition the model on: show a readable message in the
        # UI instead of running (or crashing inside) an empty generation.
        raise gr.Error("Please provide an image, a text prompt, or both.")

    toks = tokenizer(text, return_tensors="pt", add_special_tokens=True)
    # Token ids must stay integer-typed, so only their device is changed.
    batch = {"input_ids": toks.input_ids.to(model.device)}

    if image is not None:
        pix = image_processor(images=image, return_tensors="pt").pixel_values
        # Match the model's parameter dtype as well as its device; feeding
        # pixels in a different float dtype than the weights can raise a
        # dtype-mismatch error inside the vision layers.
        batch["pixel_values"] = pix.to(device=model.device, dtype=model.dtype)

    with torch.no_grad():
        # Greedy decoding (do_sample=False). `temperature` is intentionally
        # not passed: it only applies to sampling and would merely trigger
        # an "unused generation flag" warning here.
        out = model.generate(
            **batch,
            max_new_tokens=2048,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.batch_decode(out, skip_special_tokens=True)[0]
# Gradio UI: one optional image input, one optional text input, and a code
# viewer that shows the string returned by run_starvector.
with gr.Blocks(title="StarVector: Image/Text → SVG") as demo:
    gr.Markdown("# StarVector: Image/Text → SVG")
    img = gr.Image(type="pil", label="Upload image (optional)")
    txt = gr.Textbox(label="Text prompt (optional)")
    btn = gr.Button("Generate SVG")
    code = gr.Code(label="SVG Output", language="xml")
    # Clicking the button passes (image, text) to run_starvector and writes
    # its return value into the code viewer.
    btn.click(fn=run_starvector, inputs=[img, txt], outputs=code)
if __name__ == "__main__":
    # Listen on all interfaces; use $PORT when it is set, otherwise 7860.
    listen_port = int(os.getenv("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=listen_port)