OmniSVG commited on
Commit
911a9ce
·
verified ·
1 Parent(s): f053915

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -152
app.py CHANGED
@@ -374,6 +374,7 @@ TIPS_HTML = """
374
  <ul style="margin: 8px 0 0 0; padding-left: 20px;">
375
  <li><strong>8B Model:</strong> Higher quality, more details, better for complex illustrations. Requires more VRAM (~16GB+).</li>
376
  <li><strong>4B Model:</strong> Faster, less VRAM required (~8GB+). Good for simple icons and basic shapes.</li>
 
377
  </ul>
378
  </div>
379
 
@@ -495,46 +496,6 @@ TIPS_HTML = """
495
 
496
  </div>
497
 
498
- <!-- Extended Examples Section -->
499
- <div style="margin-top: 20px; padding: 15px; background: #f0f7ff; border-radius: 10px; border: 1px solid #cce5ff;">
500
- <h4 style="margin-top: 0; color: #0066cc;">More Complex Examples (Generate 6-8 candidates!)</h4>
501
-
502
- <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(280px, 1fr)); gap: 12px; margin-top: 15px;">
503
- <div class="example-prompt">
504
- <strong>Business Avatar:</strong><br/>
505
- "Circular professional avatar: man with short black hair, neutral skin tone round face, wearing dark navy suit with white shirt collar visible. Clean minimal style, centered in circle."
506
- </div>
507
- <div class="example-prompt">
508
- <strong>Female Portrait:</strong><br/>
509
- "Simple female face: oval face shape, long brown wavy hair on sides, two dot eyes, small nose, curved smile lips. Pink blush on cheeks. Cartoon portrait style."
510
- </div>
511
- <div class="example-prompt">
512
- <strong>Child Character:</strong><br/>
513
- "Cute child standing: large round head with short brown hair, big circular eyes with white highlights, small body in red t-shirt and blue shorts, simple stick arms and legs. Cheerful cartoon style."
514
- </div>
515
- <div class="example-prompt">
516
- <strong>Active Pose:</strong><br/>
517
- "Person walking: side view, circular head, rectangular torso in green jacket, legs in walking position (one forward, one back). Simple geometric style, moving right."
518
- </div>
519
- <div class="example-prompt">
520
- <strong>Forest Scene:</strong><br/>
521
- "Simple forest: light blue sky, row of 5 dark green triangular pine trees of varying heights, brown rectangular trunks, light green grass strip at bottom. Layered flat design."
522
- </div>
523
- <div class="example-prompt">
524
- <strong>Ocean View:</strong><br/>
525
- "Minimalist ocean: gradient blue sky at top, three horizontal wavy lines in dark blue for ocean, small white sailboat with triangular sail in center. Clean vector style."
526
- </div>
527
- <div class="example-prompt">
528
- <strong>City Skyline:</strong><br/>
529
- "Simple city skyline: orange sunset sky gradient, row of black rectangular building silhouettes of different heights, some with small yellow square windows. Minimalist style."
530
- </div>
531
- <div class="example-prompt">
532
- <strong>Dog Character:</strong><br/>
533
- "Friendly cartoon dog: brown oval body, round head with floppy ears, black dot nose, curved tail pointing up, four short legs. Sitting pose facing forward."
534
- </div>
535
- </div>
536
- </div>
537
-
538
  <!-- Quick Troubleshooting -->
539
  <div class="green-box" style="margin-top: 15px;">
540
  <strong>Quick Troubleshooting</strong>
@@ -582,26 +543,12 @@ def parse_args():
582
  parser.add_argument('--port', type=int, default=7860)
583
  parser.add_argument('--share', action='store_true')
584
  parser.add_argument('--debug', action='store_true')
