OmniSVG commited on
Commit
ab5aebd
·
verified ·
1 Parent(s): 84487e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -47
app.py CHANGED
@@ -37,41 +37,89 @@ svg_tokenizer = None
37
  # Thread lock for model inference
38
  generation_lock = threading.Lock()
39
 
40
- # Constants
41
  SYSTEM_PROMPT = """You are an expert SVG code generator.
42
  Generate precise, valid SVG path commands that accurately represent the described scene or object.
43
  Focus on capturing key shapes, spatial relationships, and visual composition."""
44
 
45
  SUPPORTED_FORMATS = ['.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif']
46
- TARGET_IMAGE_SIZE = 448
47
- BLACK_COLOR_TOKEN = 40012
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  # Default Hugging Face model IDs
50
- DEFAULT_QWEN_MODEL = "Qwen/Qwen2.5-VL-7B-Instruct"
51
- DEFAULT_OMNISVG_MODEL = "OmniSVG/OmniSVG1.1_8B"
 
 
 
 
 
 
 
52
 
53
- # Task configurations with defaults
54
  TASK_CONFIGS = {
55
- "text-to-svg-icon": {
56
  "default_temperature": 0.5,
57
  "default_top_p": 0.88,
58
  "default_top_k": 50,
59
  "default_repetition_penalty": 1.05,
60
- },
61
- "text-to-svg-illustration": {
62
  "default_temperature": 0.6,
63
  "default_top_p": 0.90,
64
  "default_top_k": 60,
65
  "default_repetition_penalty": 1.03,
66
- },
67
- "image-to-svg": {
68
  "default_temperature": 0.3,
69
  "default_top_p": 0.90,
70
  "default_top_k": 50,
71
  "default_repetition_penalty": 1.05,
72
- }
73
  }
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  # Custom CSS
76
  CUSTOM_CSS = """
77
  /* Main container centering */
@@ -556,8 +604,8 @@ def load_models(weight_path: str, model_path: str):
556
  # Initialize sketch decoder
557
  print("\n[2/3] Initializing SketchDecoder...")
558
  sketch_decoder = SketchDecoder(
559
- pix_len=config['model']['max_length'],
560
- text_len=200,
561
  model_path=model_path,
562
  torch_dtype=DTYPE
563
  )
@@ -625,18 +673,24 @@ def detect_text_subtype(text_prompt):
625
  return "icon"
626
 
627
 
628
- def detect_and_replace_background(image, threshold=240, edge_sample_ratio=0.1):
629
  """
630
  Detect if image has non-white background and optionally replace it.
631
 
632
  Args:
633
  image: PIL Image (RGB or RGBA)
634
- threshold: Pixel values above this are considered "white"
635
- edge_sample_ratio: Ratio of edge pixels to sample
636
 
637
  Returns:
638
  tuple: (processed_image, background_was_replaced)
639
  """
 
 
 
 
 
 
640
  img_array = np.array(image)
641
 
642
  # If already has alpha channel, composite onto white
@@ -651,7 +705,7 @@ def detect_and_replace_background(image, threshold=240, edge_sample_ratio=0.1):
651
  edge_pixels = []
652
 
653
  # Sample from all 4 edges
654
- sample_count = max(10, int(min(h, w) * edge_sample_ratio))
655
 
656
  # Top and bottom edges
657
  for i in range(0, w, max(1, w // sample_count)):
@@ -697,7 +751,7 @@ def detect_and_replace_background(image, threshold=240, edge_sample_ratio=0.1):
697
 
698
  # Create mask for background (colors similar to detected bg_color)
699
  color_diff = np.sqrt(np.sum((img_array[:, :, :3].astype(float) - np.array(bg_color)) ** 2, axis=2))
700
- bg_mask = color_diff < 30 # Threshold for color similarity
701
 
702
  # Replace background with white
703
  result = img_array.copy()
@@ -711,18 +765,22 @@ def detect_and_replace_background(image, threshold=240, edge_sample_ratio=0.1):
711
  return image, False
712
 
713
 
714
- def preprocess_image_for_svg(image, replace_background=True, target_size=448):
715
  """
