Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -12,7 +12,6 @@ import glob
|
|
| 12 |
import numpy as np
|
| 13 |
import time
|
| 14 |
import threading
|
| 15 |
-
import copy
|
| 16 |
import spaces
|
| 17 |
|
| 18 |
from huggingface_hub import hf_hub_download, snapshot_download
|
|
@@ -22,167 +21,131 @@ from transformers import AutoTokenizer, AutoProcessor
|
|
| 22 |
from qwen_vl_utils import process_vision_info
|
| 23 |
from tokenizer import SVGTokenizer
|
| 24 |
|
| 25 |
-
#
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
"""
|
| 30 |
-
Load config file and merge variant-specific settings.
|
| 31 |
-
|
| 32 |
-
Args:
|
| 33 |
-
config_path: Path to the config.yaml file
|
| 34 |
-
variant: Model variant ("8B" or "4B"). If None, uses default_variant from config.
|
| 35 |
-
|
| 36 |
-
Returns:
|
| 37 |
-
Merged configuration dictionary
|
| 38 |
-
"""
|
| 39 |
-
with open(config_path, 'r') as f:
|
| 40 |
-
raw_config = yaml.safe_load(f)
|
| 41 |
-
|
| 42 |
-
# Determine which variant to use
|
| 43 |
-
if variant is None:
|
| 44 |
-
variant = raw_config.get('default_variant', '8B')
|
| 45 |
-
|
| 46 |
-
# Check if variant exists
|
| 47 |
-
variants = raw_config.get('variants', {})
|
| 48 |
-
if variant not in variants:
|
| 49 |
-
available = list(variants.keys())
|
| 50 |
-
raise ValueError(f"Unknown model variant '{variant}'. Available variants: {available}")
|
| 51 |
-
|
| 52 |
-
# Start with a copy of raw config (excluding 'variants' key)
|
| 53 |
-
merged_config = {k: v for k, v in raw_config.items() if k != 'variants'}
|
| 54 |
-
|
| 55 |
-
# Merge variant-specific settings
|
| 56 |
-
variant_config = variants[variant]
|
| 57 |
-
for key, value in variant_config.items():
|
| 58 |
-
if isinstance(value, dict) and key in merged_config and isinstance(merged_config[key], dict):
|
| 59 |
-
# Deep merge for nested dicts
|
| 60 |
-
merged_config[key] = {**merged_config.get(key, {}), **value}
|
| 61 |
-
else:
|
| 62 |
-
merged_config[key] = value
|
| 63 |
-
|
| 64 |
-
# Store the active variant name
|
| 65 |
-
merged_config['active_variant'] = variant
|
| 66 |
-
|
| 67 |
-
return merged_config
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
def write_variant_config(config: dict, output_path: str):
|
| 71 |
-
"""
|
| 72 |
-
Write a variant-specific config file for SVGTokenizer.
|
| 73 |
-
|
| 74 |
-
Args:
|
| 75 |
-
config: Merged configuration dictionary
|
| 76 |
-
output_path: Path to write the temporary config file
|
| 77 |
-
"""
|
| 78 |
-
# Create a config without the 'variants' and 'active_variant' keys
|
| 79 |
-
clean_config = {k: v for k, v in config.items()
|
| 80 |
-
if k not in ['variants', 'active_variant', 'default_variant']}
|
| 81 |
-
|
| 82 |
-
with open(output_path, 'w') as f:
|
| 83 |
-
yaml.safe_dump(clean_config, f, default_flow_style=False)
|
| 84 |
|
| 85 |
-
|
| 86 |
-
# ============================================================
|
| 87 |
-
# Global Variables
|
| 88 |
-
# ============================================================
|
| 89 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 90 |
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
| 91 |
|
| 92 |
-
# Global Models
|
| 93 |
tokenizer = None
|
| 94 |
processor = None
|
| 95 |
sketch_decoder = None
|
| 96 |
svg_tokenizer = None
|
| 97 |
-
|
| 98 |
-
# Global Config (will be set after loading)
|
| 99 |
-
config = None
|
| 100 |
-
MODEL_VARIANT = None
|
| 101 |
|
| 102 |
# Thread lock for model inference
|
| 103 |
generation_lock = threading.Lock()
|
|
|
|
| 104 |
|
| 105 |
-
# Constants
|
| 106 |
SYSTEM_PROMPT = """You are an expert SVG code generator.
|
| 107 |
Generate precise, valid SVG path commands that accurately represent the described scene or object.
|
| 108 |
Focus on capturing key shapes, spatial relationships, and visual composition."""
|
| 109 |
|
| 110 |
SUPPORTED_FORMATS = ['.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif']
|
|
|
|
|
|
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
colors_config = cfg.get('colors', {})
|
| 136 |
-
BLACK_COLOR_TOKEN = colors_config.get('black_color_token',
|
| 137 |
-
colors_config.get('color_token_start', 40010) + 2)
|
| 138 |
-
|
| 139 |
-
# Model settings
|
| 140 |
-
model_config = cfg.get('model', {})
|
| 141 |
-
BOS_TOKEN_ID = model_config.get('bos_token_id', 196998)
|
| 142 |
-
EOS_TOKEN_ID = model_config.get('eos_token_id', 196999)
|
| 143 |
-
PAD_TOKEN_ID = model_config.get('pad_token_id', 151643)
|
| 144 |
-
MAX_LENGTH = model_config.get('max_length', 1536)
|
| 145 |
-
|
| 146 |
-
# HuggingFace model IDs
|
| 147 |
-
hf_config = cfg.get('huggingface', {})
|
| 148 |
-
DEFAULT_QWEN_MODEL = hf_config.get('qwen_model', "Qwen/Qwen2.5-VL-7B-Instruct")
|
| 149 |
-
DEFAULT_OMNISVG_MODEL = hf_config.get('omnisvg_model', "OmniSVG/OmniSVG1.1_8B")
|
| 150 |
-
|
| 151 |
-
# Task configurations
|
| 152 |
-
task_config = cfg.get('task_configs', {})
|
| 153 |
-
TASK_CONFIGS = {
|
| 154 |
-
"text-to-svg-icon": task_config.get('text_to_svg_icon', {
|
| 155 |
-
"default_temperature": 0.5,
|
| 156 |
-
"default_top_p": 0.88,
|
| 157 |
-
"default_top_k": 50,
|
| 158 |
-
"default_repetition_penalty": 1.05,
|
| 159 |
-
}),
|
| 160 |
-
"text-to-svg-illustration": task_config.get('text_to_svg_illustration', {
|
| 161 |
-
"default_temperature": 0.6,
|
| 162 |
-
"default_top_p": 0.90,
|
| 163 |
-
"default_top_k": 60,
|
| 164 |
-
"default_repetition_penalty": 1.03,
|
| 165 |
-
}),
|
| 166 |
-
"image-to-svg": task_config.get('image_to_svg', {
|
| 167 |
-
"default_temperature": 0.3,
|
| 168 |
-
"default_top_p": 0.90,
|
| 169 |
-
"default_top_k": 50,
|
| 170 |
-
"default_repetition_penalty": 1.05,
|
| 171 |
-
})
|
| 172 |
-
}
|
| 173 |
-
|
| 174 |
-
# Generation parameters
|
| 175 |
-
gen_config = cfg.get('generation', {})
|
| 176 |
-
DEFAULT_NUM_CANDIDATES = gen_config.get('default_num_candidates', 4)
|
| 177 |
-
MAX_NUM_CANDIDATES = gen_config.get('max_num_candidates', 8)
|
| 178 |
-
EXTRA_CANDIDATES_BUFFER = gen_config.get('extra_candidates_buffer', 4)
|
| 179 |
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
-
# Custom CSS
|
| 186 |
CUSTOM_CSS = """
|
| 187 |
/* Main container centering */
|
| 188 |
.gradio-container {
|
|
@@ -209,14 +172,18 @@ CUSTOM_CSS = """
|
|
| 209 |
opacity: 0.9;
|
| 210 |
font-size: 1.1em;
|
| 211 |
}
|
| 212 |
-
/* Model
|
| 213 |
-
.model-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
}
|
| 221 |
/* Tips section */
|
| 222 |
.tips-box {
|
|
@@ -295,6 +262,17 @@ CUSTOM_CSS = """
|
|
| 295 |
.green-box strong {
|
| 296 |
color: #4caf50;
|
| 297 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
/* Tab styling */
|
| 299 |
.tabs {
|
| 300 |
border-radius: 12px !important;
|
|
@@ -363,7 +341,7 @@ CUSTOM_CSS = """
|
|
| 363 |
}
|
| 364 |
"""
|
| 365 |
|
| 366 |
-
# Enhanced Tips HTML
|
| 367 |
TIPS_HTML = """
|
| 368 |
<div class="tips-box">
|
| 369 |
<h3>Prompting Guide & Best Practices</h3>
|
|
@@ -390,6 +368,15 @@ TIPS_HTML = """
|
|
| 390 |
</ul>
|
| 391 |
</div>
|
| 392 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
<!-- Parameter Tuning Tips -->
|
| 394 |
<div class="orange-box">
|
| 395 |
<strong>Parameter Tuning Guide</strong>
|
|
@@ -455,15 +442,32 @@ TIPS_HTML = """
|
|
| 455 |
<div class="example-prompt">
|
| 456 |
"A simple person: round beige head, rectangular blue shirt body, two dark gray rectangular legs. Standing pose, arms at sides, flat colors."
|
| 457 |
</div>
|
|
|
|
|
|
|
|
|
|
| 458 |
<p class="red-tip">Keep poses SIMPLE: standing, sitting, waving. Avoid complex actions!</p>
|
| 459 |
</div>
|
| 460 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
<div class="tip-category">
|
| 462 |
<h4>Landscapes & Scenes</h4>
|
| 463 |
<p>Layer elements from background to foreground. Specify color for EACH layer.</p>
|
| 464 |
<div class="example-prompt">
|
| 465 |
"Layered landscape: light blue sky at top, gray triangular mountains in middle, dark green triangular pine trees at bottom. Flat colors, simple shapes."
|
| 466 |
</div>
|
|
|
|
|
|
|
|
|
|
| 467 |
<p class="red-tip">Use geometric shapes for nature: triangular trees, wavy water, semicircle sun!</p>
|
| 468 |
</div>
|
| 469 |
|
|
@@ -473,10 +477,64 @@ TIPS_HTML = """
|
|
| 473 |
<div class="example-prompt">
|
| 474 |
"Cute cat: orange round head with two triangular ears, oval orange body, curved tail. Simple cartoon style with black outlines, sitting pose."
|
| 475 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
</div>
|
| 477 |
|
| 478 |
</div>
|
| 479 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 480 |
<!-- Quick Troubleshooting -->
|
| 481 |
<div class="green-box" style="margin-top: 15px;">
|
| 482 |
<strong>Quick Troubleshooting</strong>
|
|
@@ -489,6 +547,17 @@ TIPS_HTML = """
|
|
| 489 |
<li><strong>Inconsistent?</strong> <span class="red-tip">Generate MORE candidates (6-8) and pick the best!</span></li>
|
| 490 |
</ul>
|
| 491 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 492 |
</div>
|
| 493 |
"""
|
| 494 |
|
|
@@ -513,10 +582,9 @@ def parse_args():
|
|
| 513 |
parser.add_argument('--port', type=int, default=7860)
|
| 514 |
parser.add_argument('--share', action='store_true')
|
| 515 |
parser.add_argument('--debug', action='store_true')
|
| 516 |
-
parser.add_argument('--model_size', type=str, default=None,
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
help='Path to config file (default: ./config.yaml)')
|
| 520 |
parser.add_argument('--weight_path', type=str, default=None,
|
| 521 |
help='HuggingFace repo ID or local path for OmniSVG weights (overrides config)')
|
| 522 |
parser.add_argument('--model_path', type=str, default=None,
|
|
@@ -527,6 +595,13 @@ def parse_args():
|
|
| 527 |
def download_model_weights(repo_id: str, filename: str = "pytorch_model.bin") -> str:
|
| 528 |
"""
|
| 529 |
Download model weights from Hugging Face Hub.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 530 |
"""
|
| 531 |
print(f"Downloading {filename} from {repo_id}...")
|
| 532 |
try:
|
|
@@ -555,20 +630,29 @@ def is_local_path(path: str) -> bool:
|
|
| 555 |
return False
|
| 556 |
|
| 557 |
|
| 558 |
-
def load_models(
|
| 559 |
"""
|
| 560 |
-
Load all models
|
| 561 |
|
| 562 |
Args:
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
"""
|
| 567 |
-
global tokenizer, processor, sketch_decoder, svg_tokenizer
|
| 568 |
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 572 |
|
| 573 |
# Load Qwen tokenizer and processor
|
| 574 |
print("\n[1/3] Loading tokenizer and processor...")
|
|
@@ -585,12 +669,14 @@ def load_models(weight_path: str, model_path: str, variant_config_path: str):
|
|
| 585 |
processor.tokenizer.padding_side = "left"
|
| 586 |
print("Tokenizer and processor loaded successfully!")
|
| 587 |
|
| 588 |
-
# Initialize sketch decoder
|
| 589 |
print("\n[2/3] Initializing SketchDecoder...")
|
| 590 |
sketch_decoder = SketchDecoder(
|
|
|
|
|
|
|
|
|
|
| 591 |
pix_len=MAX_LENGTH,
|
| 592 |
text_len=config.get('text', {}).get('max_length', 200),
|
| 593 |
-
model_path=model_path,
|
| 594 |
torch_dtype=DTYPE
|
| 595 |
)
|
| 596 |
|
|
@@ -618,14 +704,46 @@ def load_models(weight_path: str, model_path: str, variant_config_path: str):
|
|
| 618 |
|
| 619 |
sketch_decoder = sketch_decoder.to(device).eval()
|
| 620 |
|
| 621 |
-
# Initialize SVG tokenizer with
|
| 622 |
-
svg_tokenizer = SVGTokenizer(
|
|
|
|
|
|
|
| 623 |
|
| 624 |
print("\n" + "="*60)
|
| 625 |
-
print("All models loaded successfully!")
|
| 626 |
print("="*60 + "\n")
|
| 627 |
|
| 628 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 629 |
def detect_text_subtype(text_prompt):
|
| 630 |
"""Auto-detect text prompt subtype"""
|
| 631 |
text_lower = text_prompt.lower()
|
|
@@ -652,7 +770,9 @@ def detect_text_subtype(text_prompt):
|
|
| 652 |
|
| 653 |
|
| 654 |
def detect_and_replace_background(image, threshold=None, edge_sample_ratio=None):
|
| 655 |
-
"""
|
|
|
|
|
|
|
| 656 |
if threshold is None:
|
| 657 |
threshold = BACKGROUND_THRESHOLD
|
| 658 |
if edge_sample_ratio is None:
|
|
@@ -686,6 +806,11 @@ def detect_and_replace_background(image, threshold=None, edge_sample_ratio=None)
|
|
| 686 |
return image, False
|
| 687 |
|
| 688 |
if len(img_array.shape) == 3 and img_array.shape[2] >= 3:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 689 |
edge_colors = []
|
| 690 |
for i in range(w):
|
| 691 |
edge_colors.append(tuple(img_array[0, i, :3]))
|
|
@@ -713,7 +838,9 @@ def detect_and_replace_background(image, threshold=None, edge_sample_ratio=None)
|
|
| 713 |
|
| 714 |
|
| 715 |
def preprocess_image_for_svg(image, replace_background=True, target_size=None):
|
| 716 |
-
"""
|
|
|
|
|
|
|
| 717 |
if target_size is None:
|
| 718 |
target_size = TARGET_IMAGE_SIZE
|
| 719 |
|
|
@@ -889,7 +1016,8 @@ def generate_candidates(inputs, task_type, subtype, temperature, top_p, top_k, r
|
|
| 889 |
|
| 890 |
all_candidates = []
|
| 891 |
|
| 892 |
-
|
|
|
|
| 893 |
'do_sample': True,
|
| 894 |
'temperature': temperature,
|
| 895 |
'top_p': top_p,
|
|
@@ -918,7 +1046,7 @@ def generate_candidates(inputs, task_type, subtype, temperature, top_p, top_k, r
|
|
| 918 |
max_new_tokens=max_length,
|
| 919 |
num_return_sequences=actual_samples,
|
| 920 |
use_cache=True,
|
| 921 |
-
**
|
| 922 |
)
|
| 923 |
|
| 924 |
input_len = input_ids.shape[1]
|
|
@@ -996,7 +1124,7 @@ def gradio_text_to_svg(text_description, num_candidates, temperature, top_p, top
|
|
| 996 |
return '<div style="text-align:center;color:#999;padding:50px;">Please enter a description</div>', ""
|
| 997 |
|
| 998 |
print("\n" + "="*60)
|
| 999 |
-
print(f"[TASK] text-to-svg ({
|
| 1000 |
print(f"[INPUT] {text_description[:100]}{'...' if len(text_description) > 100 else ''}")
|
| 1001 |
print(f"[PARAMS] candidates={num_candidates}, temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}")
|
| 1002 |
print("="*60)
|
|
@@ -1062,7 +1190,7 @@ def gradio_image_to_svg(image, num_candidates, temperature, top_p, top_k, repeti
|
|
| 1062 |
)
|
| 1063 |
|
| 1064 |
print("\n" + "="*60)
|
| 1065 |
-
print(f"[TASK] image-to-svg ({
|
| 1066 |
print(f"[INPUT] Image size: {image.size if hasattr(image, 'size') else 'unknown'}, mode: {image.mode if hasattr(image, 'mode') else 'unknown'}")
|
| 1067 |
print(f"[PARAMS] candidates={num_candidates}, temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}, replace_bg={replace_background}")
|
| 1068 |
print("="*60)
|
|
@@ -1149,43 +1277,101 @@ def get_example_images():
|
|
| 1149 |
def create_interface():
|
| 1150 |
"""Create Gradio interface"""
|
| 1151 |
|
| 1152 |
-
# Example prompts
|
| 1153 |
example_texts = [
|
|
|
|
| 1154 |
"A black triangle pointing downward, centrally positioned.",
|
| 1155 |
"A red heart shape with smooth curved edges, centered.",
|
| 1156 |
"A yellow star with five sharp points, simple geometric design, flat color.",
|
| 1157 |
"A blue arrow pointing to the right, thick solid shape, centered.",
|
| 1158 |
"A green circle with a white checkmark inside, centered.",
|
| 1159 |
"A black plus sign with equal length arms, thick lines, centered.",
|
|
|
|
|
|
|
| 1160 |
"A simple person standing: round beige head, rectangular blue shirt body, two dark gray rectangular legs, arms at sides. Flat colors.",
|
| 1161 |
"A girl with long black hair, wearing pink dress with triangular skirt, small circular face with dot eyes and curved smile. Simple cartoon style.",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1162 |
"Circular avatar: person with short black hair, round face with two dot eyes and small curved smile, wearing blue collar shirt. Minimal style, centered in circle.",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1163 |
"Layered mountain landscape: light blue sky at top, gray triangular snow-capped mountains in middle, dark green triangular pine trees at bottom. Flat colors.",
|
| 1164 |
"Sunset beach scene: orange gradient sky at top, yellow semicircle sun on horizon, dark blue wavy ocean, tan beach strip at bottom. Simple shapes.",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1165 |
"Cute orange cat sitting: round head with two triangular ears, oval body, curved tail. Black outline cartoon style, facing forward.",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1166 |
"Simple house icon: red triangular roof, beige rectangular walls, brown door in center, two blue square windows, green ground at bottom.",
|
| 1167 |
"Coffee mug: brown cylindrical cup with curved handle on right, three wavy steam lines rising from top. Flat style.",
|
| 1168 |
-
"
|
| 1169 |
]
|
| 1170 |
|
| 1171 |
example_images = get_example_images()
|
| 1172 |
|
| 1173 |
-
|
| 1174 |
-
header_html = f"""
|
| 1175 |
-
<div class="header-container">
|
| 1176 |
-
<h1>OmniSVG Generator</h1>
|
| 1177 |
-
<p>Transform images and text descriptions into scalable vector graphics</p>
|
| 1178 |
-
<div class="model-badge">Model: OmniSVG {MODEL_VARIANT} | Qwen: {DEFAULT_QWEN_MODEL.split('/')[-1]}</div>
|
| 1179 |
-
</div>
|
| 1180 |
-
"""
|
| 1181 |
-
|
| 1182 |
-
with gr.Blocks(title=f"OmniSVG Generator ({MODEL_VARIANT})") as demo:
|
| 1183 |
# Header
|
| 1184 |
-
gr.HTML(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1185 |
|
| 1186 |
# Queue status
|
| 1187 |
gr.HTML("""
|
| 1188 |
-
<div style="background: #e7f3ff; border: 1px solid #b3d7ff; border-radius: 8px; padding: 12px 15px; margin
|
| 1189 |
<span style="font-size: 1.5em;">ℹ️</span>
|
| 1190 |
<strong>Queue System Active</strong> - Requests processed one at a time. Please wait patiently if busy.
|
| 1191 |
</div>
|
|
@@ -1340,7 +1526,7 @@ def create_interface():
|
|
| 1340 |
elem_classes=["primary-btn"]
|
| 1341 |
)
|
| 1342 |
|
| 1343 |
-
gr.Markdown("### Example Prompts")
|
| 1344 |
gr.Examples(
|
| 1345 |
examples=[[text] for text in example_texts],
|
| 1346 |
inputs=[text_input],
|
|
@@ -1372,7 +1558,7 @@ def create_interface():
|
|
| 1372 |
# Footer
|
| 1373 |
gr.HTML(f"""
|
| 1374 |
<div class="footer">
|
| 1375 |
-
<p>Built with OmniSVG {
|
| 1376 |
<p style="color: #dc3545; font-weight: 600;">Remember: Generate 4-8 candidates and pick the best!</p>
|
| 1377 |
</div>
|
| 1378 |
""")
|
|
@@ -1385,31 +1571,20 @@ if __name__ == "__main__":
|
|
| 1385 |
|
| 1386 |
args = parse_args()
|
| 1387 |
|
|
|
|
|
|
|
|
|
|
| 1388 |
print("="*60)
|
| 1389 |
print("OmniSVG Demo Page - Gradio App")
|
| 1390 |
print("="*60)
|
| 1391 |
-
|
| 1392 |
-
|
| 1393 |
-
print(f"
|
| 1394 |
-
|
| 1395 |
-
MODEL_VARIANT = config['active_variant']
|
| 1396 |
-
|
| 1397 |
-
# Initialize constants from config
|
| 1398 |
-
init_config_constants(config)
|
| 1399 |
-
|
| 1400 |
-
# Override model paths if provided via command line
|
| 1401 |
-
weight_path = args.weight_path if args.weight_path else DEFAULT_OMNISVG_MODEL
|
| 1402 |
-
model_path = args.model_path if args.model_path else DEFAULT_QWEN_MODEL
|
| 1403 |
-
|
| 1404 |
-
print(f"\n[CONFIG] Active variant: {MODEL_VARIANT}")
|
| 1405 |
-
print(f"[CONFIG] Qwen model: {model_path}")
|
| 1406 |
-
print(f"[CONFIG] OmniSVG weights: {weight_path}")
|
| 1407 |
-
print(f"[CONFIG] Device: {device}")
|
| 1408 |
-
print(f"[CONFIG] Precision: {DTYPE}")
|
| 1409 |
print("="*60)
|
| 1410 |
|
| 1411 |
# Print loaded config values
|
| 1412 |
-
print("\n[CONFIG]
|
| 1413 |
print(f" - TARGET_IMAGE_SIZE: {TARGET_IMAGE_SIZE}")
|
| 1414 |
print(f" - RENDER_SIZE: {RENDER_SIZE}")
|
| 1415 |
print(f" - BLACK_COLOR_TOKEN: {BLACK_COLOR_TOKEN}")
|
|
@@ -1417,21 +1592,10 @@ if __name__ == "__main__":
|
|
| 1417 |
print(f" - BOS_TOKEN_ID: {BOS_TOKEN_ID}")
|
| 1418 |
print(f" - EOS_TOKEN_ID: {EOS_TOKEN_ID}")
|
| 1419 |
print(f" - PAD_TOKEN_ID: {PAD_TOKEN_ID}")
|
| 1420 |
-
|
| 1421 |
-
# Print variant-specific token offsets
|
| 1422 |
-
print(f"\n[CONFIG] Variant-specific ({MODEL_VARIANT}):")
|
| 1423 |
-
print(f" - base_offset: {config.get('tokens', {}).get('base_offset', 'N/A')}")
|
| 1424 |
-
print(f" - color_start_offset: {config.get('colors', {}).get('color_start_offset', 'N/A')}")
|
| 1425 |
-
print(f" - color_end_offset: {config.get('colors', {}).get('color_end_offset', 'N/A')}")
|
| 1426 |
print("="*60)
|
| 1427 |
|
| 1428 |
-
# Write variant-specific config for SVGTokenizer
|
| 1429 |
-
variant_config_path = f'./config_{MODEL_VARIANT.lower()}_runtime.yaml'
|
| 1430 |
-
write_variant_config(config, variant_config_path)
|
| 1431 |
-
print(f"\n[CONFIG] Written variant config to: {variant_config_path}")
|
| 1432 |
-
|
| 1433 |
print("\nLoading models (may download from HuggingFace Hub if needed)...")
|
| 1434 |
-
load_models(weight_path, model_path
|
| 1435 |
print("Models loaded successfully!\n")
|
| 1436 |
|
| 1437 |
demo = create_interface()
|
|
|
|
| 12 |
import numpy as np
|
| 13 |
import time
|
| 14 |
import threading
|
|
|
|
| 15 |
import spaces
|
| 16 |
|
| 17 |
from huggingface_hub import hf_hub_download, snapshot_download
|
|
|
|
| 21 |
from qwen_vl_utils import process_vision_info
|
| 22 |
from tokenizer import SVGTokenizer
|
| 23 |
|
| 24 |
+
# Load config
|
| 25 |
+
CONFIG_PATH = './config.yaml'
|
| 26 |
+
with open(CONFIG_PATH, 'r') as f:
|
| 27 |
+
config = yaml.safe_load(f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 30 |
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
| 31 |
|
| 32 |
+
# Global Models (will be loaded based on selected model size)
|
| 33 |
tokenizer = None
|
| 34 |
processor = None
|
| 35 |
sketch_decoder = None
|
| 36 |
svg_tokenizer = None
|
| 37 |
+
current_model_size = None # Track which model is currently loaded
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
# Thread lock for model inference
|
| 40 |
generation_lock = threading.Lock()
|
| 41 |
+
model_loading_lock = threading.Lock()
|
| 42 |
|
| 43 |
+
# Constants from config
|
| 44 |
SYSTEM_PROMPT = """You are an expert SVG code generator.
|
| 45 |
Generate precise, valid SVG path commands that accurately represent the described scene or object.
|
| 46 |
Focus on capturing key shapes, spatial relationships, and visual composition."""
|
| 47 |
|
| 48 |
SUPPORTED_FORMATS = ['.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif']
|
| 49 |
+
AVAILABLE_MODEL_SIZES = list(config.get('models', {}).keys())
|
| 50 |
+
DEFAULT_MODEL_SIZE = config.get('default_model_size', '8B')
|
| 51 |
|
| 52 |
+
# ============================================================
|
| 53 |
+
# Helper function to get config value (model-specific or shared)
|
| 54 |
+
# ============================================================
|
| 55 |
+
def get_config_value(model_size, *keys):
|
| 56 |
+
"""Get config value with model-specific override support."""
|
| 57 |
+
# Try model-specific config first
|
| 58 |
+
model_cfg = config.get('models', {}).get(model_size, {})
|
| 59 |
+
value = model_cfg
|
| 60 |
+
for key in keys:
|
| 61 |
+
if isinstance(value, dict) and key in value:
|
| 62 |
+
value = value[key]
|
| 63 |
+
else:
|
| 64 |
+
value = None
|
| 65 |
+
break
|
| 66 |
+
|
| 67 |
+
# Fallback to shared config if not found
|
| 68 |
+
if value is None:
|
| 69 |
+
value = config
|
| 70 |
+
for key in keys:
|
| 71 |
+
if isinstance(value, dict) and key in value:
|
| 72 |
+
value = value[key]
|
| 73 |
+
else:
|
| 74 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
+
return value
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# ============================================================
|
| 80 |
+
# Image processing settings from config (shared)
|
| 81 |
+
# ============================================================
|
| 82 |
+
image_config = config.get('image', {})
|
| 83 |
+
TARGET_IMAGE_SIZE = image_config.get('target_size', 448)
|
| 84 |
+
RENDER_SIZE = image_config.get('render_size', 512)
|
| 85 |
+
BACKGROUND_THRESHOLD = image_config.get('background_threshold', 240)
|
| 86 |
+
EMPTY_THRESHOLD_ILLUSTRATION = image_config.get('empty_threshold_illustration', 250)
|
| 87 |
+
EMPTY_THRESHOLD_ICON = image_config.get('empty_threshold_icon', 252)
|
| 88 |
+
EDGE_SAMPLE_RATIO = image_config.get('edge_sample_ratio', 0.1)
|
| 89 |
+
COLOR_SIMILARITY_THRESHOLD = image_config.get('color_similarity_threshold', 30)
|
| 90 |
+
MIN_EDGE_SAMPLES = image_config.get('min_edge_samples', 10)
|
| 91 |
+
|
| 92 |
+
# ============================================================
|
| 93 |
+
# Color settings from config (shared)
|
| 94 |
+
# ============================================================
|
| 95 |
+
colors_config = config.get('colors', {})
|
| 96 |
+
BLACK_COLOR_TOKEN = colors_config.get('black_color_token',
|
| 97 |
+
colors_config.get('color_token_start', 40010) + 2)
|
| 98 |
+
|
| 99 |
+
# ============================================================
|
| 100 |
+
# Model settings from config (shared)
|
| 101 |
+
# ============================================================
|
| 102 |
+
model_config = config.get('model', {})
|
| 103 |
+
BOS_TOKEN_ID = model_config.get('bos_token_id', 196998)
|
| 104 |
+
EOS_TOKEN_ID = model_config.get('eos_token_id', 196999)
|
| 105 |
+
PAD_TOKEN_ID = model_config.get('pad_token_id', 151643)
|
| 106 |
+
MAX_LENGTH = model_config.get('max_length', 1536)
|
| 107 |
+
|
| 108 |
+
# ============================================================
|
| 109 |
+
# Task configurations with defaults from config (shared)
|
| 110 |
+
# ============================================================
|
| 111 |
+
task_config = config.get('task_configs', {})
|
| 112 |
+
|
| 113 |
+
TASK_CONFIGS = {
|
| 114 |
+
"text-to-svg-icon": task_config.get('text_to_svg_icon', {
|
| 115 |
+
"default_temperature": 0.5,
|
| 116 |
+
"default_top_p": 0.88,
|
| 117 |
+
"default_top_k": 50,
|
| 118 |
+
"default_repetition_penalty": 1.05,
|
| 119 |
+
}),
|
| 120 |
+
"text-to-svg-illustration": task_config.get('text_to_svg_illustration', {
|
| 121 |
+
"default_temperature": 0.6,
|
| 122 |
+
"default_top_p": 0.90,
|
| 123 |
+
"default_top_k": 60,
|
| 124 |
+
"default_repetition_penalty": 1.03,
|
| 125 |
+
}),
|
| 126 |
+
"image-to-svg": task_config.get('image_to_svg', {
|
| 127 |
+
"default_temperature": 0.3,
|
| 128 |
+
"default_top_p": 0.90,
|
| 129 |
+
"default_top_k": 50,
|
| 130 |
+
"default_repetition_penalty": 1.05,
|
| 131 |
+
})
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
# ============================================================
|
| 135 |
+
# Generation parameters from config (shared)
|
| 136 |
+
# ============================================================
|
| 137 |
+
gen_config = config.get('generation', {})
|
| 138 |
+
DEFAULT_NUM_CANDIDATES = gen_config.get('default_num_candidates', 4)
|
| 139 |
+
MAX_NUM_CANDIDATES = gen_config.get('max_num_candidates', 8)
|
| 140 |
+
EXTRA_CANDIDATES_BUFFER = gen_config.get('extra_candidates_buffer', 4)
|
| 141 |
|
| 142 |
+
# ============================================================
|
| 143 |
+
# Validation settings from config (shared)
|
| 144 |
+
# ============================================================
|
| 145 |
+
validation_config = config.get('validation', {})
|
| 146 |
+
MIN_SVG_LENGTH = validation_config.get('min_svg_length', 20)
|
| 147 |
|
| 148 |
+
# Custom CSS
|
| 149 |
CUSTOM_CSS = """
|
| 150 |
/* Main container centering */
|
| 151 |
.gradio-container {
|
|
|
|
| 172 |
opacity: 0.9;
|
| 173 |
font-size: 1.1em;
|
| 174 |
}
|
| 175 |
+
/* Model selector styling */
|
| 176 |
+
.model-selector {
|
| 177 |
+
background: #f0f4f8;
|
| 178 |
+
border: 2px solid #667eea;
|
| 179 |
+
border-radius: 12px;
|
| 180 |
+
padding: 15px;
|
| 181 |
+
margin-bottom: 20px;
|
| 182 |
+
}
|
| 183 |
+
.model-selector-title {
|
| 184 |
+
font-weight: 700;
|
| 185 |
+
color: #667eea;
|
| 186 |
+
margin-bottom: 10px;
|
| 187 |
}
|
| 188 |
/* Tips section */
|
| 189 |
.tips-box {
|
|
|
|
| 262 |
.green-box strong {
|
| 263 |
color: #4caf50;
|
| 264 |
}
|
| 265 |
+
.blue-box {
|
| 266 |
+
background: #e3f2fd;
|
| 267 |
+
border: 1px solid #90caf9;
|
| 268 |
+
border-left: 4px solid #2196f3;
|
| 269 |
+
padding: 12px;
|
| 270 |
+
border-radius: 8px;
|
| 271 |
+
margin: 10px 0;
|
| 272 |
+
}
|
| 273 |
+
.blue-box strong {
|
| 274 |
+
color: #2196f3;
|
| 275 |
+
}
|
| 276 |
/* Tab styling */
|
| 277 |
.tabs {
|
| 278 |
border-radius: 12px !important;
|
|
|
|
| 341 |
}
|
| 342 |
"""
|
| 343 |
|
| 344 |
+
# Enhanced Tips HTML
|
| 345 |
TIPS_HTML = """
|
| 346 |
<div class="tips-box">
|
| 347 |
<h3>Prompting Guide & Best Practices</h3>
|
|
|
|
| 368 |
</ul>
|
| 369 |
</div>
|
| 370 |
|
| 371 |
+
<!-- Model Selection Tips -->
|
| 372 |
+
<div class="blue-box">
|
| 373 |
+
<strong>Model Selection Guide</strong>
|
| 374 |
+
<ul style="margin: 8px 0 0 0; padding-left: 20px;">
|
| 375 |
+
<li><strong>8B Model:</strong> Higher quality, more details, better for complex illustrations. Requires more VRAM (~16GB+).</li>
|
| 376 |
+
<li><strong>4B Model:</strong> Faster, less VRAM required (~8GB+). Good for simple icons and basic shapes.</li>
|
| 377 |
+
</ul>
|
| 378 |
+
</div>
|
| 379 |
+
|
| 380 |
<!-- Parameter Tuning Tips -->
|
| 381 |
<div class="orange-box">
|
| 382 |
<strong>Parameter Tuning Guide</strong>
|
|
|
|
| 442 |
<div class="example-prompt">
|
| 443 |
"A simple person: round beige head, rectangular blue shirt body, two dark gray rectangular legs. Standing pose, arms at sides, flat colors."
|
| 444 |
</div>
|
| 445 |
+
<div class="example-prompt">
|
| 446 |
+
"A girl with long black hair, pink dress with triangular skirt shape, small circular face with dot eyes and curved smile. Simple cartoon style."
|
| 447 |
+
</div>
|
| 448 |
<p class="red-tip">Keep poses SIMPLE: standing, sitting, waving. Avoid complex actions!</p>
|
| 449 |
</div>
|
| 450 |
|
| 451 |
+
<div class="tip-category">
|
| 452 |
+
<h4>Avatars & Portraits</h4>
|
| 453 |
+
<p>Use circular frame, focus on face and upper body only.</p>
|
| 454 |
+
<div class="example-prompt">
|
| 455 |
+
"Circular avatar: person with short black hair, round face with two dot eyes and small curved smile, wearing blue collar shirt. Minimal style."
|
| 456 |
+
</div>
|
| 457 |
+
<div class="example-prompt">
|
| 458 |
+
"Profile avatar silhouette: black side view of head with short hair, facing right. Simple solid shape."
|
| 459 |
+
</div>
|
| 460 |
+
</div>
|
| 461 |
+
|
| 462 |
<div class="tip-category">
|
| 463 |
<h4>Landscapes & Scenes</h4>
|
| 464 |
<p>Layer elements from background to foreground. Specify color for EACH layer.</p>
|
| 465 |
<div class="example-prompt">
|
| 466 |
"Layered landscape: light blue sky at top, gray triangular mountains in middle, dark green triangular pine trees at bottom. Flat colors, simple shapes."
|
| 467 |
</div>
|
| 468 |
+
<div class="example-prompt">
|
| 469 |
+
"Sunset beach: orange gradient sky at top, yellow semicircle sun on horizon, dark blue wavy ocean below, tan beach at bottom."
|
| 470 |
+
</div>
|
| 471 |
<p class="red-tip">Use geometric shapes for nature: triangular trees, wavy water, semicircle sun!</p>
|
| 472 |
</div>
|
| 473 |
|
|
|
|
| 477 |
<div class="example-prompt">
|
| 478 |
"Cute cat: orange round head with two triangular ears, oval orange body, curved tail. Simple cartoon style with black outlines, sitting pose."
|
| 479 |
</div>
|
| 480 |
+
<div class="example-prompt">
|
| 481 |
+
"Simple black bird: oval body, small round head, pointed triangular beak facing right, triangular tail, two stick legs. Silhouette style."
|
| 482 |
+
</div>
|
| 483 |
+
</div>
|
| 484 |
+
|
| 485 |
+
<div class="tip-category">
|
| 486 |
+
<h4>Buildings & Objects</h4>
|
| 487 |
+
<p>Use basic shapes: rectangles for walls, triangles for roofs, squares for windows.</p>
|
| 488 |
+
<div class="example-prompt">
|
| 489 |
+
"Simple house: red triangular roof on top, beige rectangular wall, brown rectangular door in center, two small blue square windows. Green ground at bottom."
|
| 490 |
+
</div>
|
| 491 |
+
<div class="example-prompt">
|
| 492 |
+
"Coffee mug: brown cylindrical cup shape with curved handle on right side, three wavy steam lines rising from top. Simple flat style."
|
| 493 |
+
</div>
|
| 494 |
</div>
|
| 495 |
|
| 496 |
</div>
|
| 497 |
|
| 498 |
+
<!-- Extended Examples Section -->
|
| 499 |
+
<div style="margin-top: 20px; padding: 15px; background: #f0f7ff; border-radius: 10px; border: 1px solid #cce5ff;">
|
| 500 |
+
<h4 style="margin-top: 0; color: #0066cc;">More Complex Examples (Generate 6-8 candidates!)</h4>
|
| 501 |
+
|
| 502 |
+
<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(280px, 1fr)); gap: 12px; margin-top: 15px;">
|
| 503 |
+
<div class="example-prompt">
|
| 504 |
+
<strong>Business Avatar:</strong><br/>
|
| 505 |
+
"Circular professional avatar: man with short black hair, neutral skin tone round face, wearing dark navy suit with white shirt collar visible. Clean minimal style, centered in circle."
|
| 506 |
+
</div>
|
| 507 |
+
<div class="example-prompt">
|
| 508 |
+
<strong>Female Portrait:</strong><br/>
|
| 509 |
+
"Simple female face: oval face shape, long brown wavy hair on sides, two dot eyes, small nose, curved smile lips. Pink blush on cheeks. Cartoon portrait style."
|
| 510 |
+
</div>
|
| 511 |
+
<div class="example-prompt">
|
| 512 |
+
<strong>Child Character:</strong><br/>
|
| 513 |
+
"Cute child standing: large round head with short brown hair, big circular eyes with white highlights, small body in red t-shirt and blue shorts, simple stick arms and legs. Cheerful cartoon style."
|
| 514 |
+
</div>
|
| 515 |
+
<div class="example-prompt">
|
| 516 |
+
<strong>Active Pose:</strong><br/>
|
| 517 |
+
"Person walking: side view, circular head, rectangular torso in green jacket, legs in walking position (one forward, one back). Simple geometric style, moving right."
|
| 518 |
+
</div>
|
| 519 |
+
<div class="example-prompt">
|
| 520 |
+
<strong>Forest Scene:</strong><br/>
|
| 521 |
+
"Simple forest: light blue sky, row of 5 dark green triangular pine trees of varying heights, brown rectangular trunks, light green grass strip at bottom. Layered flat design."
|
| 522 |
+
</div>
|
| 523 |
+
<div class="example-prompt">
|
| 524 |
+
<strong>Ocean View:</strong><br/>
|
| 525 |
+
"Minimalist ocean: gradient blue sky at top, three horizontal wavy lines in dark blue for ocean, small white sailboat with triangular sail in center. Clean vector style."
|
| 526 |
+
</div>
|
| 527 |
+
<div class="example-prompt">
|
| 528 |
+
<strong>City Skyline:</strong><br/>
|
| 529 |
+
"Simple city skyline: orange sunset sky gradient, row of black rectangular building silhouettes of different heights, some with small yellow square windows. Minimalist style."
|
| 530 |
+
</div>
|
| 531 |
+
<div class="example-prompt">
|
| 532 |
+
<strong>Dog Character:</strong><br/>
|
| 533 |
+
"Friendly cartoon dog: brown oval body, round head with floppy ears, black dot nose, curved tail pointing up, four short legs. Sitting pose facing forward."
|
| 534 |
+
</div>
|
| 535 |
+
</div>
|
| 536 |
+
</div>
|
| 537 |
+
|
| 538 |
<!-- Quick Troubleshooting -->
|
| 539 |
<div class="green-box" style="margin-top: 15px;">
|
| 540 |
<strong>Quick Troubleshooting</strong>
|
|
|
|
| 547 |
<li><strong>Inconsistent?</strong> <span class="red-tip">Generate MORE candidates (6-8) and pick the best!</span></li>
|
| 548 |
</ul>
|
| 549 |
</div>
|
| 550 |
+
|
| 551 |
+
<!-- Prompt Template -->
|
| 552 |
+
<div style="margin-top: 15px; padding: 12px; background: #e8f5e9; border-radius: 8px; border-left: 4px solid #4caf50;">
|
| 553 |
+
<strong>Recommended Prompt Structure</strong>
|
| 554 |
+
<div style="background: white; padding: 10px; border-radius: 6px; margin-top: 8px; font-family: monospace; font-size: 0.9em;">
|
| 555 |
+
[Subject] + [Shape descriptions with colors] + [Position/orientation] + [Style]
|
| 556 |
+
</div>
|
| 557 |
+
<p style="margin: 10px 0 0 0; color: #2e7d32; font-size: 0.95em;">
|
| 558 |
+
Example: "A fox logo: triangular orange head, pointed ears, white chest marking, facing right. Minimalist flat style, centered."
|
| 559 |
+
</p>
|
| 560 |
+
</div>
|
| 561 |
</div>
|
| 562 |
"""
|
| 563 |
|
|
|
|
| 582 |
parser.add_argument('--port', type=int, default=7860)
|
| 583 |
parser.add_argument('--share', action='store_true')
|
| 584 |
parser.add_argument('--debug', action='store_true')
|
| 585 |
+
parser.add_argument('--model_size', type=str, default=None,
|
| 586 |
+
choices=AVAILABLE_MODEL_SIZES,
|
| 587 |
+
help=f'Model size to load at startup (default: {DEFAULT_MODEL_SIZE}). Can be changed in UI.')
|
|
|
|
| 588 |
parser.add_argument('--weight_path', type=str, default=None,
|
| 589 |
help='HuggingFace repo ID or local path for OmniSVG weights (overrides config)')
|
| 590 |
parser.add_argument('--model_path', type=str, default=None,
|
|
|
|
| 595 |
def download_model_weights(repo_id: str, filename: str = "pytorch_model.bin") -> str:
|
| 596 |
"""
|
| 597 |
Download model weights from Hugging Face Hub.
|
| 598 |
+
|
| 599 |
+
Args:
|
| 600 |
+
repo_id: Hugging Face repository ID (e.g., 'OmniSVG/OmniSVG1.1_8B')
|
| 601 |
+
filename: Name of the weights file to download
|
| 602 |
+
|
| 603 |
+
Returns:
|
| 604 |
+
Local path to the downloaded file
|
| 605 |
"""
|
| 606 |
print(f"Downloading {filename} from {repo_id}...")
|
| 607 |
try:
|
|
|
|
| 630 |
return False
|
| 631 |
|
| 632 |
|
| 633 |
+
def load_models(model_size: str, weight_path: str = None, model_path: str = None):
|
| 634 |
"""
|
| 635 |
+
Load all models for a specific model size.
|
| 636 |
|
| 637 |
Args:
|
| 638 |
+
model_size: Model size ("8B" or "4B")
|
| 639 |
+
weight_path: Local path or HuggingFace repo ID for OmniSVG weights (optional, uses config if None)
|
| 640 |
+
model_path: Local path or HuggingFace repo ID for Qwen model (optional, uses config if None)
|
| 641 |
"""
|
| 642 |
+
global tokenizer, processor, sketch_decoder, svg_tokenizer, current_model_size
|
| 643 |
|
| 644 |
+
# Use config values if not provided
|
| 645 |
+
if weight_path is None:
|
| 646 |
+
weight_path = get_config_value(model_size, 'huggingface', 'omnisvg_model')
|
| 647 |
+
if model_path is None:
|
| 648 |
+
model_path = get_config_value(model_size, 'huggingface', 'qwen_model')
|
| 649 |
+
|
| 650 |
+
print(f"\n{'='*60}")
|
| 651 |
+
print(f"Loading {model_size} Model")
|
| 652 |
+
print(f"{'='*60}")
|
| 653 |
+
print(f"Qwen model: {model_path}")
|
| 654 |
+
print(f"OmniSVG weights: {weight_path}")
|
| 655 |
+
print(f"Precision: {DTYPE}")
|
| 656 |
|
| 657 |
# Load Qwen tokenizer and processor
|
| 658 |
print("\n[1/3] Loading tokenizer and processor...")
|
|
|
|
| 669 |
processor.tokenizer.padding_side = "left"
|
| 670 |
print("Tokenizer and processor loaded successfully!")
|
| 671 |
|
| 672 |
+
# Initialize sketch decoder with model_size
|
| 673 |
print("\n[2/3] Initializing SketchDecoder...")
|
| 674 |
sketch_decoder = SketchDecoder(
|
| 675 |
+
config_path=CONFIG_PATH,
|
| 676 |
+
model_path=model_path,
|
| 677 |
+
model_size=model_size,
|
| 678 |
pix_len=MAX_LENGTH,
|
| 679 |
text_len=config.get('text', {}).get('max_length', 200),
|
|
|
|
| 680 |
torch_dtype=DTYPE
|
| 681 |
)
|
| 682 |
|
|
|
|
| 704 |
|
| 705 |
sketch_decoder = sketch_decoder.to(device).eval()
|
| 706 |
|
| 707 |
+
# Initialize SVG tokenizer with model_size
|
| 708 |
+
svg_tokenizer = SVGTokenizer(CONFIG_PATH, model_size=model_size)
|
| 709 |
+
|
| 710 |
+
current_model_size = model_size
|
| 711 |
|
| 712 |
print("\n" + "="*60)
|
| 713 |
+
print(f"All {model_size} models loaded successfully!")
|
| 714 |
print("="*60 + "\n")
|
| 715 |
|
| 716 |
|
| 717 |
+
def switch_model(new_model_size: str):
|
| 718 |
+
"""
|
| 719 |
+
Switch to a different model size.
|
| 720 |
+
|
| 721 |
+
Args:
|
| 722 |
+
new_model_size: Target model size ("8B" or "4B")
|
| 723 |
+
|
| 724 |
+
Returns:
|
| 725 |
+
Status message
|
| 726 |
+
"""
|
| 727 |
+
global current_model_size
|
| 728 |
+
|
| 729 |
+
if new_model_size == current_model_size:
|
| 730 |
+
return f"✅ Already using {new_model_size} model"
|
| 731 |
+
|
| 732 |
+
with model_loading_lock:
|
| 733 |
+
# Clear memory
|
| 734 |
+
gc.collect()
|
| 735 |
+
if torch.cuda.is_available():
|
| 736 |
+
torch.cuda.empty_cache()
|
| 737 |
+
|
| 738 |
+
try:
|
| 739 |
+
load_models(new_model_size)
|
| 740 |
+
return f"✅ Successfully switched to {new_model_size} model"
|
| 741 |
+
except Exception as e:
|
| 742 |
+
error_msg = f"❌ Failed to switch to {new_model_size}: {str(e)}"
|
| 743 |
+
print(error_msg)
|
| 744 |
+
return error_msg
|
| 745 |
+
|
| 746 |
+
|
| 747 |
def detect_text_subtype(text_prompt):
|
| 748 |
"""Auto-detect text prompt subtype"""
|
| 749 |
text_lower = text_prompt.lower()
|
|
|
|
| 770 |
|
| 771 |
|
| 772 |
def detect_and_replace_background(image, threshold=None, edge_sample_ratio=None):
|
| 773 |
+
"""
|
| 774 |
+
Detect if image has non-white background and optionally replace it.
|
| 775 |
+
"""
|
| 776 |
if threshold is None:
|
| 777 |
threshold = BACKGROUND_THRESHOLD
|
| 778 |
if edge_sample_ratio is None:
|
|
|
|
| 806 |
return image, False
|
| 807 |
|
| 808 |
if len(img_array.shape) == 3 and img_array.shape[2] >= 3:
|
| 809 |
+
if img_array.shape[2] == 4:
|
| 810 |
+
gray = np.mean(img_array[:, :, :3], axis=2)
|
| 811 |
+
else:
|
| 812 |
+
gray = np.mean(img_array, axis=2)
|
| 813 |
+
|
| 814 |
edge_colors = []
|
| 815 |
for i in range(w):
|
| 816 |
edge_colors.append(tuple(img_array[0, i, :3]))
|
|
|
|
| 838 |
|
| 839 |
|
| 840 |
def preprocess_image_for_svg(image, replace_background=True, target_size=None):
|
| 841 |
+
"""
|
| 842 |
+
Preprocess image for SVG generation.
|
| 843 |
+
"""
|
| 844 |
if target_size is None:
|
| 845 |
target_size = TARGET_IMAGE_SIZE
|
| 846 |
|
|
|
|
| 1016 |
|
| 1017 |
all_candidates = []
|
| 1018 |
|
| 1019 |
+
# Generation config with user parameters
|
| 1020 |
+
gen_cfg = {
|
| 1021 |
'do_sample': True,
|
| 1022 |
'temperature': temperature,
|
| 1023 |
'top_p': top_p,
|
|
|
|
| 1046 |
max_new_tokens=max_length,
|
| 1047 |
num_return_sequences=actual_samples,
|
| 1048 |
use_cache=True,
|
| 1049 |
+
**gen_cfg
|
| 1050 |
)
|
| 1051 |
|
| 1052 |
input_len = input_ids.shape[1]
|
|
|
|
| 1124 |
return '<div style="text-align:center;color:#999;padding:50px;">Please enter a description</div>', ""
|
| 1125 |
|
| 1126 |
print("\n" + "="*60)
|
| 1127 |
+
print(f"[TASK] text-to-svg (Model: {current_model_size})")
|
| 1128 |
print(f"[INPUT] {text_description[:100]}{'...' if len(text_description) > 100 else ''}")
|
| 1129 |
print(f"[PARAMS] candidates={num_candidates}, temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}")
|
| 1130 |
print("="*60)
|
|
|
|
| 1190 |
)
|
| 1191 |
|
| 1192 |
print("\n" + "="*60)
|
| 1193 |
+
print(f"[TASK] image-to-svg (Model: {current_model_size})")
|
| 1194 |
print(f"[INPUT] Image size: {image.size if hasattr(image, 'size') else 'unknown'}, mode: {image.mode if hasattr(image, 'mode') else 'unknown'}")
|
| 1195 |
print(f"[PARAMS] candidates={num_candidates}, temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}, replace_bg={replace_background}")
|
| 1196 |
print("="*60)
|
|
|
|
| 1277 |
def create_interface():
|
| 1278 |
"""Create Gradio interface"""
|
| 1279 |
|
| 1280 |
+
# 30 Example prompts covering various categories
|
| 1281 |
example_texts = [
|
| 1282 |
+
# === Simple Icons (1-6) ===
|
| 1283 |
"A black triangle pointing downward, centrally positioned.",
|
| 1284 |
"A red heart shape with smooth curved edges, centered.",
|
| 1285 |
"A yellow star with five sharp points, simple geometric design, flat color.",
|
| 1286 |
"A blue arrow pointing to the right, thick solid shape, centered.",
|
| 1287 |
"A green circle with a white checkmark inside, centered.",
|
| 1288 |
"A black plus sign with equal length arms, thick lines, centered.",
|
| 1289 |
+
|
| 1290 |
+
# === Characters & People (7-12) ===
|
| 1291 |
"A simple person standing: round beige head, rectangular blue shirt body, two dark gray rectangular legs, arms at sides. Flat colors.",
|
| 1292 |
"A girl with long black hair, wearing pink dress with triangular skirt, small circular face with dot eyes and curved smile. Simple cartoon style.",
|
| 1293 |
+
"A child waving: large round head with brown messy hair, big circular eyes, small body in red t-shirt and blue shorts, one arm raised. Cheerful cartoon style.",
|
| 1294 |
+
"A person sitting on chair: side view, round head, rectangular torso in green sweater, bent legs on simple chair shape. Relaxed pose.",
|
| 1295 |
+
"A running person: side view silhouette in black, dynamic pose with one leg forward, arms pumping. Motion style.",
|
| 1296 |
+
|
| 1297 |
+
# === Avatars & Portraits (13-17) ===
|
| 1298 |
"Circular avatar: person with short black hair, round face with two dot eyes and small curved smile, wearing blue collar shirt. Minimal style, centered in circle.",
|
| 1299 |
+
"Female avatar: oval face with long wavy brown hair, simple eyes, pink lips, wearing v-neck purple top. Soft cartoon style in circular frame.",
|
| 1300 |
+
"Profile silhouette avatar: black side view of head with short hair and glasses outline, facing right. Simple solid shape.",
|
| 1301 |
+
"Cute cartoon avatar: round face with big sparkly eyes, rosy cheeks, short bob haircut in orange. Kawaii style, circular frame.",
|
| 1302 |
+
"Professional headshot avatar: person with neat hair, neutral expression, wearing suit collar. Corporate minimal style, circular frame.",
|
| 1303 |
+
|
| 1304 |
+
# === Landscapes & Scenes (18-23) ===
|
| 1305 |
"Layered mountain landscape: light blue sky at top, gray triangular snow-capped mountains in middle, dark green triangular pine trees at bottom. Flat colors.",
|
| 1306 |
"Sunset beach scene: orange gradient sky at top, yellow semicircle sun on horizon, dark blue wavy ocean, tan beach strip at bottom. Simple shapes.",
|
| 1307 |
+
"Forest scene: light blue sky, row of 5 dark green triangular pine trees of varying heights on brown trunks, light green grass at bottom.",
|
| 1308 |
+
"City skyline at dusk: purple-orange gradient sky, row of black rectangular building silhouettes of different heights, some with yellow window squares.",
|
| 1309 |
+
"Desert landscape: light orange sky with white circle sun, tan sand dunes as curved shapes, one green cactus with arms on the right side.",
|
| 1310 |
+
"Countryside scene: blue sky with white fluffy clouds, green rolling hills, small red barn with white door in the center, yellow hay bales.",
|
| 1311 |
+
|
| 1312 |
+
# === Animals (24-27) ===
|
| 1313 |
"Cute orange cat sitting: round head with two triangular ears, oval body, curved tail. Black outline cartoon style, facing forward.",
|
| 1314 |
+
"Simple black bird: oval body, round head, pointed triangular beak facing right, triangular tail, two stick legs. Silhouette style.",
|
| 1315 |
+
"Friendly cartoon dog: brown oval body, round head with floppy ears, black dot nose, wagging curved tail, four short legs. Sitting pose.",
|
| 1316 |
+
"Red fox logo: triangular orange face with pointed ears, white chest marking, bushy tail. Minimalist style, facing right, centered.",
|
| 1317 |
+
|
| 1318 |
+
# === Objects & Misc (28-30) ===
|
| 1319 |
"Simple house icon: red triangular roof, beige rectangular walls, brown door in center, two blue square windows, green ground at bottom.",
|
| 1320 |
"Coffee mug: brown cylindrical cup with curved handle on right, three wavy steam lines rising from top. Flat style.",
|
| 1321 |
+
"Open book: two rectangular white pages spread open, black text lines on each page, brown spine in center. Simple top-down view."
|
| 1322 |
]
|
| 1323 |
|
| 1324 |
example_images = get_example_images()
|
| 1325 |
|
| 1326 |
+
with gr.Blocks(title="OmniSVG Generator", css=CUSTOM_CSS) as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1327 |
# Header
|
| 1328 |
+
gr.HTML("""
|
| 1329 |
+
<div class="header-container">
|
| 1330 |
+
<h1>OmniSVG Generator</h1>
|
| 1331 |
+
<p>Transform images and text descriptions into scalable vector graphics</p>
|
| 1332 |
+
</div>
|
| 1333 |
+
""")
|
| 1334 |
+
|
| 1335 |
+
# Model Selection Section
|
| 1336 |
+
with gr.Row():
|
| 1337 |
+
with gr.Column():
|
| 1338 |
+
gr.HTML("""
|
| 1339 |
+
<div class="blue-box">
|
| 1340 |
+
<strong>🔧 Model Selection</strong>
|
| 1341 |
+
<p style="margin: 5px 0 0 0; font-size: 0.9em;">
|
| 1342 |
+
Choose between <b>8B</b> (higher quality, more VRAM) or <b>4B</b> (faster, less VRAM).
|
| 1343 |
+
</p>
|
| 1344 |
+
</div>
|
| 1345 |
+
""")
|
| 1346 |
+
|
| 1347 |
+
with gr.Row():
|
| 1348 |
+
model_selector = gr.Dropdown(
|
| 1349 |
+
choices=AVAILABLE_MODEL_SIZES,
|
| 1350 |
+
value=DEFAULT_MODEL_SIZE,
|
| 1351 |
+
label="Model Size",
|
| 1352 |
+
info="8B: ~16GB VRAM, higher quality | 4B: ~8GB VRAM, faster",
|
| 1353 |
+
interactive=True,
|
| 1354 |
+
scale=1
|
| 1355 |
+
)
|
| 1356 |
+
model_status = gr.Textbox(
|
| 1357 |
+
label="Model Status",
|
| 1358 |
+
value=f"✅ Ready: {DEFAULT_MODEL_SIZE} model loaded",
|
| 1359 |
+
interactive=False,
|
| 1360 |
+
scale=2
|
| 1361 |
+
)
|
| 1362 |
+
switch_btn = gr.Button("Switch Model", variant="secondary", scale=1)
|
| 1363 |
+
|
| 1364 |
+
# Model switch handler
|
| 1365 |
+
switch_btn.click(
|
| 1366 |
+
fn=switch_model,
|
| 1367 |
+
inputs=[model_selector],
|
| 1368 |
+
outputs=[model_status],
|
| 1369 |
+
queue=True
|
| 1370 |
+
)
|
| 1371 |
|
| 1372 |
# Queue status
|
| 1373 |
gr.HTML("""
|
| 1374 |
+
<div style="background: #e7f3ff; border: 1px solid #b3d7ff; border-radius: 8px; padding: 12px 15px; margin: 15px 0;">
|
| 1375 |
<span style="font-size: 1.5em;">ℹ️</span>
|
| 1376 |
<strong>Queue System Active</strong> - Requests processed one at a time. Please wait patiently if busy.
|
| 1377 |
</div>
|
|
|
|
| 1526 |
elem_classes=["primary-btn"]
|
| 1527 |
)
|
| 1528 |
|
| 1529 |
+
gr.Markdown("### Example Prompts (30)")
|
| 1530 |
gr.Examples(
|
| 1531 |
examples=[[text] for text in example_texts],
|
| 1532 |
inputs=[text_input],
|
|
|
|
| 1558 |
# Footer
|
| 1559 |
gr.HTML(f"""
|
| 1560 |
<div class="footer">
|
| 1561 |
+
<p>Built with OmniSVG | Current Model: <strong>{DEFAULT_MODEL_SIZE}</strong></p>
|
| 1562 |
<p style="color: #dc3545; font-weight: 600;">Remember: Generate 4-8 candidates and pick the best!</p>
|
| 1563 |
</div>
|
| 1564 |
""")
|
|
|
|
| 1571 |
|
| 1572 |
args = parse_args()
|
| 1573 |
|
| 1574 |
+
# Determine initial model size
|
| 1575 |
+
initial_model_size = args.model_size or DEFAULT_MODEL_SIZE
|
| 1576 |
+
|
| 1577 |
print("="*60)
|
| 1578 |
print("OmniSVG Demo Page - Gradio App")
|
| 1579 |
print("="*60)
|
| 1580 |
+
print(f"Available model sizes: {AVAILABLE_MODEL_SIZES}")
|
| 1581 |
+
print(f"Initial model size: {initial_model_size}")
|
| 1582 |
+
print(f"Device: {device}")
|
| 1583 |
+
print(f"Precision: {DTYPE}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1584 |
print("="*60)
|
| 1585 |
|
| 1586 |
# Print loaded config values
|
| 1587 |
+
print("\n[CONFIG] Shared settings:")
|
| 1588 |
print(f" - TARGET_IMAGE_SIZE: {TARGET_IMAGE_SIZE}")
|
| 1589 |
print(f" - RENDER_SIZE: {RENDER_SIZE}")
|
| 1590 |
print(f" - BLACK_COLOR_TOKEN: {BLACK_COLOR_TOKEN}")
|
|
|
|
| 1592 |
print(f" - BOS_TOKEN_ID: {BOS_TOKEN_ID}")
|
| 1593 |
print(f" - EOS_TOKEN_ID: {EOS_TOKEN_ID}")
|
| 1594 |
print(f" - PAD_TOKEN_ID: {PAD_TOKEN_ID}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1595 |
print("="*60)
|
| 1596 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1597 |
print("\nLoading models (may download from HuggingFace Hub if needed)...")
|
| 1598 |
+
load_models(initial_model_size, args.weight_path, args.model_path)
|
| 1599 |
print("Models loaded successfully!\n")
|
| 1600 |
|
| 1601 |
demo = create_interface()
|