585
- parser.add_argument('--model_size', type=str, default=None,
586
- choices=AVAILABLE_MODEL_SIZES,
587
- help=f'Model size to load at startup (default: {DEFAULT_MODEL_SIZE}). Can be changed in UI.')
588
- parser.add_argument('--weight_path', type=str, default=None,
589
- help='HuggingFace repo ID or local path for OmniSVG weights (overrides config)')
590
- parser.add_argument('--model_path', type=str, default=None,
591
- help='HuggingFace repo ID or local path for Qwen model (overrides config)')
592
  return parser.parse_args()
593
 
594
 
595
  def download_model_weights(repo_id: str, filename: str = "pytorch_model.bin") -> str:
596
  """
597
  Download model weights from Hugging Face Hub.
598
-
599
- Args:
600
- repo_id: Hugging Face repository ID (e.g., 'OmniSVG/OmniSVG1.1_8B')
601
- filename: Name of the weights file to download
602
-
603
- Returns:
604
- Local path to the downloaded file
605
  """
606
  print(f"Downloading {filename} from {repo_id}...")
607
  try:
@@ -633,11 +580,6 @@ def is_local_path(path: str) -> bool:
633
  def load_models(model_size: str, weight_path: str = None, model_path: str = None):
634
  """
635
  Load all models for a specific model size.
636
-
637
- Args:
638
- model_size: Model size ("8B" or "4B")
639
- weight_path: Local path or HuggingFace repo ID for OmniSVG weights (optional, uses config if None)
640
- model_path: Local path or HuggingFace repo ID for Qwen model (optional, uses config if None)
641
  """
642
  global tokenizer, processor, sketch_decoder, svg_tokenizer, current_model_size
643
 
