Spaces:
Running on Zero
Running on Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,7 +14,7 @@ from decoder import SketchDecoder
|
|
| 14 |
from transformers import AutoTokenizer, AutoProcessor
|
| 15 |
from qwen_vl_utils import process_vision_info
|
| 16 |
from tokenizer import SVGTokenizer
|
| 17 |
-
import spaces
|
| 18 |
|
| 19 |
# 读取配置
|
| 20 |
with open('config.yaml', 'r') as f:
|
|
@@ -53,19 +53,38 @@ def load_models():
|
|
| 53 |
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", padding_side="left")
|
| 54 |
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", padding_side="left")
|
| 55 |
|
| 56 |
-
|
|
|
|
| 57 |
sketch_weight_path = hf_hub_download(repo_id="OmniSVG/OmniSVG", filename="pytorch_model.bin")
|
| 58 |
|
| 59 |
-
# 加载权重
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
-
if device.type == "cuda":
|
| 63 |
-
sketch_decoder = sketch_decoder.to(device)
|
| 64 |
-
|
| 65 |
sketch_decoder.eval()
|
| 66 |
svg_tokenizer = SVGTokenizer('config.yaml')
|
| 67 |
|
| 68 |
-
print(f"✅ Models loaded
|
| 69 |
|
| 70 |
def process_and_resize_image(image_input, target_size=(200, 200)):
|
| 71 |
if isinstance(image_input, str):
|
|
@@ -264,7 +283,6 @@ def create_interface():
|
|
| 264 |
]
|
| 265 |
example_images = get_example_images()
|
| 266 |
|
| 267 |
-
# 删除了 theme 参数
|
| 268 |
with gr.Blocks(title="OmniSVG Demo Page") as demo:
|
| 269 |
gr.Markdown("# OmniSVG Demo Page")
|
| 270 |
gr.Markdown("Generate SVG code from images or text descriptions")
|
|
@@ -279,7 +297,6 @@ def create_interface():
|
|
| 279 |
image_generate_btn = gr.Button("Generate SVG", variant="primary")
|
| 280 |
|
| 281 |
with gr.Column():
|
| 282 |
-
# 删除了 show_copy_button=True
|
| 283 |
image_svg_output = gr.Textbox(label="Generated SVG Code", lines=10, max_lines=20)
|
| 284 |
image_png_preview = gr.Image(label="SVG Preview", type="pil")
|
| 285 |
|
|
@@ -293,7 +310,6 @@ def create_interface():
|
|
| 293 |
text_generate_btn = gr.Button("Generate SVG", variant="primary")
|
| 294 |
|
| 295 |
with gr.Column():
|
| 296 |
-
# 删除了 show_copy_button=True
|
| 297 |
text_svg_output = gr.Textbox(label="Generated SVG Code", lines=10, max_lines=20)
|
| 298 |
text_png_preview = gr.Image(label="SVG Preview", type="pil")
|
| 299 |
|
|
|
|
| 14 |
from transformers import AutoTokenizer, AutoProcessor
|
| 15 |
from qwen_vl_utils import process_vision_info
|
| 16 |
from tokenizer import SVGTokenizer
|
| 17 |
+
import spaces
|
| 18 |
|
| 19 |
# 读取配置
|
| 20 |
with open('config.yaml', 'r') as f:
|
|
|
|
| 53 |
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", padding_side="left")
|
| 54 |
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", padding_side="left")
|
| 55 |
|
| 56 |
+
# [关键修改] 1. 下载权重
|
| 57 |
+
print("Downloading weights...")
|
| 58 |
sketch_weight_path = hf_hub_download(repo_id="OmniSVG/OmniSVG", filename="pytorch_model.bin")
|
| 59 |
|
| 60 |
+
# [关键修改] 2. 加载权重到 CPU 内存
|
| 61 |
+
print("Loading state dict...")
|
| 62 |
+
full_state_dict = torch.load(sketch_weight_path, map_location="cpu")
|
| 63 |
+
|
| 64 |
+
# [关键修改] 3. 处理权重键名 (Strip 'transformer.' prefix)
|
| 65 |
+
# 原始权重是 SketchDecoder 保存的,带有 "transformer." 前缀
|
| 66 |
+
# 我们要直接传给内部的 Qwen 模型,所以需要去掉这个前缀
|
| 67 |
+
qwen_state_dict = {}
|
| 68 |
+
for key in list(full_state_dict.keys()):
|
| 69 |
+
if key.startswith("transformer."):
|
| 70 |
+
new_key = key.replace("transformer.", "", 1)
|
| 71 |
+
qwen_state_dict[new_key] = full_state_dict.pop(key) # pop saving memory
|
| 72 |
+
|
| 73 |
+
del full_state_dict # 释放内存
|
| 74 |
+
gc.collect()
|
| 75 |
+
|
| 76 |
+
# [关键修改] 4. 初始化 Decoder 并传入处理后的权重
|
| 77 |
+
# 此时 bitsandbytes 会在初始化时直接将这些权重量化为 4-bit
|
| 78 |
+
print("Initializing quantized model with custom weights...")
|
| 79 |
+
sketch_decoder = SketchDecoder(state_dict=qwen_state_dict)
|
| 80 |
+
|
| 81 |
+
# 此时模型已经在 GPU 上(由 device_map="auto" 处理)
|
| 82 |
+
# 且权重已经是 OmniSVG 的权重了,无需再次 load_state_dict
|
| 83 |
|
|
|
|
|
|
|
|
|
|
| 84 |
sketch_decoder.eval()
|
| 85 |
svg_tokenizer = SVGTokenizer('config.yaml')
|
| 86 |
|
| 87 |
+
print(f"✅ Models loaded successfully!")
|
| 88 |
|
| 89 |
def process_and_resize_image(image_input, target_size=(200, 200)):
|
| 90 |
if isinstance(image_input, str):
|
|
|
|
| 283 |
]
|
| 284 |
example_images = get_example_images()
|
| 285 |
|
|
|
|
| 286 |
with gr.Blocks(title="OmniSVG Demo Page") as demo:
|
| 287 |
gr.Markdown("# OmniSVG Demo Page")
|
| 288 |
gr.Markdown("Generate SVG code from images or text descriptions")
|
|
|
|
| 297 |
image_generate_btn = gr.Button("Generate SVG", variant="primary")
|
| 298 |
|
| 299 |
with gr.Column():
|
|
|
|
| 300 |
image_svg_output = gr.Textbox(label="Generated SVG Code", lines=10, max_lines=20)
|
| 301 |
image_png_preview = gr.Image(label="SVG Preview", type="pil")
|
| 302 |
|
|
|
|
| 310 |
text_generate_btn = gr.Button("Generate SVG", variant="primary")
|
| 311 |
|
| 312 |
with gr.Column():
|
|
|
|
| 313 |
text_svg_output = gr.Textbox(label="Generated SVG Code", lines=10, max_lines=20)
|
| 314 |
text_png_preview = gr.Image(label="SVG Preview", type="pil")
|
| 315 |
|