716
  Preprocess image for SVG generation.
717
 
718
  Args:
719
  image: Input PIL Image or path
720
  replace_background: Whether to replace non-white backgrounds
721
- target_size: Target size for resizing
722
 
723
  Returns:
724
  tuple: (processed_pil_image, was_modified)
725
  """
 
 
 
 
726
  # Load image if path
727
  if isinstance(image, str):
728
  raw_img = Image.open(image)
@@ -792,8 +850,12 @@ Requirements:
792
  return inputs
793
 
794
 
795
- def render_svg_to_image(svg_str, size=512):
796
  """Render SVG to high-quality PIL Image"""
 
 
 
 
797
  try:
798
  png_data = cairosvg.svg2png(
799
  bytestring=svg_str.encode('utf-8'),
@@ -858,7 +920,7 @@ def create_gallery_html(candidates, cols=4):
858
 
859
  def is_valid_candidate(svg_str, img, subtype="illustration"):
860
  """Check candidate validity"""
861
- if not svg_str or len(svg_str) < 20:
862
  return False, "too_short"
863
 
864
  if '<svg' not in svg_str:
@@ -870,7 +932,7 @@ def is_valid_candidate(svg_str, img, subtype="illustration"):
870
  img_array = np.array(img)
871
  mean_val = img_array.mean()
872
 
873
- threshold = 250 if subtype == "illustration" else 252
874
 
875
  if mean_val > threshold:
876
  return False, "empty_image"
@@ -907,12 +969,12 @@ def generate_candidates(inputs, task_type, subtype, temperature, top_p, top_k, r
907
  'repetition_penalty': repetition_penalty,
908
  'early_stopping': True,
909
  'no_repeat_ngram_size': 0,
910
- 'eos_token_id': config['model']['eos_token_id'],
911
- 'pad_token_id': config['model']['pad_token_id'],
912
- 'bos_token_id': config['model']['bos_token_id'],
913
  }
914
 
915
- actual_samples = num_samples + 4
916
 
917
  try:
918
  if progress_callback:
@@ -942,9 +1004,9 @@ def generate_candidates(inputs, task_type, subtype, temperature, top_p, top_k, r
942
  current_ids = generated_ids_batch[i:i+1]
943
 
944
  fake_wrapper = torch.cat([
945
- torch.full((1, 1), config['model']['bos_token_id'], device=device),
946
  current_ids,
947
- torch.full((1, 1), config['model']['eos_token_id'], device=device)
948
  ], dim=1)
949
 
950
  generated_xy = svg_tokenizer.process_generated_tokens(fake_wrapper)
@@ -965,7 +1027,7 @@ def generate_candidates(inputs, task_type, subtype, temperature, top_p, top_k, r
965
  if 'width=' not in svg_str:
966
  svg_str = svg_str.replace('<svg', f'<svg width="{TARGET_IMAGE_SIZE}" height="{TARGET_IMAGE_SIZE}"', 1)
967
 
968
- png_image = render_svg_to_image(svg_str, size=512)
969
 
970
  is_valid, reason = is_valid_candidate(svg_str, png_image, subtype)
971
  if is_valid:
@@ -1024,7 +1086,6 @@ def gradio_text_to_svg(text_description, num_candidates, temperature, top_p, top
1024
  progress(0.05, f"Detected: {subtype}")
1025
 
1026
  inputs = prepare_inputs("text-to-svg", text_description.strip())
1027
- max_length = config['model']['max_length']
1028
 
1029
  def update_progress(val, msg):
1030
  progress(val, msg)
@@ -1032,7 +1093,7 @@ def gradio_text_to_svg(text_description, num_candidates, temperature, top_p, top
1032
  all_candidates = generate_candidates(
1033
  inputs, "text-to-svg", subtype,
1034
  temperature, top_p, int(top_k), repetition_penalty,
1035
- max_length, int(num_candidates),
1036
  progress_callback=update_progress
1037
  )
1038
 
@@ -1105,7 +1166,6 @@ def gradio_image_to_svg(image, num_candidates, temperature, top_p, top_k, repeti
1105
  try:
1106
  progress(0.1, "Preparing model inputs...")
1107
  inputs = prepare_inputs("image-to-svg", tmp_path)
1108
- max_length = config['model']['max_length']
1109
 
1110
  def update_progress(val, msg):
1111
  progress(val, msg)
@@ -1113,7 +1173,7 @@ def gradio_image_to_svg(image, num_candidates, temperature, top_p, top_k, repeti
1113
  all_candidates = generate_candidates(
1114
  inputs, "image-to-svg", "image",
1115
  temperature, top_p, int(top_k), repetition_penalty,
1116
- max_length, int(num_candidates),
1117
  progress_callback=update_progress
1118
  )
1119
 
@@ -1249,7 +1309,7 @@ def create_interface():
1249
  with gr.Group(elem_classes=["settings-group"]):
1250
  gr.Markdown("### Settings")
1251
  img_num_candidates = gr.Slider(
1252
- minimum=1, maximum=8, value=4, step=1,
1253
  label="Number of Candidates"
1254
  )
1255
  img_replace_bg = gr.Checkbox(
@@ -1260,20 +1320,28 @@ def create_interface():
1260
 
1261
  with gr.Accordion("Advanced Parameters", open=False):
1262
  img_temperature = gr.Slider(
1263
- minimum=0.1, maximum=1.0, value=0.3, step=0.05,
 
 
1264
  label="Temperature (Lower=accurate)",
1265
  info="0.2-0.4 recommended"
1266
  )
1267
  img_top_p = gr.Slider(
1268
- minimum=0.5, maximum=1.0, value=0.90, step=0.02,
 
 
1269
  label="Top-P"
1270
  )
1271
  img_top_k = gr.Slider(
1272
- minimum=10, maximum=100, value=50, step=5,
 
 
1273
  label="Top-K"
1274
  )
1275
  img_rep_penalty = gr.Slider(
1276
- minimum=1.0, maximum=1.3, value=1.05, step=0.01,
 
 
1277
  label="Repetition Penalty"
1278
  )
1279
 
@@ -1327,27 +1395,35 @@ def create_interface():
1327
  with gr.Group(elem_classes=["settings-group"]):
1328
  gr.Markdown("### Settings")
1329
  text_num_candidates = gr.Slider(
1330
- minimum=1, maximum=8, value=6, step=1,
1331
  label="Number of Candidates",
1332
  info="More = better chances!"
1333
  )
1334
 
1335
  with gr.Accordion("Advanced Parameters", open=False):
1336
  text_temperature = gr.Slider(
1337
- minimum=0.1, maximum=1.0, value=0.5, step=0.05,
 
 
1338
  label="Temperature",
1339
  info="Icons: 0.3-0.5 | Complex: 0.5-0.7"
1340
  )
1341
  text_top_p = gr.Slider(
1342
- minimum=0.5, maximum=1.0, value=0.90, step=0.02,
 
 
1343
  label="Top-P"
1344
  )
1345
  text_top_k = gr.Slider(
1346
- minimum=10, maximum=100, value=60, step=5,
 
 
1347
  label="Top-K"
1348
  )
1349
  text_rep_penalty = gr.Slider(
1350
- minimum=1.0, maximum=1.3, value=1.03, step=0.01,
 
 
1351
  label="Repetition Penalty",
1352
  info="Increase if you see repetitive patterns"
1353
  )
@@ -1413,6 +1489,17 @@ if __name__ == "__main__":
1413
  print(f"Precision: {DTYPE}")
1414
  print("="*60)
1415
 
 
 
 
 
 
 
 
 
 
 
 
1416
  print("\nLoading models (may download from HuggingFace Hub if needed)...")
1417
  load_models(args.weight_path, args.model_path)
1418
  print("Models loaded successfully!\n")
@@ -1426,5 +1513,4 @@ if __name__ == "__main__":
1426
  server_port=args.port,
1427
  share=args.share,
1428
  debug=args.debug,
1429
- )
1430
-
 
37
  # Thread lock for model inference
38
  generation_lock = threading.Lock()
39
 
40
+ # Constants from config
41
  SYSTEM_PROMPT = """You are an expert SVG code generator.