@@ -714,34 +656,34 @@ def load_models(model_size: str, weight_path: str = None, model_path: str = None
714
  print("="*60 + "\n")
715
 
716
 
717
- def switch_model(new_model_size: str):
718
  """
719
- Switch to a different model size.
720
-
721
- Args:
722
- new_model_size: Target model size ("8B" or "4B")
723
-
724
- Returns:
725
- Status message
726
  """
727
- global current_model_size
728
 
729
- if new_model_size == current_model_size:
730
- return f"✅ Already using {new_model_size} model"
731
 
732
  with model_loading_lock:
733
- # Clear memory
734
- gc.collect()
735
- if torch.cuda.is_available():
736
- torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
737
 
738
- try:
739
- load_models(new_model_size)
740
- return f"✅ Successfully switched to {new_model_size} model"
741
- except Exception as e:
742
- error_msg = f"❌ Failed to switch to {new_model_size}: {str(e)}"
743
- print(error_msg)
744
- return error_msg
745
 
746
 
747
  def detect_text_subtype(text_prompt):
@@ -1116,20 +1058,21 @@ def generate_candidates(inputs, task_type, subtype, temperature, top_p, top_k, r
1116
  return all_candidates
1117
 
1118
 
1119
- @spaces.GPU
1120
- def gradio_text_to_svg(text_description, num_candidates, temperature, top_p, top_k, repetition_penalty,
1121
  progress=gr.Progress()):
1122
- """Gradio interface - text-to-svg with advanced parameters"""
1123
  if not text_description or text_description.strip() == "":
1124
- return '<div style="text-align:center;color:#999;padding:50px;">Please enter a description</div>', ""
1125
 
1126
  print("\n" + "="*60)
1127
- print(f"[TASK] text-to-svg (Model: {current_model_size})")
 
1128
  print(f"[INPUT] {text_description[:100]}{'...' if len(text_description) > 100 else ''}")
1129
  print(f"[PARAMS] candidates={num_candidates}, temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}")
1130
  print("="*60)
1131
 
1132
- progress(0, "Starting generation...")
1133
 
1134
  gc.collect()
1135
  if torch.cuda.is_available():
@@ -1137,9 +1080,16 @@ def gradio_text_to_svg(text_description, num_candidates, temperature, top_p, top
1137
 
1138
  start_time = time.time()
1139
 
 
 
 
 
 
 
 
1140
  subtype = detect_text_subtype(text_description)
1141
  print(f"[SUBTYPE] Detected: {subtype}")
1142
- progress(0.05, f"Detected: {subtype}")
1143
 
1144
  inputs = prepare_inputs("text-to-svg", text_description.strip())
1145
 
@@ -1161,7 +1111,8 @@ def gradio_text_to_svg(text_description, num_candidates, temperature, top_p, top
1161
  print("[WARNING] No valid SVG generated")
1162
  return (
1163
  '<div style="text-align:center;color:#999;padding:50px;">No valid SVG generated. Try different parameters or rephrase your prompt.</div>',
1164
- f"<!-- No valid SVG (took {elapsed:.1f}s) -->"
 
1165
  )
1166
 
1167
  svg_codes = []
@@ -1174,28 +1125,30 @@ def gradio_text_to_svg(text_description, num_candidates, temperature, top_p, top
1174
  progress(1.0, f"Done! {len(all_candidates)} candidates in {elapsed:.1f}s")
1175
  print(f"[COMPLETE] text-to-svg finished\n")
1176
 
1177
- return gallery_html, combined_svg
1178
 
1179
 
1180
- @spaces.GPU
1181
- def gradio_image_to_svg(image, num_candidates, temperature, top_p, top_k, repetition_penalty,
1182
  replace_background, progress=gr.Progress()):
1183
- """Gradio interface - image-to-svg with background handling"""
1184
 
1185
  if image is None:
1186
  return (
1187
  '<div style="text-align:center;color:#999;padding:50px;">Please upload an image</div>',
1188
  "",
1189
- None
 
1190
  )
1191
 
1192
  print("\n" + "="*60)
1193
- print(f"[TASK] image-to-svg (Model: {current_model_size})")
 
1194
  print(f"[INPUT] Image size: {image.size if hasattr(image, 'size') else 'unknown'}, mode: {image.mode if hasattr(image, 'mode') else 'unknown'}")
1195
  print(f"[PARAMS] candidates={num_candidates}, temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}, replace_bg={replace_background}")
1196
  print("="*60)
1197
 
1198
- progress(0, "Processing input image...")
1199
 
1200
  gc.collect()
1201
  if torch.cuda.is_available():
@@ -1203,6 +1156,13 @@ def gradio_image_to_svg(image, num_candidates, temperature, top_p, top_k, repeti
1203
 
1204
  start_time = time.time()
1205
 
 
 
 
 
 
 
 
1206
  img_processed, was_modified = preprocess_image_for_svg(
1207
  image,
1208
  replace_background=replace_background,
@@ -1211,7 +1171,7 @@ def gradio_image_to_svg(image, num_candidates, temperature, top_p, top_k, repeti
1211
 
1212
  if was_modified:
1213
  print("[PREPROCESS] Background processed/replaced")
1214
- progress(0.05, "Background processed")
1215
 
1216
  with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
1217
  img_processed.save(tmp_file.name, format='PNG', quality=100)
@@ -1240,7 +1200,8 @@ def gradio_image_to_svg(image, num_candidates, temperature, top_p, top_k, repeti
1240
  return (
1241
  '<div style="text-align:center;color:#999;padding:50px;">No valid SVG generated. Try adjusting parameters.</div>',
1242
  f"<!-- No valid SVG (took {elapsed:.1f}s) -->",
1243
- img_processed
 
1244
  )
1245
 
1246
  svg_codes = []
@@ -1253,7 +1214,7 @@ def gradio_image_to_svg(image, num_candidates, temperature, top_p, top_k, repeti
1253
  progress(1.0, f"Done! {len(all_candidates)} candidates in {elapsed:.1f}s")
1254
  print(f"[COMPLETE] image-to-svg finished\n")
1255
 
1256
- return gallery_html, combined_svg, img_processed
1257
 
1258
  finally:
1259
  if os.path.exists(tmp_path):
@@ -1323,7 +1284,7 @@ def create_interface():
1323
 
1324
  example_images = get_example_images()
1325
 
1326
- with gr.Blocks(title="OmniSVG Demo Page") as demo:
1327
  # Header
1328
  gr.HTML("""
1329
  <div class="header-container">
@@ -1332,48 +1293,11 @@ def create_interface():
1332
  </div>
1333
  """)
1334
 
1335
- # Model Selection Section
1336
- with gr.Row():
1337
- with gr.Column():
1338
- gr.HTML("""
1339
- <div class="blue-box">
1340
- <strong>🔧 Model Selection</strong>
1341
- <p style="margin: 5px 0 0 0; font-size: 0.9em;">
1342
- Choose between <b>8B</b> (higher quality, more VRAM) or <b>4B</b> (faster, less VRAM).
1343
- </p>
1344
- </div>
1345
- """)
1346
-
1347
- with gr.Row():
1348
- model_selector = gr.Dropdown(
1349
- choices=AVAILABLE_MODEL_SIZES,
1350
- value=DEFAULT_MODEL_SIZE,
1351
- label="Model Size",
1352
- info="8B: ~16GB VRAM, higher quality | 4B: ~8GB VRAM, faster",
1353
- interactive=True,
1354
- scale=1
1355
- )
1356
- model_status = gr.Textbox(
1357
- label="Model Status",
1358
- value=f"✅ Ready: {DEFAULT_MODEL_SIZE} model loaded",
1359
- interactive=False,
1360
- scale=2
1361
- )
1362
- switch_btn = gr.Button("Switch Model", variant="secondary", scale=1)
1363
-
1364
- # Model switch handler
1365
- switch_btn.click(
1366
- fn=switch_model,
1367
- inputs=[model_selector],
1368
- outputs=[model_status],
1369
- queue=True
1370
- )
1371
-
1372
  # Queue status
1373
  gr.HTML("""
1374
  <div style="background: #e7f3ff; border: 1px solid #b3d7ff; border-radius: 8px; padding: 12px 15px; margin: 15px 0;">
1375
  <span style="font-size: 1.5em;">ℹ️</span>
1376
- <strong>Queue System Active</strong> - Requests processed one at a time. Please wait patiently if busy.
1377
  </div>
1378
  """)
1379
 
@@ -1399,6 +1323,15 @@ def create_interface():
1399
 
1400
  with gr.Group(elem_classes=["settings-group"]):
1401
  gr.Markdown("### Settings")
 
 
 
 
 
 
 
 
 
1402
  img_num_candidates = gr.Slider(
1403
  minimum=1, maximum=MAX_NUM_CANDIDATES, value=DEFAULT_NUM_CANDIDATES, step=1,
1404
  label="Number of Candidates"
@@ -1443,6 +1376,13 @@ def create_interface():
1443
  elem_classes=["primary-btn"]
1444
  )
1445
 
 
 
 
 
 
 
 
1446
  if example_images:
1447
  gr.Markdown("### Examples")
1448
  gr.Examples(examples=example_images, inputs=[image_input], label="")
@@ -1461,9 +1401,9 @@ def create_interface():
1461
 
1462
  image_generate_btn.click(
1463
  fn=gradio_image_to_svg,
1464
- inputs=[image_input, img_num_candidates, img_temperature, img_top_p,
1465
  img_top_k, img_rep_penalty, img_replace_bg],
1466
- outputs=[image_gallery, image_svg_output, image_processed],
1467
  queue=True
1468
  )
1469
 
@@ -1485,6 +1425,15 @@ def create_interface():
1485
 
1486
  with gr.Group(elem_classes=["settings-group"]):
1487
  gr.Markdown("### Settings")
 
 
 
 
 
 
 
 
 
1488
  text_num_candidates = gr.Slider(
1489
  minimum=1, maximum=MAX_NUM_CANDIDATES, value=6, step=1,
1490
  label="Number of Candidates",
@@ -1526,6 +1475,13 @@ def create_interface():
1526
  elem_classes=["primary-btn"]
1527
  )
1528
 
 
 
 
 
 
 
 
1529
  gr.Markdown("### Example Prompts (30)")
1530
  gr.Examples(
1531
  examples=[[text] for text in example_texts],
@@ -1549,16 +1505,16 @@ def create_interface():
1549
 
1550
  text_generate_btn.click(
1551
  fn=gradio_text_to_svg,
1552
- inputs=[text_input, text_num_candidates, text_temperature, text_top_p,
1553
  text_top_k, text_rep_penalty],
1554
- outputs=[text_gallery, text_svg_output],
1555
  queue=True
1556
  )
1557
 
1558
  # Footer
1559
  gr.HTML(f"""
1560
  <div class="footer">
1561
- <p>Built with OmniSVG | Current Model: <strong>{DEFAULT_MODEL_SIZE}</strong></p>
1562
  <p style="color: #dc3545; font-weight: 600;">Remember: Generate 4-8 candidates and pick the best!</p>
1563
  </div>
1564
  """)
@@ -1571,16 +1527,14 @@ if __name__ == "__main__":
1571
 
1572
  args = parse_args()
1573
 
1574
- # Determine initial model size
1575
- initial_model_size = args.model_size or DEFAULT_MODEL_SIZE
1576
-
1577
  print("="*60)
1578
- print("OmniSVG Demo Page - Gradio App")
1579
  print("="*60)
1580
  print(f"Available model sizes: {AVAILABLE_MODEL_SIZES}")
1581
- print(f"Initial model size: {initial_model_size}")
1582
  print(f"Device: {device}")
1583
  print(f"Precision: {DTYPE}")
 
1584
  print("="*60)
1585
 
1586
  # Print loaded config values
@@ -1594,10 +1548,6 @@ if __name__ == "__main__":
1594
  print(f" - PAD_TOKEN_ID: {PAD_TOKEN_ID}")
1595
  print("="*60)
1596
 
1597
- print("\nLoading models (may download from HuggingFace Hub if needed)...")
1598
- load_models(initial_model_size, args.weight_path, args.model_path)
1599
- print("Models loaded successfully!\n")
1600
-
1601
  demo = create_interface()
1602
 
1603
  demo.queue(default_concurrency_limit=1, max_size=20)
 
374
  <ul style="margin: 8px 0 0 0; padding-left: 20px;">
375
  <li><strong>8B Model:</strong> Higher quality, more details, better for complex illustrations. Requires more VRAM (~16GB+).</li>
376
  <li><strong>4B Model:</strong> Faster, less VRAM required (~8GB+). Good for simple icons and basic shapes.</li>
377
+ <li><strong>Note:</strong> First generation with a new model size may take longer to load.</li>
378
  </ul>
379
  </div>
380
 
 
496
 
497
  </div>
498
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499
  <!-- Quick Troubleshooting -->
500
  <div class="green-box" style="margin-top: 15px;">
501
  <strong>Quick Troubleshooting</strong>
 
543
  parser.add_argument('--port', type=int, default=7860)
544
  parser.add_argument('--share', action='store_true')
545
  parser.add_argument('--debug', action='store_true')
 
 
 
 
 
 
 
546
  return parser.parse_args()
547
 
548
 
549
  def download_model_weights(repo_id: str, filename: str = "pytorch_model.bin") -> str:
550
  """
551
  Download model weights from Hugging Face Hub.
 
 
 
 
 
 
 
552
  """
553
  print(f"Downloading {filename} from {repo_id}...")
554
  try:
 
580
  def load_models(model_size: str, weight_path: str = None, model_path: str = None):
581
  """
582
  Load all models for a specific model size.
 
 
 
 
 
583
  """
584
  global tokenizer, processor, sketch_decoder, svg_tokenizer, current_model_size
585
 
 
656
  print("="*60 + "\n")
657
 
658
 
659
+ def ensure_model_loaded(model_size: str):
660
  """
661
+ Ensure the specified model is loaded. Load or switch if necessary.
662
+ This function should be called within @spaces.GPU decorated functions.
 
 
 
 
 
663
  """
664
+ global current_model_size, sketch_decoder, tokenizer, processor, svg_tokenizer
665
 
666
+ if current_model_size == model_size and sketch_decoder is not None:
667
+ return # Already loaded
668
 
669
  with model_loading_lock:
670
+ # Double-check after acquiring lock
671
+ if current_model_size == model_size and sketch_decoder is not None:
672
+ return
673
+
674
+ # Clear old models if switching
675
+ if current_model_size is not None:
676
+ print(f"Switching from {current_model_size} to {model_size}...")
677
+ del sketch_decoder
678
+ del tokenizer
679
+ del processor
680
+ del svg_tokenizer
681
+ gc.collect()
682
+ if torch.cuda.is_available():
683
+ torch.cuda.empty_cache()
684
 
685
+ # Load new model
686
+ load_models(model_size)
 
 
 
 
 
687
 
688
 
689
  def detect_text_subtype(text_prompt):
 
1058
  return all_candidates
1059
 
1060
 
1061
+ @spaces.GPU(duration=300) # 增加时间以支持模型加载
1062
+ def gradio_text_to_svg(text_description, model_size, num_candidates, temperature, top_p, top_k, repetition_penalty,
1063
  progress=gr.Progress()):
1064
+ """Gradio interface - text-to-svg with model selection"""
1065
  if not text_description or text_description.strip() == "":
1066
+ return '<div style="text-align:center;color:#999;padding:50px;">Please enter a description</div>', "", f"Ready (no model loaded)"
1067
 
1068
  print("\n" + "="*60)
1069
+ print(f"[TASK] text-to-svg")
1070
+ print(f"[MODEL] Requested: {model_size}")
1071
  print(f"[INPUT] {text_description[:100]}{'...' if len(text_description) > 100 else ''}")
1072
  print(f"[PARAMS] candidates={num_candidates}, temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}")
1073
  print("="*60)
1074
 
1075
+ progress(0, "Initializing...")
1076
 
1077
  gc.collect()
1078
  if torch.cuda.is_available():
 
1080
 
1081
  start_time = time.time()
1082
 
1083
+ # 确保模型已加载
1084
+ progress(0.02, f"Loading {model_size} model (first time may take a while)...")
1085
+ ensure_model_loaded(model_size)
1086
+ model_status = f"✅ Using {model_size} model"
1087
+
1088
+ progress(0.05, "Model ready, starting generation...")
1089
+
1090
  subtype = detect_text_subtype(text_description)
1091
  print(f"[SUBTYPE] Detected: {subtype}")
1092
+ progress(0.08, f"Detected: {subtype}")
1093
 
1094
  inputs = prepare_inputs("text-to-svg", text_description.strip())
1095
 
 
1111
  print("[WARNING] No valid SVG generated")
1112
  return (
1113
  '<div style="text-align:center;color:#999;padding:50px;">No valid SVG generated. Try different parameters or rephrase your prompt.</div>',
1114
+ f"<!-- No valid SVG (took {elapsed:.1f}s) -->",
1115
+ model_status
1116
  )
1117
 
1118
  svg_codes = []
 
1125
  progress(1.0, f"Done! {len(all_candidates)} candidates in {elapsed:.1f}s")
1126
  print(f"[COMPLETE] text-to-svg finished\n")
1127
 
1128
+ return gallery_html, combined_svg, model_status
1129
 
1130
 
1131
+ @spaces.GPU(duration=300) # 增加时间以支持模型加载
1132
+ def gradio_image_to_svg(image, model_size, num_candidates, temperature, top_p, top_k, repetition_penalty,
1133
  replace_background, progress=gr.Progress()):
1134
+ """Gradio interface - image-to-svg with model selection"""
1135
 
1136
  if image is None:
1137
  return (
1138
  '<div style="text-align:center;color:#999;padding:50px;">Please upload an image</div>',
1139
  "",
1140
+ None,
1141
+ f"Ready (no model loaded)"
1142
  )
1143
 
1144
  print("\n" + "="*60)
1145
+ print(f"[TASK] image-to-svg")
1146
+ print(f"[MODEL] Requested: {model_size}")
1147
  print(f"[INPUT] Image size: {image.size if hasattr(image, 'size') else 'unknown'}, mode: {image.mode if hasattr(image, 'mode') else 'unknown'}")
1148
  print(f"[PARAMS] candidates={num_candidates}, temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}, replace_bg={replace_background}")
1149
  print("="*60)
1150
 
1151
+ progress(0, "Initializing...")
1152
 
1153
  gc.collect()
1154
  if torch.cuda.is_available():
 
1156
 
1157
  start_time = time.time()
1158
 
1159
+ # 确保模型已加载
1160
+ progress(0.02, f"Loading {model_size} model (first time may take a while)...")
1161
+ ensure_model_loaded(model_size)
1162
+ model_status = f"✅ Using {model_size} model"
1163
+
1164
+ progress(0.05, "Processing input image...")
1165
+
1166
  img_processed, was_modified = preprocess_image_for_svg(
1167
  image,
1168
  replace_background=replace_background,
 
1171
 
1172
  if was_modified:
1173
  print("[PREPROCESS] Background processed/replaced")
1174
+ progress(0.08, "Background processed")
1175
 
1176
  with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
1177
  img_processed.save(tmp_file.name, format='PNG', quality=100)
 
1200
  return (
1201
  '<div style="text-align:center;color:#999;padding:50px;">No valid SVG generated. Try adjusting parameters.</div>',
1202
  f"<!-- No valid SVG (took {elapsed:.1f}s) -->",
1203
+ img_processed,
1204
+ model_status
1205
  )
1206
 
1207
  svg_codes = []
 
1214
  progress(1.0, f"Done! {len(all_candidates)} candidates in {elapsed:.1f}s")
1215
  print(f"[COMPLETE] image-to-svg finished\n")
1216
 
1217
+ return gallery_html, combined_svg, img_processed, model_status
1218
 
1219
  finally:
1220
  if os.path.exists(tmp_path):
 
1284
 
1285
  example_images = get_example_images()
1286
 
1287
+ with gr.Blocks(title="OmniSVG Generator", css=CUSTOM_CSS) as demo:
1288
  # Header
1289
  gr.HTML("""
1290
  <div class="header-container">
 
1293
  </div>
1294
  """)
1295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1296
  # Queue status
1297
  gr.HTML("""
1298
  <div style="background: #e7f3ff; border: 1px solid #b3d7ff; border-radius: 8px; padding: 12px 15px; margin: 15px 0;">
1299
  <span style="font-size: 1.5em;">ℹ️</span>
1300
+ <strong>Queue System Active</strong> - Requests processed one at a time. First generation with a new model may take longer to load.
1301
  </div>
1302
  """)
1303
 
 
1323
 
1324
  with gr.Group(elem_classes=["settings-group"]):
1325
  gr.Markdown("### Settings")
1326
+
1327
+ # Model Selection
1328
+ img_model_size = gr.Dropdown(
1329
+ choices=AVAILABLE_MODEL_SIZES,
1330
+ value=DEFAULT_MODEL_SIZE,
1331
+ label="Model Size",
1332
+ info="8B: Higher quality (~16GB VRAM) | 4B: Faster (~8GB VRAM)"
1333
+ )
1334
+
1335
  img_num_candidates = gr.Slider(
1336
  minimum=1, maximum=MAX_NUM_CANDIDATES, value=DEFAULT_NUM_CANDIDATES, step=1,
1337
  label="Number of Candidates"
 
1376
  elem_classes=["primary-btn"]
1377
  )
1378
 
1379
+ # Model status display
1380
+ img_model_status = gr.Textbox(
1381
+ label="Model Status",
1382
+ value="Ready (model loads on first generation)",
1383
+ interactive=False
1384
+ )
1385
+
1386
  if example_images:
1387
  gr.Markdown("### Examples")
1388
  gr.Examples(examples=example_images, inputs=[image_input], label="")
 
1401
 
1402
  image_generate_btn.click(
1403
  fn=gradio_image_to_svg,
1404
+ inputs=[image_input, img_model_size, img_num_candidates, img_temperature, img_top_p,
1405
  img_top_k, img_rep_penalty, img_replace_bg],
1406
+ outputs=[image_gallery, image_svg_output, image_processed, img_model_status],
1407
  queue=True
1408
  )
1409
 
 
1425
 
1426
  with gr.Group(elem_classes=["settings-group"]):
1427
  gr.Markdown("### Settings")
1428
+
1429
+ # Model Selection
1430
+ text_model_size = gr.Dropdown(
1431
+ choices=AVAILABLE_MODEL_SIZES,
1432
+ value=DEFAULT_MODEL_SIZE,
1433
+ label="Model Size",
1434
+ info="8B: Higher quality (~16GB VRAM) | 4B: Faster (~8GB VRAM)"
1435
+ )
1436
+
1437
  text_num_candidates = gr.Slider(
1438
  minimum=1, maximum=MAX_NUM_CANDIDATES, value=6, step=1,
1439
  label="Number of Candidates",
 
1475
  elem_classes=["primary-btn"]
1476
  )
1477
 
1478
+ # Model status display
1479
+ text_model_status = gr.Textbox(
1480
+ label="Model Status",
1481
+ value="Ready (model loads on first generation)",
1482
+ interactive=False
1483
+ )
1484
+
1485
  gr.Markdown("### Example Prompts (30)")
1486
  gr.Examples(
1487
  examples=[[text] for text in example_texts],
 
1505
 
1506
  text_generate_btn.click(
1507
  fn=gradio_text_to_svg,
1508
+ inputs=[text_input, text_model_size, text_num_candidates, text_temperature, text_top_p,
1509
  text_top_k, text_rep_penalty],
1510
+ outputs=[text_gallery, text_svg_output, text_model_status],
1511
  queue=True
1512
  )
1513
 
1514
  # Footer
1515
  gr.HTML(f"""
1516
  <div class="footer">
1517
+ <p>Built with OmniSVG | Available models: {', '.join(AVAILABLE_MODEL_SIZES)}</p>
1518
  <p style="color: #dc3545; font-weight: 600;">Remember: Generate 4-8 candidates and pick the best!</p>
1519
  </div>
1520
  """)
 
1527
 
1528
  args = parse_args()
1529
 
 
 
 
1530
  print("="*60)
1531
+ print("OmniSVG Demo Page - Gradio App (HuggingFace Space)")
1532
  print("="*60)
1533
  print(f"Available model sizes: {AVAILABLE_MODEL_SIZES}")
1534
+ print(f"Default model size: {DEFAULT_MODEL_SIZE}")
1535
  print(f"Device: {device}")
1536
  print(f"Precision: {DTYPE}")
1537
+ print("Models will be loaded on-demand when first used.")
1538
  print("="*60)
1539
 
1540
  # Print loaded config values
 
1548
  print(f" - PAD_TOKEN_ID: {PAD_TOKEN_ID}")
1549
  print("="*60)
1550
 
 
 
 
 
1551
  demo = create_interface()
1552
 
1553
  demo.queue(default_concurrency_limit=1, max_size=20)