pabbelt commited on
Commit
9b594c8
·
verified ·
1 Parent(s): 375543c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -34
app.py CHANGED
@@ -1,26 +1,14 @@
1
- import io, base64
 
2
  import gradio as gr
3
- import spaces
4
  from PIL import Image
5
  import torch
6
- from transformers import (
7
- AutoModelForCausalLM,
8
- AutoTokenizer,
9
- AutoImageProcessor,
10
- )
11
 
12
  MODEL_ID = "starvector/starvector-8b-im2svg"
13
 
14
- # Load separate components explicitly
15
- tokenizer = AutoTokenizer.from_pretrained(
16
- MODEL_ID,
17
- use_fast=False, # ensure GPT2Tokenizer (python) not Siglip fast
18
- trust_remote_code=True
19
- )
20
- image_processor = AutoImageProcessor.from_pretrained(
21
- MODEL_ID,
22
- trust_remote_code=True
23
- )
24
  model = AutoModelForCausalLM.from_pretrained(
25
  MODEL_ID,
26
  torch_dtype=torch.bfloat16,
@@ -29,37 +17,28 @@ model = AutoModelForCausalLM.from_pretrained(
29
  trust_remote_code=True
30
  ).eval()
31
 
32
- # Safety: some GPT2 checkpoints have no pad token
33
  if tokenizer.pad_token_id is None:
34
  tokenizer.pad_token = tokenizer.eos_token
35
 
36
- def _prep_inputs(image: Image.Image | None, text: str):
37
  text = text or ""
38
  toks = tokenizer(text, return_tensors="pt", add_special_tokens=True)
39
  batch = {"input_ids": toks.input_ids}
40
  if image is not None:
41
  pix = image_processor(images=image, return_tensors="pt").pixel_values
42
  batch["pixel_values"] = pix
43
- return batch
44
-
45
- @spaces.GPU(duration=180)
46
- def run_starvector(image: Image.Image | None, text: str) -> str:
47
- inputs = _prep_inputs(image, text)
48
- # move tensors to the model device(s)
49
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
50
  with torch.no_grad():
51
  out = model.generate(
52
- **inputs,
53
  max_new_tokens=2048,
54
  temperature=0.2,
55
  do_sample=False,
56
  pad_token_id=tokenizer.pad_token_id,
57
- eos_token_id=tokenizer.eos_token_id
58
  )
59
- svg = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
60
- return svg
61
 
62
- # --- Your UI wiring (example) ---
63
  with gr.Blocks(title="StarVector: Image/Text → SVG") as demo:
64
  gr.Markdown("# StarVector: Image/Text → SVG")
65
  img = gr.Image(type="pil", label="Upload image (optional)")
@@ -68,9 +47,6 @@ with gr.Blocks(title="StarVector: Image/Text → SVG") as demo:
68
  code = gr.Code(label="SVG Output", language="xml")
69
  btn.click(run_starvector, [img, txt], code)
70
 
71
- import os
72
  if __name__ == "__main__":
73
  port = int(os.environ.get("PORT", 7860))
74
- # Gradio will listen on 0.0.0.0 for Spaces
75
  demo.launch(server_name="0.0.0.0", server_port=port)
76
-
 
1
+ # app.py
2
+ import os, io, base64
3
  import gradio as gr
 
4
  from PIL import Image
5
  import torch
6
+ from transformers import AutoTokenizer, AutoImageProcessor, AutoModelForCausalLM
 
 
 
 
7
 
8
  MODEL_ID = "starvector/starvector-8b-im2svg"
9
 
10
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False, trust_remote_code=True)
11
+ image_processor = AutoImageProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 
 
 
 
 
 
 
 
12
  model = AutoModelForCausalLM.from_pretrained(
13
  MODEL_ID,
14
  torch_dtype=torch.bfloat16,
 
17
  trust_remote_code=True
18
  ).eval()
19
 
 
20
  if tokenizer.pad_token_id is None:
21
  tokenizer.pad_token = tokenizer.eos_token
22
 
23
+ def run_starvector(image: Image.Image | None, text: str) -> str:
24
  text = text or ""
25
  toks = tokenizer(text, return_tensors="pt", add_special_tokens=True)
26
  batch = {"input_ids": toks.input_ids}
27
  if image is not None:
28
  pix = image_processor(images=image, return_tensors="pt").pixel_values
29
  batch["pixel_values"] = pix
30
+ batch = {k: v.to(model.device) for k, v in batch.items()}
 
 
 
 
 
 
31
  with torch.no_grad():
32
  out = model.generate(
33
+ **batch,
34
  max_new_tokens=2048,
35
  temperature=0.2,
36
  do_sample=False,
37
  pad_token_id=tokenizer.pad_token_id,
38
+ eos_token_id=tokenizer.eos_token_id,
39
  )
40
+ return tokenizer.batch_decode(out, skip_special_tokens=True)[0]
 
41
 
 
42
  with gr.Blocks(title="StarVector: Image/Text → SVG") as demo:
43
  gr.Markdown("# StarVector: Image/Text → SVG")
44
  img = gr.Image(type="pil", label="Upload image (optional)")
 
47
  code = gr.Code(label="SVG Output", language="xml")
48
  btn.click(run_starvector, [img, txt], code)
49
 
 
50
  if __name__ == "__main__":
51
  port = int(os.environ.get("PORT", 7860))
 
52
  demo.launch(server_name="0.0.0.0", server_port=port)