File size: 1,906 Bytes
9b594c8
 
16fa907
 
 
9b594c8
16fa907
e581ee6
 
9b594c8
 
e581ee6
 
 
 
31b1a0b
16fa907
e581ee6
 
31b1a0b
 
 
9b594c8
e581ee6
31b1a0b
 
e581ee6
31b1a0b
 
9b594c8
e581ee6
 
9b594c8
e581ee6
 
31b1a0b
 
9b594c8
e581ee6
9b594c8
16fa907
e581ee6
 
31b1a0b
 
e581ee6
31b1a0b
 
16fa907
e581ee6
ee01f42
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# app.py
import os, io, base64
import gradio as gr
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoImageProcessor, AutoModelForCausalLM

# Hugging Face Hub repository that every loader below pulls from.
MODEL_ID = "starvector/starvector-8b-im2svg"

# trust_remote_code=True allows transformers to run the Python code that ships
# inside the model repository, so MODEL_ID must point at a trusted repo.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    use_fast=False,
    trust_remote_code=True,
)
# run_starvector hands tokenizer.pad_token_id to generate(); when the
# tokenizer defines no pad token, reuse the end-of-sequence token for it.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

image_processor = AutoImageProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)

# Weights are loaded in bfloat16 and placed by device_map="auto"; the model is
# switched to eval mode once, at import time, because this app only runs
# inference.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
).eval()

def run_starvector(image: Image.Image | None, text: str) -> str:
    """Generate SVG markup from an optional image and an optional text prompt.

    Args:
        image: PIL image coming from the Gradio upload widget, or ``None``
            when the user did not upload one.
        text: Prompt text. ``None`` or an empty string is tokenized as ``""``.

    Returns:
        The first sequence returned by ``model.generate``, decoded with
        special tokens stripped.
    """
    text = text or ""
    encoded = tokenizer(text, return_tensors="pt", add_special_tokens=True)

    inputs = {"input_ids": encoded["input_ids"].to(model.device)}
    # The pad token may have been aliased to the EOS token at load time, in
    # which case generate() cannot infer the mask from input_ids alone, so
    # forward the tokenizer's own mask whenever it produced one.
    if "attention_mask" in encoded:
        inputs["attention_mask"] = encoded["attention_mask"].to(model.device)

    if image is not None:
        pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        # The model weights are loaded in bfloat16 while the image processor
        # emits float32 tensors; cast to the model dtype to avoid a dtype
        # mismatch inside the vision layers.
        inputs["pixel_values"] = pixel_values.to(device=model.device, dtype=model.dtype)

    with torch.no_grad():
        # Greedy decoding (do_sample=False). No temperature is passed: it has
        # no effect without sampling and transformers flags the combination
        # as an invalid generation configuration.
        generated = model.generate(
            **inputs,
            max_new_tokens=2048,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # NOTE(review): for a decoder-only model the generated ids presumably
    # start with the prompt tokens, so the prompt text may be echoed in front
    # of the SVG — confirm against this model's generate() before slicing.
    return tokenizer.decode(generated[0], skip_special_tokens=True)

# Gradio layout: an optional image upload and an optional text prompt feed
# run_starvector, whose string result is shown in an XML-highlighted code box.
with gr.Blocks(title="StarVector: Image/Text → SVG") as demo:
    gr.Markdown("# StarVector: Image/Text → SVG")

    # Inputs (type="pil" makes Gradio pass a PIL image to the callback).
    img = gr.Image(label="Upload image (optional)", type="pil")
    txt = gr.Textbox(label="Text prompt (optional)")

    # Trigger and output.
    btn = gr.Button("Generate SVG")
    code = gr.Code(language="xml", label="SVG Output")

    btn.click(fn=run_starvector, inputs=[img, txt], outputs=code)

if __name__ == "__main__":
    # Listen on all interfaces; the port comes from the PORT environment
    # variable and falls back to 7860 when it is unset.
    port = int(os.getenv("PORT", "7860"))
    demo.launch(
        server_name="0.0.0.0",
        server_port=port,
    )