OmniSVG committed on
Commit
0fa895c
·
verified ·
1 Parent(s): 67c9e68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -43
app.py CHANGED
@@ -14,6 +14,8 @@ import time
14
  import threading
15
  import spaces
16
 
 
 
17
  from decoder import SketchDecoder
18
  from transformers import AutoTokenizer, AutoProcessor
19
  from qwen_vl_utils import process_vision_info
@@ -44,6 +46,10 @@ SUPPORTED_FORMATS = ['.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif']
44
  TARGET_IMAGE_SIZE = 448
45
  BLACK_COLOR_TOKEN = 40012
46
 
 
 
 
 
47
  # Task configurations with defaults
48
  TASK_CONFIGS = {
49
  "text-to-svg-icon": {
@@ -74,7 +80,6 @@ CUSTOM_CSS = """
74
  margin: 0 auto !important;
75
  padding: 20px !important;
76
  }
77
-
78
  /* Header styling */
79
  .header-container {
80
  text-align: center;
@@ -84,19 +89,16 @@ CUSTOM_CSS = """
84
  border-radius: 16px;
85
  color: white;
86
  }
87
-
88
  .header-container h1 {
89
  margin: 0;
90
  font-size: 2.5em;
91
  font-weight: 700;
92
  }
93
-
94
  .header-container p {
95
  margin: 10px 0 0 0;
96
  opacity: 0.9;
97
  font-size: 1.1em;
98
  }
99
-
100
  /* Tips section */
101
  .tips-box {
102
  background: #f8f9fa;
@@ -105,14 +107,12 @@ CUSTOM_CSS = """
105
  margin-bottom: 20px;
106
  border: 1px solid #e0e0e0;
107
  }
108
-
109
  .tips-box h3 {
110
  margin-top: 0;
111
  color: #333;
112
  border-bottom: 2px solid #667eea;
113
  padding-bottom: 10px;
114
  }
115
-
116
  .tip-category {
117
  background: white;
118
  border-radius: 8px;
@@ -120,19 +120,16 @@ CUSTOM_CSS = """
120
  margin: 10px 0;
121
  border-left: 4px solid #667eea;
122
  }
123
-
124
  .tip-category h4 {
125
  margin: 0 0 10px 0;
126
  color: #667eea;
127
  }
128
-
129
  .tip-category code {
130
  background: #f0f0f0;
131
  padding: 2px 6px;
132
  border-radius: 4px;
133
  font-size: 0.9em;
134
  }
135
-
136
  .example-prompt {
137
  background: #e8f4fd;
138
  padding: 10px;
@@ -142,12 +139,10 @@ CUSTOM_CSS = """
142
  font-size: 0.95em;
143
  color: #333;
144
  }
145
-
146
  .red-tip {
147
  color: #dc3545;
148
  font-weight: 600;
149
  }
150
-
151
  .red-box {
152
  background: #fff5f5;
153
  border: 1px solid #ffcccc;
@@ -156,11 +151,9 @@ CUSTOM_CSS = """
156
  border-radius: 8px;
157
  margin: 10px 0;
158
  }
159
-
160
  .red-box strong {
161
  color: #dc3545;
162
  }
163
-
164
  .orange-box {
165
  background: #fff8e6;
166
  border: 1px solid #ffc107;
@@ -169,11 +162,9 @@ CUSTOM_CSS = """
169
  border-radius: 8px;
170
  margin: 10px 0;
171
  }
172
-
173
  .orange-box strong {
174
  color: #ff9800;
175
  }
176
-
177
  .green-box {
178
  background: #e8f5e9;
179
  border: 1px solid #81c784;
@@ -182,21 +173,17 @@ CUSTOM_CSS = """
182
  border-radius: 8px;
183
  margin: 10px 0;
184
  }
185
-
186
  .green-box strong {
187
  color: #4caf50;
188
  }
189
-
190
  /* Tab styling */
191
  .tabs {
192
  border-radius: 12px !important;
193
  overflow: hidden;
194
  }
195
-
196
  .tabitem {
197
  padding: 20px !important;
198
  }
199
-
200
  /* Button styling */
201
  .primary-btn {
202
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
@@ -205,12 +192,10 @@ CUSTOM_CSS = """
205
  padding: 12px 24px !important;
206
  font-size: 1.1em !important;
207
  }
208
-
209
  .primary-btn:hover {
210
  transform: translateY(-2px);
211
  box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
212
  }
213
-
214
  /* Settings group */
215
  .settings-group {
216
  background: #f8f9fa;
@@ -218,14 +203,12 @@ CUSTOM_CSS = """
218
  padding: 15px;
219
  margin: 10px 0;
220
  }
221
-
222
  .advanced-settings {
223
  background: #f0f4f8;
224
  border-radius: 8px;
225
  padding: 12px;
226
  margin-top: 10px;
227
  }
228
-
229
  /* Code output */
230
  .code-output textarea {
231
  font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace !important;
@@ -234,18 +217,15 @@ CUSTOM_CSS = """
234
  color: #d4d4d4 !important;
235
  border-radius: 8px !important;
236
  }
237
-
238
  /* Input image area */
239
  .input-image {
240
  border: 2px dashed #ccc;
241
  border-radius: 12px;
242
  transition: border-color 0.3s;
243
  }
244
-
245
  .input-image:hover {
246
  border-color: #667eea;
247
  }
248
-
249
  /* Footer */
250
  .footer {
251
  text-align: center;
@@ -253,7 +233,6 @@ CUSTOM_CSS = """
253
  color: #666;
254
  font-size: 0.9em;
255
  }
256
-
257
  /* Responsive adjustments */
258
  @media (max-width: 768px) {
259
  .gradio-container {
@@ -530,22 +509,85 @@ def parse_args():
530
  parser.add_argument('--port', type=int, default=7860)
531
  parser.add_argument('--share', action='store_true')
532
  parser.add_argument('--debug', action='store_true')
533
- parser.add_argument('--weight_path', type=str, default="/mnt/jfs-test/OmniSVG_result/8B_1126/1688_bs_4/merge_slerp/merge_150_350_bf16")
534
- parser.add_argument('--model_path', type=str, default="/mnt/jfs-test/Qwen2.5-VL-7B-Instruct")
 
 
535
  return parser.parse_args()
536
 
537
 
538
- def load_models(weight_path, model_path):
539
- """Load all models"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
540
  global tokenizer, processor, sketch_decoder, svg_tokenizer
541
 
542
- print(f"Loading models from {model_path}...")
 
543
  print(f"Using precision: {DTYPE}")
544
 
545
- tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
546
- processor = AutoProcessor.from_pretrained(model_path, padding_side="left")
 
 
 
 
 
 
 
 
 
 
547
  processor.tokenizer.padding_side = "left"
 
548
 
 
 
549
  sketch_decoder = SketchDecoder(
550
  pix_len=config['model']['max_length'],
551
  text_len=200,
@@ -553,17 +595,42 @@ def load_models(weight_path, model_path):
553
  torch_dtype=DTYPE
554
  )
555
 
556
- bin_path = os.path.join(weight_path, "pytorch_model.bin")
557
- if os.path.exists(bin_path):
558
- print(f"Loading weights from: {bin_path}")
559
- sketch_decoder.load_state_dict(torch.load(bin_path, map_location='cpu'))
 
 
 
 
 
 
 
 
 
 
 
 
 
560
  else:
561
- raise FileNotFoundError(f"No weights found at {bin_path}")
 
 
562
 
 
 
 
 
 
 
563
  sketch_decoder = sketch_decoder.to(device).eval()
 
 
564
  svg_tokenizer = SVGTokenizer('./config.yaml')
565
 
 
566
  print("All models loaded successfully!")
 
567
 
568
 
569
  def detect_text_subtype(text_prompt):
@@ -731,7 +798,6 @@ def prepare_inputs(task_type, content):
731
  prompt_text = str(content).strip()
732
 
733
  instruction = f"""Generate an SVG illustration for: {prompt_text}
734
-
735
  Requirements:
736
  - Create complete SVG path commands
737
  - Include proper coordinates and colors
@@ -1354,12 +1420,13 @@ if __name__ == "__main__":
1354
  print("="*60)
1355
  print("OmniSVG Generator - Gradio App")
1356
  print("="*60)
1357
- print(f"Model path: {args.model_path}")
1358
- print(f"Weight path: {args.weight_path}")
1359
  print(f"Device: {device}")
 
1360
  print("="*60)
1361
 
1362
- print("\nLoading models...")
1363
  load_models(args.weight_path, args.model_path)
1364
  print("Models loaded successfully!\n")
1365
 
 
14
  import threading
15
  import spaces
16
 
17
+ from huggingface_hub import hf_hub_download, snapshot_download
18
+
19
  from decoder import SketchDecoder
20
  from transformers import AutoTokenizer, AutoProcessor
21
  from qwen_vl_utils import process_vision_info
 
46
  TARGET_IMAGE_SIZE = 448
47
  BLACK_COLOR_TOKEN = 40012
48
 
49
+ # Default Hugging Face model IDs
50
+ DEFAULT_QWEN_MODEL = "Qwen/Qwen2.5-VL-7B-Instruct"
51
+ DEFAULT_OMNISVG_MODEL = "OmniSVG/OmniSVG1.1_8B"
52
+
53
  # Task configurations with defaults
54
  TASK_CONFIGS = {
55
  "text-to-svg-icon": {
 
80
  margin: 0 auto !important;
81
  padding: 20px !important;
82
  }
 
83
  /* Header styling */
84
  .header-container {
85
  text-align: center;
 
89
  border-radius: 16px;
90
  color: white;
91
  }
 
92
  .header-container h1 {
93
  margin: 0;
94
  font-size: 2.5em;
95
  font-weight: 700;
96
  }
 
97
  .header-container p {
98
  margin: 10px 0 0 0;
99
  opacity: 0.9;
100
  font-size: 1.1em;
101
  }
 
102
  /* Tips section */
103
  .tips-box {
104
  background: #f8f9fa;
 
107
  margin-bottom: 20px;
108
  border: 1px solid #e0e0e0;
109
  }
 
110
  .tips-box h3 {
111
  margin-top: 0;
112
  color: #333;
113
  border-bottom: 2px solid #667eea;
114
  padding-bottom: 10px;
115
  }
 
116
  .tip-category {
117
  background: white;
118
  border-radius: 8px;
 
120
  margin: 10px 0;
121
  border-left: 4px solid #667eea;
122
  }
 
123
  .tip-category h4 {
124
  margin: 0 0 10px 0;
125
  color: #667eea;
126
  }
 
127
  .tip-category code {
128
  background: #f0f0f0;
129
  padding: 2px 6px;
130
  border-radius: 4px;
131
  font-size: 0.9em;
132
  }
 
133
  .example-prompt {
134
  background: #e8f4fd;
135
  padding: 10px;
 
139
  font-size: 0.95em;
140
  color: #333;
141
  }
 
142
  .red-tip {
143
  color: #dc3545;
144
  font-weight: 600;
145
  }
 
146
  .red-box {
147
  background: #fff5f5;
148
  border: 1px solid #ffcccc;
 
151
  border-radius: 8px;
152
  margin: 10px 0;
153
  }
 
154
  .red-box strong {
155
  color: #dc3545;
156
  }
 
157
  .orange-box {
158
  background: #fff8e6;
159
  border: 1px solid #ffc107;
 
162
  border-radius: 8px;
163
  margin: 10px 0;
164
  }
 
165
  .orange-box strong {
166
  color: #ff9800;
167
  }
 
168
  .green-box {
169
  background: #e8f5e9;
170
  border: 1px solid #81c784;
 
173
  border-radius: 8px;
174
  margin: 10px 0;
175
  }
 
176
  .green-box strong {
177
  color: #4caf50;
178
  }
 
179
  /* Tab styling */
180
  .tabs {
181
  border-radius: 12px !important;
182
  overflow: hidden;
183
  }
 
184
  .tabitem {
185
  padding: 20px !important;
186
  }
 
187
  /* Button styling */
188
  .primary-btn {
189
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
 
192
  padding: 12px 24px !important;
193
  font-size: 1.1em !important;
194
  }
 
195
  .primary-btn:hover {
196
  transform: translateY(-2px);
197
  box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
198
  }
 
199
  /* Settings group */
200
  .settings-group {
201
  background: #f8f9fa;
 
203
  padding: 15px;
204
  margin: 10px 0;
205
  }
 
206
  .advanced-settings {
207
  background: #f0f4f8;
208
  border-radius: 8px;
209
  padding: 12px;
210
  margin-top: 10px;
211
  }
 
212
  /* Code output */
213
  .code-output textarea {
214
  font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace !important;
 
217
  color: #d4d4d4 !important;
218
  border-radius: 8px !important;
219
  }
 
220
  /* Input image area */
221
  .input-image {
222
  border: 2px dashed #ccc;
223
  border-radius: 12px;
224
  transition: border-color 0.3s;
225
  }
 
226
  .input-image:hover {
227
  border-color: #667eea;
228
  }
 
229
  /* Footer */
230
  .footer {
231
  text-align: center;
 
233
  color: #666;
234
  font-size: 0.9em;
235
  }
 
236
  /* Responsive adjustments */
237
  @media (max-width: 768px) {
238
  .gradio-container {
 
509
  parser.add_argument('--port', type=int, default=7860)
510
  parser.add_argument('--share', action='store_true')
511
  parser.add_argument('--debug', action='store_true')
512
+ parser.add_argument('--weight_path', type=str, default=DEFAULT_OMNISVG_MODEL,
513
+ help=f'HuggingFace repo ID or local path for OmniSVG weights (default: {DEFAULT_OMNISVG_MODEL})')
514
+ parser.add_argument('--model_path', type=str, default=DEFAULT_QWEN_MODEL,
515
+ help=f'HuggingFace repo ID or local path for Qwen model (default: {DEFAULT_QWEN_MODEL})')
516
  return parser.parse_args()
517
 
518
 
519
def download_model_weights(repo_id: str, filename: str = "pytorch_model.bin") -> str:
    """
    Download model weights from Hugging Face Hub.

    Args:
        repo_id: Hugging Face repository ID (e.g., 'OmniSVG/OmniSVG1.1_8B')
        filename: Name of the weights file to download

    Returns:
        Local path to the downloaded file

    Raises:
        Exception: Any error raised by ``hf_hub_download`` is logged to
            stdout and then re-raised unchanged.
    """
    print(f"Downloading {filename} from {repo_id}...")
    try:
        # ``resume_download`` is deliberately not passed: recent
        # huggingface_hub releases resume interrupted downloads by default
        # and treat the argument as deprecated, so forwarding it only risks
        # a warning or a TypeError on newer versions.
        local_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
        )
    except Exception as e:
        # Log for the console, but let the caller decide how to handle it.
        print(f"Error downloading from {repo_id}: {e}")
        raise
    print(f"Successfully downloaded to: {local_path}")
    return local_path
542
+
543
+
544
def is_local_path(path: str) -> bool:
    """
    Heuristically decide whether ``path`` is a local filesystem path.

    A value that is not recognised as local is meant to be interpreted by the
    caller as a Hugging Face repo ID (e.g., 'OmniSVG/OmniSVG1.1_8B').

    Args:
        path: Local filesystem path or Hugging Face repo ID

    Returns:
        True if ``path`` exists on disk or looks like a filesystem path,
        False otherwise (including for an empty or missing value)
    """
    # An empty/None value is neither a usable path nor a repo ID.
    if not path:
        return False
    # Anything that already exists on disk is local, whatever it looks like.
    if os.path.exists(path):
        return True
    # Explicit absolute, relative (POSIX or Windows) and home-directory
    # prefixes can never be valid repo IDs, so treat them as local even when
    # the target does not exist yet; the caller then reports a missing file
    # instead of attempting a Hub download with a nonsensical repo ID.
    # Note that '~' is only classified here, not expanded.
    if path.startswith(('/', './', '../', '.\\', '..\\', '~')):
        return True
    # Windows drive-letter paths such as 'C:\\models'.
    if len(path) > 1 and path[1] == ':':
        return True
    # A not-yet-existing entry inside an existing directory is still local.
    # Caveat: a repo ID 'org/name' is classified as local when a directory
    # named 'org' happens to exist in the current working directory.
    parent_dir = os.path.dirname(path)
    return os.path.sep in path and os.path.exists(parent_dir)
558
+
559
+
560
+ def load_models(weight_path: str, model_path: str):
561
+ """
562
+ Load all models with support for both local paths and HuggingFace Hub.
563
+
564
+ Args:
565
+ weight_path: Local path or HuggingFace repo ID for OmniSVG weights
566
+ model_path: Local path or HuggingFace repo ID for Qwen model
567
+ """
568
  global tokenizer, processor, sketch_decoder, svg_tokenizer
569
 
570
+ print(f"Loading Qwen model from: {model_path}")
571
+ print(f"Loading OmniSVG weights from: {weight_path}")
572
  print(f"Using precision: {DTYPE}")
573
 
574
+ # Load Qwen tokenizer and processor (transformers handles HF download automatically)
575
+ print("\n[1/3] Loading tokenizer and processor...")
576
+ tokenizer = AutoTokenizer.from_pretrained(
577
+ model_path,
578
+ padding_side="left",
579
+ trust_remote_code=True
580
+ )
581
+ processor = AutoProcessor.from_pretrained(
582
+ model_path,
583
+ padding_side="left",
584
+ trust_remote_code=True
585
+ )
586
  processor.tokenizer.padding_side = "left"
587
+ print("Tokenizer and processor loaded successfully!")
588
 
589
+ # Initialize sketch decoder
590
+ print("\n[2/3] Initializing SketchDecoder...")
591
  sketch_decoder = SketchDecoder(
592
  pix_len=config['model']['max_length'],
593
  text_len=200,
 
595
  torch_dtype=DTYPE
596
  )
597
 
598
+ # Load OmniSVG weights
599
+ print("\n[3/3] Loading OmniSVG weights...")
600
+
601
+ # Check if weight_path is local or needs to be downloaded from HF
602
+ if is_local_path(weight_path):
603
+ # Local path - check for pytorch_model.bin
604
+ bin_path = os.path.join(weight_path, "pytorch_model.bin")
605
+ if not os.path.exists(bin_path):
606
+ # Maybe it's a direct path to the bin file
607
+ if os.path.exists(weight_path) and weight_path.endswith('.bin'):
608
+ bin_path = weight_path
609
+ else:
610
+ raise FileNotFoundError(
611
+ f"Could not find pytorch_model.bin at {weight_path}. "
612
+ f"Please provide a valid local path or HuggingFace repo ID."
613
+ )
614
+ print(f"Loading weights from local path: {bin_path}")
615
  else:
616
+ # HuggingFace repo ID - download the weights
617
+ print(f"Downloading weights from HuggingFace: {weight_path}")
618
+ bin_path = download_model_weights(weight_path, "pytorch_model.bin")
619
 
620
+ # Load the weights
621
+ state_dict = torch.load(bin_path, map_location='cpu')
622
+ sketch_decoder.load_state_dict(state_dict)
623
+ print("OmniSVG weights loaded successfully!")
624
+
625
+ # Move to device and set eval mode
626
  sketch_decoder = sketch_decoder.to(device).eval()
627
+
628
+ # Initialize SVG tokenizer
629
  svg_tokenizer = SVGTokenizer('./config.yaml')
630
 
631
+ print("\n" + "="*60)
632
  print("All models loaded successfully!")
633
+ print("="*60 + "\n")
634
 
635
 
636
  def detect_text_subtype(text_prompt):
 
798
  prompt_text = str(content).strip()
799
 
800
  instruction = f"""Generate an SVG illustration for: {prompt_text}
 
801
  Requirements:
802
  - Create complete SVG path commands
803
  - Include proper coordinates and colors
 
1420
  print("="*60)
1421
  print("OmniSVG Generator - Gradio App")
1422
  print("="*60)
1423
+ print(f"Qwen model: {args.model_path}")
1424
+ print(f"OmniSVG weights: {args.weight_path}")
1425
  print(f"Device: {device}")
1426
+ print(f"Precision: {DTYPE}")
1427
  print("="*60)
1428
 
1429
+ print("\nLoading models (may download from HuggingFace Hub if needed)...")
1430
  load_models(args.weight_path, args.model_path)
1431
  print("Models loaded successfully!\n")
1432