|
|
|
|
|
import os, io, base64 |
|
|
import gradio as gr |
|
|
from PIL import Image |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoImageProcessor, AutoModelForCausalLM |
|
|
|
|
|
# Hugging Face Hub id of the StarVector image-to-SVG checkpoint (8B variant).
MODEL_ID = "starvector/starvector-8b-im2svg"


# StarVector ships custom tokenizer/processor/model code on the Hub, so
# trust_remote_code=True is required for all three loads.
# NOTE(review): trust_remote_code executes repository-provided Python —
# acceptable only because MODEL_ID is pinned to a known publisher.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False, trust_remote_code=True)


image_processor = AutoImageProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)


model = AutoModelForCausalLM.from_pretrained(


    MODEL_ID,


    torch_dtype=torch.bfloat16,  # half-precision weights; assumes bf16-capable hardware — TODO confirm


    low_cpu_mem_usage=True,  # stream weights in instead of materializing a full fp32 copy


    device_map="auto",  # let accelerate place/shard the model on available device(s)


    trust_remote_code=True


).eval()  # inference only; disables dropout etc.


# generate() needs a pad token id; many causal-LM tokenizers define none,
# so fall back to reusing the EOS token.
if tokenizer.pad_token_id is None:


    tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
|
def run_starvector(image: Image.Image | None, text: str) -> str:
    """Generate SVG source from an optional image and/or text prompt.

    Args:
        image: Optional PIL image to vectorize; when ``None`` generation is
            driven by the text prompt alone.
        text: Text prompt; ``None``/empty is normalized to ``""``.

    Returns:
        The decoded model output (SVG markup) for the first/only batch item.
    """
    text = text or ""
    toks = tokenizer(text, return_tensors="pt", add_special_tokens=True)
    # Pass the attention mask explicitly: pad_token may be aliased to EOS at
    # module level, and generate() warns (and can mis-handle padding) when it
    # has to infer the mask in that configuration.
    batch = {"input_ids": toks.input_ids, "attention_mask": toks.attention_mask}
    if image is not None:
        pix = image_processor(images=image, return_tensors="pt").pixel_values
        # The processor emits float32 but the model is loaded in bfloat16;
        # cast to the model's dtype to avoid a dtype mismatch in the vision tower.
        batch["pixel_values"] = pix.to(dtype=model.dtype)
    # Move every tensor onto the model's device in one pass.
    batch = {k: v.to(model.device) for k, v in batch.items()}
    with torch.no_grad():
        out = model.generate(
            **batch,
            max_new_tokens=2048,
            # Greedy decoding. The previous version also passed
            # temperature=0.2, which is ignored (with a warning) when
            # do_sample=False, so it has been removed.
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.batch_decode(out, skip_special_tokens=True)[0]
|
|
|
|
|
# --- Gradio UI: image and/or text in, generated SVG source out ---
# NOTE(review): the literal "β" in the titles below looks like a mis-encoded
# arrow ("→") — confirm the intended text before changing it.
with gr.Blocks(title="StarVector: Image/Text β SVG") as demo:


    gr.Markdown("# StarVector: Image/Text β SVG")


    # type="pil" hands run_starvector a PIL.Image (or None when left empty).
    img = gr.Image(type="pil", label="Upload image (optional)")


    txt = gr.Textbox(label="Text prompt (optional)")


    btn = gr.Button("Generate SVG")


    # SVG is XML, so reuse the XML highlighter for the output pane.
    code = gr.Code(label="SVG Output", language="xml")


    # Wire the button: (image, text) -> SVG string shown in `code`.
    btn.click(run_starvector, [img, txt], code)
|
|
|
|
|
if __name__ == "__main__":
    # Honor the PORT environment variable (set by most hosting platforms),
    # falling back to Gradio's conventional 7860.
    server_port = int(os.environ.get("PORT", "7860"))
    # Bind to all interfaces so the demo is reachable from outside a container.
    demo.launch(server_name="0.0.0.0", server_port=server_port)
|
|
|