Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -374,6 +374,7 @@ TIPS_HTML = """
|
|
| 374 |
<ul style="margin: 8px 0 0 0; padding-left: 20px;">
|
| 375 |
<li><strong>8B Model:</strong> Higher quality, more details, better for complex illustrations. Requires more VRAM (~16GB+).</li>
|
| 376 |
<li><strong>4B Model:</strong> Faster, less VRAM required (~8GB+). Good for simple icons and basic shapes.</li>
|
|
|
|
| 377 |
</ul>
|
| 378 |
</div>
|
| 379 |
|
|
@@ -495,46 +496,6 @@ TIPS_HTML = """
|
|
| 495 |
|
| 496 |
</div>
|
| 497 |
|
| 498 |
-
<!-- Extended Examples Section -->
|
| 499 |
-
<div style="margin-top: 20px; padding: 15px; background: #f0f7ff; border-radius: 10px; border: 1px solid #cce5ff;">
|
| 500 |
-
<h4 style="margin-top: 0; color: #0066cc;">More Complex Examples (Generate 6-8 candidates!)</h4>
|
| 501 |
-
|
| 502 |
-
<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(280px, 1fr)); gap: 12px; margin-top: 15px;">
|
| 503 |
-
<div class="example-prompt">
|
| 504 |
-
<strong>Business Avatar:</strong><br/>
|
| 505 |
-
"Circular professional avatar: man with short black hair, neutral skin tone round face, wearing dark navy suit with white shirt collar visible. Clean minimal style, centered in circle."
|
| 506 |
-
</div>
|
| 507 |
-
<div class="example-prompt">
|
| 508 |
-
<strong>Female Portrait:</strong><br/>
|
| 509 |
-
"Simple female face: oval face shape, long brown wavy hair on sides, two dot eyes, small nose, curved smile lips. Pink blush on cheeks. Cartoon portrait style."
|
| 510 |
-
</div>
|
| 511 |
-
<div class="example-prompt">
|
| 512 |
-
<strong>Child Character:</strong><br/>
|
| 513 |
-
"Cute child standing: large round head with short brown hair, big circular eyes with white highlights, small body in red t-shirt and blue shorts, simple stick arms and legs. Cheerful cartoon style."
|
| 514 |
-
</div>
|
| 515 |
-
<div class="example-prompt">
|
| 516 |
-
<strong>Active Pose:</strong><br/>
|
| 517 |
-
"Person walking: side view, circular head, rectangular torso in green jacket, legs in walking position (one forward, one back). Simple geometric style, moving right."
|
| 518 |
-
</div>
|
| 519 |
-
<div class="example-prompt">
|
| 520 |
-
<strong>Forest Scene:</strong><br/>
|
| 521 |
-
"Simple forest: light blue sky, row of 5 dark green triangular pine trees of varying heights, brown rectangular trunks, light green grass strip at bottom. Layered flat design."
|
| 522 |
-
</div>
|
| 523 |
-
<div class="example-prompt">
|
| 524 |
-
<strong>Ocean View:</strong><br/>
|
| 525 |
-
"Minimalist ocean: gradient blue sky at top, three horizontal wavy lines in dark blue for ocean, small white sailboat with triangular sail in center. Clean vector style."
|
| 526 |
-
</div>
|
| 527 |
-
<div class="example-prompt">
|
| 528 |
-
<strong>City Skyline:</strong><br/>
|
| 529 |
-
"Simple city skyline: orange sunset sky gradient, row of black rectangular building silhouettes of different heights, some with small yellow square windows. Minimalist style."
|
| 530 |
-
</div>
|
| 531 |
-
<div class="example-prompt">
|
| 532 |
-
<strong>Dog Character:</strong><br/>
|
| 533 |
-
"Friendly cartoon dog: brown oval body, round head with floppy ears, black dot nose, curved tail pointing up, four short legs. Sitting pose facing forward."
|
| 534 |
-
</div>
|
| 535 |
-
</div>
|
| 536 |
-
</div>
|
| 537 |
-
|
| 538 |
<!-- Quick Troubleshooting -->
|
| 539 |
<div class="green-box" style="margin-top: 15px;">
|
| 540 |
<strong>Quick Troubleshooting</strong>
|
|
@@ -582,26 +543,12 @@ def parse_args():
|
|
| 582 |
parser.add_argument('--port', type=int, default=7860)
|
| 583 |
parser.add_argument('--share', action='store_true')
|
| 584 |
parser.add_argument('--debug', action='store_true')
|
| 585 |
-
parser.add_argument('--model_size', type=str, default=None,
|
| 586 |
-
choices=AVAILABLE_MODEL_SIZES,
|
| 587 |
-
help=f'Model size to load at startup (default: {DEFAULT_MODEL_SIZE}). Can be changed in UI.')
|
| 588 |
-
parser.add_argument('--weight_path', type=str, default=None,
|
| 589 |
-
help='HuggingFace repo ID or local path for OmniSVG weights (overrides config)')
|
| 590 |
-
parser.add_argument('--model_path', type=str, default=None,
|
| 591 |
-
help='HuggingFace repo ID or local path for Qwen model (overrides config)')
|
| 592 |
return parser.parse_args()
|
| 593 |
|
| 594 |
|
| 595 |
def download_model_weights(repo_id: str, filename: str = "pytorch_model.bin") -> str:
|
| 596 |
"""
|
| 597 |
Download model weights from Hugging Face Hub.
|
| 598 |
-
|
| 599 |
-
Args:
|
| 600 |
-
repo_id: Hugging Face repository ID (e.g., 'OmniSVG/OmniSVG1.1_8B')
|
| 601 |
-
filename: Name of the weights file to download
|
| 602 |
-
|
| 603 |
-
Returns:
|
| 604 |
-
Local path to the downloaded file
|
| 605 |
"""
|
| 606 |
print(f"Downloading {filename} from {repo_id}...")
|
| 607 |
try:
|
|
@@ -633,11 +580,6 @@ def is_local_path(path: str) -> bool:
|
|
| 633 |
def load_models(model_size: str, weight_path: str = None, model_path: str = None):
|
| 634 |
"""
|
| 635 |
Load all models for a specific model size.
|
| 636 |
-
|
| 637 |
-
Args:
|
| 638 |
-
model_size: Model size ("8B" or "4B")
|
| 639 |
-
weight_path: Local path or HuggingFace repo ID for OmniSVG weights (optional, uses config if None)
|
| 640 |
-
model_path: Local path or HuggingFace repo ID for Qwen model (optional, uses config if None)
|
| 641 |
"""
|
| 642 |
global tokenizer, processor, sketch_decoder, svg_tokenizer, current_model_size
|
| 643 |
|
|
@@ -714,34 +656,34 @@ def load_models(model_size: str, weight_path: str = None, model_path: str = None
|
|
| 714 |
print("="*60 + "\n")
|
| 715 |
|
| 716 |
|
| 717 |
-
def
|
| 718 |
"""
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
Args:
|
| 722 |
-
new_model_size: Target model size ("8B" or "4B")
|
| 723 |
-
|
| 724 |
-
Returns:
|
| 725 |
-
Status message
|
| 726 |
"""
|
| 727 |
-
global current_model_size
|
| 728 |
|
| 729 |
-
if
|
| 730 |
-
return
|
| 731 |
|
| 732 |
with model_loading_lock:
|
| 733 |
-
#
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 737 |
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
return f"✅ Successfully switched to {new_model_size} model"
|
| 741 |
-
except Exception as e:
|
| 742 |
-
error_msg = f"❌ Failed to switch to {new_model_size}: {str(e)}"
|
| 743 |
-
print(error_msg)
|
| 744 |
-
return error_msg
|
| 745 |
|
| 746 |
|
| 747 |
def detect_text_subtype(text_prompt):
|
|
@@ -1116,20 +1058,21 @@ def generate_candidates(inputs, task_type, subtype, temperature, top_p, top_k, r
|
|
| 1116 |
return all_candidates
|
| 1117 |
|
| 1118 |
|
| 1119 |
-
@spaces.GPU
|
| 1120 |
-
def gradio_text_to_svg(text_description, num_candidates, temperature, top_p, top_k, repetition_penalty,
|
| 1121 |
progress=gr.Progress()):
|
| 1122 |
-
"""Gradio interface - text-to-svg with
|
| 1123 |
if not text_description or text_description.strip() == "":
|
| 1124 |
-
return '<div style="text-align:center;color:#999;padding:50px;">Please enter a description</div>', ""
|
| 1125 |
|
| 1126 |
print("\n" + "="*60)
|
| 1127 |
-
print(f"[TASK] text-to-svg
|
|
|
|
| 1128 |
print(f"[INPUT] {text_description[:100]}{'...' if len(text_description) > 100 else ''}")
|
| 1129 |
print(f"[PARAMS] candidates={num_candidates}, temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}")
|
| 1130 |
print("="*60)
|
| 1131 |
|
| 1132 |
-
progress(0, "
|
| 1133 |
|
| 1134 |
gc.collect()
|
| 1135 |
if torch.cuda.is_available():
|
|
@@ -1137,9 +1080,16 @@ def gradio_text_to_svg(text_description, num_candidates, temperature, top_p, top
|
|
| 1137 |
|
| 1138 |
start_time = time.time()
|
| 1139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1140 |
subtype = detect_text_subtype(text_description)
|
| 1141 |
print(f"[SUBTYPE] Detected: {subtype}")
|
| 1142 |
-
progress(0.
|
| 1143 |
|
| 1144 |
inputs = prepare_inputs("text-to-svg", text_description.strip())
|
| 1145 |
|
|
@@ -1161,7 +1111,8 @@ def gradio_text_to_svg(text_description, num_candidates, temperature, top_p, top
|
|
| 1161 |
print("[WARNING] No valid SVG generated")
|
| 1162 |
return (
|
| 1163 |
'<div style="text-align:center;color:#999;padding:50px;">No valid SVG generated. Try different parameters or rephrase your prompt.</div>',
|
| 1164 |
-
f"<!-- No valid SVG (took {elapsed:.1f}s) -->"
|
|
|
|
| 1165 |
)
|
| 1166 |
|
| 1167 |
svg_codes = []
|
|
@@ -1174,28 +1125,30 @@ def gradio_text_to_svg(text_description, num_candidates, temperature, top_p, top
|
|
| 1174 |
progress(1.0, f"Done! {len(all_candidates)} candidates in {elapsed:.1f}s")
|
| 1175 |
print(f"[COMPLETE] text-to-svg finished\n")
|
| 1176 |
|
| 1177 |
-
return gallery_html, combined_svg
|
| 1178 |
|
| 1179 |
|
| 1180 |
-
@spaces.GPU
|
| 1181 |
-
def gradio_image_to_svg(image, num_candidates, temperature, top_p, top_k, repetition_penalty,
|
| 1182 |
replace_background, progress=gr.Progress()):
|
| 1183 |
-
"""Gradio interface - image-to-svg with
|
| 1184 |
|
| 1185 |
if image is None:
|
| 1186 |
return (
|
| 1187 |
'<div style="text-align:center;color:#999;padding:50px;">Please upload an image</div>',
|
| 1188 |
"",
|
| 1189 |
-
None
|
|
|
|
| 1190 |
)
|
| 1191 |
|
| 1192 |
print("\n" + "="*60)
|
| 1193 |
-
print(f"[TASK] image-to-svg
|
|
|
|
| 1194 |
print(f"[INPUT] Image size: {image.size if hasattr(image, 'size') else 'unknown'}, mode: {image.mode if hasattr(image, 'mode') else 'unknown'}")
|
| 1195 |
print(f"[PARAMS] candidates={num_candidates}, temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}, replace_bg={replace_background}")
|
| 1196 |
print("="*60)
|
| 1197 |
|
| 1198 |
-
progress(0, "
|
| 1199 |
|
| 1200 |
gc.collect()
|
| 1201 |
if torch.cuda.is_available():
|
|
@@ -1203,6 +1156,13 @@ def gradio_image_to_svg(image, num_candidates, temperature, top_p, top_k, repeti
|
|
| 1203 |
|
| 1204 |
start_time = time.time()
|
| 1205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1206 |
img_processed, was_modified = preprocess_image_for_svg(
|
| 1207 |
image,
|
| 1208 |
replace_background=replace_background,
|
|
@@ -1211,7 +1171,7 @@ def gradio_image_to_svg(image, num_candidates, temperature, top_p, top_k, repeti
|
|
| 1211 |
|
| 1212 |
if was_modified:
|
| 1213 |
print("[PREPROCESS] Background processed/replaced")
|
| 1214 |
-
progress(0.
|
| 1215 |
|
| 1216 |
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
|
| 1217 |
img_processed.save(tmp_file.name, format='PNG', quality=100)
|
|
@@ -1240,7 +1200,8 @@ def gradio_image_to_svg(image, num_candidates, temperature, top_p, top_k, repeti
|
|
| 1240 |
return (
|
| 1241 |
'<div style="text-align:center;color:#999;padding:50px;">No valid SVG generated. Try adjusting parameters.</div>',
|
| 1242 |
f"<!-- No valid SVG (took {elapsed:.1f}s) -->",
|
| 1243 |
-
img_processed
|
|
|
|
| 1244 |
)
|
| 1245 |
|
| 1246 |
svg_codes = []
|
|
@@ -1253,7 +1214,7 @@ def gradio_image_to_svg(image, num_candidates, temperature, top_p, top_k, repeti
|
|
| 1253 |
progress(1.0, f"Done! {len(all_candidates)} candidates in {elapsed:.1f}s")
|
| 1254 |
print(f"[COMPLETE] image-to-svg finished\n")
|
| 1255 |
|
| 1256 |
-
return gallery_html, combined_svg, img_processed
|
| 1257 |
|
| 1258 |
finally:
|
| 1259 |
if os.path.exists(tmp_path):
|
|
@@ -1323,7 +1284,7 @@ def create_interface():
|
|
| 1323 |
|
| 1324 |
example_images = get_example_images()
|
| 1325 |
|
| 1326 |
-
with gr.Blocks(title="OmniSVG
|
| 1327 |
# Header
|
| 1328 |
gr.HTML("""
|
| 1329 |
<div class="header-container">
|
|
@@ -1332,48 +1293,11 @@ def create_interface():
|
|
| 1332 |
</div>
|
| 1333 |
""")
|
| 1334 |
|
| 1335 |
-
# Model Selection Section
|
| 1336 |
-
with gr.Row():
|
| 1337 |
-
with gr.Column():
|
| 1338 |
-
gr.HTML("""
|
| 1339 |
-
<div class="blue-box">
|
| 1340 |
-
<strong>🔧 Model Selection</strong>
|
| 1341 |
-
<p style="margin: 5px 0 0 0; font-size: 0.9em;">
|
| 1342 |
-
Choose between <b>8B</b> (higher quality, more VRAM) or <b>4B</b> (faster, less VRAM).
|
| 1343 |
-
</p>
|
| 1344 |
-
</div>
|
| 1345 |
-
""")
|
| 1346 |
-
|
| 1347 |
-
with gr.Row():
|
| 1348 |
-
model_selector = gr.Dropdown(
|
| 1349 |
-
choices=AVAILABLE_MODEL_SIZES,
|
| 1350 |
-
value=DEFAULT_MODEL_SIZE,
|
| 1351 |
-
label="Model Size",
|
| 1352 |
-
info="8B: ~16GB VRAM, higher quality | 4B: ~8GB VRAM, faster",
|
| 1353 |
-
interactive=True,
|
| 1354 |
-
scale=1
|
| 1355 |
-
)
|
| 1356 |
-
model_status = gr.Textbox(
|
| 1357 |
-
label="Model Status",
|
| 1358 |
-
value=f"✅ Ready: {DEFAULT_MODEL_SIZE} model loaded",
|
| 1359 |
-
interactive=False,
|
| 1360 |
-
scale=2
|
| 1361 |
-
)
|
| 1362 |
-
switch_btn = gr.Button("Switch Model", variant="secondary", scale=1)
|
| 1363 |
-
|
| 1364 |
-
# Model switch handler
|
| 1365 |
-
switch_btn.click(
|
| 1366 |
-
fn=switch_model,
|
| 1367 |
-
inputs=[model_selector],
|
| 1368 |
-
outputs=[model_status],
|
| 1369 |
-
queue=True
|
| 1370 |
-
)
|
| 1371 |
-
|
| 1372 |
# Queue status
|
| 1373 |
gr.HTML("""
|
| 1374 |
<div style="background: #e7f3ff; border: 1px solid #b3d7ff; border-radius: 8px; padding: 12px 15px; margin: 15px 0;">
|
| 1375 |
<span style="font-size: 1.5em;">ℹ️</span>
|
| 1376 |
-
<strong>Queue System Active</strong> - Requests processed one at a time.
|
| 1377 |
</div>
|
| 1378 |
""")
|
| 1379 |
|
|
@@ -1399,6 +1323,15 @@ def create_interface():
|
|
| 1399 |
|
| 1400 |
with gr.Group(elem_classes=["settings-group"]):
|
| 1401 |
gr.Markdown("### Settings")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1402 |
img_num_candidates = gr.Slider(
|
| 1403 |
minimum=1, maximum=MAX_NUM_CANDIDATES, value=DEFAULT_NUM_CANDIDATES, step=1,
|
| 1404 |
label="Number of Candidates"
|
|
@@ -1443,6 +1376,13 @@ def create_interface():
|
|
| 1443 |
elem_classes=["primary-btn"]
|
| 1444 |
)
|
| 1445 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1446 |
if example_images:
|
| 1447 |
gr.Markdown("### Examples")
|
| 1448 |
gr.Examples(examples=example_images, inputs=[image_input], label="")
|
|
@@ -1461,9 +1401,9 @@ def create_interface():
|
|
| 1461 |
|
| 1462 |
image_generate_btn.click(
|
| 1463 |
fn=gradio_image_to_svg,
|
| 1464 |
-
inputs=[image_input, img_num_candidates, img_temperature, img_top_p,
|
| 1465 |
img_top_k, img_rep_penalty, img_replace_bg],
|
| 1466 |
-
outputs=[image_gallery, image_svg_output, image_processed],
|
| 1467 |
queue=True
|
| 1468 |
)
|
| 1469 |
|
|
@@ -1485,6 +1425,15 @@ def create_interface():
|
|
| 1485 |
|
| 1486 |
with gr.Group(elem_classes=["settings-group"]):
|
| 1487 |
gr.Markdown("### Settings")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1488 |
text_num_candidates = gr.Slider(
|
| 1489 |
minimum=1, maximum=MAX_NUM_CANDIDATES, value=6, step=1,
|
| 1490 |
label="Number of Candidates",
|
|
@@ -1526,6 +1475,13 @@ def create_interface():
|
|
| 1526 |
elem_classes=["primary-btn"]
|
| 1527 |
)
|
| 1528 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1529 |
gr.Markdown("### Example Prompts (30)")
|
| 1530 |
gr.Examples(
|
| 1531 |
examples=[[text] for text in example_texts],
|
|
@@ -1549,16 +1505,16 @@ def create_interface():
|
|
| 1549 |
|
| 1550 |
text_generate_btn.click(
|
| 1551 |
fn=gradio_text_to_svg,
|
| 1552 |
-
inputs=[text_input, text_num_candidates, text_temperature, text_top_p,
|
| 1553 |
text_top_k, text_rep_penalty],
|
| 1554 |
-
outputs=[text_gallery, text_svg_output],
|
| 1555 |
queue=True
|
| 1556 |
)
|
| 1557 |
|
| 1558 |
# Footer
|
| 1559 |
gr.HTML(f"""
|
| 1560 |
<div class="footer">
|
| 1561 |
-
<p>Built with OmniSVG |
|
| 1562 |
<p style="color: #dc3545; font-weight: 600;">Remember: Generate 4-8 candidates and pick the best!</p>
|
| 1563 |
</div>
|
| 1564 |
""")
|
|
@@ -1571,16 +1527,14 @@ if __name__ == "__main__":
|
|
| 1571 |
|
| 1572 |
args = parse_args()
|
| 1573 |
|
| 1574 |
-
# Determine initial model size
|
| 1575 |
-
initial_model_size = args.model_size or DEFAULT_MODEL_SIZE
|
| 1576 |
-
|
| 1577 |
print("="*60)
|
| 1578 |
-
print("OmniSVG Demo Page - Gradio App")
|
| 1579 |
print("="*60)
|
| 1580 |
print(f"Available model sizes: {AVAILABLE_MODEL_SIZES}")
|
| 1581 |
-
print(f"
|
| 1582 |
print(f"Device: {device}")
|
| 1583 |
print(f"Precision: {DTYPE}")
|
|
|
|
| 1584 |
print("="*60)
|
| 1585 |
|
| 1586 |
# Print loaded config values
|
|
@@ -1594,10 +1548,6 @@ if __name__ == "__main__":
|
|
| 1594 |
print(f" - PAD_TOKEN_ID: {PAD_TOKEN_ID}")
|
| 1595 |
print("="*60)
|
| 1596 |
|
| 1597 |
-
print("\nLoading models (may download from HuggingFace Hub if needed)...")
|
| 1598 |
-
load_models(initial_model_size, args.weight_path, args.model_path)
|
| 1599 |
-
print("Models loaded successfully!\n")
|
| 1600 |
-
|
| 1601 |
demo = create_interface()
|
| 1602 |
|
| 1603 |
demo.queue(default_concurrency_limit=1, max_size=20)
|
|
|
|
| 374 |
<ul style="margin: 8px 0 0 0; padding-left: 20px;">
|
| 375 |
<li><strong>8B Model:</strong> Higher quality, more details, better for complex illustrations. Requires more VRAM (~16GB+).</li>
|
| 376 |
<li><strong>4B Model:</strong> Faster, less VRAM required (~8GB+). Good for simple icons and basic shapes.</li>
|
| 377 |
+
<li><strong>Note:</strong> First generation with a new model size may take longer to load.</li>
|
| 378 |
</ul>
|
| 379 |
</div>
|
| 380 |
|
|
|
|
| 496 |
|
| 497 |
</div>
|
| 498 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 499 |
<!-- Quick Troubleshooting -->
|
| 500 |
<div class="green-box" style="margin-top: 15px;">
|
| 501 |
<strong>Quick Troubleshooting</strong>
|
|
|
|
| 543 |
parser.add_argument('--port', type=int, default=7860)
|
| 544 |
parser.add_argument('--share', action='store_true')
|
| 545 |
parser.add_argument('--debug', action='store_true')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 546 |
return parser.parse_args()
|
| 547 |
|
| 548 |
|
| 549 |
def download_model_weights(repo_id: str, filename: str = "pytorch_model.bin") -> str:
|
| 550 |
"""
|
| 551 |
Download model weights from Hugging Face Hub.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 552 |
"""
|
| 553 |
print(f"Downloading {filename} from {repo_id}...")
|
| 554 |
try:
|
|
|
|
| 580 |
def load_models(model_size: str, weight_path: str = None, model_path: str = None):
|
| 581 |
"""
|
| 582 |
Load all models for a specific model size.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 583 |
"""
|
| 584 |
global tokenizer, processor, sketch_decoder, svg_tokenizer, current_model_size
|
| 585 |
|
|
|
|
| 656 |
print("="*60 + "\n")
|
| 657 |
|
| 658 |
|
| 659 |
+
def ensure_model_loaded(model_size: str):
    """
    Ensure the specified model is loaded. Load or switch if necessary.
    This function should be called within @spaces.GPU decorated functions.

    Args:
        model_size: Model size to make active; passed unchanged to
            ``load_models``.

    Returns:
        None. On success the module-level model globals refer to the
        requested model (presumably rebound by ``load_models`` -- confirm
        against its body).

    Raises:
        Whatever ``load_models`` raises. In that case the module-level
        state is left "unloaded" (all handles ``None``,
        ``current_model_size`` ``None``) so a later call retries cleanly.
    """
    global current_model_size, sketch_decoder, tokenizer, processor, svg_tokenizer

    # Fast path without the lock: the requested model is already active.
    if current_model_size == model_size and sketch_decoder is not None:
        return  # Already loaded

    with model_loading_lock:
        # Re-check after acquiring the lock: another request may have
        # finished loading this model while we were waiting.
        if current_model_size == model_size and sketch_decoder is not None:
            return

        # Clear old models if switching
        if current_model_size is not None:
            print(f"Switching from {current_model_size} to {model_size}...")
            # Rebind to None instead of `del`: deleting a global unbinds the
            # name, so if load_models() fails below, the next call would hit
            # NameError on `sketch_decoder is not None` (or on a second
            # `del`) instead of simply retrying the load.
            sketch_decoder = None
            tokenizer = None
            processor = None
            svg_tokenizer = None
            # Nothing is loaded any more; do not keep advertising the old
            # size if the load below raises.
            current_model_size = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Load new model
        load_models(model_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 687 |
|
| 688 |
|
| 689 |
def detect_text_subtype(text_prompt):
|
|
|
|
| 1058 |
return all_candidates
|
| 1059 |
|
| 1060 |
|
| 1061 |
+
@spaces.GPU(duration=300) # Longer GPU duration to allow time for model loading
|
| 1062 |
+
def gradio_text_to_svg(text_description, model_size, num_candidates, temperature, top_p, top_k, repetition_penalty,
|
| 1063 |
progress=gr.Progress()):
|
| 1064 |
+
"""Gradio interface - text-to-svg with model selection"""
|
| 1065 |
if not text_description or text_description.strip() == "":
|
| 1066 |
+
return '<div style="text-align:center;color:#999;padding:50px;">Please enter a description</div>', "", f"Ready (no model loaded)"
|
| 1067 |
|
| 1068 |
print("\n" + "="*60)
|
| 1069 |
+
print(f"[TASK] text-to-svg")
|
| 1070 |
+
print(f"[MODEL] Requested: {model_size}")
|
| 1071 |
print(f"[INPUT] {text_description[:100]}{'...' if len(text_description) > 100 else ''}")
|
| 1072 |
print(f"[PARAMS] candidates={num_candidates}, temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}")
|
| 1073 |
print("="*60)
|
| 1074 |
|
| 1075 |
+
progress(0, "Initializing...")
|
| 1076 |
|
| 1077 |
gc.collect()
|
| 1078 |
if torch.cuda.is_available():
|
|
|
|
| 1080 |
|
| 1081 |
start_time = time.time()
|
| 1082 |
|
| 1083 |
+
# Ensure the model is loaded
|
| 1084 |
+
progress(0.02, f"Loading {model_size} model (first time may take a while)...")
|
| 1085 |
+
ensure_model_loaded(model_size)
|
| 1086 |
+
model_status = f"✅ Using {model_size} model"
|
| 1087 |
+
|
| 1088 |
+
progress(0.05, "Model ready, starting generation...")
|
| 1089 |
+
|
| 1090 |
subtype = detect_text_subtype(text_description)
|
| 1091 |
print(f"[SUBTYPE] Detected: {subtype}")
|
| 1092 |
+
progress(0.08, f"Detected: {subtype}")
|
| 1093 |
|
| 1094 |
inputs = prepare_inputs("text-to-svg", text_description.strip())
|
| 1095 |
|
|
|
|
| 1111 |
print("[WARNING] No valid SVG generated")
|
| 1112 |
return (
|
| 1113 |
'<div style="text-align:center;color:#999;padding:50px;">No valid SVG generated. Try different parameters or rephrase your prompt.</div>',
|
| 1114 |
+
f"<!-- No valid SVG (took {elapsed:.1f}s) -->",
|
| 1115 |
+
model_status
|
| 1116 |
)
|
| 1117 |
|
| 1118 |
svg_codes = []
|
|
|
|
| 1125 |
progress(1.0, f"Done! {len(all_candidates)} candidates in {elapsed:.1f}s")
|
| 1126 |
print(f"[COMPLETE] text-to-svg finished\n")
|
| 1127 |
|
| 1128 |
+
return gallery_html, combined_svg, model_status
|
| 1129 |
|
| 1130 |
|
| 1131 |
+
@spaces.GPU(duration=300) # Longer GPU duration to allow time for model loading
|
| 1132 |
+
def gradio_image_to_svg(image, model_size, num_candidates, temperature, top_p, top_k, repetition_penalty,
|
| 1133 |
replace_background, progress=gr.Progress()):
|
| 1134 |
+
"""Gradio interface - image-to-svg with model selection"""
|
| 1135 |
|
| 1136 |
if image is None:
|
| 1137 |
return (
|
| 1138 |
'<div style="text-align:center;color:#999;padding:50px;">Please upload an image</div>',
|
| 1139 |
"",
|
| 1140 |
+
None,
|
| 1141 |
+
f"Ready (no model loaded)"
|
| 1142 |
)
|
| 1143 |
|
| 1144 |
print("\n" + "="*60)
|
| 1145 |
+
print(f"[TASK] image-to-svg")
|
| 1146 |
+
print(f"[MODEL] Requested: {model_size}")
|
| 1147 |
print(f"[INPUT] Image size: {image.size if hasattr(image, 'size') else 'unknown'}, mode: {image.mode if hasattr(image, 'mode') else 'unknown'}")
|
| 1148 |
print(f"[PARAMS] candidates={num_candidates}, temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}, replace_bg={replace_background}")
|
| 1149 |
print("="*60)
|
| 1150 |
|
| 1151 |
+
progress(0, "Initializing...")
|
| 1152 |
|
| 1153 |
gc.collect()
|
| 1154 |
if torch.cuda.is_available():
|
|
|
|
| 1156 |
|
| 1157 |
start_time = time.time()
|
| 1158 |
|
| 1159 |
+
# Ensure the model is loaded
|
| 1160 |
+
progress(0.02, f"Loading {model_size} model (first time may take a while)...")
|
| 1161 |
+
ensure_model_loaded(model_size)
|
| 1162 |
+
model_status = f"✅ Using {model_size} model"
|
| 1163 |
+
|
| 1164 |
+
progress(0.05, "Processing input image...")
|
| 1165 |
+
|
| 1166 |
img_processed, was_modified = preprocess_image_for_svg(
|
| 1167 |
image,
|
| 1168 |
replace_background=replace_background,
|
|
|
|
| 1171 |
|
| 1172 |
if was_modified:
|
| 1173 |
print("[PREPROCESS] Background processed/replaced")
|
| 1174 |
+
progress(0.08, "Background processed")
|
| 1175 |
|
| 1176 |
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
|
| 1177 |
img_processed.save(tmp_file.name, format='PNG', quality=100)
|
|
|
|
| 1200 |
return (
|
| 1201 |
'<div style="text-align:center;color:#999;padding:50px;">No valid SVG generated. Try adjusting parameters.</div>',
|
| 1202 |
f"<!-- No valid SVG (took {elapsed:.1f}s) -->",
|
| 1203 |
+
img_processed,
|
| 1204 |
+
model_status
|
| 1205 |
)
|
| 1206 |
|
| 1207 |
svg_codes = []
|
|
|
|
| 1214 |
progress(1.0, f"Done! {len(all_candidates)} candidates in {elapsed:.1f}s")
|
| 1215 |
print(f"[COMPLETE] image-to-svg finished\n")
|
| 1216 |
|
| 1217 |
+
return gallery_html, combined_svg, img_processed, model_status
|
| 1218 |
|
| 1219 |
finally:
|
| 1220 |
if os.path.exists(tmp_path):
|
|
|
|
| 1284 |
|
| 1285 |
example_images = get_example_images()
|
| 1286 |
|
| 1287 |
+
with gr.Blocks(title="OmniSVG Generator", css=CUSTOM_CSS) as demo:
|
| 1288 |
# Header
|
| 1289 |
gr.HTML("""
|
| 1290 |
<div class="header-container">
|
|
|
|
| 1293 |
</div>
|
| 1294 |
""")
|
| 1295 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1296 |
# Queue status
|
| 1297 |
gr.HTML("""
|
| 1298 |
<div style="background: #e7f3ff; border: 1px solid #b3d7ff; border-radius: 8px; padding: 12px 15px; margin: 15px 0;">
|
| 1299 |
<span style="font-size: 1.5em;">ℹ️</span>
|
| 1300 |
+
<strong>Queue System Active</strong> - Requests processed one at a time. First generation with a new model may take longer to load.
|
| 1301 |
</div>
|
| 1302 |
""")
|
| 1303 |
|
|
|
|
| 1323 |
|
| 1324 |
with gr.Group(elem_classes=["settings-group"]):
|
| 1325 |
gr.Markdown("### Settings")
|
| 1326 |
+
|
| 1327 |
+
# Model Selection
|
| 1328 |
+
img_model_size = gr.Dropdown(
|
| 1329 |
+
choices=AVAILABLE_MODEL_SIZES,
|
| 1330 |
+
value=DEFAULT_MODEL_SIZE,
|
| 1331 |
+
label="Model Size",
|
| 1332 |
+
info="8B: Higher quality (~16GB VRAM) | 4B: Faster (~8GB VRAM)"
|
| 1333 |
+
)
|
| 1334 |
+
|
| 1335 |
img_num_candidates = gr.Slider(
|
| 1336 |
minimum=1, maximum=MAX_NUM_CANDIDATES, value=DEFAULT_NUM_CANDIDATES, step=1,
|
| 1337 |
label="Number of Candidates"
|
|
|
|
| 1376 |
elem_classes=["primary-btn"]
|
| 1377 |
)
|
| 1378 |
|
| 1379 |
+
# Model status display
|
| 1380 |
+
img_model_status = gr.Textbox(
|
| 1381 |
+
label="Model Status",
|
| 1382 |
+
value="Ready (model loads on first generation)",
|
| 1383 |
+
interactive=False
|
| 1384 |
+
)
|
| 1385 |
+
|
| 1386 |
if example_images:
|
| 1387 |
gr.Markdown("### Examples")
|
| 1388 |
gr.Examples(examples=example_images, inputs=[image_input], label="")
|
|
|
|
| 1401 |
|
| 1402 |
image_generate_btn.click(
|
| 1403 |
fn=gradio_image_to_svg,
|
| 1404 |
+
inputs=[image_input, img_model_size, img_num_candidates, img_temperature, img_top_p,
|
| 1405 |
img_top_k, img_rep_penalty, img_replace_bg],
|
| 1406 |
+
outputs=[image_gallery, image_svg_output, image_processed, img_model_status],
|
| 1407 |
queue=True
|
| 1408 |
)
|
| 1409 |
|
|
|
|
| 1425 |
|
| 1426 |
with gr.Group(elem_classes=["settings-group"]):
|
| 1427 |
gr.Markdown("### Settings")
|
| 1428 |
+
|
| 1429 |
+
# Model Selection
|
| 1430 |
+
text_model_size = gr.Dropdown(
|
| 1431 |
+
choices=AVAILABLE_MODEL_SIZES,
|
| 1432 |
+
value=DEFAULT_MODEL_SIZE,
|
| 1433 |
+
label="Model Size",
|
| 1434 |
+
info="8B: Higher quality (~16GB VRAM) | 4B: Faster (~8GB VRAM)"
|
| 1435 |
+
)
|
| 1436 |
+
|
| 1437 |
text_num_candidates = gr.Slider(
|
| 1438 |
minimum=1, maximum=MAX_NUM_CANDIDATES, value=6, step=1,
|
| 1439 |
label="Number of Candidates",
|
|
|
|
| 1475 |
elem_classes=["primary-btn"]
|
| 1476 |
)
|
| 1477 |
|
| 1478 |
+
# Model status display
|
| 1479 |
+
text_model_status = gr.Textbox(
|
| 1480 |
+
label="Model Status",
|
| 1481 |
+
value="Ready (model loads on first generation)",
|
| 1482 |
+
interactive=False
|
| 1483 |
+
)
|
| 1484 |
+
|
| 1485 |
gr.Markdown("### Example Prompts (30)")
|
| 1486 |
gr.Examples(
|
| 1487 |
examples=[[text] for text in example_texts],
|
|
|
|
| 1505 |
|
| 1506 |
text_generate_btn.click(
|
| 1507 |
fn=gradio_text_to_svg,
|
| 1508 |
+
inputs=[text_input, text_model_size, text_num_candidates, text_temperature, text_top_p,
|
| 1509 |
text_top_k, text_rep_penalty],
|
| 1510 |
+
outputs=[text_gallery, text_svg_output, text_model_status],
|
| 1511 |
queue=True
|
| 1512 |
)
|
| 1513 |
|
| 1514 |
# Footer
|
| 1515 |
gr.HTML(f"""
|
| 1516 |
<div class="footer">
|
| 1517 |
+
<p>Built with OmniSVG | Available models: {', '.join(AVAILABLE_MODEL_SIZES)}</p>
|
| 1518 |
<p style="color: #dc3545; font-weight: 600;">Remember: Generate 4-8 candidates and pick the best!</p>
|
| 1519 |
</div>
|
| 1520 |
""")
|
|
|
|
| 1527 |
|
| 1528 |
args = parse_args()
|
| 1529 |
|
|
|
|
|
|
|
|
|
|
| 1530 |
print("="*60)
|
| 1531 |
+
print("OmniSVG Demo Page - Gradio App (HuggingFace Space)")
|
| 1532 |
print("="*60)
|
| 1533 |
print(f"Available model sizes: {AVAILABLE_MODEL_SIZES}")
|
| 1534 |
+
print(f"Default model size: {DEFAULT_MODEL_SIZE}")
|
| 1535 |
print(f"Device: {device}")
|
| 1536 |
print(f"Precision: {DTYPE}")
|
| 1537 |
+
print("Models will be loaded on-demand when first used.")
|
| 1538 |
print("="*60)
|
| 1539 |
|
| 1540 |
# Print loaded config values
|
|
|
|
| 1548 |
print(f" - PAD_TOKEN_ID: {PAD_TOKEN_ID}")
|
| 1549 |
print("="*60)
|
| 1550 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1551 |
demo = create_interface()
|
| 1552 |
|
| 1553 |
demo.queue(default_concurrency_limit=1, max_size=20)
|