OmniSVG committed on
Commit
fd45ec3
·
verified ·
1 Parent(s): 062fa0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +385 -221
app.py CHANGED
@@ -12,7 +12,6 @@ import glob
12
  import numpy as np
13
  import time
14
  import threading
15
- import copy
16
  import spaces
17
 
18
  from huggingface_hub import hf_hub_download, snapshot_download
@@ -22,167 +21,131 @@ from transformers import AutoTokenizer, AutoProcessor
22
  from qwen_vl_utils import process_vision_info
23
  from tokenizer import SVGTokenizer
24
 
25
- # ============================================================
26
- # Configuration Loading with Variant Support
27
- # ============================================================
28
- def load_config(config_path: str, variant: str = None) -> dict:
29
- """
30
- Load config file and merge variant-specific settings.
31
-
32
- Args:
33
- config_path: Path to the config.yaml file
34
- variant: Model variant ("8B" or "4B"). If None, uses default_variant from config.
35
-
36
- Returns:
37
- Merged configuration dictionary
38
- """
39
- with open(config_path, 'r') as f:
40
- raw_config = yaml.safe_load(f)
41
-
42
- # Determine which variant to use
43
- if variant is None:
44
- variant = raw_config.get('default_variant', '8B')
45
-
46
- # Check if variant exists
47
- variants = raw_config.get('variants', {})
48
- if variant not in variants:
49
- available = list(variants.keys())
50
- raise ValueError(f"Unknown model variant '{variant}'. Available variants: {available}")
51
-
52
- # Start with a copy of raw config (excluding 'variants' key)
53
- merged_config = {k: v for k, v in raw_config.items() if k != 'variants'}
54
-
55
- # Merge variant-specific settings
56
- variant_config = variants[variant]
57
- for key, value in variant_config.items():
58
- if isinstance(value, dict) and key in merged_config and isinstance(merged_config[key], dict):
59
- # Deep merge for nested dicts
60
- merged_config[key] = {**merged_config.get(key, {}), **value}
61
- else:
62
- merged_config[key] = value
63
-
64
- # Store the active variant name
65
- merged_config['active_variant'] = variant
66
-
67
- return merged_config
68
-
69
-
70
- def write_variant_config(config: dict, output_path: str):
71
- """
72
- Write a variant-specific config file for SVGTokenizer.
73
-
74
- Args:
75
- config: Merged configuration dictionary
76
- output_path: Path to write the temporary config file
77
- """
78
- # Create a config without the 'variants' and 'active_variant' keys
79
- clean_config = {k: v for k, v in config.items()
80
- if k not in ['variants', 'active_variant', 'default_variant']}
81
-
82
- with open(output_path, 'w') as f:
83
- yaml.safe_dump(clean_config, f, default_flow_style=False)
84
 
85
-
86
- # ============================================================
87
- # Global Variables
88
- # ============================================================
89
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
90
  DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
91
 
92
- # Global Models
93
  tokenizer = None
94
  processor = None
95
  sketch_decoder = None
96
  svg_tokenizer = None
97
-
98
- # Global Config (will be set after loading)
99
- config = None
100
- MODEL_VARIANT = None
101
 
102
  # Thread lock for model inference
103
  generation_lock = threading.Lock()
 
104
 
105
- # Constants (will be set from config)
106
  SYSTEM_PROMPT = """You are an expert SVG code generator.
107
  Generate precise, valid SVG path commands that accurately represent the described scene or object.
108
  Focus on capturing key shapes, spatial relationships, and visual composition."""
109
 
110
  SUPPORTED_FORMATS = ['.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif']
 
 
111
 
112
-
113
- def init_config_constants(cfg: dict):
114
- """Initialize global constants from config."""
115
- global TARGET_IMAGE_SIZE, RENDER_SIZE, BACKGROUND_THRESHOLD
116
- global EMPTY_THRESHOLD_ILLUSTRATION, EMPTY_THRESHOLD_ICON
117
- global EDGE_SAMPLE_RATIO, COLOR_SIMILARITY_THRESHOLD, MIN_EDGE_SAMPLES
118
- global BLACK_COLOR_TOKEN, BOS_TOKEN_ID, EOS_TOKEN_ID, PAD_TOKEN_ID, MAX_LENGTH
119
- global DEFAULT_QWEN_MODEL, DEFAULT_OMNISVG_MODEL
120
- global TASK_CONFIGS, DEFAULT_NUM_CANDIDATES, MAX_NUM_CANDIDATES, EXTRA_CANDIDATES_BUFFER
121
- global MIN_SVG_LENGTH
122
-
123
- # Image processing settings
124
- image_config = cfg.get('image', {})
125
- TARGET_IMAGE_SIZE = image_config.get('target_size', 448)
126
- RENDER_SIZE = image_config.get('render_size', 512)
127
- BACKGROUND_THRESHOLD = image_config.get('background_threshold', 240)
128
- EMPTY_THRESHOLD_ILLUSTRATION = image_config.get('empty_threshold_illustration', 250)
129
- EMPTY_THRESHOLD_ICON = image_config.get('empty_threshold_icon', 252)
130
- EDGE_SAMPLE_RATIO = image_config.get('edge_sample_ratio', 0.1)
131
- COLOR_SIMILARITY_THRESHOLD = image_config.get('color_similarity_threshold', 30)
132
- MIN_EDGE_SAMPLES = image_config.get('min_edge_samples', 10)
133
-
134
- # Color settings
135
- colors_config = cfg.get('colors', {})
136
- BLACK_COLOR_TOKEN = colors_config.get('black_color_token',
137
- colors_config.get('color_token_start', 40010) + 2)
138
-
139
- # Model settings
140
- model_config = cfg.get('model', {})
141
- BOS_TOKEN_ID = model_config.get('bos_token_id', 196998)
142
- EOS_TOKEN_ID = model_config.get('eos_token_id', 196999)
143
- PAD_TOKEN_ID = model_config.get('pad_token_id', 151643)
144
- MAX_LENGTH = model_config.get('max_length', 1536)
145
-
146
- # HuggingFace model IDs
147
- hf_config = cfg.get('huggingface', {})
148
- DEFAULT_QWEN_MODEL = hf_config.get('qwen_model', "Qwen/Qwen2.5-VL-7B-Instruct")
149
- DEFAULT_OMNISVG_MODEL = hf_config.get('omnisvg_model', "OmniSVG/OmniSVG1.1_8B")
150
-
151
- # Task configurations
152
- task_config = cfg.get('task_configs', {})
153
- TASK_CONFIGS = {
154
- "text-to-svg-icon": task_config.get('text_to_svg_icon', {
155
- "default_temperature": 0.5,
156
- "default_top_p": 0.88,
157
- "default_top_k": 50,
158
- "default_repetition_penalty": 1.05,
159
- }),
160
- "text-to-svg-illustration": task_config.get('text_to_svg_illustration', {
161
- "default_temperature": 0.6,
162
- "default_top_p": 0.90,
163
- "default_top_k": 60,
164
- "default_repetition_penalty": 1.03,
165
- }),
166
- "image-to-svg": task_config.get('image_to_svg', {
167
- "default_temperature": 0.3,
168
- "default_top_p": 0.90,
169
- "default_top_k": 50,
170
- "default_repetition_penalty": 1.05,
171
- })
172
- }
173
-
174
- # Generation parameters
175
- gen_config = cfg.get('generation', {})
176
- DEFAULT_NUM_CANDIDATES = gen_config.get('default_num_candidates', 4)
177
- MAX_NUM_CANDIDATES = gen_config.get('max_num_candidates', 8)
178
- EXTRA_CANDIDATES_BUFFER = gen_config.get('extra_candidates_buffer', 4)
179
 
