OmniSVG committed on
Commit
558fa67
·
verified ·
1 Parent(s): ac3151e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -11
app.py CHANGED
@@ -14,7 +14,7 @@ from decoder import SketchDecoder
14
  from transformers import AutoTokenizer, AutoProcessor
15
  from qwen_vl_utils import process_vision_info
16
  from tokenizer import SVGTokenizer
17
- import spaces # 引入 GPU 装饰器
18
 
19
  # 读取配置
20
  with open('config.yaml', 'r') as f:
@@ -53,19 +53,38 @@ def load_models():
53
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", padding_side="left")
54
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", padding_side="left")
55
 
56
- sketch_decoder = SketchDecoder()
 
57
  sketch_weight_path = hf_hub_download(repo_id="OmniSVG/OmniSVG", filename="pytorch_model.bin")
58
 
59
- # 加载权重
60
- sketch_decoder.load_state_dict(torch.load(sketch_weight_path, map_location="cpu"), strict=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- if device.type == "cuda":
63
- sketch_decoder = sketch_decoder.to(device)
64
-
65
  sketch_decoder.eval()
66
  svg_tokenizer = SVGTokenizer('config.yaml')
67
 
68
- print(f"✅ Models loaded on {device}")
69
 
70
  def process_and_resize_image(image_input, target_size=(200, 200)):
71
  if isinstance(image_input, str):
@@ -264,7 +283,6 @@ def create_interface():
264
  ]
265
  example_images = get_example_images()
266
 
267
- # 删除了 theme 参数
268
  with gr.Blocks(title="OmniSVG Demo Page") as demo:
269
  gr.Markdown("# OmniSVG Demo Page")
270
  gr.Markdown("Generate SVG code from images or text descriptions")
@@ -279,7 +297,6 @@ def create_interface():
279
  image_generate_btn = gr.Button("Generate SVG", variant="primary")
280
 
281
  with gr.Column():
282
- # 删除了 show_copy_button=True
283
  image_svg_output = gr.Textbox(label="Generated SVG Code", lines=10, max_lines=20)
284
  image_png_preview = gr.Image(label="SVG Preview", type="pil")
285
 
@@ -293,7 +310,6 @@ def create_interface():
293
  text_generate_btn = gr.Button("Generate SVG", variant="primary")
294
 
295
  with gr.Column():
296
- # 删除了 show_copy_button=True
297
  text_svg_output = gr.Textbox(label="Generated SVG Code", lines=10, max_lines=20)
298
  text_png_preview = gr.Image(label="SVG Preview", type="pil")
299
 
 
14
  from transformers import AutoTokenizer, AutoProcessor
15
  from qwen_vl_utils import process_vision_info
16
  from tokenizer import SVGTokenizer
17
+ import spaces
18
 
19
  # 读取配置
20
  with open('config.yaml', 'r') as f:
 
53
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", padding_side="left")
54
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", padding_side="left")
55
 
56
+ # [关键修改] 1. 下载权重
57
+ print("Downloading weights...")
58
  sketch_weight_path = hf_hub_download(repo_id="OmniSVG/OmniSVG", filename="pytorch_model.bin")
59
 
60
+ # [关键修改] 2. 加载权重到 CPU 内存
61
+ print("Loading state dict...")
62
+ full_state_dict = torch.load(sketch_weight_path, map_location="cpu")
63
+
64
+ # [关键修改] 3. 处理权重键名 (Strip 'transformer.' prefix)
65
+ # 原始权重是 SketchDecoder 保存的,带有 "transformer." 前缀
66
+ # 我们要直接传给内部的 Qwen 模型,所以需要去掉这个前缀
67
+ qwen_state_dict = {}
68
+ for key in list(full_state_dict.keys()):
69
+ if key.startswith("transformer."):
70
+ new_key = key.replace("transformer.", "", 1)
71
+ qwen_state_dict[new_key] = full_state_dict.pop(key) # pop saving memory
72
+
73
+ del full_state_dict # 释放内存
74
+ gc.collect()
75
+
76
+ # [关键修改] 4. 初始化 Decoder 并传入处理后的权重
77
+ # 此时 bitsandbytes 会在初始化时直接将这些权重量化为 4-bit
78
+ print("Initializing quantized model with custom weights...")
79
+ sketch_decoder = SketchDecoder(state_dict=qwen_state_dict)
80
+
81
+ # 此时模型已经在 GPU 上(由 device_map="auto" 处理)
82
+ # 且权重已经是 OmniSVG 的权重了,无需再次 load_state_dict
83
 
 
 
 
84
  sketch_decoder.eval()
85
  svg_tokenizer = SVGTokenizer('config.yaml')
86
 
87
+ print(f"✅ Models loaded successfully!")
88
 
89
  def process_and_resize_image(image_input, target_size=(200, 200)):
90
  if isinstance(image_input, str):
 
283
  ]
284
  example_images = get_example_images()
285
 
 
286
  with gr.Blocks(title="OmniSVG Demo Page") as demo:
287
  gr.Markdown("# OmniSVG Demo Page")
288
  gr.Markdown("Generate SVG code from images or text descriptions")
 
297
  image_generate_btn = gr.Button("Generate SVG", variant="primary")
298
 
299
  with gr.Column():
 
300
  image_svg_output = gr.Textbox(label="Generated SVG Code", lines=10, max_lines=20)
301
  image_png_preview = gr.Image(label="SVG Preview", type="pil")
302
 
 
310
  text_generate_btn = gr.Button("Generate SVG", variant="primary")
311
 
312
  with gr.Column():
 
313
  text_svg_output = gr.Textbox(label="Generated SVG Code", lines=10, max_lines=20)
314
  text_png_preview = gr.Image(label="SVG Preview", type="pil")
315