42
  Generate precise, valid SVG path commands that accurately represent the described scene or object.
43
  Focus on capturing key shapes, spatial relationships, and visual composition."""
44
 
45
  SUPPORTED_FORMATS = ['.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif']
 
 
46
 
47
+ # ============================================================
48
+ # Image processing settings from config
49
+ # ============================================================
50
+ image_config = config.get('image', {})
51
+ TARGET_IMAGE_SIZE = image_config.get('target_size', 448)
52
+ RENDER_SIZE = image_config.get('render_size', 512)
53
+ BACKGROUND_THRESHOLD = image_config.get('background_threshold', 240)
54
+ EMPTY_THRESHOLD_ILLUSTRATION = image_config.get('empty_threshold_illustration', 250)
55
+ EMPTY_THRESHOLD_ICON = image_config.get('empty_threshold_icon', 252)
56
+ EDGE_SAMPLE_RATIO = image_config.get('edge_sample_ratio', 0.1)
57
+ COLOR_SIMILARITY_THRESHOLD = image_config.get('color_similarity_threshold', 30)
58
+ MIN_EDGE_SAMPLES = image_config.get('min_edge_samples', 10)
59
+
60
+ # ============================================================
61
+ # Color settings from config
62
+ # ============================================================
63
+ colors_config = config.get('colors', {})
64
+ BLACK_COLOR_TOKEN = colors_config.get('black_color_token',
65
+ colors_config.get('color_token_start', 40010) + 2)
66
+
67
+ # ============================================================
68
+ # Model settings from config
69
+ # ============================================================
70
+ model_config = config.get('model', {})
71
+ BOS_TOKEN_ID = model_config.get('bos_token_id', 196998)
72
+ EOS_TOKEN_ID = model_config.get('eos_token_id', 196999)
73
+ PAD_TOKEN_ID = model_config.get('pad_token_id', 151643)
74
+ MAX_LENGTH = model_config.get('max_length', 3537)
75
+
76
+ # ============================================================
77
  # Default Hugging Face model IDs
78
+ # ============================================================
79
+ hf_config = config.get('huggingface', {})
80
+ DEFAULT_QWEN_MODEL = hf_config.get('qwen_model', "Qwen/Qwen2.5-VL-7B-Instruct")
81
+ DEFAULT_OMNISVG_MODEL = hf_config.get('omnisvg_model', "OmniSVG/OmniSVG1.1_8B")
82
+
83
+ # ============================================================
84
+ # Task configurations with defaults from config
85
+ # ============================================================
86
+ task_config = config.get('task_configs', {})
87
 
 
88
  TASK_CONFIGS = {
89
+ "text-to-svg-icon": task_config.get('text_to_svg_icon', {
90
  "default_temperature": 0.5,
91
  "default_top_p": 0.88,
92
  "default_top_k": 50,
93
  "default_repetition_penalty": 1.05,
94
+ }),
95
+ "text-to-svg-illustration": task_config.get('text_to_svg_illustration', {
96
  "default_temperature": 0.6,
97
  "default_top_p": 0.90,
98
  "default_top_k": 60,
99
  "default_repetition_penalty": 1.03,
100
+ }),
101
+ "image-to-svg": task_config.get('image_to_svg', {
102
  "default_temperature": 0.3,
103
  "default_top_p": 0.90,
104
  "default_top_k": 50,
105
  "default_repetition_penalty": 1.05,
106
+ })
107
  }
108
 
109
+ # ============================================================
110
+ # Generation parameters from config
111
+ # ============================================================
112
+ gen_config = config.get('generation', {})
113
+ DEFAULT_NUM_CANDIDATES = gen_config.get('default_num_candidates', 4)
114
+ MAX_NUM_CANDIDATES = gen_config.get('max_num_candidates', 8)
115
+ EXTRA_CANDIDATES_BUFFER = gen_config.get('extra_candidates_buffer', 4)
116
+
117
+ # ============================================================
118
+ # Validation settings from config
119
+ # ============================================================
120
+ validation_config = config.get('validation', {})
121
+ MIN_SVG_LENGTH = validation_config.get('min_svg_length', 20)
122
+
123
  # Custom CSS
124
  CUSTOM_CSS = """