180
- # Validation settings
181
- validation_config = cfg.get('validation', {})
182
- MIN_SVG_LENGTH = validation_config.get('min_svg_length', 20)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
 
 
 
 
 
184
 
185
- # Custom CSS (same as before)
186
  CUSTOM_CSS = """
187
  /* Main container centering */
188
  .gradio-container {
@@ -209,14 +172,18 @@ CUSTOM_CSS = """
209
  opacity: 0.9;
210
  font-size: 1.1em;
211
  }
212
- /* Model badge styling */
213
- .model-badge {
214
- display: inline-block;
215
- background: rgba(255,255,255,0.2);
216
- padding: 4px 12px;
217
- border-radius: 20px;
218
- font-size: 0.85em;
219
- margin-top: 8px;
 
 
 
 
220
  }
221
  /* Tips section */
222
  .tips-box {
@@ -295,6 +262,17 @@ CUSTOM_CSS = """
295
  .green-box strong {
296
  color: #4caf50;
297
  }
 
 
 
 
 
 
 
 
 
 
 
298
  /* Tab styling */
299
  .tabs {
300
  border-radius: 12px !important;
@@ -363,7 +341,7 @@ CUSTOM_CSS = """
363
  }
364
  """
365
 
366
- # Enhanced Tips HTML (same as before - abbreviated for brevity)
367
  TIPS_HTML = """
368
  <div class="tips-box">
369
  <h3>Prompting Guide & Best Practices</h3>
@@ -390,6 +368,15 @@ TIPS_HTML = """
390
  </ul>
391
  </div>
392
 
 
 
 
 
 
 
 
 
 
393
  <!-- Parameter Tuning Tips -->
394
  <div class="orange-box">
395
  <strong>Parameter Tuning Guide</strong>
@@ -455,15 +442,32 @@ TIPS_HTML = """
455
  <div class="example-prompt">
456
  "A simple person: round beige head, rectangular blue shirt body, two dark gray rectangular legs. Standing pose, arms at sides, flat colors."
457
  </div>
 
 
 
458
  <p class="red-tip">Keep poses SIMPLE: standing, sitting, waving. Avoid complex actions!</p>
459
  </div>
460
 
 
 
 
 
 
 
 
 
 
 
 
461
  <div class="tip-category">
462
  <h4>Landscapes & Scenes</h4>
463
  <p>Layer elements from background to foreground. Specify color for EACH layer.</p>
464
  <div class="example-prompt">
465
  "Layered landscape: light blue sky at top, gray triangular mountains in middle, dark green triangular pine trees at bottom. Flat colors, simple shapes."
466
  </div>
 
 
 
467
  <p class="red-tip">Use geometric shapes for nature: triangular trees, wavy water, semicircle sun!</p>
468
  </div>
469
 
@@ -473,10 +477,64 @@ TIPS_HTML = """
473
  <div class="example-prompt">
474
  "Cute cat: orange round head with two triangular ears, oval orange body, curved tail. Simple cartoon style with black outlines, sitting pose."
475
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476
  </div>
477
 
478
  </div>
479
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
480
  <!-- Quick Troubleshooting -->
481
  <div class="green-box" style="margin-top: 15px;">
482
  <strong>Quick Troubleshooting</strong>
@@ -489,6 +547,17 @@ TIPS_HTML = """
489
  <li><strong>Inconsistent?</strong> <span class="red-tip">Generate MORE candidates (6-8) and pick the best!</span></li>
490
  </ul>
491
  </div>
 
 
 
 
 
 
 
 
 
 
 
492
  </div>
493
  """
494
 
@@ -513,10 +582,9 @@ def parse_args():
513
  parser.add_argument('--port', type=int, default=7860)
514
  parser.add_argument('--share', action='store_true')
515
  parser.add_argument('--debug', action='store_true')
516
- parser.add_argument('--model_size', type=str, default=None, choices=['8B', '4B'],
517
- help='Model size variant to use (8B or 4B). Overrides config default.')
518
- parser.add_argument('--config', type=str, default='./config.yaml',
519
- help='Path to config file (default: ./config.yaml)')
520
  parser.add_argument('--weight_path', type=str, default=None,
521
  help='HuggingFace repo ID or local path for OmniSVG weights (overrides config)')
522
  parser.add_argument('--model_path', type=str, default=None,
@@ -527,6 +595,13 @@ def parse_args():
527
  def download_model_weights(repo_id: str, filename: str = "pytorch_model.bin") -> str:
528
  """
529
  Download model weights from Hugging Face Hub.
 
 
 
 
 
 
 
530
  """
531
  print(f"Downloading {filename} from {repo_id}...")
532
  try:
@@ -555,20 +630,29 @@ def is_local_path(path: str) -> bool:
555
  return False
556
 
557
 
558
- def load_models(weight_path: str, model_path: str, variant_config_path: str):
559
  """
560
- Load all models with support for both local paths and HuggingFace Hub.
561
 
562
  Args:
