OmniSVG committed on
Commit
01e90e4
·
verified ·
1 Parent(s): 992bf0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1284 -261
app.py CHANGED
@@ -9,28 +9,519 @@ import argparse
9
  import gc
10
  import yaml
11
  import glob
12
- import shutil
13
- from huggingface_hub import hf_hub_download, snapshot_download
 
 
14
  from decoder import SketchDecoder
15
  from transformers import AutoTokenizer, AutoProcessor
16
  from qwen_vl_utils import process_vision_info
17
  from tokenizer import SVGTokenizer
18
- import spaces
19
 
20
- # 读取配置
21
- with open('config.yaml', 'r') as f:
22
  config = yaml.safe_load(f)
23
 
24
- # 全局变量
 
 
 
25
  tokenizer = None
26
  processor = None
27
  sketch_decoder = None
28
  svg_tokenizer = None
29
- device = "cpu"
30
 
31
- # System prompt
32
- SYSTEM_PROMPT = "You are a multimodal SVG generation assistant capable of generating SVG code from both text descriptions and images."
 
 
 
 
 
 
33
  SUPPORTED_FORMATS = ['.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  def parse_args():
36
  parser = argparse.ArgumentParser(description='SVG Generator Service')
@@ -38,314 +529,846 @@ def parse_args():
38
  parser.add_argument('--port', type=int, default=7860)
39
  parser.add_argument('--share', action='store_true')
40
  parser.add_argument('--debug', action='store_true')
 
 
41
  return parser.parse_args()
42
 
43
- def load_models():
44
- """Load models safely (Lazy Loading with Model Construction)"""
45
- global tokenizer, processor, sketch_decoder, svg_tokenizer, device
 
46
 
47
- if sketch_decoder is not None:
48
- return
 
 
 
 
49
 
50
- print("🚀 Loading models inside GPU container...")
51
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- if tokenizer is None:
54
- # 1. 准备本地模型目录
55
- local_model_dir = "custom_model_build"
56
-
57
- # 只有当目录里没有权重文件时才执行构建
58
- if not os.path.exists(os.path.join(local_model_dir, "pytorch_model.bin")):
59
- print("🛠️ Building custom model directory...")
60
- os.makedirs(local_model_dir, exist_ok=True)
61
-
62
- # (A) 下载 Qwen 的配置文件
63
- print("Downloading Qwen configurations...")
64
- snapshot_download(
65
- repo_id="Qwen/Qwen2.5-VL-3B-Instruct",
66
- local_dir=local_model_dir,
67
- allow_patterns=["*.json", "*.txt", "*.py"], # 这会下载 index.json,下面我们会删掉它
68
- ignore_patterns=["*.safetensors", "*.bin", "*.pt"]
69
- )
70
-
71
- # (B) 下载 OmniSVG 权重
72
- print("Downloading OmniSVG weights...")
73
- sketch_weight_path = hf_hub_download(repo_id="OmniSVG/OmniSVG", filename="pytorch_model.bin")
74
-
75
- # (C) 处理并保存权重
76
- print("Processing and saving weights...")
77
- state_dict = torch.load(sketch_weight_path, map_location="cpu")
78
-
79
- new_state_dict = {}
80
- for key in list(state_dict.keys()):
81
- if key.startswith("transformer."):
82
- new_key = key.replace("transformer.", "", 1)
83
- new_state_dict[new_key] = state_dict[key]
84
- else:
85
- new_state_dict[key] = state_dict[key]
86
-
87
- torch.save(new_state_dict, os.path.join(local_model_dir, "pytorch_model.bin"))
88
- del state_dict, new_state_dict
89
- gc.collect()
90
- print("✅ Custom model directory built successfully.")
91
-
92
- # [关键修复] 强制删除所有的 index.json 文件
93
- # 即使之前的运行残留了这些文件,这里也会把它们清理掉,防止报错 FileNotFoundError
94
- print("🧹 Cleaning up conflicting index files...")
95
- for index_file in glob.glob(os.path.join(local_model_dir, "*.index.json")):
96
- try:
97
- os.remove(index_file)
98
- print(f" Removed: {index_file}")
99
- except Exception as e:
100
- print(f" Failed to remove {index_file}: {e}")
101
 
102
- # 2. 从本地目录加载模型
103
- print("Initializing quantized model from local directory...")
104
- sketch_decoder = SketchDecoder(model_path=local_model_dir)
105
- sketch_decoder.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
- # 3. 加载 Tokenizer
108
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", padding_side="left")
109
- processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", padding_side="left")
110
- svg_tokenizer = SVGTokenizer('config.yaml')
 
 
 
 
111
 
112
- print(f"✅ Models loaded successfully on {device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
- def process_and_resize_image(image_input, target_size=(200, 200)):
115
- if isinstance(image_input, str):
116
- image = Image.open(image_input)
117
- elif isinstance(image_input, Image.Image):
118
- image = image_input
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  else:
120
- image = Image.fromarray(image_input)
 
 
 
 
 
121
 
122
- image = image.resize(target_size, Image.Resampling.LANCZOS)
123
- return image
 
 
124
 
125
- def get_example_images():
126
- example_dir = "./examples"
127
- example_images = []
128
- if os.path.exists(example_dir):
129
- for ext in SUPPORTED_FORMATS:
130
- pattern = os.path.join(example_dir, f"*{ext}")
131
- example_images.extend(glob.glob(pattern))
132
- example_images.sort()
133
- return example_images
134
 
135
- def process_text_to_svg(text_description):
136
- messages = [{
137
- "role": "system",
138
- "content": SYSTEM_PROMPT
139
- }, {
140
- "role": "user",
141
- "content": [
142
- {"type": "text", "text": f"Task: text-to-svg\nDescription: {text_description}\nGenerate SVG code based on the above description."}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  ]
144
- }]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
- text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
147
- inputs = processor(text=[text_input], truncation=True, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- input_ids = inputs['input_ids'].to(device)
150
- attention_mask = inputs['attention_mask'].to(device)
151
- pixel_values = None
152
- image_grid_thw = None
153
-
154
- return input_ids, attention_mask, pixel_values, image_grid_thw
155
-
156
- def process_image_to_svg(image_path):
157
- messages = [{
158
- "role": "system",
159
- "content": SYSTEM_PROMPT
160
- }, {
161
- "role": "user",
162
- "content": [
163
- {"type": "text", "text": f"Task: image-to-svg\nGenerate SVG code that accurately represents the following image."},
164
- {"type": "image", "image": image_path},
165
- ]
166
- }]
 
167
 
168
- text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
169
- image_inputs, _ = process_vision_info(messages)
170
 
171
- inputs = processor(
172
- text=[text_input],
173
- images=image_inputs,
174
- truncation=True,
175
- return_tensors="pt"
176
- )
 
 
 
 
 
 
 
 
 
 
 
177
 
178
  input_ids = inputs['input_ids'].to(device)
179
  attention_mask = inputs['attention_mask'].to(device)
180
- pixel_values = inputs['pixel_values'].to(device) if 'pixel_values' in inputs else None
181
- image_grid_thw = inputs['image_grid_thw'].to(device) if 'image_grid_thw' in inputs else None
182
 
183
- return input_ids, attention_mask, pixel_values, image_grid_thw
184
-
185
- def generate_svg(input_ids, attention_mask, pixel_values=None, image_grid_thw=None, task_type="image-to-svg"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  try:
187
- gc.collect()
188
- if torch.cuda.is_available():
189
- torch.cuda.empty_cache()
190
-
191
- print(f"Generating SVG for {task_type}...")
192
-
193
- if task_type == "image-to-svg":
194
- gen_config = dict(
195
- do_sample=True,
196
- temperature=0.1,
197
- top_p=0.001,
198
- top_k=1,
199
- num_beams=5,
200
- repetition_penalty=1.05,
201
- )
202
- else:
203
- gen_config = dict(
204
- do_sample=True,
205
- temperature=0.8,
206
- top_p=0.95,
207
- top_k=50,
208
- repetition_penalty=1.05,
209
- early_stopping=True,
210
- )
211
 
212
- if torch.cuda.is_available():
213
- torch.cuda.synchronize()
214
-
215
- model_config = config['model']
216
- max_length = model_config['max_length']
217
- output_ids = torch.ones(1, max_length).long().to(device) * model_config['eos_token_id']
218
-
219
- with torch.no_grad():
220
- results = sketch_decoder.transformer.generate(
221
- input_ids=input_ids,
222
- attention_mask=attention_mask,
223
- pixel_values=pixel_values,
224
- image_grid_thw=image_grid_thw,
225
- max_new_tokens=max_length-1,
226
- num_return_sequences=1,
227
- bos_token_id=model_config['bos_token_id'],
228
- eos_token_id=model_config['eos_token_id'],
229
- pad_token_id=model_config['pad_token_id'],
230
- use_cache=True,
231
- **gen_config
232
- )
233
- results = results[:, :max_length-1]
234
- output_ids[:, :results.shape[1]] = results
235
-
236
- generated_xy, generated_colors = svg_tokenizer.process_generated_tokens(output_ids)
237
- svg_tensors = svg_tokenizer.raster_svg(generated_xy)
238
-
239
- if not svg_tensors or not svg_tensors[0]:
240
- return "Error: No valid SVG paths generated", None
241
 
242
- print('Creating SVG...')
243
- svg = svg_tokenizer.apply_colors_to_svg(svg_tensors[0], generated_colors)
244
- svg_str = svg.to_str()
 
 
 
 
 
 
 
 
245
 
246
- png_data = cairosvg.svg2png(bytestring=svg_str.encode('utf-8'))
247
- png_image = Image.open(io.BytesIO(png_data))
248
 
249
- return svg_str, png_image
250
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  except Exception as e:
252
- print(f"Generation error: {e}")
253
  import traceback
254
  traceback.print_exc()
255
- return f"Error: {e}", None
 
 
 
 
256
 
257
  @spaces.GPU
258
- def gradio_image_to_svg(image):
259
- load_models()
260
- if image is None: return "Please upload an image", None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
 
262
- processed_image = process_and_resize_image(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
264
- processed_image.save(tmp_file.name, format='PNG')
265
  tmp_path = tmp_file.name
266
 
267
  try:
268
- input_ids, attention_mask, pixel_values, image_grid_thw = process_image_to_svg(tmp_path)
269
- svg_code, png_image = generate_svg(input_ids, attention_mask, pixel_values, image_grid_thw, "image-to-svg")
270
- return svg_code, png_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  finally:
272
- if os.path.exists(tmp_path): os.unlink(tmp_path)
 
273
 
274
- @spaces.GPU
275
- def gradio_text_to_svg(text_description):
276
- load_models()
277
- if not text_description or text_description.strip() == "":
278
- return "Please enter a description", None
279
 
280
- input_ids, attention_mask, pixel_values, image_grid_thw = process_text_to_svg(text_description)
281
- svg_code, png_image = generate_svg(input_ids, attention_mask, pixel_values, image_grid_thw, "text-to-svg")
282
- return svg_code, png_image
 
 
 
 
 
283
 
284
  def create_interface():
 
 
 
285
  example_texts = [
286
- "A yellow t-shirt with a heart design represents love and positivity.",
287
- "A bright yellow emoji with a surprised expression and rosy cheeks hovers above a shadow.",
288
- "A brown coffee cup on a white saucer is seen from a top-down perspective.",
289
- "A cartoon firefighter in a red and yellow uniform represents safety and protection.",
290
- "A cute bunny face with pink ears rosy cheeks and a playful red tongue conveys charm and cheerfulness.",
291
- "A bearded man with orange hair and a mustache represents a hipster style portrait.",
292
- "A colorful ice cream popsicle with a hint of chocolate at the bottom on a stick.",
293
- "A light blue shopping bag features a white flower with a red center and scattered dots.",
294
- "A yellow phone icon and orange arrow on a blue smartphone screen symbolize an incoming call.",
295
- "A sad wilted flower with pink petals slumps over an orange cloud with a blue striped background.",
296
- "A cartoon character with dark blue hair and a mustache wears a blue suit against a light blue circular background.",
297
- "A blue bookmark icon with a white plus sign in the center.",
298
- "A computer monitor displays a bar graph with yellow orange and green bars.",
299
- "A blue and gray database icon is overlaid with a yellow star in the bottom right corner.",
300
- "An orange thermometer with a circular base represents temperature measurement.",
301
- "A green delivery truck icon with a checkmark symbolizing a completed delivery.",
302
- "A blue and gray microphone icon symbolizes audio recording or voice input.",
303
- "Cloud icon with an upward arrow symbolizes uploading or cloud storage.",
304
- "A brown chocolate bar is depicted in four square segments with a shiny glossy finish.",
305
- "A colorful moving truck icon with a red and orange cargo container.",
306
- "A light blue T-shirt icon is outlined with a bold blue border.",
307
- "A person in a blue shirt and dark pants stands with one hand in a pocket gesturing outward.",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  ]
 
309
  example_images = get_example_images()
310
 
311
- with gr.Blocks(title="OmniSVG Demo Page") as demo:
312
- gr.Markdown("# OmniSVG Demo Page")
313
- gr.Markdown("Generate SVG code from images or text descriptions")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
 
315
  with gr.Tabs():
316
- with gr.TabItem("Image-to-SVG"):
317
- with gr.Row():
318
- with gr.Column():
319
- image_input = gr.Image(label="Input Image", type="pil", image_mode="RGBA")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
  if example_images:
321
- gr.Examples(examples=example_images, inputs=[image_input], label="Example Images", examples_per_page=12)
322
- image_generate_btn = gr.Button("Generate SVG", variant="primary")
323
 
324
- with gr.Column():
325
- image_svg_output = gr.Textbox(label="Generated SVG Code", lines=10, max_lines=20)
326
- image_png_preview = gr.Image(label="SVG Preview", type="pil")
 
 
 
 
 
 
 
 
327
 
328
- image_generate_btn.click(fn=gradio_image_to_svg, inputs=[image_input], outputs=[image_svg_output, image_png_preview], queue=True)
329
-
330
- with gr.TabItem("Text-to-SVG"):
331
- with gr.Row():
332
- with gr.Column():
333
- text_input = gr.Textbox(label="Description", placeholder="Enter SVG description...", lines=3)
334
- gr.Examples(examples=[[text] for text in example_texts], inputs=[text_input], label="Example Descriptions", examples_per_page=10)
335
- text_generate_btn = gr.Button("Generate SVG", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
 
337
- with gr.Column():
338
- text_svg_output = gr.Textbox(label="Generated SVG Code", lines=10, max_lines=20)
339
- text_png_preview = gr.Image(label="SVG Preview", type="pil")
 
 
 
 
 
 
 
 
 
 
340
 
341
- text_generate_btn.click(fn=gradio_text_to_svg, inputs=[text_input], outputs=[text_svg_output, text_png_preview], queue=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
 
343
- gr.Markdown("""## Usage Instructions...""")
344
- return demo
345
 
346
  if __name__ == "__main__":
347
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
348
  args = parse_args()
349
- print("Application starting... Models will be loaded on demand.")
 
 
 
 
 
 
 
 
 
 
 
 
350
  demo = create_interface()
351
- demo.launch(server_name=args.listen, server_port=args.port, share=args.share, debug=args.debug)
 
 
 
 
 
 
 
 
 
9
  import gc
10
  import yaml
11
  import glob
12
+ import numpy as np
13
+ import time
14
+ import threading
15
+
16
  from decoder import SketchDecoder
17
  from transformers import AutoTokenizer, AutoProcessor
18
  from qwen_vl_utils import process_vision_info
19
  from tokenizer import SVGTokenizer
 
20
 
21
+ # Load config
22
+ with open('./config.yaml', 'r') as f:
23
  config = yaml.safe_load(f)
24
 
25
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
26
+ DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
27
+
28
+ # Global Models
29
  tokenizer = None
30
  processor = None
31
  sketch_decoder = None
32
  svg_tokenizer = None
 
33
 
34
+ # Thread lock for model inference
35
+ generation_lock = threading.Lock()
36
+
37
+ # Constants
38
+ SYSTEM_PROMPT = """You are an expert SVG code generator.
39
+ Generate precise, valid SVG path commands that accurately represent the described scene or object.
40
+ Focus on capturing key shapes, spatial relationships, and visual composition."""
41
+
42
  SUPPORTED_FORMATS = ['.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif']
43
+ TARGET_IMAGE_SIZE = 448
44
+ BLACK_COLOR_TOKEN = 40012
45
+
46
+ # Task configurations with defaults
47
+ TASK_CONFIGS = {
48
+ "text-to-svg-icon": {
49
+ "default_temperature": 0.5,
50
+ "default_top_p": 0.88,
51
+ "default_top_k": 50,
52
+ "default_repetition_penalty": 1.05,
53
+ },
54
+ "text-to-svg-illustration": {
55
+ "default_temperature": 0.6,
56
+ "default_top_p": 0.90,
57
+ "default_top_k": 60,
58
+ "default_repetition_penalty": 1.03,
59
+ },
60
+ "image-to-svg": {
61
+ "default_temperature": 0.3,
62
+ "default_top_p": 0.90,
63
+ "default_top_k": 50,
64
+ "default_repetition_penalty": 1.05,
65
+ }
66
+ }
67
+
68
+ # Custom CSS
69
+ CUSTOM_CSS = """
70
+ /* Main container centering */
71
+ .gradio-container {
72
+ max-width: 1400px !important;
73
+ margin: 0 auto !important;
74
+ padding: 20px !important;
75
+ }
76
+
77
+ /* Header styling */
78
+ .header-container {
79
+ text-align: center;
80
+ margin-bottom: 20px;
81
+ padding: 20px;
82
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
83
+ border-radius: 16px;
84
+ color: white;
85
+ }
86
+
87
+ .header-container h1 {
88
+ margin: 0;
89
+ font-size: 2.5em;
90
+ font-weight: 700;
91
+ }
92
+
93
+ .header-container p {
94
+ margin: 10px 0 0 0;
95
+ opacity: 0.9;
96
+ font-size: 1.1em;
97
+ }
98
+
99
+ /* Tips section */
100
+ .tips-box {
101
+ background: #f8f9fa;
102
+ border-radius: 12px;
103
+ padding: 20px;
104
+ margin-bottom: 20px;
105
+ border: 1px solid #e0e0e0;
106
+ }
107
+
108
+ .tips-box h3 {
109
+ margin-top: 0;
110
+ color: #333;
111
+ border-bottom: 2px solid #667eea;
112
+ padding-bottom: 10px;
113
+ }
114
+
115
+ .tip-category {
116
+ background: white;
117
+ border-radius: 8px;
118
+ padding: 15px;
119
+ margin: 10px 0;
120
+ border-left: 4px solid #667eea;
121
+ }
122
+
123
+ .tip-category h4 {
124
+ margin: 0 0 10px 0;
125
+ color: #667eea;
126
+ }
127
+
128
+ .tip-category code {
129
+ background: #f0f0f0;
130
+ padding: 2px 6px;
131
+ border-radius: 4px;
132
+ font-size: 0.9em;
133
+ }
134
+
135
+ .example-prompt {
136
+ background: #e8f4fd;
137
+ padding: 10px;
138
+ border-radius: 6px;
139
+ margin: 8px 0;
140
+ font-style: italic;
141
+ font-size: 0.95em;
142
+ color: #333;
143
+ }
144
+
145
+ .red-tip {
146
+ color: #dc3545;
147
+ font-weight: 600;
148
+ }
149
+
150
+ .red-box {
151
+ background: #fff5f5;
152
+ border: 1px solid #ffcccc;
153
+ border-left: 4px solid #dc3545;
154
+ padding: 12px;
155
+ border-radius: 8px;
156
+ margin: 10px 0;
157
+ }
158
+
159
+ .red-box strong {
160
+ color: #dc3545;
161
+ }
162
+
163
+ .orange-box {
164
+ background: #fff8e6;
165
+ border: 1px solid #ffc107;
166
+ border-left: 4px solid #ff9800;
167
+ padding: 12px;
168
+ border-radius: 8px;
169
+ margin: 10px 0;
170
+ }
171
+
172
+ .orange-box strong {
173
+ color: #ff9800;
174
+ }
175
+
176
+ .green-box {
177
+ background: #e8f5e9;
178
+ border: 1px solid #81c784;
179
+ border-left: 4px solid #4caf50;
180
+ padding: 12px;
181
+ border-radius: 8px;
182
+ margin: 10px 0;
183
+ }
184
+
185
+ .green-box strong {
186
+ color: #4caf50;
187
+ }
188
+
189
+ /* Tab styling */
190
+ .tabs {
191
+ border-radius: 12px !important;
192
+ overflow: hidden;
193
+ }
194
+
195
+ .tabitem {
196
+ padding: 20px !important;
197
+ }
198
+
199
+ /* Button styling */
200
+ .primary-btn {
201
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
202
+ border: none !important;
203
+ font-weight: 600 !important;
204
+ padding: 12px 24px !important;
205
+ font-size: 1.1em !important;
206
+ }
207
+
208
+ .primary-btn:hover {
209
+ transform: translateY(-2px);
210
+ box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
211
+ }
212
+
213
+ /* Settings group */
214
+ .settings-group {
215
+ background: #f8f9fa;
216
+ border-radius: 10px;
217
+ padding: 15px;
218
+ margin: 10px 0;
219
+ }
220
+
221
+ .advanced-settings {
222
+ background: #f0f4f8;
223
+ border-radius: 8px;
224
+ padding: 12px;
225
+ margin-top: 10px;
226
+ }
227
+
228
+ /* Code output */
229
+ .code-output textarea {
230
+ font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace !important;
231
+ font-size: 12px !important;
232
+ background: #1e1e1e !important;
233
+ color: #d4d4d4 !important;
234
+ border-radius: 8px !important;
235
+ }
236
+
237
+ /* Input image area */
238
+ .input-image {
239
+ border: 2px dashed #ccc;
240
+ border-radius: 12px;
241
+ transition: border-color 0.3s;
242
+ }
243
+
244
+ .input-image:hover {
245
+ border-color: #667eea;
246
+ }
247
+
248
+ /* Footer */
249
+ .footer {
250
+ text-align: center;
251
+ padding: 20px;
252
+ color: #666;
253
+ font-size: 0.9em;
254
+ }
255
+
256
+ /* Responsive adjustments */
257
+ @media (max-width: 768px) {
258
+ .gradio-container {
259
+ padding: 10px !important;
260
+ }
261
+ .header-container h1 {
262
+ font-size: 1.8em;
263
+ }
264
+ }
265
+ """
266
+
267
+ # Enhanced Tips HTML - Bilingual with Red Tips
268
+ TIPS_HTML = """
269
+ <div class="tips-box">
270
+ <h3>💡 Prompting Guide & Best Practices | 提示词指南与最佳实践</h3>
271
+
272
+ <!-- Critical Red Tips Section -->
273
+ <div class="red-box">
274
+ <strong>🔴 CRITICAL: Tips That WILL Improve Your Results | 关键:一定能提升效果的技巧</strong>
275
+ <ul style="margin: 8px 0 0 0; padding-left: 20px;">
276
+ <li style="color: #dc3545; font-weight: 600;">
277
+ <strong>🎲 Generate 4-8 candidates and pick the best one!</strong> Results vary significantly between generations - this is NORMAL!<br/>
278
+ <span style="color: #666; font-weight: normal;">生成4-8个候选结果并选择最好的!每次生成结果差异很大 - 这是正常的!</span>
279
+ </li>
280
+ <li style="color: #dc3545; font-weight: 600;">
281
+ <strong>📐 Use GEOMETRIC descriptions:</strong> "triangular roof", "circular head", "rectangular body", "curved tail"<br/>
282
+ <span style="color: #666; font-weight: normal;">使用几何描述:"三角形屋顶"、"圆形头部"、"矩形身体"、"弯曲尾巴"</span>
283
+ </li>
284
+ <li style="color: #dc3545; font-weight: 600;">
285
+ <strong>🎨 ALWAYS specify colors for EACH element:</strong> "black outline", "red roof", "blue shirt", "green grass"<br/>
286
+ <span style="color: #666; font-weight: normal;">始终为每个元素指定颜色:"黑色轮廓"、"红色屋顶"、"蓝色衬衫"、"绿色草地"</span>
287
+ </li>
288
+ <li style="color: #dc3545; font-weight: 600;">
289
+ <strong>⬜ Say "white background" or "on white background"</strong> for cleaner results<br/>
290
+ <span style="color: #666; font-weight: normal;">说"白色背景"或"在白色背景上"可获得更干净的结果</span>
291
+ </li>
292
+ <li style="color: #dc3545; font-weight: 600;">
293
+ <strong>📍 Describe position & orientation:</strong> "centrally positioned", "pointing upward", "facing right", "at the bottom"<br/>
294
+ <span style="color: #666; font-weight: normal;">描述位置和方向:"居中放置"、"指向上方"、"朝右"、"在底部"</span>
295
+ </li>
296
+ <li style="color: #dc3545; font-weight: 600;">
297
+ <strong>✂️ Keep it SIMPLE:</strong> Avoid complex sentences. Use short, clear phrases connected by commas.<br/>
298
+ <span style="color: #666; font-weight: normal;">保持简单:避免复杂句子。使用简短清晰的短语,用逗号连接。</span>
299
+ </li>
300
+ </ul>
301
+ </div>
302
+
303
+ <!-- Parameter Tuning Tips -->
304
+ <div class="orange-box">
305
+ <strong>🎛️ Parameter Tuning Guide | 参数调整指南</strong>
306
+ <table style="width: 100%; margin-top: 10px; border-collapse: collapse;">
307
+ <tr style="background: rgba(255,255,255,0.5);">
308
+ <th style="padding: 8px; text-align: left; border-bottom: 1px solid #ddd;">Scenario 场景</th>
309
+ <th style="padding: 8px; text-align: center; border-bottom: 1px solid #ddd;">Temperature</th>
310
+ <th style="padding: 8px; text-align: center; border-bottom: 1px solid #ddd;">Top-P</th>
311
+ <th style="padding: 8px; text-align: center; border-bottom: 1px solid #ddd;">Top-K</th>
312
+ <th style="padding: 8px; text-align: center; border-bottom: 1px solid #ddd;">Rep. Penalty</th>
313
+ </tr>
314
+ <tr>
315
+ <td style="padding: 8px;">Simple icons/shapes 简单图标</td>
316
+ <td style="padding: 8px; text-align: center;">0.3 - 0.5</td>
317
+ <td style="padding: 8px; text-align: center;">0.85 - 0.90</td>
318
+ <td style="padding: 8px; text-align: center;">40 - 50</td>
319
+ <td style="padding: 8px; text-align: center;">1.05</td>
320
+ </tr>
321
+ <tr style="background: rgba(255,255,255,0.3);">
322
+ <td style="padding: 8px;">Characters/Avatars 人物/头像</td>
323
+ <td style="padding: 8px; text-align: center;">0.5 - 0.7</td>
324
+ <td style="padding: 8px; text-align: center;">0.88 - 0.92</td>
325
+ <td style="padding: 8px; text-align: center;">50 - 70</td>
326
+ <td style="padding: 8px; text-align: center;">1.02 - 1.05</td>
327
+ </tr>
328
+ <tr>
329
+ <td style="padding: 8px;">Landscapes/Scenes 风景/场景</td>
330
+ <td style="padding: 8px; text-align: center;">0.5 - 0.7</td>
331
+ <td style="padding: 8px; text-align: center;">0.88 - 0.92</td>
332
+ <td style="padding: 8px; text-align: center;">50 - 60</td>
333
+ <td style="padding: 8px; text-align: center;">1.03</td>
334
+ </tr>
335
+ <tr style="background: rgba(255,255,255,0.3);">
336
+ <td style="padding: 8px;">Image-to-SVG 图像转SVG</td>
337
+ <td style="padding: 8px; text-align: center;">0.2 - 0.4</td>
338
+ <td style="padding: 8px; text-align: center;">0.88 - 0.92</td>
339
+ <td style="padding: 8px; text-align: center;">40 - 50</td>
340
+ <td style="padding: 8px; text-align: center;">1.05</td>
341
+ </tr>
342
+ </table>
343
+ <p style="margin: 10px 0 0 0; font-size: 0.9em; color: #856404;">
344
+ 💡 <strong>Tip:</strong> If results are too chaotic, lower temperature. If too simple/empty, raise it slightly.<br/>
345
+ 如果结果太混乱,降低温度。如果太简单/空白,稍微提高。
346
+ </p>
347
+ </div>
348
+
349
+ <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); gap: 15px; margin-top: 15px;">
350
+
351
+ <div class="tip-category">
352
+ <h4>🎯 Icons & Simple Shapes | 图标与简单形状</h4>
353
+ <p>Use clear geometric descriptions with explicit colors.<br/>
354
+ <span style="color: #666; font-size: 0.9em;">使用清晰的几何描述和明确的颜色。</span></p>
355
+ <div class="example-prompt">
356
+ "A black triangle pointing downward, centrally positioned on white background."<br/>
357
+ <span style="color: #666;">"黑色三角形,指向下方,居中在白色背景上。"</span>
358
+ </div>
359
+ <div class="example-prompt">
360
+ "A red heart shape with smooth curved edges, centered on white background."<br/>
361
+ <span style="color: #666;">"红色心形,边缘光滑弯曲,居中在白色背景上。"</span>
362
+ </div>
363
+ <p><strong>Keywords:</strong> <code>triangle</code> <code>circle</code> <code>arrow</code> <code>heart</code> <code>star</code> <code>centered</code></p>
364
+ </div>
365
+
366
+ <div class="tip-category">
367
+ <h4>👤 Characters & People | 人物角色</h4>
368
+ <p>Break down into simple geometric parts. Describe each body part with shape + color.<br/>
369
+ <span style="color: #666; font-size: 0.9em;">分解为简单几何部分。用形状+颜色描述每个身体部位。</span></p>
370
+ <div class="example-prompt">
371
+ "A simple person: round beige head, rectangular blue shirt body, two dark gray rectangular legs. Standing pose, arms at sides, flat colors, white background."<br/>
372
+ <span style="color: #666;">"简单人物:米色圆形头,蓝色矩形衬衫身体,两条深灰矩形腿。站立姿势,双臂下垂,平面颜色,白色背景。"</span>
373
+ </div>
374
+ <div class="example-prompt">
375
+ "A girl with long black hair, pink dress with triangular skirt shape, small circular face with dot eyes and curved smile. Simple cartoon style."<br/>
376
+ <span style="color: #666;">"长黑发女孩,粉色连衣裙(三角形裙摆),小圆脸配点状眼睛和弯曲微笑。简单卡通风格。"</span>
377
+ </div>
378
+ <p class="red-tip">⚠️ Keep poses SIMPLE: standing, sitting, waving. Avoid complex actions!</p>
379
+ </div>
380
+
381
+ <div class="tip-category">
382
+ <h4>😊 Avatars & Portraits | 头像与肖像</h4>
383
+ <p>Use circular frame, focus on face and upper body only.<br/>
384
+ <span style="color: #666; font-size: 0.9em;">使用圆形框架,只关注脸部和上半身。</span></p>
385
+ <div class="example-prompt">
386
+ "Circular avatar: person with short black hair, round face with two dot eyes and small curved smile, wearing blue collar shirt. Minimal style, white background."<br/>
387
+ <span style="color: #666;">"圆形头像:短黑发人物,圆脸配两个点状眼睛和小弯曲微笑,穿蓝色衬衫领子。极简风格,白色背景。"</span>
388
+ </div>
389
+ <div class="example-prompt">
390
+ "Profile avatar silhouette: black side view of head with short hair, facing right. Simple solid shape on white background."<br/>
391
+ <span style="color: #666;">"侧面头像剪影:黑色短发头部侧视图,朝右。简单实心形状,白色背景。"</span>
392
+ </div>
393
+ </div>
394
+
395
+ <div class="tip-category">
396
+ <h4>🏔️ Landscapes & Scenes | 风景与场景</h4>
397
+ <p>Layer elements from background to foreground. Specify color for EACH layer.<br/>
398
+ <span style="color: #666; font-size: 0.9em;">从背景到前景分层。为每层指定颜色。</span></p>
399
+ <div class="example-prompt">
400
+ "Layered landscape: light blue sky at top, gray triangular mountains in middle, dark green triangular pine trees at bottom. Flat colors, simple shapes."<br/>
401
+ <span style="color: #666;">"分层风景:顶部浅蓝天空,中间灰色三角山脉,底部深绿三角松树。平面颜色,简单形状。"</span>
402
+ </div>
403
+ <div class="example-prompt">
404
+ "Sunset beach: orange gradient sky at top, yellow semicircle sun on horizon, dark blue wavy ocean below, tan beach at bottom."<br/>
405
+ <span style="color: #666;">"日落海滩:顶部橙色渐变天空,地平线黄色半圆太阳,下方深蓝波浪海洋,底部棕褐色沙滩。"</span>
406
+ </div>
407
+ <p class="red-tip">⚠️ Use geometric shapes for nature: triangular trees, wavy water, semicircle sun!</p>
408
+ </div>
409
+
410
+ <div class="tip-category">
411
+ <h4>🐱 Animals | 动物</h4>
412
+ <p>Describe as geometric shapes: oval body, round head, triangular ears, curved tail.<br/>
413
+ <span style="color: #666; font-size: 0.9em;">描述为几何形状:椭圆身体,圆头,三角耳朵,弯曲尾巴。</span></p>
414
+ <div class="example-prompt">
415
+ "Cute cat: orange round head with two triangular ears, oval orange body, curved tail. Simple cartoon style with black outlines, sitting pose, white background."<br/>
416
+ <span style="color: #666;">"可爱猫咪:橙色圆头配两个三角耳朵,橙色椭圆身体,弯曲尾巴。简单卡通风格,黑色轮廓,坐姿,白色背景。"</span>
417
+ </div>
418
+ <div class="example-prompt">
419
+ "Simple black bird: oval body, small round head, pointed triangular beak facing right, triangular tail, two stick legs. Silhouette style on white."<br/>
420
+ <span style="color: #666;">"简单黑鸟:椭圆身体,小圆头,尖三角喙朝右,三角尾巴,两条棒状腿。白色背景剪影风格。"</span>
421
+ </div>
422
+ </div>
423
+
424
+ <div class="tip-category">
425
+ <h4>🏠 Buildings & Objects | 建筑与物体</h4>
426
+ <p>Use basic shapes: rectangles for walls, triangles for roofs, squares for windows.<br/>
427
+ <span style="color: #666; font-size: 0.9em;">使用基本形状:矩形墙壁,三角屋顶,方形窗户。</span></p>
428
+ <div class="example-prompt">
429
+ "Simple house: red triangular roof on top, beige rectangular wall, brown rectangular door in center, two small blue square windows. Green ground at bottom, white background."<br/>
430
+ <span style="color: #666;">"简单房屋:顶部红色三角屋顶,米色矩形墙壁,中间棕色矩形门,两个小蓝色方形窗户。底部绿色地面,白色背景。"</span>
431
+ </div>
432
+ <div class="example-prompt">
433
+ "Coffee mug: brown cylindrical cup shape with curved handle on right side, three wavy steam lines rising from top. Simple flat style on white."<br/>
434
+ <span style="color: #666;">"咖啡杯:棕色圆柱杯身,右侧弯曲把手,顶部三条波浪蒸汽线上升。简单平面风格,白色背景。"</span>
435
+ </div>
436
+ </div>
437
+
438
+ </div>
439
+
440
+ <!-- Extended Examples Section -->
441
+ <div style="margin-top: 20px; padding: 15px; background: #f0f7ff; border-radius: 10px; border: 1px solid #cce5ff;">
442
+ <h4 style="margin-top: 0; color: #0066cc;">🎨 More Complex Examples (Generate 6-8 candidates!) | 更多复杂示例(请生成6-8个候选!)</h4>
443
+
444
+ <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(280px, 1fr)); gap: 12px; margin-top: 15px;">
445
+ <div class="example-prompt">
446
+ <strong>👨‍💼 Business Avatar:</strong><br/>
447
+ "Circular professional avatar: man with short black hair, neutral skin tone round face, wearing dark navy suit with white shirt collar visible. Clean minimal style, centered in circle, white background."
448
+ </div>
449
+ <div class="example-prompt">
450
+ <strong>👩 Female Portrait:</strong><br/>
451
+ "Simple female face: oval face shape, long brown wavy hair on sides, two dot eyes, small nose, curved smile lips. Pink blush on cheeks. Cartoon portrait style, white background."
452
+ </div>
453
+ <div class="example-prompt">
454
+ <strong>🧒 Child Character:</strong><br/>
455
+ "Cute child standing: large round head with short brown hair, big circular eyes with white highlights, small body in red t-shirt and blue shorts, simple stick arms and legs. Cheerful cartoon style."
456
+ </div>
457
+ <div class="example-prompt">
458
+ <strong>🏃 Active Pose:</strong><br/>
459
+ "Person walking: side view, circular head, rectangular torso in green jacket, legs in walking position (one forward, one back). Simple geometric style, moving right, white background."
460
+ </div>
461
+ <div class="example-prompt">
462
+ <strong>🌲 Forest Scene:</strong><br/>
463
+ "Simple forest: light blue sky, row of 5 dark green triangular pine trees of varying heights, brown rectangular trunks, light green grass strip at bottom. Layered flat design."
464
+ </div>
465
+ <div class="example-prompt">
466
+ <strong>🌊 Ocean View:</strong><br/>
467
+ "Minimalist ocean: gradient blue sky at top, three horizontal wavy lines in dark blue for ocean, small white sailboat with triangular sail in center. Clean vector style."
468
+ </div>
469
+ <div class="example-prompt">
470
+ <strong>🌆 City Skyline:</strong><br/>
471
+ "Simple city skyline: orange sunset sky gradient, row of black rectangular building silhouettes of different heights, some with small yellow square windows. Minimalist style."
472
+ </div>
473
+ <div class="example-prompt">
474
+ <strong>🐕 Dog Character:</strong><br/>
475
+ "Friendly cartoon dog: brown oval body, round head with floppy ears, black dot nose, curved tail pointing up, four short legs. Sitting pose facing forward, white background."
476
+ </div>
477
+ </div>
478
+ </div>
479
+
480
+ <!-- Quick Troubleshooting -->
481
+ <div class="green-box" style="margin-top: 15px;">
482
+ <strong>⚡ Quick Troubleshooting | 快速故障排除</strong>
483
+ <ul style="margin: 8px 0 0 0; padding-left: 20px;">
484
+ <li><strong>Messy/chaotic? 混乱?</strong> → Lower temperature to 0.3-0.4, simplify description, reduce top_k</li>
485
+ <li><strong>Too simple/empty? 太简单?</strong> → Raise temperature to 0.5-0.6, add more shape details</li>
486
+ <li><strong>Wrong colors? 颜色错误?</strong> → Explicitly name EVERY color: "red roof", "blue shirt", "black outline"</li>
487
+ <li><strong>Missing elements? 元素缺失?</strong> → Add position words: "at top", "in center", "at bottom left"</li>
488
+ <li><strong>Repetitive patterns? 重复图案?</strong> → Increase repetition_penalty to 1.08-1.15</li>
489
+ <li><strong>Inconsistent? 不一致?</strong> → <span class="red-tip">Generate MORE candidates (6-8) and pick the best!</span></li>
490
+ </ul>
491
+ </div>
492
+
493
+ <!-- Prompt Template -->
494
+ <div style="margin-top: 15px; padding: 12px; background: #e8f5e9; border-radius: 8px; border-left: 4px solid #4caf50;">
495
+ <strong>✅ Recommended Prompt Structure | 推荐提示词结构</strong>
496
+ <div style="background: white; padding: 10px; border-radius: 6px; margin-top: 8px; font-family: monospace; font-size: 0.9em;">
497
+ [Subject] + [Shape descriptions with colors] + [Position/orientation] + [Style] + [Background]
498
+ </div>
499
+ <p style="margin: 10px 0 0 0; color: #2e7d32; font-size: 0.95em;">
500
+ ✓ "A fox logo: triangular orange head, pointed ears, white chest marking, facing right. Minimalist flat style, centered on white background."
501
+ </p>
502
+ </div>
503
+ </div>
504
+ """
505
+
506
+ # Image-to-SVG specific tips
507
+ IMAGE_TIPS_HTML = """
508
+ <div class="red-box">
509
+ <strong>🔴 Image-to-SVG Tips | 图片转SVG技巧</strong>
510
+ <ul style="margin: 8px 0 0 0; padding-left: 20px;">
511
+ <li><strong>Best input: Simple images with white/transparent background</strong><br/>
512
+ <span style="color: #666;">最佳输入:白色或透明背景的简单图片</span></li>
513
+ <li><strong>PNG with transparency (RGBA) works best!</strong> We auto-convert to white background.<br/>
514
+ <span style="color: #666;">透明背景的PNG效果最好!我们会自动转换为白色背景。</span></li>
515
+ <li><strong>For complex backgrounds:</strong> Enable "Replace Background" option below.<br/>
516
+ <span style="color: #666;">复杂背景图片:启用下方的"替换背景"选项。</span></li>
517
+ <li><strong>Lower temperature (0.2-0.4)</strong> for more accurate reproduction.<br/>
518
+ <span style="color: #666;">较低温度(0.2-0.4)可获得更准确的复制效果。</span></li>
519
+ <li style="color: #dc3545; font-weight: 600;"><strong>Generate 4-8 candidates!</strong> Pick the one that best matches your input.<br/>
520
+ <span style="color: #666; font-weight: normal;">生成4-8个候选!选择最匹配输入的那个。</span></li>
521
+ </ul>
522
+ </div>
523
+ """
524
+
525
 
526
  def parse_args():
527
  parser = argparse.ArgumentParser(description='SVG Generator Service')
 
529
  parser.add_argument('--port', type=int, default=7860)
530
  parser.add_argument('--share', action='store_true')
531
  parser.add_argument('--debug', action='store_true')
532
+ parser.add_argument('--weight_path', type=str, default="/mnt/jfs-test/OmniSVG_result/8B_1126/1688_bs_4/merge_slerp/merge_150_350_bf16")
533
+ parser.add_argument('--model_path', type=str, default="/mnt/jfs-test/Qwen2.5-VL-7B-Instruct")
534
  return parser.parse_args()
535
 
536
+
537
def load_models(weight_path, model_path):
    """Load tokenizer, processor, SVG decoder and SVG tokenizer into module globals.

    Args:
        weight_path: Directory expected to contain ``pytorch_model.bin`` with the
            decoder weights.
        model_path: Path or identifier handed to ``from_pretrained`` for the
            tokenizer/processor and to ``SketchDecoder`` as ``model_path``.

    Raises:
        FileNotFoundError: If ``pytorch_model.bin`` is not present under
            ``weight_path``. This is checked before anything is loaded so a
            wrong path fails immediately instead of after the base model has
            already been instantiated.
    """
    global tokenizer, processor, sketch_decoder, svg_tokenizer

    # Fail fast: validate the checkpoint location before any expensive loading.
    bin_path = os.path.join(weight_path, "pytorch_model.bin")
    if not os.path.exists(bin_path):
        raise FileNotFoundError(f"No weights found at {bin_path}")

    print(f"Loading models from {model_path}...")
    print(f"Using precision: {DTYPE}")

    # Left padding is set both via the constructor argument and explicitly on
    # the processor's tokenizer, as in the original code.
    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
    processor = AutoProcessor.from_pretrained(model_path, padding_side="left")
    processor.tokenizer.padding_side = "left"

    sketch_decoder = SketchDecoder(
        pix_len=config['model']['max_length'],
        text_len=200,
        model_path=model_path,
        torch_dtype=DTYPE
    )

    print(f"Loading weights from: {bin_path}")
    # NOTE(review): torch.load unpickles the file; only point weight_path at
    # trusted checkpoints (consider weights_only=True if the installed torch
    # version supports it).
    sketch_decoder.load_state_dict(torch.load(bin_path, map_location='cpu'))

    sketch_decoder = sketch_decoder.to(device).eval()
    svg_tokenizer = SVGTokenizer('./config.yaml')

    print("All models loaded successfully!")
566
+
567
+
568
def _contains_keyword(text, keywords):
    """Return True if *text* contains any keyword as a whole word (optionally plural)."""
    import re  # local import keeps this block self-contained

    # \b anchors prevent hits inside longer words ("man" in "many",
    # "cat" in "location"); the optional "s"/"es" keeps simple plurals working.
    pattern = r"\b(?:" + "|".join(re.escape(kw) for kw in keywords) + r")(?:e?s)?\b"
    return re.search(pattern, text) is not None


def detect_text_subtype(text_prompt):
    """Classify a text prompt as ``"icon"`` or ``"illustration"``.

    Icon keywords are checked first and win when present. Otherwise the prompt
    is an illustration if it contains an illustration keyword or is longer
    than 50 characters; anything else defaults to ``"icon"``.

    Keywords are matched case-insensitively as whole words (with an optional
    plural suffix) rather than as raw substrings, so e.g. "many" no longer
    counts as "man" and "ahead" no longer counts as "head".

    Args:
        text_prompt: The user's prompt. ``None`` is treated as an empty string.

    Returns:
        ``"icon"`` or ``"illustration"``.
    """
    prompt = text_prompt or ""
    text_lower = prompt.lower()

    icon_keywords = (
        'icon', 'logo', 'symbol', 'badge', 'button', 'emoji', 'glyph', 'simple',
        'arrow', 'triangle', 'circle', 'square', 'heart', 'star', 'checkmark',
    )
    if _contains_keyword(text_lower, icon_keywords):
        return "icon"

    illustration_keywords = (
        'illustration', 'scene', 'person', 'people', 'character', 'man', 'woman', 'boy', 'girl',
        'avatar', 'portrait', 'face', 'head', 'body',
        'cat', 'dog', 'bird', 'animal', 'pet', 'fox', 'rabbit',
        'sitting', 'standing', 'walking', 'running', 'sleeping', 'holding', 'playing',
        'house', 'building', 'tree', 'garden', 'landscape', 'mountain', 'forest', 'city',
        'ocean', 'beach', 'sunset', 'sunrise', 'sky',
    )
    # Long prompts are assumed to describe a scene even without a keyword hit.
    if _contains_keyword(text_lower, illustration_keywords) or len(prompt) > 50:
        return "illustration"

    return "icon"
591
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
592
 
593
def detect_and_replace_background(image, threshold=240, edge_sample_ratio=0.1, color_tolerance=30):
    """Detect a non-white background and replace it with white.

    RGBA input is flattened onto a white canvas and reported as replaced.
    For other modes, pixels sampled along the four borders decide whether the
    background already counts as white; if not, the most frequent border colour
    is taken as the background and every pixel within ``color_tolerance`` of it
    (Euclidean distance over the first three channels) is set to white. Note
    that this is a global colour match, not a flood fill: interior pixels that
    share the background colour are whitened as well.

    Args:
        image: PIL Image.
        threshold: The per-channel mean of the sampled border pixels must
            exceed this for the background to be considered already white.
        edge_sample_ratio: Fraction of the shorter side used to derive how many
            border samples are taken (at least 10).
        color_tolerance: Maximum colour distance from the detected background
            colour for a pixel to be replaced.

    Returns:
        tuple: ``(processed_image, background_was_replaced)``. When nothing is
        replaced the original ``image`` object is returned unchanged.
    """
    from collections import Counter

    # Transparent input: composite onto white; no array work required.
    if image.mode == 'RGBA':
        bg = Image.new('RGBA', image.size, (255, 255, 255, 255))
        return Image.alpha_composite(bg, image).convert('RGB'), True

    img_array = np.array(image)
    h, w = img_array.shape[:2]
    if h == 0 or w == 0:
        # Nothing to sample on an empty image.
        return image, False

    # Sample evenly spaced pixels along all four borders.
    sample_count = max(10, int(min(h, w) * edge_sample_ratio))
    col_idx = np.arange(0, w, max(1, w // sample_count))
    row_idx = np.arange(0, h, max(1, h // sample_count))
    edge_samples = np.concatenate([
        img_array[0, col_idx],
        img_array[h - 1, col_idx],
        img_array[row_idx, 0],
        img_array[row_idx, w - 1],
    ])

    # Background already white-ish: hand back the original untouched.
    if np.all(edge_samples.mean(axis=0) > threshold):
        return image, False

    # Replacement needs at least three colour channels.
    if img_array.ndim != 3 or img_array.shape[2] < 3:
        return image, False

    rgb = img_array[:, :, :3]

    # Collect every border pixel. Top/bottom and left/right are interleaved so
    # that ties in Counter.most_common resolve in the same first-seen order as
    # a pixel-by-pixel walk along the borders.
    border = np.concatenate([
        np.stack([rgb[0], rgb[h - 1]], axis=1).reshape(-1, 3),
        np.stack([rgb[:, 0], rgb[:, w - 1]], axis=1).reshape(-1, 3),
    ])
    bg_color = Counter(map(tuple, border.tolist())).most_common(1)[0][0]

    # Mask of pixels whose colour is close to the detected background colour.
    color_diff = np.sqrt(np.sum((rgb.astype(float) - np.array(bg_color)) ** 2, axis=2))
    bg_mask = color_diff < color_tolerance

    result = img_array.copy()
    result[bg_mask] = 255  # every channel (including a 4th, if present) -> 255

    return Image.fromarray(result).convert('RGB'), True
677
 
678
+
679
def preprocess_image_for_svg(image, replace_background=True, target_size=448):
    """Normalise an input image for SVG generation.

    Steps: load the image if a path was given, flatten any transparency onto a
    white background, convert to RGB, optionally whiten a non-white background
    via ``detect_and_replace_background``, and resize to a square of
    ``target_size`` pixels (aspect ratio is not preserved).

    Args:
        image: PIL Image, or a filesystem path to one.
        replace_background: Whether to run non-white background replacement.
        target_size: Side length in pixels of the square output.

    Returns:
        tuple: ``(processed_pil_image, was_modified)`` where ``was_modified`` is
        True if transparency was flattened or the background was replaced.
    """
    if isinstance(image, str):
        # Image.open is lazy and keeps the file open; copy the pixel data out
        # and release the handle immediately.
        with Image.open(image) as opened:
            raw_img = opened.copy()
    else:
        raw_img = image

    was_modified = False

    # Transparency can live in an alpha band (RGBA/LA/PA) or in a
    # 'transparency' info entry (e.g. palette or RGB PNGs with tRNS). A plain
    # convert('RGB') would discard the latter, so route every transparent
    # image through RGBA and composite onto white.
    has_transparency = (
        raw_img.mode in ('RGBA', 'LA', 'PA')
        or 'transparency' in raw_img.info
    )
    if has_transparency:
        rgba = raw_img if raw_img.mode == 'RGBA' else raw_img.convert('RGBA')
        white = Image.new('RGBA', rgba.size, (255, 255, 255, 255))
        img_with_bg = Image.alpha_composite(white, rgba).convert('RGB')
        was_modified = True
    elif raw_img.mode != 'RGB':
        img_with_bg = raw_img.convert('RGB')
    else:
        img_with_bg = raw_img

    # Optionally detect and replace a non-white background.
    if replace_background:
        img_with_bg, bg_replaced = detect_and_replace_background(img_with_bg)
        was_modified = was_modified or bg_replaced

    img_resized = img_with_bg.resize((target_size, target_size), Image.Resampling.LANCZOS)

    return img_resized, was_modified
725
 
 
 
 
 
 
 
 
 
 
726
 
727
def prepare_inputs(task_type, content):
    """Build the processor inputs for one generation request.

    Args:
        task_type: ``"text-to-svg"`` selects the text prompt path; any other
            value is handled as image-to-svg.
        content: The prompt text for text-to-svg, otherwise the image reference
            placed in the chat message for ``process_vision_info``.

    Returns:
        The object returned by ``processor(...)`` with ``return_tensors="pt"``.
    """
    is_text_task = task_type == "text-to-svg"

    if is_text_task:
        prompt_text = str(content).strip()
        instruction = (
            f"Generate an SVG illustration for: {prompt_text}\n"
            "\n"
            "Requirements:\n"
            "- Create complete SVG path commands\n"
            "- Include proper coordinates and colors\n"
            "- Maintain visual clarity and composition"
        )
        user_content = [{"type": "text", "text": instruction}]
    else:
        user_content = [
            {"type": "text", "text": "Generate SVG code that accurately represents this image:"},
            {"type": "image", "image": content},
        ]

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]
    text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    if is_text_task:
        return processor(text=[text_input], padding=True, truncation=True, return_tensors="pt")

    # Image path: extract the vision inputs referenced by the chat messages.
    image_inputs, _ = process_vision_info(messages)
    return processor(text=[text_input], images=image_inputs, padding=True, truncation=True, return_tensors="pt")
759
+
760
+
761
def render_svg_to_image(svg_str, size=512):
    """Rasterise an SVG string to an RGB PIL image on a white background.

    Args:
        svg_str: SVG markup.
        size: Output width and height in pixels passed to cairosvg.

    Returns:
        The rendered RGB image, or ``None`` if anything fails (the error is
        printed, not raised).
    """
    try:
        svg_bytes = svg_str.encode('utf-8')
        png_bytes = cairosvg.svg2png(
            bytestring=svg_bytes,
            output_width=size,
            output_height=size,
        )
        rendered = Image.open(io.BytesIO(png_bytes)).convert("RGBA")
        # Paste through the alpha band so transparent areas end up white.
        alpha_band = rendered.split()[3]
        canvas = Image.new("RGB", rendered.size, (255, 255, 255))
        canvas.paste(rendered, mask=alpha_band)
    except Exception as e:
        print(f"Render error: {e}")
        return None
    return canvas
776
+
777
+
778
def create_gallery_html(candidates, cols=4):
    """Render SVG candidates as an HTML grid of preview cards.

    Args:
        candidates: Sequence of dicts with at least ``'svg'`` (markup) and
            ``'path_count'`` keys.
        cols: Number of grid columns.

    Returns:
        HTML string; a placeholder message when ``candidates`` is empty.
    """
    if not candidates:
        return '<div style="text-align:center;color:#999;padding:50px;">No candidates generated / 未生成候选</div>'

    def _card(position, cand):
        # Inline SVGs need a viewBox to scale inside the fixed-size preview box.
        markup = cand['svg']
        if 'viewBox' not in markup:
            markup = markup.replace('<svg', f'<svg viewBox="0 0 {TARGET_IMAGE_SIZE} {TARGET_IMAGE_SIZE}"', 1)
        path_count = cand['path_count']
        return f'''
        <div style="
            background: white;
            border: 1px solid #ddd;
            border-radius: 8px;
            padding: 10px;
            text-align: center;
            transition: transform 0.2s, box-shadow 0.2s;
            cursor: pointer;
        " onmouseover="this.style.transform='scale(1.02)';this.style.boxShadow='0 4px 12px rgba(0,0,0,0.15)';"
           onmouseout="this.style.transform='scale(1)';this.style.boxShadow='none';">
            <div style="width: 180px; height: 180px; margin: 0 auto; display: flex; justify-content: center; align-items: center; overflow: hidden;">
                {markup}
            </div>
            <div style="margin-top: 8px; font-size: 12px; color: #666;">
                #{position} | {path_count} paths
            </div>
        </div>
        '''

    cards = ''.join(_card(position, cand) for position, cand in enumerate(candidates, start=1))

    return f'''
    <div style="
        display: grid;
        grid-template-columns: repeat({cols}, 1fr);
        gap: 15px;
        padding: 15px;
        background: #fafafa;
        border-radius: 12px;
    ">
        {cards}
    </div>
    '''
823
+
824
+
825
def is_valid_candidate(svg_str, img, subtype="illustration"):
    """Decide whether a generated SVG and its rendering are usable.

    Args:
        svg_str: Generated SVG markup.
        img: Rendered image (anything ``np.array`` accepts) or ``None`` when
            rendering failed.
        subtype: ``"illustration"`` uses a blank cutoff of 250; every other
            value uses 252.

    Returns:
        tuple: ``(is_valid, reason)`` with reason one of ``"too_short"``,
        ``"no_svg_tag"``, ``"render_failed"``, ``"empty_image"`` or ``"ok"``.
    """
    if not svg_str or len(svg_str) < 20:
        verdict = (False, "too_short")
    elif '<svg' not in svg_str:
        verdict = (False, "no_svg_tag")
    elif img is None:
        verdict = (False, "render_failed")
    else:
        # A rendering whose mean pixel value is above the cutoff is treated as
        # (almost) blank white.
        blank_cutoff = 252 if subtype != "illustration" else 250
        brightness = np.array(img).mean()
        if brightness > blank_cutoff:
            verdict = (False, "empty_image")
        else:
            verdict = (True, "ok")
    return verdict
845
+
846
+
847
def _svg_root_has_width(svg_str):
    """Return True if the opening ``<svg ...>`` tag itself has a ``width`` attribute."""
    import re  # local import keeps this block self-contained

    root_tag = re.search(r'<svg\b[^>]*>', svg_str)
    if root_tag is None:
        return False
    # The lookbehind rejects attributes that merely end in "width",
    # e.g. stroke-width.
    return re.search(r'(?<![\w-])width\s*=', root_tag.group(0)) is not None


def generate_candidates(inputs, task_type, subtype, temperature, top_p, top_k, repetition_penalty,
                        max_length, num_samples, progress_callback=None):
    """Sample SVG token sequences and turn them into validated candidates.

    A single ``generate`` call produces ``num_samples + 4`` sequences (the
    extra ones compensate for samples that fail decoding, rendering or the
    validity check). Each sequence is decoded to an SVG, rendered, checked
    with ``is_valid_candidate``, and collected until ``num_samples`` valid
    candidates exist.

    Args:
        inputs: Processor output containing ``input_ids`` and
            ``attention_mask`` and, for image tasks, ``pixel_values`` and
            ``image_grid_thw``.
        task_type: Task label; not used inside this function.
        subtype: Forwarded to ``is_valid_candidate``.
        temperature, top_p, top_k, repetition_penalty: Sampling parameters.
        max_length: Passed as ``max_new_tokens``.
        num_samples: Number of valid candidates wanted.
        progress_callback: Optional ``callable(fraction, message)``.

    Returns:
        List of dicts with keys ``'svg'``, ``'img'``, ``'path_count'`` and
        ``'index'``. Possibly shorter than ``num_samples``, or empty; errors
        are printed, never raised.
    """
    def report(fraction, message):
        if progress_callback:
            progress_callback(fraction, message)

    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    model_inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask
    }

    if 'pixel_values' in inputs:
        model_inputs["pixel_values"] = inputs['pixel_values'].to(device, dtype=DTYPE)

    if 'image_grid_thw' in inputs:
        model_inputs["image_grid_thw"] = inputs['image_grid_thw'].to(device)

    all_candidates = []

    # Sampling configuration from the user's parameters. `early_stopping` is
    # intentionally not passed: it only applies to beam search and is unused
    # (and warned about) when sampling with a single beam.
    gen_config = {
        'do_sample': True,
        'temperature': temperature,
        'top_p': top_p,
        'top_k': int(top_k),
        'repetition_penalty': repetition_penalty,
        'no_repeat_ngram_size': 0,
        'eos_token_id': config['model']['eos_token_id'],
        'pad_token_id': config['model']['pad_token_id'],
        'bos_token_id': config['model']['bos_token_id'],
    }

    # Oversample so that invalid generations can be discarded.
    actual_samples = num_samples + 4

    try:
        report(0.1, "Waiting for model access / 等待模型访问...")

        # NOTE(review): decoding and rendering are kept inside the lock as the
        # conservative choice — confirm whether the lock is meant to cover
        # only the generate() call.
        with generation_lock:
            report(0.15, "Generating SVG tokens / 生成SVG令牌...")

            with torch.no_grad():
                results = sketch_decoder.transformer.generate(
                    **model_inputs,
                    max_new_tokens=max_length,
                    num_return_sequences=actual_samples,
                    use_cache=True,
                    **gen_config
                )

            # Strip the prompt tokens; keep only the newly generated part.
            input_len = input_ids.shape[1]
            generated_ids_batch = results[:, input_len:]

            report(0.5, "Processing generated tokens / 处理生成的令牌...")

            for i in range(min(actual_samples, generated_ids_batch.shape[0])):
                try:
                    current_ids = generated_ids_batch[i:i+1]

                    # Wrap the sample in BOS/EOS before handing it to the SVG
                    # tokenizer; match the dtype of the generated ids.
                    fake_wrapper = torch.cat([
                        torch.full((1, 1), config['model']['bos_token_id'],
                                   dtype=current_ids.dtype, device=device),
                        current_ids,
                        torch.full((1, 1), config['model']['eos_token_id'],
                                   dtype=current_ids.dtype, device=device)
                    ], dim=1)

                    generated_xy = svg_tokenizer.process_generated_tokens(fake_wrapper)
                    if len(generated_xy) == 0:
                        continue

                    svg_tensors, color_tensors = svg_tokenizer.raster_svg(generated_xy)
                    if not svg_tensors or not svg_tensors[0]:
                        continue

                    # Pad missing colours with black so every path has one.
                    num_paths = len(svg_tensors[0])
                    while len(color_tensors) < num_paths:
                        color_tensors.append(BLACK_COLOR_TOKEN)

                    svg = svg_tokenizer.apply_colors_to_svg(svg_tensors[0], color_tensors)
                    svg_str = svg.to_str()

                    # Give the root element explicit dimensions when it lacks
                    # them (checked on the <svg> tag only, so a path's
                    # stroke-width does not count).
                    if not _svg_root_has_width(svg_str):
                        svg_str = svg_str.replace('<svg', f'<svg width="{TARGET_IMAGE_SIZE}" height="{TARGET_IMAGE_SIZE}"', 1)

                    png_image = render_svg_to_image(svg_str, size=512)

                    is_valid, reason = is_valid_candidate(svg_str, png_image, subtype)
                    if is_valid:
                        all_candidates.append({
                            'svg': svg_str,
                            'img': png_image,
                            'path_count': num_paths,
                            'index': len(all_candidates) + 1
                        })

                    report(0.5 + 0.4 * (i / actual_samples),
                           f"Found {len(all_candidates)} valid / 找到 {len(all_candidates)} 个有效...")

                    if len(all_candidates) >= num_samples:
                        break

                except Exception as e:
                    # Best effort: one bad sample must not abort the batch.
                    print(f"  Candidate {i} error: {e}")
                    continue

    except Exception as e:
        print(f"Generation Error: {e}")
        import traceback
        traceback.print_exc()

    report(0.95, f"Generated {len(all_candidates)} valid / 生成了 {len(all_candidates)} 个有效")

    return all_candidates
965
 
966
@spaces.GPU
def gradio_text_to_svg(text_description, num_candidates, temperature, top_p, top_k, repetition_penalty,
                       progress=gr.Progress()):
    """Gradio handler: generate SVG candidates from a text description.

    Returns:
        tuple: ``(gallery_html, combined_svg_code)``. For empty input or when
        no valid candidate is produced, the first element is a placeholder
        message.
    """
    if not text_description or text_description.strip() == "":
        placeholder = '<div style="text-align:center;color:#999;padding:50px;">Please enter a description / 请输入描述</div>'
        return placeholder, ""

    progress(0, "Starting generation / 开始生成...")

    # Release cached memory before starting a new generation.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    start_time = time.time()

    subtype = detect_text_subtype(text_description)
    progress(0.05, f"Detected: {subtype} / 检测到: {subtype}")

    inputs = prepare_inputs("text-to-svg", text_description.strip())

    all_candidates = generate_candidates(
        inputs, "text-to-svg", subtype,
        temperature, top_p, int(top_k), repetition_penalty,
        config['model']['max_length'], int(num_candidates),
        progress_callback=lambda val, msg: progress(val, msg)
    )

    elapsed = time.time() - start_time

    if not all_candidates:
        failure_html = '<div style="text-align:center;color:#999;padding:50px;">No valid SVG generated. Try different parameters or rephrase your prompt.<br/>未生成有效的SVG。请尝试不同参数或重新描述。</div>'
        return failure_html, f"<!-- No valid SVG (took {elapsed:.1f}s) -->"

    # One commented header per candidate, joined into a single code listing.
    combined_svg = "\n\n".join(
        f"<!-- ====== Candidate {number} | Paths: {cand['path_count']} ====== -->\n{cand['svg']}"
        for number, cand in enumerate(all_candidates, start=1)
    )
    gallery_html = create_gallery_html(all_candidates)

    progress(1.0, f"Done! {len(all_candidates)} candidates in {elapsed:.1f}s / 完成!{len(all_candidates)} 个候选,{elapsed:.1f}秒")

    return gallery_html, combined_svg
1015
+
1016
@spaces.GPU
def gradio_image_to_svg(image, num_candidates, temperature, top_p, top_k, repetition_penalty,
                        replace_background, progress=gr.Progress()):
    """Gradio handler: generate SVG candidates from an uploaded image.

    The image is preprocessed (transparency flattened, optional background
    replacement, resize to ``TARGET_IMAGE_SIZE``), written to a temporary PNG
    that is handed to ``prepare_inputs``, and the temporary file is always
    removed afterwards.

    Returns:
        tuple: ``(gallery_html, combined_svg_code, processed_image)``. With no
        input image the third element is ``None``; when no valid candidate is
        produced the first element is a placeholder message.
    """
    if image is None:
        return (
            '<div style="text-align:center;color:#999;padding:50px;">Please upload an image / 请上传图片</div>',
            "",
            None
        )

    progress(0, "Processing input image / 处理输入图片...")

    # Release cached memory before starting a new generation.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    start_time = time.time()

    # Preprocess image with optional background replacement
    img_processed, was_modified = preprocess_image_for_svg(
        image,
        replace_background=replace_background,
        target_size=TARGET_IMAGE_SIZE
    )

    if was_modified:
        progress(0.05, "Background processed / 背景已处理")

    tmp_path = None
    try:
        # Create the temp file inside the try so that a failing save() cannot
        # leave it behind. Writing through the open handle avoids reopening
        # the file by name while it is still open.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
            tmp_path = tmp_file.name
            img_processed.save(tmp_file, format='PNG')

        progress(0.1, "Preparing model inputs / 准备模型输入...")
        inputs = prepare_inputs("image-to-svg", tmp_path)
        max_length = config['model']['max_length']

        def update_progress(val, msg):
            progress(val, msg)

        all_candidates = generate_candidates(
            inputs, "image-to-svg", "image",
            temperature, top_p, int(top_k), repetition_penalty,
            max_length, int(num_candidates),
            progress_callback=update_progress
        )

        elapsed = time.time() - start_time

        if not all_candidates:
            return (
                '<div style="text-align:center;color:#999;padding:50px;">No valid SVG generated. Try adjusting parameters.<br/>未生成有效的SVG。请尝试调整参数。</div>',
                f"<!-- No valid SVG (took {elapsed:.1f}s) -->",
                img_processed
            )

        svg_codes = []
        for i, cand in enumerate(all_candidates):
            svg_codes.append(f"<!-- ====== Candidate {i+1} | Paths: {cand['path_count']} ====== -->\n{cand['svg']}")

        combined_svg = "\n\n".join(svg_codes)
        gallery_html = create_gallery_html(all_candidates)

        progress(1.0, f"Done! {len(all_candidates)} candidates in {elapsed:.1f}s")

        return gallery_html, combined_svg, img_processed

    finally:
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
1089
 
1090
+
1091
def get_example_images():
    """Return the sorted paths of example images found in ``./examples``.

    A file qualifies when its name ends with one of ``SUPPORTED_FORMATS``,
    compared case-insensitively so that ``photo.PNG`` is found on
    case-sensitive filesystems too. Hidden files (leading dot) are skipped.

    Returns:
        Sorted list of paths of the form ``./examples/<name>``; empty when the
        directory does not exist.
    """
    example_dir = "./examples"
    if not os.path.isdir(example_dir):
        return []

    allowed_suffixes = tuple(ext.lower() for ext in SUPPORTED_FORMATS)
    return sorted(
        os.path.join(example_dir, name)
        for name in os.listdir(example_dir)
        if not name.startswith('.') and name.lower().endswith(allowed_suffixes)
    )
1103
+
1104
 
1105
  def create_interface():
1106
+ """Create Gradio interface"""
1107
+
1108
+ # 30 Example prompts covering various categories
1109
  example_texts = [
1110
+ # === Simple Icons (1-6) ===
1111
+ "A black triangle pointing downward, centrally positioned on white background.",
1112
+ "A red heart shape with smooth curved edges, centered on white background.",
1113
+ "A yellow star with five sharp points, simple geometric design, flat color on white background.",
1114
+ "A blue arrow pointing to the right, thick solid shape, centered on white background.",
1115
+ "A green circle with a white checkmark inside, centered on white background.",
1116
+ "A black plus sign with equal length arms, thick lines, centered on white background.",
1117
+
1118
+ # === Characters & People (7-12) ===
1119
+ "A simple person standing: round beige head, rectangular blue shirt body, two dark gray rectangular legs, arms at sides. Flat colors, white background.",
1120
+ "A girl with long black hair, wearing pink dress with triangular skirt, small circular face with dot eyes and curved smile. Simple cartoon style, white background.",
1121
+ "A businessman: circular head with short black hair, rectangular dark navy suit body, straight standing pose. Professional minimal style, white background.",
1122
+ "A child waving: large round head with brown messy hair, big circular eyes, small body in red t-shirt and blue shorts, one arm raised. Cheerful cartoon style.",
1123
+ "A person sitting on chair: side view, round head, rectangular torso in green sweater, bent legs on simple chair shape. Relaxed pose, white background.",
1124
+ "A running person: side view silhouette in black, dynamic pose with one leg forward, arms pumping. Motion style, white background.",
1125
+
1126
+ # === Avatars & Portraits (13-17) ===
1127
+ "Circular avatar: person with short black hair, round face with two dot eyes and small curved smile, wearing blue collar shirt. Minimal style, centered in circle.",
1128
+ "Female avatar: oval face with long wavy brown hair, simple eyes, pink lips, wearing v-neck purple top. Soft cartoon style in circular frame.",
1129
+ "Profile silhouette avatar: black side view of head with short hair and glasses outline, facing right. Simple solid shape on white.",
1130
+ "Cute cartoon avatar: round face with big sparkly eyes, rosy cheeks, short bob haircut in orange. Kawaii style, circular frame.",
1131
+ "Professional headshot avatar: person with neat hair, neutral expression, wearing suit collar. Corporate minimal style, circular frame, white background.",
1132
+
1133
+ # === Landscapes & Scenes (18-23) ===
1134
+ "Layered mountain landscape: light blue sky at top, gray triangular snow-capped mountains in middle, dark green triangular pine trees at bottom. Flat colors.",
1135
+ "Sunset beach scene: orange gradient sky at top, yellow semicircle sun on horizon, dark blue wavy ocean, tan beach strip at bottom. Simple shapes.",
1136
+ "Forest scene: light blue sky, row of 5 dark green triangular pine trees of varying heights on brown trunks, light green grass at bottom.",
1137
+ "City skyline at dusk: purple-orange gradient sky, row of black rectangular building silhouettes of different heights, some with yellow window squares.",
1138
+ "Desert landscape: light orange sky with white circle sun, tan sand dunes as curved shapes, one green cactus with arms on the right side.",
1139
+ "Countryside scene: blue sky with white fluffy clouds, green rolling hills, small red barn with white door in the center, yellow hay bales.",
1140
+
1141
+ # === Animals (24-27) ===
1142
+ "Cute orange cat sitting: round head with two triangular ears, oval body, curved tail. Black outline cartoon style, facing forward, white background.",
1143
+ "Simple black bird: oval body, round head, pointed triangular beak facing right, triangular tail, two stick legs. Silhouette style on white.",
1144
+ "Friendly cartoon dog: brown oval body, round head with floppy ears, black dot nose, wagging curved tail, four short legs. Sitting pose.",
1145
+ "Red fox logo: triangular orange face with pointed ears, white chest marking, bushy tail. Minimalist style, facing right, centered on white.",
1146
+
1147
+ # === Objects & Misc (28-30) ===
1148
+ "Simple house icon: red triangular roof, beige rectangular walls, brown door in center, two blue square windows, green ground at bottom.",
1149
+ "Coffee mug: brown cylindrical cup with curved handle on right, three wavy steam lines rising from top. Flat style on white background.",
1150
+ "Open book: two rectangular white pages spread open, black text lines on each page, brown spine in center. Simple top-down view."
1151
  ]
1152
+
1153
  example_images = get_example_images()
1154
 
1155
+ with gr.Blocks(title="OmniSVG Generator", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
1156
+ # Header
1157
+ gr.HTML("""
1158
+ <div class="header-container">
1159
+ <h1>🎨 OmniSVG Generator</h1>
1160
+ <p>Transform images and text descriptions into scalable vector graphics</p>
1161
+ <p style="font-size: 0.9em; opacity: 0.8;">将图像和文本描述转换为可缩放矢量图形</p>
1162
+ </div>
1163
+ """)
1164
+
1165
+ # Queue status
1166
+ gr.HTML("""
1167
+ <div style="background: #e7f3ff; border: 1px solid #b3d7ff; border-radius: 8px; padding: 12px 15px; margin-bottom: 15px;">
1168
+ <span style="font-size: 1.5em;">ℹ️</span>
1169
+ <strong>Queue System Active</strong> - Requests processed one at a time. Please wait patiently if busy.<br/>
1170
+ <span style="color: #666;">队列系统已启用 - 请求按顺序处理,繁忙时请耐心等待。</span>
1171
+ </div>
1172
+ """)
1173
+
1174
+ # Tips section
1175
+ gr.HTML(TIPS_HTML)
1176
 
1177
  with gr.Tabs():
1178
+ # ==================== Image-to-SVG Tab ====================
1179
+ with gr.TabItem("🖼️ Image-to-SVG", id="image-tab"):
1180
+ gr.HTML(IMAGE_TIPS_HTML)
1181
+
1182
+ with gr.Row(equal_height=False):
1183
+ with gr.Column(scale=1, min_width=300):
1184
+ gr.Markdown("### 📤 Upload Image / 上传图片")
1185
+ image_input = gr.Image(
1186
+ label="Drag, upload, or Ctrl+V to paste / 拖拽、上传或Ctrl+V粘贴",
1187
+ type="pil",
1188
+ image_mode="RGBA",
1189
+ height=250,
1190
+ sources=["upload", "clipboard"],
1191
+ elem_classes=["input-image"]
1192
+ )
1193
+
1194
+ with gr.Group(elem_classes=["settings-group"]):
1195
+ gr.Markdown("### ⚙️ Settings / 设置")
1196
+ img_num_candidates = gr.Slider(
1197
+ minimum=1, maximum=8, value=4, step=1,
1198
+ label="Number of Candidates / 候选数量"
1199
+ )
1200
+ img_replace_bg = gr.Checkbox(
1201
+ label="Replace non-white background / 替换非白色背景",
1202
+ value=True,
1203
+ info="Enable for images with colored backgrounds / 对有色背景图片启用"
1204
+ )
1205
+
1206
+ with gr.Accordion("🔧 Advanced Parameters / 高级参数", open=False):
1207
+ img_temperature = gr.Slider(
1208
+ minimum=0.1, maximum=1.0, value=0.3, step=0.05,
1209
+ label="Temperature (Lower=accurate)",
1210
+ info="0.2-0.4 recommended / 建议0.2-0.4"
1211
+ )
1212
+ img_top_p = gr.Slider(
1213
+ minimum=0.5, maximum=1.0, value=0.90, step=0.02,
1214
+ label="Top-P"
1215
+ )
1216
+ img_top_k = gr.Slider(
1217
+ minimum=10, maximum=100, value=50, step=5,
1218
+ label="Top-K"
1219
+ )
1220
+ img_rep_penalty = gr.Slider(
1221
+ minimum=1.0, maximum=1.3, value=1.05, step=0.01,
1222
+ label="Repetition Penalty"
1223
+ )
1224
+
1225
+ image_generate_btn = gr.Button(
1226
+ "🚀 Generate SVG / 生成SVG",
1227
+ variant="primary",
1228
+ size="lg",
1229
+ elem_classes=["primary-btn"]
1230
+ )
1231
+
1232
  if example_images:
1233
+ gr.Markdown("### 📁 Examples")
1234
+ gr.Examples(examples=example_images, inputs=[image_input], label="")
1235
 
1236
+ with gr.Column(scale=2, min_width=500):
1237
+ gr.Markdown("### 📥 Processed Input / 处理后输入")
1238
+ image_processed = gr.Image(label="", type="pil", height=120)
1239
+
1240
+ gr.Markdown("### 🖼️ Generated SVG Candidates / 生成的SVG候选")
1241
+ image_gallery = gr.HTML(
1242
+ value='<div style="text-align:center;color:#999;padding:50px;background:#fafafa;border-radius:12px;">Generated SVGs will appear here / 生成的SVG将显示在这里</div>'
1243
+ )
1244
+
1245
+ gr.Markdown("### 📝 SVG Code")
1246
+ image_svg_output = gr.Code(label="", language="html", lines=10, elem_classes=["code-output"])
1247
 
1248
+ image_generate_btn.click(
1249
+ fn=gradio_image_to_svg,
1250
+ inputs=[image_input, img_num_candidates, img_temperature, img_top_p,
1251
+ img_top_k, img_rep_penalty, img_replace_bg],
1252
+ outputs=[image_gallery, image_svg_output, image_processed],
1253
+ queue=True
1254
+ )
1255
+
1256
+ # ==================== Text-to-SVG Tab ====================
1257
+ with gr.TabItem("✏️ Text-to-SVG", id="text-tab"):
1258
+ with gr.Row(equal_height=False):
1259
+ with gr.Column(scale=1, min_width=300):
1260
+ gr.Markdown("### 📝 Description / 描述")
1261
+ gr.HTML("""
1262
+ <div style="background: #fff5f5; padding: 10px; border-radius: 8px; border-left: 4px solid #dc3545; margin-bottom: 10px;">
1263
+ <strong style="color: #dc3545;">🔴 Generate 4-8 candidates and pick the best!</strong><br/>
1264
+ 生成4-8个候选结果并选择最好的!
1265
+ </div>
1266
+ """)
1267
+ text_input = gr.Textbox(
1268
+ label="",
1269
+ placeholder="Describe your SVG with geometric shapes and colors...\n用几何形状和颜色描述您的SVG...\n\nExample: A black triangle pointing downward, centrally positioned on white background.",
1270
+ lines=5
1271
+ )
1272
+
1273
+ with gr.Group(elem_classes=["settings-group"]):
1274
+ gr.Markdown("### ⚙️ Settings / 设置")
1275
+ text_num_candidates = gr.Slider(
1276
+ minimum=1, maximum=8, value=6, step=1,
1277
+ label="Number of Candidates / 候选数量",
1278
+ info="More = better chances! / 越多越好!"
1279
+ )
1280
+
1281
+ with gr.Accordion("🔧 Advanced Parameters / 高级参数", open=False):
1282
+ text_temperature = gr.Slider(
1283
+ minimum=0.1, maximum=1.0, value=0.5, step=0.05,
1284
+ label="Temperature",
1285
+ info="Icons: 0.3-0.5 | Complex: 0.5-0.7"
1286
+ )
1287
+ text_top_p = gr.Slider(
1288
+ minimum=0.5, maximum=1.0, value=0.90, step=0.02,
1289
+ label="Top-P"
1290
+ )
1291
+ text_top_k = gr.Slider(
1292
+ minimum=10, maximum=100, value=60, step=5,
1293
+ label="Top-K"
1294
+ )
1295
+ text_rep_penalty = gr.Slider(
1296
+ minimum=1.0, maximum=1.3, value=1.03, step=0.01,
1297
+ label="Repetition Penalty",
1298
+ info="Increase if you see repetitive patterns"
1299
+ )
1300
+
1301
+ text_generate_btn = gr.Button(
1302
+ "🚀 Generate SVG / 生成SVG",
1303
+ variant="primary",
1304
+ size="lg",
1305
+ elem_classes=["primary-btn"]
1306
+ )
1307
+
1308
+ gr.Markdown("### 📝 Example Prompts (30)")
1309
+ gr.Examples(
1310
+ examples=[[text] for text in example_texts],
1311
+ inputs=[text_input],
1312
+ label=""
1313
+ )
1314
 
1315
+ with gr.Column(scale=2, min_width=500):
1316
+ gr.Markdown("### 🖼️ Generated SVG Candidates / 生成的SVG候选")
1317
+ gr.HTML("""
1318
+ <div style="background: #d4edda; padding: 10px; border-radius: 8px; margin-bottom: 10px;">
1319
+ <strong>💡 Pick the best from multiple candidates! / 从多个候选中选择最好的!</strong>
1320
+ </div>
1321
+ """)
1322
+ text_gallery = gr.HTML(
1323
+ value='<div style="text-align:center;color:#999;padding:50px;background:#fafafa;border-radius:12px;">Generated SVGs will appear here / 生成的SVG将显示在这里</div>'
1324
+ )
1325
+
1326
+ gr.Markdown("### 📝 SVG Code")
1327
+ text_svg_output = gr.Code(label="", language="html", lines=12, elem_classes=["code-output"])
1328
 
1329
+ text_generate_btn.click(
1330
+ fn=gradio_text_to_svg,
1331
+ inputs=[text_input, text_num_candidates, text_temperature, text_top_p,
1332
+ text_top_k, text_rep_penalty],
1333
+ outputs=[text_gallery, text_svg_output],
1334
+ queue=True
1335
+ )
1336
+
1337
+ # Footer
1338
+ gr.HTML("""
1339
+ <div class="footer">
1340
+ <p>Built with ❤️ using OmniSVG</p>
1341
+ <p style="color: #dc3545; font-weight: 600;">🔴 Remember: Generate 4-8 candidates and pick the best! / 记住:生成4-8个候选并选择最好的!</p>
1342
+ </div>
1343
+ """)
1344
+
1345
+ return demo
1346
 
 
 
1347
 
1348
  if __name__ == "__main__":
1349
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
1350
+
1351
  args = parse_args()
1352
+
1353
+ print("="*60)
1354
+ print("OmniSVG Generator - Gradio App")
1355
+ print("="*60)
1356
+ print(f"Model path: {args.model_path}")
1357
+ print(f"Weight path: {args.weight_path}")
1358
+ print(f"Device: {device}")
1359
+ print("="*60)
1360
+
1361
+ print("\nLoading models...")
1362
+ load_models(args.weight_path, args.model_path)
1363
+ print("Models loaded successfully!\n")
1364
+
1365
  demo = create_interface()
1366
+
1367
+ demo.queue(default_concurrency_limit=1, max_size=20)
1368
+
1369
+ demo.launch(
1370
+ server_name=args.listen,
1371
+ server_port=args.port,
1372
+ share=args.share,
1373
+ debug=args.debug,
1374
+ )