125
  /* Main container centering */
 
604
  # Initialize sketch decoder
605
  print("\n[2/3] Initializing SketchDecoder...")
606
  sketch_decoder = SketchDecoder(
607
+ pix_len=MAX_LENGTH,
608
+ text_len=config.get('text', {}).get('max_length', 200),
609
  model_path=model_path,
610
  torch_dtype=DTYPE
611
  )
 
673
  return "icon"
674
 
675
 
676
+ def detect_and_replace_background(image, threshold=None, edge_sample_ratio=None):
677
  """
678
  Detect if image has non-white background and optionally replace it.
679
 
680
  Args:
681
  image: PIL Image (RGB or RGBA)
682
+ threshold: Pixel values above this are considered "white" (from config)
683
+ edge_sample_ratio: Ratio of edge pixels to sample (from config)
684
 
685
  Returns:
686
  tuple: (processed_image, background_was_replaced)
687
  """
688
+ # Use config values if not provided
689
+ if threshold is None:
690
+ threshold = BACKGROUND_THRESHOLD
691
+ if edge_sample_ratio is None:
692
+ edge_sample_ratio = EDGE_SAMPLE_RATIO
693
+
694
  img_array = np.array(image)
695
 
696
  # If already has alpha channel, composite onto white
 
705
  edge_pixels = []