563
- weight_path: Local path or HuggingFace repo ID for OmniSVG weights
564
- model_path: Local path or HuggingFace repo ID for Qwen model
565
- variant_config_path: Path to the variant-specific config file for SVGTokenizer
566
  """
567
- global tokenizer, processor, sketch_decoder, svg_tokenizer
568
 
569
- print(f"Loading Qwen model from: {model_path}")
570
- print(f"Loading OmniSVG weights from: {weight_path}")
571
- print(f"Using precision: {DTYPE}")
 
 
 
 
 
 
 
 
 
572
 
573
  # Load Qwen tokenizer and processor
574
  print("\n[1/3] Loading tokenizer and processor...")
@@ -585,12 +669,14 @@ def load_models(weight_path: str, model_path: str, variant_config_path: str):
585
  processor.tokenizer.padding_side = "left"
586
  print("Tokenizer and processor loaded successfully!")
587
 
588
- # Initialize sketch decoder
589
  print("\n[2/3] Initializing SketchDecoder...")
590
  sketch_decoder = SketchDecoder(
 
 
 
591
  pix_len=MAX_LENGTH,
592
  text_len=config.get('text', {}).get('max_length', 200),
593
- model_path=model_path,
594
  torch_dtype=DTYPE
595
  )
596
 
@@ -618,14 +704,46 @@ def load_models(weight_path: str, model_path: str, variant_config_path: str):
618
 
619
  sketch_decoder = sketch_decoder.to(device).eval()
620
 
621
- # Initialize SVG tokenizer with variant-specific config
622
- svg_tokenizer = SVGTokenizer(variant_config_path)
 
 
623
 
624
  print("\n" + "="*60)
625
- print("All models loaded successfully!")
626
  print("="*60 + "\n")
627
 
628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
629
  def detect_text_subtype(text_prompt):
630
  """Auto-detect text prompt subtype"""
631
  text_lower = text_prompt.lower()
@@ -652,7 +770,9 @@ def detect_text_subtype(text_prompt):
652
 
653
 
654
  def detect_and_replace_background(image, threshold=None, edge_sample_ratio=None):
655
- """Detect if image has non-white background and optionally replace it."""
 
 
656
  if threshold is None:
657
  threshold = BACKGROUND_THRESHOLD
658
  if edge_sample_ratio is None:
@@ -686,6 +806,11 @@ def detect_and_replace_background(image, threshold=None, edge_sample_ratio=None)
686
  return image, False
687
 
688
  if len(img_array.shape) == 3 and img_array.shape[2] >= 3:
 
 
 
 
 
689
  edge_colors = []
690
  for i in range(w):
691
  edge_colors.append(tuple(img_array[0, i, :3]))
@@ -713,7 +838,9 @@ def detect_and_replace_background(image, threshold=None, edge_sample_ratio=None)
713
 
714
 
715
  def preprocess_image_for_svg(image, replace_background=True, target_size=None):
716
- """Preprocess image for SVG generation."""
 
 
717
  if target_size is None:
718
  target_size = TARGET_IMAGE_SIZE
719
 
@@ -889,7 +1016,8 @@ def generate_candidates(inputs, task_type, subtype, temperature, top_p, top_k, r
889
 
890
  all_candidates = []
891
 
892
- gen_config = {
 
893
  'do_sample': True,
894
  'temperature': temperature,
895
  'top_p': top_p,
@@ -918,7 +1046,7 @@ def generate_candidates(inputs, task_type, subtype, temperature, top_p, top_k, r
918
  max_new_tokens=max_length,
919
  num_return_sequences=actual_samples,
920
  use_cache=True,
921
- **gen_config
922
  )
923
 
924
  input_len = input_ids.shape[1]
@@ -996,7 +1124,7 @@ def gradio_text_to_svg(text_description, num_candidates, temperature, top_p, top
996
  return '<div style="text-align:center;color:#999;padding:50px;">Please enter a description</div>', ""
997
 
998
  print("\n" + "="*60)
999
- print(f"[TASK] text-to-svg ({MODEL_VARIANT})")
1000
  print(f"[INPUT] {text_description[:100]}{'...' if len(text_description) > 100 else ''}")
1001
  print(f"[PARAMS] candidates={num_candidates}, temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}")
1002
  print("="*60)
@@ -1062,7 +1190,7 @@ def gradio_image_to_svg(image, num_candidates, temperature, top_p, top_k, repeti
1062
  )
1063
 
1064
  print("\n" + "="*60)
1065
- print(f"[TASK] image-to-svg ({MODEL_VARIANT})")
1066
  print(f"[INPUT] Image size: {image.size if hasattr(image, 'size') else 'unknown'}, mode: {image.mode if hasattr(image, 'mode') else 'unknown'}")
1067
  print(f"[PARAMS] candidates={num_candidates}, temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}, replace_bg={replace_background}")
1068
  print("="*60)
@@ -1149,43 +1277,101 @@ def get_example_images():
1149
  def create_interface():
1150
  """Create Gradio interface"""
1151
 
1152
- # Example prompts
1153
  example_texts = [
 
1154
  "A black triangle pointing downward, centrally positioned.",
1155
  "A red heart shape with smooth curved edges, centered.",
1156
  "A yellow star with five sharp points, simple geometric design, flat color.",
1157
  "A blue arrow pointing to the right, thick solid shape, centered.",
1158
  "A green circle with a white checkmark inside, centered.",
1159
  "A black plus sign with equal length arms, thick lines, centered.",
 
 
1160
  "A simple person standing: round beige head, rectangular blue shirt body, two dark gray rectangular legs, arms at sides. Flat colors.",
1161
  "A girl with long black hair, wearing pink dress with triangular skirt, small circular face with dot eyes and curved smile. Simple cartoon style.",
 
 
 
 
 
1162
  "Circular avatar: person with short black hair, round face with two dot eyes and small curved smile, wearing blue collar shirt. Minimal style, centered in circle.",
 
 
 
 
 
 
1163
  "Layered mountain landscape: light blue sky at top, gray triangular snow-capped mountains in middle, dark green triangular pine trees at bottom. Flat colors.",
1164
  "Sunset beach scene: orange gradient sky at top, yellow semicircle sun on horizon, dark blue wavy ocean, tan beach strip at bottom. Simple shapes.",
 
 
 
 
 
 
1165
  "Cute orange cat sitting: round head with two triangular ears, oval body, curved tail. Black outline cartoon style, facing forward.",
 
 
 
 
 
1166
  "Simple house icon: red triangular roof, beige rectangular walls, brown door in center, two blue square windows, green ground at bottom.",
1167
  "Coffee mug: brown cylindrical cup with curved handle on right, three wavy steam lines rising from top. Flat style.",
1168
- "Red fox logo: triangular orange face with pointed ears, white chest marking, bushy tail. Minimalist style, facing right, centered.",
1169
  ]
1170
 
1171
  example_images = get_example_images()
1172
 
1173
- # Dynamic header with model info
1174
- header_html = f"""
1175
- <div class="header-container">
1176
- <h1>OmniSVG Generator</h1>
1177
- <p>Transform images and text descriptions into scalable vector graphics</p>
1178
- <div class="model-badge">Model: OmniSVG {MODEL_VARIANT} | Qwen: {DEFAULT_QWEN_MODEL.split('/')[-1]}</div>
1179
- </div>
1180
- """
1181
-
1182
- with gr.Blocks(title=f"OmniSVG Generator ({MODEL_VARIANT})") as demo:
1183
  # Header
1184
- gr.HTML(header_html)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1185
 
1186
  # Queue status
1187
  gr.HTML("""
1188
- <div style="background: #e7f3ff; border: 1px solid #b3d7ff; border-radius: 8px; padding: 12px 15px; margin-bottom: 15px;">
1189
  <span style="font-size: 1.5em;">ℹ️</span>
1190
  <strong>Queue System Active</strong> - Requests processed one at a time. Please wait patiently if busy.
1191
  </div>
@@ -1340,7 +1526,7 @@ def create_interface():
1340
  elem_classes=["primary-btn"]
