pabbelt commited on
Commit
e581ee6
·
verified ·
1 Parent(s): 845da2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -58
app.py CHANGED
@@ -1,69 +1,57 @@
 
1
  import gradio as gr
 
2
  from PIL import Image
3
- from transformers import AutoModelForCausalLM
4
- from starvector.data.util import process_and_rasterize_svg
5
  import torch
 
6
 
7
- print("Loading model...")
8
- model_name = "starvector/starvector-8b-im2svg"
9
- starvector = AutoModelForCausalLM.from_pretrained(
10
- model_name,
11
- torch_dtype=torch.float16,
 
 
 
12
  trust_remote_code=True
13
- )
14
- processor = starvector.model.processor
15
- tokenizer = starvector.model.svg_transformer.tokenizer
 
 
 
 
16
 
17
- # Move to GPU
18
- device = "cuda" if torch.cuda.is_available() else "cpu"
19
- starvector.to(device)
20
- starvector.eval()
21
- print(f"Model loaded on {device}!")
 
 
 
 
 
 
 
 
 
22
 
23
- def convert_image_to_svg(image_pil):
24
- """Convert uploaded image to SVG"""
25
- try:
26
- # Process image
27
- image = processor(image_pil, return_tensors="pt")['pixel_values'].to(device)
28
- if image.shape[0] != 1:
29
- image = image.unsqueeze(0)
30
-
31
- batch = {"image": image}
32
-
33
- # Generate SVG
34
- raw_svg = starvector.generate_im2svg(batch, max_length=4000)[0]
35
- svg, raster_image = process_and_rasterize_svg(raw_svg)
36
-
37
- # Save SVG to file
38
- svg_path = "output.svg"
39
- with open(svg_path, 'w') as f:
40
- f.write(svg)
41
-
42
- return svg_path, raster_image
43
-
44
- except Exception as e:
45
- return None, f"Error: {str(e)}"
46
 
47
- # Create Gradio interface
48
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
49
- gr.Markdown("# 🎨 Image to SVG Converter")
50
- gr.Markdown("Convert your images to SVG format using StarVector AI")
51
-
52
  with gr.Row():
53
- with gr.Column():
54
- input_image = gr.Image(type="pil", label="Upload Image")
55
- convert_btn = gr.Button("Convert to SVG", variant="primary")
56
-
57
- with gr.Column():
58
- output_file = gr.File(label="Download SVG")
59
- output_preview = gr.Image(label="Preview")
60
-
61
- convert_btn.click(
62
- fn=convert_image_to_svg,
63
- inputs=input_image,
64
- outputs=[output_file, output_preview]
65
  )
66
-
67
- gr.Markdown("### Example: Upload a PNG/JPG image and get an SVG file!")
68
 
69
- demo.launch()
 
 
1
+ import io, base64
2
  import gradio as gr
3
+ import spaces
4
  from PIL import Image
 
 
5
  import torch
6
+ from transformers import AutoProcessor, AutoModelForCausalLM
7
 
8
+ MODEL_ID = "starvector/starvector-8b-im2svg"
9
+
10
+ # Load once at startup; ZeroGPU allocates GPU when the decorated function runs.
11
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
12
+ model = AutoModelForCausalLM.from_pretrained(
13
+ MODEL_ID,
14
+ torch_dtype=torch.bfloat16,
15
+ low_cpu_mem_usage=True,
16
  trust_remote_code=True
17
+ ).eval()
18
+
19
+ def _prep_inputs(image: Image.Image | None, text: str):
20
+ text = text or ""
21
+ if image is not None:
22
+ return processor(images=image, text=text, return_tensors="pt")
23
+ return processor(text=text, return_tensors="pt")
24
 
25
+ @spaces.GPU(duration=180) # request ZeroGPU for up to 180s per call
26
+ def run_starvector(image: Image.Image | None, text: str) -> str:
27
+ inputs = _prep_inputs(image, text)
28
+ # Move only tensors to the GPU at call time
29
+ inputs = {k: v.to("cuda") if hasattr(v, "to") else v for k, v in inputs.items()}
30
+ with torch.no_grad():
31
+ out = model.generate(
32
+ **inputs,
33
+ max_new_tokens=2048,
34
+ temperature=0.2,
35
+ do_sample=False
36
+ )
37
+ svg = processor.batch_decode(out, skip_special_tokens=True)[0]
38
+ return svg
39
 
40
+ def preview_svg(svg_code: str) -> str:
41
+ # Safe inline preview wrapper
42
+ return f"<div style='height:480px;overflow:auto;border:1px solid #ccc'><pre>{gr.utils.sanitize_html(svg_code)}</pre></div>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
+ with gr.Blocks(title="StarVector: Image/Text → SVG") as demo:
45
+ gr.Markdown("# StarVector: Image/Text → SVG")
 
 
 
46
  with gr.Row():
47
+ img = gr.Image(type="pil", label="Upload image (optional)")
48
+ txt = gr.Textbox(label="Text prompt (optional)")
49
+ btn = gr.Button("Generate SVG")
50
+ svg_code = gr.Code(label="SVG Output", language="xml")
51
+ svg_render = gr.HTML(label="Preview")
52
+ btn.click(fn=run_starvector, inputs=[img, txt], outputs=svg_code).then(
53
+ fn=preview_svg, inputs=svg_code, outputs=svg_render
 
 
 
 
 
54
  )
 
 
55
 
56
+ if __name__ == "__main__":
57
+ demo.launch()