File size: 1,906 Bytes
9b594c8 16fa907 9b594c8 16fa907 e581ee6 9b594c8 e581ee6 31b1a0b 16fa907 e581ee6 31b1a0b 9b594c8 e581ee6 31b1a0b e581ee6 31b1a0b 9b594c8 e581ee6 9b594c8 e581ee6 31b1a0b 9b594c8 e581ee6 9b594c8 16fa907 e581ee6 31b1a0b e581ee6 31b1a0b 16fa907 e581ee6 ee01f42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
# app.py
import os, io, base64
import gradio as gr
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoImageProcessor, AutoModelForCausalLM
# Hugging Face Hub repo to load. Defaults to the StarVector 8B image->SVG
# checkpoint; set STARVECTOR_MODEL_ID to try another checkpoint without
# editing this file.
MODEL_ID = os.environ.get("STARVECTOR_MODEL_ID", "starvector/starvector-8b-im2svg")

# NOTE(review): this assumes the repo ships tokenizer and image-processor
# configs that the generic Auto* classes can load (with remote code enabled).
# TODO confirm against the model card, which may expose them on the loaded
# model instead.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False, trust_remote_code=True)
image_processor = AutoImageProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="auto",  # let transformers/accelerate place weights on available devices
    trust_remote_code=True,
).eval()  # inference only: disable dropout etc.

# generate() is given tokenizer.pad_token_id below; fall back to EOS when the
# tokenizer defines no pad token so that id is never None.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
def run_starvector(
    image: Image.Image | None,
    text: str,
    *,
    max_new_tokens: int = 2048,
) -> str:
    """Generate SVG markup from an optional image and/or text prompt.

    Args:
        image: Optional PIL image (from the Gradio ``gr.Image`` input). When
            given, it is run through ``image_processor`` and passed to the
            model as ``pixel_values``.
        text: Optional text prompt (from the Gradio ``gr.Textbox``); ``None``
            is treated as an empty string.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The decoded generation with special tokens removed. If the model
        echoes the text prompt at the start of its output, those prompt
        tokens are dropped so only newly generated text is returned.

    Raises:
        gr.Error: If neither an image nor a non-blank prompt was provided.
    """
    text = text or ""
    if image is None and not text.strip():
        # Nothing to condition on: show a readable message in the UI instead
        # of running an unconditioned (possibly empty-prompt) generation.
        raise gr.Error("Please upload an image or enter a text prompt.")

    toks = tokenizer(text, return_tensors="pt", add_special_tokens=True)
    batch = {"input_ids": toks.input_ids}
    attention_mask = toks.get("attention_mask")
    if attention_mask is not None:
        # Pass the mask explicitly: pad_token may have been aliased to
        # eos_token at load time, so generate() cannot reliably infer it.
        batch["attention_mask"] = attention_mask

    if image is not None:
        pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        # Weights are loaded in bfloat16; cast the pixel values (commonly
        # float32 from the processor) to the model dtype so the vision
        # encoder does not hit an input/weight dtype mismatch.
        batch["pixel_values"] = pixel_values.to(dtype=model.dtype)

    batch = {k: v.to(model.device) for k, v in batch.items()}

    with torch.no_grad():
        out = model.generate(
            **batch,
            max_new_tokens=max_new_tokens,
            # Greedy decoding. A temperature would be ignored here (and
            # triggers a warning), so none is passed.
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    sequence = out[0]
    prompt_ids = batch["input_ids"][0]
    prompt_len = prompt_ids.shape[0]
    # Standard decoder-only generate() returns prompt + continuation; drop the
    # echoed prompt so the output is just the SVG. This is checked rather than
    # assumed, since remote-code models may return only the new tokens.
    if sequence.shape[0] >= prompt_len and torch.equal(
        sequence[:prompt_len].to(prompt_ids.device), prompt_ids
    ):
        sequence = sequence[prompt_len:]
    return tokenizer.decode(sequence, skip_special_tokens=True)
# Single-page UI: optional image + optional prompt -> generated SVG source.
with gr.Blocks(title="StarVector: Image/Text → SVG") as demo:
    gr.Markdown("# StarVector: Image/Text → SVG")
    img = gr.Image(type="pil", label="Upload image (optional)")
    txt = gr.Textbox(label="Text prompt (optional)")
    btn = gr.Button("Generate SVG")
    # gr.Code only accepts languages from its fixed list, and "xml" is not in
    # it (construction raises ValueError). "html" is supported and highlights
    # SVG markup.
    code = gr.Code(label="SVG Output", language="html")
    # Inputs are passed positionally to run_starvector(image, text).
    btn.click(run_starvector, [img, txt], code)
if __name__ == "__main__":
    # Port comes from the PORT environment variable when set, else 7860.
    listen_port = int(os.getenv("PORT", "7860"))
    # 0.0.0.0 binds on all network interfaces, not just localhost.
    demo.launch(server_name="0.0.0.0", server_port=listen_port)
|