1341
  )
1342
 
1343
- gr.Markdown("### Example Prompts")
1344
  gr.Examples(
1345
  examples=[[text] for text in example_texts],
1346
  inputs=[text_input],
@@ -1372,7 +1558,7 @@ def create_interface():
1372
  # Footer
1373
  gr.HTML(f"""
1374
  <div class="footer">
1375
- <p>Built with OmniSVG {MODEL_VARIANT}</p>
1376
  <p style="color: #dc3545; font-weight: 600;">Remember: Generate 4-8 candidates and pick the best!</p>
1377
  </div>
1378
  """)
@@ -1385,31 +1571,20 @@ if __name__ == "__main__":
1385
 
1386
  args = parse_args()
1387
 
 
 
 
1388
  print("="*60)
1389
  print("OmniSVG Demo Page - Gradio App")
1390
  print("="*60)
1391
-
1392
- # Load config with variant support
1393
- print(f"\nLoading config from: {args.config}")
1394
- config = load_config(args.config, variant=args.model_size)
1395
- MODEL_VARIANT = config['active_variant']
1396
-
1397
- # Initialize constants from config
1398
- init_config_constants(config)
1399
-
1400
- # Override model paths if provided via command line
1401
- weight_path = args.weight_path if args.weight_path else DEFAULT_OMNISVG_MODEL
1402
- model_path = args.model_path if args.model_path else DEFAULT_QWEN_MODEL
1403
-
1404
- print(f"\n[CONFIG] Active variant: {MODEL_VARIANT}")
1405
- print(f"[CONFIG] Qwen model: {model_path}")
1406
- print(f"[CONFIG] OmniSVG weights: {weight_path}")
1407
- print(f"[CONFIG] Device: {device}")
1408
- print(f"[CONFIG] Precision: {DTYPE}")
1409
  print("="*60)
1410
 
1411
  # Print loaded config values
1412
- print("\n[CONFIG] Loaded settings:")
1413
  print(f" - TARGET_IMAGE_SIZE: {TARGET_IMAGE_SIZE}")
1414
  print(f" - RENDER_SIZE: {RENDER_SIZE}")
1415
  print(f" - BLACK_COLOR_TOKEN: {BLACK_COLOR_TOKEN}")
@@ -1417,21 +1592,10 @@ if __name__ == "__main__":
1417
  print(f" - BOS_TOKEN_ID: {BOS_TOKEN_ID}")
1418
  print(f" - EOS_TOKEN_ID: {EOS_TOKEN_ID}")
1419
  print(f" - PAD_TOKEN_ID: {PAD_TOKEN_ID}")
1420
-
1421
- # Print variant-specific token offsets
1422
- print(f"\n[CONFIG] Variant-specific ({MODEL_VARIANT}):")
1423
- print(f" - base_offset: {config.get('tokens', {}).get('base_offset', 'N/A')}")
1424
- print(f" - color_start_offset: {config.get('colors', {}).get('color_start_offset', 'N/A')}")
1425
- print(f" - color_end_offset: {config.get('colors', {}).get('color_end_offset', 'N/A')}")
1426
  print("="*60)
1427
 
1428
- # Write variant-specific config for SVGTokenizer
1429
- variant_config_path = f'./config_{MODEL_VARIANT.lower()}_runtime.yaml'
1430
- write_variant_config(config, variant_config_path)
1431
- print(f"\n[CONFIG] Written variant config to: {variant_config_path}")
1432
-
1433
  print("\nLoading models (may download from HuggingFace Hub if needed)...")
1434
- load_models(weight_path, model_path, variant_config_path)
1435
  print("Models loaded successfully!\n")
1436
 
1437
  demo = create_interface()
 
12
  import numpy as np
13
  import time
14
  import threading
 
15
  import spaces
16
 
17
  from huggingface_hub import hf_hub_download, snapshot_download
 
21
  from qwen_vl_utils import process_vision_info
22
  from tokenizer import SVGTokenizer
23
 
24
+ # Load config
25
+ CONFIG_PATH = './config.yaml'
26
+ with open(CONFIG_PATH, 'r') as f:
27
+ config = yaml.safe_load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
 
 
 
 
29
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
30
  DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
31
 
32
+ # Global Models (will be loaded based on selected model size)
33
  tokenizer = None
34
  processor = None
35
  sketch_decoder = None
36
  svg_tokenizer = None
37
+ current_model_size = None # Track which model is currently loaded
 
 
 
38
 
39
  # Thread lock for model inference
40
  generation_lock = threading.Lock()
41
+ model_loading_lock = threading.Lock()
42
 
43
+ # Constants from config
44
  SYSTEM_PROMPT = """You are an expert SVG code generator.
45
  Generate precise, valid SVG path commands that accurately represent the described scene or object.
46
  Focus on capturing key shapes, spatial relationships, and visual composition."""
47
 
48
  SUPPORTED_FORMATS = ['.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif']
49
+ AVAILABLE_MODEL_SIZES = list(config.get('models', {}).keys())
50
+ DEFAULT_MODEL_SIZE = config.get('default_model_size', '8B')
51
 
52
+ # ============================================================
53
+ # Helper function to get config value (model-specific or shared)
54
+ # ============================================================
55
+ def get_config_value(model_size, *keys):
56
+ """Get config value with model-specific override support."""
57
+ # Try model-specific config first
58
+ model_cfg = config.get('models', {}).get(model_size, {})
59
+ value = model_cfg
60
+ for key in keys:
61
+ if isinstance(value, dict) and key in value:
62
+ value = value[key]
63
+ else:
64
+ value = None
65
+ break
66
+
67
+ # Fallback to shared config if not found
68
+ if value is None:
69
+ value = config
70
+ for key in keys:
71
+ if isinstance(value, dict) and key in value:
72
+ value = value[key]
73
+ else:
74
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ return value
77
+
78
+
79
+ # ============================================================
80
+ # Image processing settings from config (shared)
81
+ # ============================================================
82
+ image_config = config.get('image', {})
83
+ TARGET_IMAGE_SIZE = image_config.get('target_size', 448)
84
+ RENDER_SIZE = image_config.get('render_size', 512)
85
+ BACKGROUND_THRESHOLD = image_config.get('background_threshold', 240)
86
+ EMPTY_THRESHOLD_ILLUSTRATION = image_config.get('empty_threshold_illustration', 250)
87
+ EMPTY_THRESHOLD_ICON = image_config.get('empty_threshold_icon', 252)
88
+ EDGE_SAMPLE_RATIO = image_config.get('edge_sample_ratio', 0.1)
89
+ COLOR_SIMILARITY_THRESHOLD = image_config.get('color_similarity_threshold', 30)
90
+ MIN_EDGE_SAMPLES = image_config.get('min_edge_samples', 10)
91
+
92
+ # ============================================================
93
+ # Color settings from config (shared)
94
+ # ============================================================
95
+ colors_config = config.get('colors', {})
96
+ BLACK_COLOR_TOKEN = colors_config.get('black_color_token',
97
+ colors_config.get('color_token_start', 40010) + 2)
98
+
99
+ # ============================================================
100
+ # Model settings from config (shared)
101
+ # ============================================================
102
+ model_config = config.get('model', {})
103
+ BOS_TOKEN_ID = model_config.get('bos_token_id', 196998)
104
+ EOS_TOKEN_ID = model_config.get('eos_token_id', 196999)
105
+ PAD_TOKEN_ID = model_config.get('pad_token_id', 151643)
106
+ MAX_LENGTH = model_config.get('max_length', 1536)
107
+
108
+ # ============================================================
109
+ # Task configurations with defaults from config (shared)
110
+ # ============================================================
111
+ task_config = config.get('task_configs', {})
112
+
113
+ TASK_CONFIGS = {
114
+ "text-to-svg-icon": task_config.get('text_to_svg_icon', {
115
+ "default_temperature": 0.5,
116
+ "default_top_p": 0.88,
117
+ "default_top_k": 50,
118
+ "default_repetition_penalty": 1.05,
119
+ }),
120
+ "text-to-svg-illustration": task_config.get('text_to_svg_illustration', {
121
+ "default_temperature": 0.6,
122
+ "default_top_p": 0.90,
123
+ "default_top_k": 60,
124
+ "default_repetition_penalty": 1.03,
125
+ }),
126
+ "image-to-svg": task_config.get('image_to_svg', {
127
+ "default_temperature": 0.3,
128
+ "default_top_p": 0.90,
129
+ "default_top_k": 50,
130
+ "default_repetition_penalty": 1.05,
131
+ })
132
+ }
133
+
134
+ # ============================================================
135
+ # Generation parameters from config (shared)
136
+ # ============================================================
137
+ gen_config = config.get('generation', {})
138
+ DEFAULT_NUM_CANDIDATES = gen_config.get('default_num_candidates', 4)
139
+ MAX_NUM_CANDIDATES = gen_config.get('max_num_candidates', 8)
140
+ EXTRA_CANDIDATES_BUFFER = gen_config.get('extra_candidates_buffer', 4)
141
 
142
+ # ============================================================
143
+ # Validation settings from config (shared)
144
+ # ============================================================
145
+ validation_config = config.get('validation', {})
146
+ MIN_SVG_LENGTH = validation_config.get('min_svg_length', 20)
147
 
148
+ # Custom CSS
149
  CUSTOM_CSS = """
150
  /* Main container centering */
151
  .gradio-container {
 
172
  opacity: 0.9;
173
  font-size: 1.1em;
174
  }
175
+ /* Model selector styling */
176
+ .model-selector {
177
+ background: #f0f4f8;
178
+ border: 2px solid #667eea;
179
+ border-radius: 12px;
180
+ padding: 15px;
181
+ margin-bottom: 20px;
182
+ }
183
+ .model-selector-title {
184
+ font-weight: 700;
185
+ color: #667eea;
186
+ margin-bottom: 10px;
187
  }
188
  /* Tips section */
189
  .tips-box {
 
262
  .green-box strong {
263
  color: #4caf50;
264
  }
265
+ .blue-box {
266
+ background: #e3f2fd;
267
+ border: 1px solid #90caf9;
268
+ border-left: 4px solid #2196f3;
269
+ padding: 12px;
270
+ border-radius: 8px;
271
+ margin: 10px 0;
272
+ }
273
+ .blue-box strong {
274
+ color: #2196f3;
275
+ }
276
  /* Tab styling */
277
  .tabs {
278
  border-radius: 12px !important;
 
341
  }
342
  """
343
 
344
+ # Enhanced Tips HTML
345
  TIPS_HTML = """
346
  <div class="tips-box">
347
  <h3>Prompting Guide & Best Practices</h3>
 
368
  </ul>
369
  </div>
370
 
371
+ <!-- Model Selection Tips -->
372
+ <div class="blue-box">
373
+ <strong>Model Selection Guide</strong>
374
+ <ul style="margin: 8px 0 0 0; padding-left: 20px;">
375
+ <li><strong>8B Model:</strong> Higher quality, more details, better for complex illustrations. Requires more VRAM (~16GB+).</li>
376
+ <li><strong>4B Model:</strong> Faster, less VRAM required (~8GB+). Good for simple icons and basic shapes.</li>
377
+ </ul>
378
+ </div>
379
+
380
  <!-- Parameter Tuning Tips -->
381
  <div class="orange-box">
382
  <strong>Parameter Tuning Guide</strong>
 
442
  <div class="example-prompt">
443
  "A simple person: round beige head, rectangular blue shirt body, two dark gray rectangular legs. Standing pose, arms at sides, flat colors."
444
  </div>
445
+ <div class="example-prompt">
446
+ "A girl with long black hair, pink dress with triangular skirt shape, small circular face with dot eyes and curved smile. Simple cartoon style."
447
+ </div>
448
  <p class="red-tip">Keep poses SIMPLE: standing, sitting, waving. Avoid complex actions!</p>
449
  </div>
450
 
451
+ <div class="tip-category">
452
+ <h4>Avatars & Portraits</h4>
453
+ <p>Use circular frame, focus on face and upper body only.</p>
454
+ <div class="example-prompt">
455
+ "Circular avatar: person with short black hair, round face with two dot eyes and small curved smile, wearing blue collar shirt. Minimal style."
456
+ </div>
457
+ <div class="example-prompt">
458
+ "Profile avatar silhouette: black side view of head with short hair, facing right. Simple solid shape."
459
+ </div>
460
+ </div>
461
+
462
  <div class="tip-category">
463
  <h4>Landscapes & Scenes</h4>
464
  <p>Layer elements from background to foreground. Specify color for EACH layer.</p>
465
  <div class="example-prompt">
466
  "Layered landscape: light blue sky at top, gray triangular mountains in middle, dark green triangular pine trees at bottom. Flat colors, simple shapes."
467
  </div>
468
+ <div class="example-prompt">
469
+ "Sunset beach: orange gradient sky at top, yellow semicircle sun on horizon, dark blue wavy ocean below, tan beach at bottom."
470
+ </div>
471
  <p class="red-tip">Use geometric shapes for nature: triangular trees, wavy water, semicircle sun!</p>
472
  </div>
473
 
 
477
  <div class="example-prompt">
478
  "Cute cat: orange round head with two triangular ears, oval orange body, curved tail. Simple cartoon style with black outlines, sitting pose."
479
  </div>
480
+ <div class="example-prompt">
481
+ "Simple black bird: oval body, small round head, pointed triangular beak facing right, triangular tail, two stick legs. Silhouette style."
482
+ </div>
483
+ </div>
484
+
485
+ <div class="tip-category">
486
+ <h4>Buildings & Objects</h4>
487
+ <p>Use basic shapes: rectangles for walls, triangles for roofs, squares for windows.</p>
488
+ <div class="example-prompt">
489
+ "Simple house: red triangular roof on top, beige rectangular wall, brown rectangular door in center, two small blue square windows. Green ground at bottom."
490
+ </div>
491
+ <div class="example-prompt">
492
+ "Coffee mug: brown cylindrical cup shape with curved handle on right side, three wavy steam lines rising from top. Simple flat style."
493
+ </div>
494
  </div>
495
 
496
  </div>
497
 
498
+ <!-- Extended Examples Section -->
499
+ <div style="margin-top: 20px; padding: 15px; background: #f0f7ff; border-radius: 10px; border: 1px solid #cce5ff;">
500
+ <h4 style="margin-top: 0; color: #0066cc;">More Complex Examples (Generate 6-8 candidates!)</h4>
501
+
502
+ <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(280px, 1fr)); gap: 12px; margin-top: 15px;">
503
+ <div class="example-prompt">
504
+ <strong>Business Avatar:</strong><br/>
505
+ "Circular professional avatar: man with short black hair, neutral skin tone round face, wearing dark navy suit with white shirt collar visible. Clean minimal style, centered in circle."
506
+ </div>
507
+ <div class="example-prompt">
508
+ <strong>Female Portrait:</strong><br/>
509
+ "Simple female face: oval face shape, long brown wavy hair on sides, two dot eyes, small nose, curved smile lips. Pink blush on cheeks. Cartoon portrait style."
510
+ </div>
511
+ <div class="example-prompt">
512
+ <strong>Child Character:</strong><br/>
513
+ "Cute child standing: large round head with short brown hair, big circular eyes with white highlights, small body in red t-shirt and blue shorts, simple stick arms and legs. Cheerful cartoon style."
514
+ </div>
515
+ <div class="example-prompt">
516
+ <strong>Active Pose:</strong><br/>
517
+ "Person walking: side view, circular head, rectangular torso in green jacket, legs in walking position (one forward, one back). Simple geometric style, moving right."
518
+ </div>
519
+ <div class="example-prompt">
520
+ <strong>Forest Scene:</strong><br/>
521
+ "Simple forest: light blue sky, row of 5 dark green triangular pine trees of varying heights, brown rectangular trunks, light green grass strip at bottom. Layered flat design."
522
+ </div>
523
+ <div class="example-prompt">
524
+ <strong>Ocean View:</strong><br/>
525
+ "Minimalist ocean: gradient blue sky at top, three horizontal wavy lines in dark blue for ocean, small white sailboat with triangular sail in center. Clean vector style."
526
+ </div>
527
+ <div class="example-prompt">
528
+ <strong>City Skyline:</strong><br/>
529
+ "Simple city skyline: orange sunset sky gradient, row of black rectangular building silhouettes of different heights, some with small yellow square windows. Minimalist style."
530
+ </div>
531
+ <div class="example-prompt">
532
+ <strong>Dog Character:</strong><br/>
533
+ "Friendly cartoon dog: brown oval body, round head with floppy ears, black dot nose, curved tail pointing up, four short legs. Sitting pose facing forward."
534
+ </div>
535
+ </div>
536
+ </div>
537
+
538
  <!-- Quick Troubleshooting -->
539
  <div class="green-box" style="margin-top: 15px;">
540
  <strong>Quick Troubleshooting</strong>
 
547
  <li><strong>Inconsistent?</strong> <span class="red-tip">Generate MORE candidates (6-8) and pick the best!</span></li>
548
  </ul>
549
  </div>
550
+
551
+ <!-- Prompt Template -->
552
+ <div style="margin-top: 15px; padding: 12px; background: #e8f5e9; border-radius: 8px; border-left: 4px solid #4caf50;">
553
+ <strong>Recommended Prompt Structure</strong>
554
+ <div style="background: white; padding: 10px; border-radius: 6px; margin-top: 8px; font-family: monospace; font-size: 0.9em;">
555
+ [Subject] + [Shape descriptions with colors] + [Position/orientation] + [Style]
556
+ </div>
557
+ <p style="margin: 10px 0 0 0; color: #2e7d32; font-size: 0.95em;">
558
+ Example: "A fox logo: triangular orange head, pointed ears, white chest marking, facing right. Minimalist flat style, centered."
559
+ </p>
560
+ </div>
561
  </div>
562
  """
563
 
 
582
  parser.add_argument('--port', type=int, default=7860)
583
  parser.add_argument('--share', action='store_true')
584
  parser.add_argument('--debug', action='store_true')
585
+ parser.add_argument('--model_size', type=str, default=None,
586
+ choices=AVAILABLE_MODEL_SIZES,
587
+ help=f'Model size to load at startup (default: {DEFAULT_MODEL_SIZE}). Can be changed in UI.')
 
588
  parser.add_argument('--weight_path', type=str, default=None,
589
  help='HuggingFace repo ID or local path for OmniSVG weights (overrides config)')
590
  parser.add_argument('--model_path', type=str, default=None,
 
595
  def download_model_weights(repo_id: str, filename: str = "pytorch_model.bin") -> str:
596
  """
597
  Download model weights from Hugging Face Hub.
598
+
599
+ Args:
600
+ repo_id: Hugging Face repository ID (e.g., 'OmniSVG/OmniSVG1.1_8B')
601
+ filename: Name of the weights file to download
602
+
603
+ Returns:
604
+ Local path to the downloaded file
605
  """
606
  print(f"Downloading {filename} from {repo_id}...")
607
  try:
 
630
  return False
631
 
632
 
633
+ def load_models(model_size: str, weight_path: str = None, model_path: str = None):
634
  """
635
+ Load all models for a specific model size.
636
 
637
  Args:
638
+ model_size: Model size ("8B" or "4B")
639
+ weight_path: Local path or HuggingFace repo ID for OmniSVG weights (optional, uses config if None)
640
+ model_path: Local path or HuggingFace repo ID for Qwen model (optional, uses config if None)
641
  """
642
+ global tokenizer, processor, sketch_decoder, svg_tokenizer, current_model_size
643
 
644
+ # Use config values if not provided
645
+ if weight_path is None:
646
+ weight_path = get_config_value(model_size, 'huggingface', 'omnisvg_model')
647
+ if model_path is None:
648
+ model_path = get_config_value(model_size, 'huggingface', 'qwen_model')
649
+
650
+ print(f"\n{'='*60}")
651
+ print(f"Loading {model_size} Model")
652
+ print(f"{'='*60}")
653
+ print(f"Qwen model: {model_path}")
654
+ print(f"OmniSVG weights: {weight_path}")
655
+ print(f"Precision: {DTYPE}")
656
 
657
  # Load Qwen tokenizer and processor
658
  print("\n[1/3] Loading tokenizer and processor...")
 
669
  processor.tokenizer.padding_side = "left"
670
  print("Tokenizer and processor loaded successfully!")
671
 
672
+ # Initialize sketch decoder with model_size
673
  print("\n[2/3] Initializing SketchDecoder...")
674
  sketch_decoder = SketchDecoder(
675
+ config_path=CONFIG_PATH,
676
+ model_path=model_path,
677
+ model_size=model_size,
678
  pix_len=MAX_LENGTH,
679
  text_len=config.get('text', {}).get('max_length', 200),
 
680
  torch_dtype=DTYPE
681
  )
682
 
 
704
 
705
  sketch_decoder = sketch_decoder.to(device).eval()
706
 
707
+ # Initialize SVG tokenizer with model_size
708
+ svg_tokenizer = SVGTokenizer(CONFIG_PATH, model_size=model_size)
709
+
710
+ current_model_size = model_size
711
 
712
  print("\n" + "="*60)
713
+ print(f"All {model_size} models loaded successfully!")
714
  print("="*60 + "\n")
715
 
716
 
717
def switch_model(new_model_size: str):
    """
    Switch the loaded models to a different model size.

    The "already loaded" check is performed while holding
    ``model_loading_lock``. ``load_models`` only updates
    ``current_model_size`` once loading has finished, so checking outside the
    lock could report "Already using ..." while another request is in the
    middle of replacing the model, or let two requests for the same size both
    trigger a full reload.

    Args:
        new_model_size: Target model size ("8B" or "4B")

    Returns:
        Status message describing the outcome (no-op, success, or failure).
    """
    with model_loading_lock:
        # Re-evaluated under the lock so the answer reflects any switch that
        # completed while this request was waiting.
        if new_model_size == current_model_size:
            return f"✅ Already using {new_model_size} model"

        # Clear memory before loading the new weights.
        # NOTE(review): the previously loaded model objects are still
        # referenced by module-level globals here, so this presumably only
        # reclaims unreferenced garbage - confirm whether the old model
        # should be released explicitly before the new one is loaded.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        try:
            # load_models() rebinds the model globals and sets
            # current_model_size on success.
            load_models(new_model_size)
        except Exception as e:
            # Report the failure to the UI instead of raising.
            error_msg = f"❌ Failed to switch to {new_model_size}: {str(e)}"
            print(error_msg)
            return error_msg

        return f"✅ Successfully switched to {new_model_size} model"
745
+
746
+
747
  def detect_text_subtype(text_prompt):
748
  """Auto-detect text prompt subtype"""
749
  text_lower = text_prompt.lower()
 
770
 
771
 
772
  def detect_and_replace_background(image, threshold=None, edge_sample_ratio=None):
773
+ """
774
+ Detect if image has non-white background and optionally replace it.
775
+ """
776
  if threshold is None:
777
  threshold = BACKGROUND_THRESHOLD
778
  if edge_sample_ratio is None:
 
806
  return image, False
807
 
808
  if len(img_array.shape) == 3 and img_array.shape[2] >= 3:
809
+ if img_array.shape[2] == 4:
810
+ gray = np.mean(img_array[:, :, :3], axis=2)
811
+ else:
812
+ gray = np.mean(img_array, axis=2)
813
+
814
  edge_colors = []
815
  for i in range(w):
816
  edge_colors.append(tuple(img_array[0, i, :3]))
 
838
 
839
 
840
  def preprocess_image_for_svg(image, replace_background=True, target_size=None):
841
+ """
842
+ Preprocess image for SVG generation.
843
+ """
844
  if target_size is None:
845
  target_size = TARGET_IMAGE_SIZE
846
 
 
1016
 
1017
  all_candidates = []
1018
 
1019
+ # Generation config with user parameters
1020
+ gen_cfg = {
1021
  'do_sample': True,
1022
  'temperature': temperature,
1023
  'top_p': top_p,
 
1046
  max_new_tokens=max_length,
1047
  num_return_sequences=actual_samples,
1048
  use_cache=True,
1049
+ **gen_cfg
1050
  )
1051
 
1052
  input_len = input_ids.shape[1]
 
1124
  return '<div style="text-align:center;color:#999;padding:50px;">Please enter a description</div>', ""
1125
 
1126
  print("\n" + "="*60)
1127
+ print(f"[TASK] text-to-svg (Model: {current_model_size})")
1128
  print(f"[INPUT] {text_description[:100]}{'...' if len(text_description) > 100 else ''}")
1129
  print(f"[PARAMS] candidates={num_candidates}, temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}")
1130
  print("="*60)
 
1190
  )
1191
 
1192
  print("\n" + "="*60)
1193
+ print(f"[TASK] image-to-svg (Model: {current_model_size})")
1194
  print(f"[INPUT] Image size: {image.size if hasattr(image, 'size') else 'unknown'}, mode: {image.mode if hasattr(image, 'mode') else 'unknown'}")
1195
  print(f"[PARAMS] candidates={num_candidates}, temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}, replace_bg={replace_background}")
1196
  print("="*60)
 
1277
  def create_interface():
1278
  """Create Gradio interface"""
1279
 
1280
+ # Example prompts covering various categories (29 entries in total)
1281
  example_texts = [
1282
+ # === Simple Icons (1-6) ===
1283
  "A black triangle pointing downward, centrally positioned.",
1284
  "A red heart shape with smooth curved edges, centered.",
1285
  "A yellow star with five sharp points, simple geometric design, flat color.",
1286
  "A blue arrow pointing to the right, thick solid shape, centered.",
1287
  "A green circle with a white checkmark inside, centered.",
1288
  "A black plus sign with equal length arms, thick lines, centered.",
1289
+
1290
+ # === Characters & People (7-12) ===
1291
  "A simple person standing: round beige head, rectangular blue shirt body, two dark gray rectangular legs, arms at sides. Flat colors.",
1292
  "A girl with long black hair, wearing pink dress with triangular skirt, small circular face with dot eyes and curved smile. Simple cartoon style.",
1293
+ "A child waving: large round head with brown messy hair, big circular eyes, small body in red t-shirt and blue shorts, one arm raised. Cheerful cartoon style.",
1294
+ "A person sitting on chair: side view, round head, rectangular torso in green sweater, bent legs on simple chair shape. Relaxed pose.",
1295
+ "A running person: side view silhouette in black, dynamic pose with one leg forward, arms pumping. Motion style.",
1296
+
1297
+ # === Avatars & Portraits (13-17) ===
1298
  "Circular avatar: person with short black hair, round face with two dot eyes and small curved smile, wearing blue collar shirt. Minimal style, centered in circle.",
1299
+ "Female avatar: oval face with long wavy brown hair, simple eyes, pink lips, wearing v-neck purple top. Soft cartoon style in circular frame.",
1300
+ "Profile silhouette avatar: black side view of head with short hair and glasses outline, facing right. Simple solid shape.",
1301
+ "Cute cartoon avatar: round face with big sparkly eyes, rosy cheeks, short bob haircut in orange. Kawaii style, circular frame.",
1302
+ "Professional headshot avatar: person with neat hair, neutral expression, wearing suit collar. Corporate minimal style, circular frame.",
1303
+
1304
+ # === Landscapes & Scenes (18-23) ===
1305
  "Layered mountain landscape: light blue sky at top, gray triangular snow-capped mountains in middle, dark green triangular pine trees at bottom. Flat colors.",
1306
  "Sunset beach scene: orange gradient sky at top, yellow semicircle sun on horizon, dark blue wavy ocean, tan beach strip at bottom. Simple shapes.",
1307
+ "Forest scene: light blue sky, row of 5 dark green triangular pine trees of varying heights on brown trunks, light green grass at bottom.",
1308
+ "City skyline at dusk: purple-orange gradient sky, row of black rectangular building silhouettes of different heights, some with yellow window squares.",
1309
+ "Desert landscape: light orange sky with white circle sun, tan sand dunes as curved shapes, one green cactus with arms on the right side.",
1310
+ "Countryside scene: blue sky with white fluffy clouds, green rolling hills, small red barn with white door in the center, yellow hay bales.",
1311
+
1312
+ # === Animals (24-27) ===
1313
  "Cute orange cat sitting: round head with two triangular ears, oval body, curved tail. Black outline cartoon style, facing forward.",
1314
+ "Simple black bird: oval body, round head, pointed triangular beak facing right, triangular tail, two stick legs. Silhouette style.",
1315
+ "Friendly cartoon dog: brown oval body, round head with floppy ears, black dot nose, wagging curved tail, four short legs. Sitting pose.",
1316
+ "Red fox logo: triangular orange face with pointed ears, white chest marking, bushy tail. Minimalist style, facing right, centered.",
1317
+
1318
+ # === Objects & Misc (28-30) ===
1319
  "Simple house icon: red triangular roof, beige rectangular walls, brown door in center, two blue square windows, green ground at bottom.",
1320
  "Coffee mug: brown cylindrical cup with curved handle on right, three wavy steam lines rising from top. Flat style.",
1321
+ "Open book: two rectangular white pages spread open, black text lines on each page, brown spine in center. Simple top-down view."
1322
  ]
1323
 
1324
  example_images = get_example_images()
1325
 
1326
+ with gr.Blocks(title="OmniSVG Generator", css=CUSTOM_CSS) as demo:
 
 
 
 
 
 
 
 
 
1327
  # Header
1328
+ gr.HTML("""
1329
+ <div class="header-container">
1330
+ <h1>OmniSVG Generator</h1>
1331
+ <p>Transform images and text descriptions into scalable vector graphics</p>
1332
+ </div>
1333
+ """)
1334
+
1335
+ # Model Selection Section
1336
+ with gr.Row():
1337
+ with gr.Column():
1338
+ gr.HTML("""
1339
+ <div class="blue-box">
1340
+ <strong>🔧 Model Selection</strong>
1341
+ <p style="margin: 5px 0 0 0; font-size: 0.9em;">
1342
+ Choose between <b>8B</b> (higher quality, more VRAM) or <b>4B</b> (faster, less VRAM).
1343
+ </p>
1344
+ </div>
1345
+ """)
1346
+
1347
+ with gr.Row():
1348
+ model_selector = gr.Dropdown(
1349
+ choices=AVAILABLE_MODEL_SIZES,
1350
+ value=DEFAULT_MODEL_SIZE,
1351
+ label="Model Size",
1352
+ info="8B: ~16GB VRAM, higher quality | 4B: ~8GB VRAM, faster",
1353
+ interactive=True,
1354
+ scale=1
1355
+ )
1356
+ model_status = gr.Textbox(
1357
+ label="Model Status",
1358
+ value=f"✅ Ready: {DEFAULT_MODEL_SIZE} model loaded",
1359
+ interactive=False,
1360
+ scale=2
1361
+ )
1362
+ switch_btn = gr.Button("Switch Model", variant="secondary", scale=1)
1363
+
1364
+ # Model switch handler
1365
+ switch_btn.click(
1366
+ fn=switch_model,
1367
+ inputs=[model_selector],
1368
+ outputs=[model_status],
1369
+ queue=True
1370
+ )
1371
 
1372
  # Queue status
1373
  gr.HTML("""
1374
+ <div style="background: #e7f3ff; border: 1px solid #b3d7ff; border-radius: 8px; padding: 12px 15px; margin: 15px 0;">
1375
  <span style="font-size: 1.5em;">ℹ️</span>
1376
  <strong>Queue System Active</strong> - Requests processed one at a time. Please wait patiently if busy.
1377
  </div>
 
1526
  elem_classes=["primary-btn"]
1527
  )
1528
 
1529
+ gr.Markdown("### Example Prompts (30)")
1530
  gr.Examples(
1531
  examples=[[text] for text in example_texts],
1532
  inputs=[text_input],
 
1558
  # Footer
1559
  gr.HTML(f"""
1560
  <div class="footer">
1561
+ <p>Built with OmniSVG | Current Model: <strong>{DEFAULT_MODEL_SIZE}</strong></p>
1562
  <p style="color: #dc3545; font-weight: 600;">Remember: Generate 4-8 candidates and pick the best!</p>
1563
  </div>
1564
  """)
 
1571
 
1572
  args = parse_args()
1573
 
1574
+ # Determine initial model size
1575
+ initial_model_size = args.model_size or DEFAULT_MODEL_SIZE
1576
+
1577
  print("="*60)
1578
  print("OmniSVG Demo Page - Gradio App")
1579
  print("="*60)
1580
+ print(f"Available model sizes: {AVAILABLE_MODEL_SIZES}")
1581
+ print(f"Initial model size: {initial_model_size}")
1582
+ print(f"Device: {device}")
1583
+ print(f"Precision: {DTYPE}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1584
  print("="*60)
1585
 
1586
  # Print loaded config values
1587
+ print("\n[CONFIG] Shared settings:")
1588
  print(f" - TARGET_IMAGE_SIZE: {TARGET_IMAGE_SIZE}")
1589
  print(f" - RENDER_SIZE: {RENDER_SIZE}")
1590
  print(f" - BLACK_COLOR_TOKEN: {BLACK_COLOR_TOKEN}")
 
1592
  print(f" - BOS_TOKEN_ID: {BOS_TOKEN_ID}")
1593
  print(f" - EOS_TOKEN_ID: {EOS_TOKEN_ID}")
1594
  print(f" - PAD_TOKEN_ID: {PAD_TOKEN_ID}")
 
 
 
 
 
 
1595
  print("="*60)
1596
 
 
 
 
 
 
1597
  print("\nLoading models (may download from HuggingFace Hub if needed)...")
1598
+ load_models(initial_model_size, args.weight_path, args.model_path)
1599
  print("Models loaded successfully!\n")
1600
 
1601
  demo = create_interface()