706
 
707
  # Sample from all 4 edges
708
+ sample_count = max(MIN_EDGE_SAMPLES, int(min(h, w) * edge_sample_ratio))
709
 
710
  # Top and bottom edges
711
  for i in range(0, w, max(1, w // sample_count)):
 
751
 
752
  # Create mask for background (colors similar to detected bg_color)
753
  color_diff = np.sqrt(np.sum((img_array[:, :, :3].astype(float) - np.array(bg_color)) ** 2, axis=2))
754
+ bg_mask = color_diff < COLOR_SIMILARITY_THRESHOLD
755
 
756
  # Replace background with white
757
  result = img_array.copy()
 
765
  return image, False
766
 
767
 
768
+ def preprocess_image_for_svg(image, replace_background=True, target_size=None):
769
  """
770
  Preprocess image for SVG generation.
771
 
772
  Args:
773
  image: Input PIL Image or path
774
  replace_background: Whether to replace non-white backgrounds
775
+ target_size: Target size for resizing (from config)
776
 
777
  Returns:
778
  tuple: (processed_pil_image, was_modified)
779
  """
780
+ # Use config value if not provided
781
+ if target_size is None:
782
+ target_size = TARGET_IMAGE_SIZE
783
+
784
  # Load image if path
785
  if isinstance(image, str):
786
  raw_img = Image.open(image)
 
850
  return inputs
851
 
852
 
853
+ def render_svg_to_image(svg_str, size=None):
854
  """Render SVG to high-quality PIL Image"""
855
+ # Use config value if not provided
856
+ if size is None:
857
+ size = RENDER_SIZE
858
+
859
  try:
860
  png_data = cairosvg.svg2png(
861
  bytestring=svg_str.encode('utf-8'),
 
920
 
921
  def is_valid_candidate(svg_str, img, subtype="illustration"):
922
  """Check candidate validity"""
923
+ if not svg_str or len(svg_str) < MIN_SVG_LENGTH:
924
  return False, "too_short"
925
 
926
  if '<svg' not in svg_str:
 
932
  img_array = np.array(img)
933
  mean_val = img_array.mean()
934
 
935
+ threshold = EMPTY_THRESHOLD_ILLUSTRATION if subtype == "illustration" else EMPTY_THRESHOLD_ICON
936
 
937
  if mean_val > threshold:
938
  return False, "empty_image"
 
969
  'repetition_penalty': repetition_penalty,
970
  'early_stopping': True,
971
  'no_repeat_ngram_size': 0,
972
+ 'eos_token_id': EOS_TOKEN_ID,
973
+ 'pad_token_id': PAD_TOKEN_ID,
974
+ 'bos_token_id': BOS_TOKEN_ID,
975
  }
976
 
977
+ actual_samples = num_samples + EXTRA_CANDIDATES_BUFFER
978
 
979
  try:
980
  if progress_callback:
 
1004
  current_ids = generated_ids_batch[i:i+1]
1005
 
1006
  fake_wrapper = torch.cat([
1007
+ torch.full((1, 1), BOS_TOKEN_ID, device=device),
1008
  current_ids,
1009
+ torch.full((1, 1), EOS_TOKEN_ID, device=device)
1010
  ], dim=1)
1011
 
1012
  generated_xy = svg_tokenizer.process_generated_tokens(fake_wrapper)
 
1027
  if 'width=' not in svg_str:
1028
  svg_str = svg_str.replace('<svg', f'<svg width="{TARGET_IMAGE_SIZE}" height="{TARGET_IMAGE_SIZE}"', 1)
1029
 
1030
+ png_image = render_svg_to_image(svg_str, size=RENDER_SIZE)
1031
 
1032
  is_valid, reason = is_valid_candidate(svg_str, png_image, subtype)
1033
  if is_valid:
 
1086
  progress(0.05, f"Detected: {subtype}")
1087
 
1088
  inputs = prepare_inputs("text-to-svg", text_description.strip())
 
1089
 
1090
  def update_progress(val, msg):
1091
  progress(val, msg)
 
1093
  all_candidates = generate_candidates(
1094
  inputs, "text-to-svg", subtype,
1095
  temperature, top_p, int(top_k), repetition_penalty,
1096
+ MAX_LENGTH, int(num_candidates),
1097
  progress_callback=update_progress
1098
  )
1099
 
 
1166
  try:
1167
  progress(0.1, "Preparing model inputs...")
1168
  inputs = prepare_inputs("image-to-svg", tmp_path)
 
1169
 
1170
  def update_progress(val, msg):
1171
  progress(val, msg)
 
1173
  all_candidates = generate_candidates(
1174
  inputs, "image-to-svg", "image",
1175
  temperature, top_p, int(top_k), repetition_penalty,
1176
+ MAX_LENGTH, int(num_candidates),
1177
  progress_callback=update_progress
1178
  )
1179
 
 
1309
  with gr.Group(elem_classes=["settings-group"]):
1310
  gr.Markdown("### Settings")
1311
  img_num_candidates = gr.Slider(
1312
+ minimum=1, maximum=MAX_NUM_CANDIDATES, value=DEFAULT_NUM_CANDIDATES, step=1,
1313
  label="Number of Candidates"
1314
  )
1315
  img_replace_bg = gr.Checkbox(
 
1320
 
1321
  with gr.Accordion("Advanced Parameters", open=False):
1322
  img_temperature = gr.Slider(
1323
+ minimum=0.1, maximum=1.0,
1324
+ value=TASK_CONFIGS["image-to-svg"].get("default_temperature", 0.3),
1325
+ step=0.05,
1326
  label="Temperature (Lower=accurate)",
1327
  info="0.2-0.4 recommended"
1328
  )
1329
  img_top_p = gr.Slider(
1330
+ minimum=0.5, maximum=1.0,
1331
+ value=TASK_CONFIGS["image-to-svg"].get("default_top_p", 0.90),
1332
+ step=0.02,
1333
  label="Top-P"
1334
  )
1335
  img_top_k = gr.Slider(
1336
+ minimum=10, maximum=100,
1337
+ value=TASK_CONFIGS["image-to-svg"].get("default_top_k", 50),
1338
+ step=5,
1339
  label="Top-K"
1340
  )
1341
  img_rep_penalty = gr.Slider(
1342
+ minimum=1.0, maximum=1.3,
1343
+ value=TASK_CONFIGS["image-to-svg"].get("default_repetition_penalty", 1.05),
1344
+ step=0.01,
1345
  label="Repetition Penalty"
1346
  )
1347
 
 
1395
  with gr.Group(elem_classes=["settings-group"]):
1396
  gr.Markdown("### Settings")
1397
  text_num_candidates = gr.Slider(
1398
+ minimum=1, maximum=MAX_NUM_CANDIDATES, value=6, step=1,
1399
  label="Number of Candidates",
1400
  info="More = better chances!"
1401
  )
1402
 
1403
  with gr.Accordion("Advanced Parameters", open=False):
1404
  text_temperature = gr.Slider(
1405
+ minimum=0.1, maximum=1.0,
1406
+ value=TASK_CONFIGS["text-to-svg-icon"].get("default_temperature", 0.5),
1407
+ step=0.05,
1408
  label="Temperature",
1409
  info="Icons: 0.3-0.5 | Complex: 0.5-0.7"
1410
  )
1411
  text_top_p = gr.Slider(
1412
+ minimum=0.5, maximum=1.0,
1413
+ value=TASK_CONFIGS["text-to-svg-icon"].get("default_top_p", 0.90),
1414
+ step=0.02,
1415
  label="Top-P"
1416
  )
1417
  text_top_k = gr.Slider(
1418
+ minimum=10, maximum=100,
1419
+ value=TASK_CONFIGS["text-to-svg-icon"].get("default_top_k", 60),
1420
+ step=5,
1421
  label="Top-K"
1422
  )
1423
  text_rep_penalty = gr.Slider(
1424
+ minimum=1.0, maximum=1.3,
1425
+ value=TASK_CONFIGS["text-to-svg-icon"].get("default_repetition_penalty", 1.03),
1426
+ step=0.01,
1427
  label="Repetition Penalty",
1428
  info="Increase if you see repetitive patterns"
1429
  )
 
1489
  print(f"Precision: {DTYPE}")
1490
  print("="*60)
1491
 
1492
+ # Print loaded config values
1493
+ print("\n[CONFIG] Loaded settings:")
1494
+ print(f" - TARGET_IMAGE_SIZE: {TARGET_IMAGE_SIZE}")
1495
+ print(f" - RENDER_SIZE: {RENDER_SIZE}")
1496
+ print(f" - BLACK_COLOR_TOKEN: {BLACK_COLOR_TOKEN}")
1497
+ print(f" - MAX_LENGTH: {MAX_LENGTH}")
1498
+ print(f" - BOS_TOKEN_ID: {BOS_TOKEN_ID}")
1499
+ print(f" - EOS_TOKEN_ID: {EOS_TOKEN_ID}")
1500
+ print(f" - PAD_TOKEN_ID: {PAD_TOKEN_ID}")
1501
+ print("="*60)
1502
+
1503
  print("\nLoading models (may download from HuggingFace Hub if needed)...")
1504
  load_models(args.weight_path, args.model_path)
1505
  print("Models loaded successfully!\n")
 
1513
  server_port=args.port,
1514
  share=args.share,
1515
  debug=args.debug,
1516
+ )