Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,28 +9,519 @@ import argparse
|
|
| 9 |
import gc
|
| 10 |
import yaml
|
| 11 |
import glob
|
| 12 |
-
import
|
| 13 |
-
|
|
|
|
|
|
|
| 14 |
from decoder import SketchDecoder
|
| 15 |
from transformers import AutoTokenizer, AutoProcessor
|
| 16 |
from qwen_vl_utils import process_vision_info
|
| 17 |
from tokenizer import SVGTokenizer
|
| 18 |
-
import spaces
|
| 19 |
|
| 20 |
-
#
|
| 21 |
-
with open('config.yaml', 'r') as f:
|
| 22 |
config = yaml.safe_load(f)
|
| 23 |
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
| 25 |
tokenizer = None
|
| 26 |
processor = None
|
| 27 |
sketch_decoder = None
|
| 28 |
svg_tokenizer = None
|
| 29 |
-
device = "cpu"
|
| 30 |
|
| 31 |
-
#
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
SUPPORTED_FORMATS = ['.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
def parse_args():
|
| 36 |
parser = argparse.ArgumentParser(description='SVG Generator Service')
|
|
@@ -38,314 +529,846 @@ def parse_args():
|
|
| 38 |
parser.add_argument('--port', type=int, default=7860)
|
| 39 |
parser.add_argument('--share', action='store_true')
|
| 40 |
parser.add_argument('--debug', action='store_true')
|
|
|
|
|
|
|
| 41 |
return parser.parse_args()
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
-
if tokenizer is None:
|
| 54 |
-
# 1. 准备本地模型目录
|
| 55 |
-
local_model_dir = "custom_model_build"
|
| 56 |
-
|
| 57 |
-
# 只有当目录里没有权重文件时才执行构建
|
| 58 |
-
if not os.path.exists(os.path.join(local_model_dir, "pytorch_model.bin")):
|
| 59 |
-
print("🛠️ Building custom model directory...")
|
| 60 |
-
os.makedirs(local_model_dir, exist_ok=True)
|
| 61 |
-
|
| 62 |
-
# (A) 下载 Qwen 的配置文件
|
| 63 |
-
print("Downloading Qwen configurations...")
|
| 64 |
-
snapshot_download(
|
| 65 |
-
repo_id="Qwen/Qwen2.5-VL-3B-Instruct",
|
| 66 |
-
local_dir=local_model_dir,
|
| 67 |
-
allow_patterns=["*.json", "*.txt", "*.py"], # 这会下载 index.json,下面我们会删掉它
|
| 68 |
-
ignore_patterns=["*.safetensors", "*.bin", "*.pt"]
|
| 69 |
-
)
|
| 70 |
-
|
| 71 |
-
# (B) 下载 OmniSVG 权重
|
| 72 |
-
print("Downloading OmniSVG weights...")
|
| 73 |
-
sketch_weight_path = hf_hub_download(repo_id="OmniSVG/OmniSVG", filename="pytorch_model.bin")
|
| 74 |
-
|
| 75 |
-
# (C) 处理并保存权重
|
| 76 |
-
print("Processing and saving weights...")
|
| 77 |
-
state_dict = torch.load(sketch_weight_path, map_location="cpu")
|
| 78 |
-
|
| 79 |
-
new_state_dict = {}
|
| 80 |
-
for key in list(state_dict.keys()):
|
| 81 |
-
if key.startswith("transformer."):
|
| 82 |
-
new_key = key.replace("transformer.", "", 1)
|
| 83 |
-
new_state_dict[new_key] = state_dict[key]
|
| 84 |
-
else:
|
| 85 |
-
new_state_dict[key] = state_dict[key]
|
| 86 |
-
|
| 87 |
-
torch.save(new_state_dict, os.path.join(local_model_dir, "pytorch_model.bin"))
|
| 88 |
-
del state_dict, new_state_dict
|
| 89 |
-
gc.collect()
|
| 90 |
-
print("✅ Custom model directory built successfully.")
|
| 91 |
-
|
| 92 |
-
# [关键修复] 强制删除所有的 index.json 文件
|
| 93 |
-
# 即使之前的运行残留了这些文件,这里也会把它们清理掉,防止报错 FileNotFoundError
|
| 94 |
-
print("🧹 Cleaning up conflicting index files...")
|
| 95 |
-
for index_file in glob.glob(os.path.join(local_model_dir, "*.index.json")):
|
| 96 |
-
try:
|
| 97 |
-
os.remove(index_file)
|
| 98 |
-
print(f" Removed: {index_file}")
|
| 99 |
-
except Exception as e:
|
| 100 |
-
print(f" Failed to remove {index_file}: {e}")
|
| 101 |
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
-
#
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
else:
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
| 124 |
|
| 125 |
-
def get_example_images():
|
| 126 |
-
example_dir = "./examples"
|
| 127 |
-
example_images = []
|
| 128 |
-
if os.path.exists(example_dir):
|
| 129 |
-
for ext in SUPPORTED_FORMATS:
|
| 130 |
-
pattern = os.path.join(example_dir, f"*{ext}")
|
| 131 |
-
example_images.extend(glob.glob(pattern))
|
| 132 |
-
example_images.sort()
|
| 133 |
-
return example_images
|
| 134 |
|
| 135 |
-
def
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
"
|
| 141 |
-
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
]
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
-
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
|
|
|
| 167 |
|
| 168 |
-
|
| 169 |
-
|
| 170 |
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
input_ids = inputs['input_ids'].to(device)
|
| 179 |
attention_mask = inputs['attention_mask'].to(device)
|
| 180 |
-
pixel_values = inputs['pixel_values'].to(device) if 'pixel_values' in inputs else None
|
| 181 |
-
image_grid_thw = inputs['image_grid_thw'].to(device) if 'image_grid_thw' in inputs else None
|
| 182 |
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
try:
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
torch.cuda.empty_cache()
|
| 190 |
-
|
| 191 |
-
print(f"Generating SVG for {task_type}...")
|
| 192 |
-
|
| 193 |
-
if task_type == "image-to-svg":
|
| 194 |
-
gen_config = dict(
|
| 195 |
-
do_sample=True,
|
| 196 |
-
temperature=0.1,
|
| 197 |
-
top_p=0.001,
|
| 198 |
-
top_k=1,
|
| 199 |
-
num_beams=5,
|
| 200 |
-
repetition_penalty=1.05,
|
| 201 |
-
)
|
| 202 |
-
else:
|
| 203 |
-
gen_config = dict(
|
| 204 |
-
do_sample=True,
|
| 205 |
-
temperature=0.8,
|
| 206 |
-
top_p=0.95,
|
| 207 |
-
top_k=50,
|
| 208 |
-
repetition_penalty=1.05,
|
| 209 |
-
early_stopping=True,
|
| 210 |
-
)
|
| 211 |
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
model_config = config['model']
|
| 216 |
-
max_length = model_config['max_length']
|
| 217 |
-
output_ids = torch.ones(1, max_length).long().to(device) * model_config['eos_token_id']
|
| 218 |
-
|
| 219 |
-
with torch.no_grad():
|
| 220 |
-
results = sketch_decoder.transformer.generate(
|
| 221 |
-
input_ids=input_ids,
|
| 222 |
-
attention_mask=attention_mask,
|
| 223 |
-
pixel_values=pixel_values,
|
| 224 |
-
image_grid_thw=image_grid_thw,
|
| 225 |
-
max_new_tokens=max_length-1,
|
| 226 |
-
num_return_sequences=1,
|
| 227 |
-
bos_token_id=model_config['bos_token_id'],
|
| 228 |
-
eos_token_id=model_config['eos_token_id'],
|
| 229 |
-
pad_token_id=model_config['pad_token_id'],
|
| 230 |
-
use_cache=True,
|
| 231 |
-
**gen_config
|
| 232 |
-
)
|
| 233 |
-
results = results[:, :max_length-1]
|
| 234 |
-
output_ids[:, :results.shape[1]] = results
|
| 235 |
-
|
| 236 |
-
generated_xy, generated_colors = svg_tokenizer.process_generated_tokens(output_ids)
|
| 237 |
-
svg_tensors = svg_tokenizer.raster_svg(generated_xy)
|
| 238 |
-
|
| 239 |
-
if not svg_tensors or not svg_tensors[0]:
|
| 240 |
-
return "Error: No valid SVG paths generated", None
|
| 241 |
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
-
|
| 247 |
-
|
| 248 |
|
| 249 |
-
|
| 250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
except Exception as e:
|
| 252 |
-
print(f"Generation
|
| 253 |
import traceback
|
| 254 |
traceback.print_exc()
|
| 255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
|
| 257 |
@spaces.GPU
|
| 258 |
-
def
|
| 259 |
-
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
|
| 262 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
|
| 264 |
-
|
| 265 |
tmp_path = tmp_file.name
|
| 266 |
|
| 267 |
try:
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
finally:
|
| 272 |
-
if os.path.exists(tmp_path):
|
|
|
|
| 273 |
|
| 274 |
-
|
| 275 |
-
def
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
|
| 284 |
def create_interface():
|
|
|
|
|
|
|
|
|
|
| 285 |
example_texts = [
|
| 286 |
-
|
| 287 |
-
"A
|
| 288 |
-
"A
|
| 289 |
-
"A
|
| 290 |
-
"A
|
| 291 |
-
"A
|
| 292 |
-
"A
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
"A
|
| 296 |
-
"A
|
| 297 |
-
"A
|
| 298 |
-
"A
|
| 299 |
-
"A
|
| 300 |
-
"
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
"
|
| 304 |
-
"
|
| 305 |
-
"
|
| 306 |
-
"
|
| 307 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
]
|
|
|
|
| 309 |
example_images = get_example_images()
|
| 310 |
|
| 311 |
-
with gr.Blocks(title="OmniSVG
|
| 312 |
-
|
| 313 |
-
gr.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
|
| 315 |
with gr.Tabs():
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
if example_images:
|
| 321 |
-
gr.
|
| 322 |
-
|
| 323 |
|
| 324 |
-
with gr.Column():
|
| 325 |
-
|
| 326 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
|
| 328 |
-
image_generate_btn.click(
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
|
| 337 |
-
with gr.Column():
|
| 338 |
-
|
| 339 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
|
| 341 |
-
text_generate_btn.click(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
|
| 343 |
-
gr.Markdown("""## Usage Instructions...""")
|
| 344 |
-
return demo
|
| 345 |
|
| 346 |
if __name__ == "__main__":
|
| 347 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
| 348 |
args = parse_args()
|
| 349 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
demo = create_interface()
|
| 351 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import gc
|
| 10 |
import yaml
|
| 11 |
import glob
|
| 12 |
+
import numpy as np
|
| 13 |
+
import time
|
| 14 |
+
import threading
|
| 15 |
+
|
| 16 |
from decoder import SketchDecoder
|
| 17 |
from transformers import AutoTokenizer, AutoProcessor
|
| 18 |
from qwen_vl_utils import process_vision_info
|
| 19 |
from tokenizer import SVGTokenizer
|
|
|
|
| 20 |
|
| 21 |
+
# Load config
|
| 22 |
+
with open('./config.yaml', 'r') as f:
|
| 23 |
config = yaml.safe_load(f)
|
| 24 |
|
| 25 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 26 |
+
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
| 27 |
+
|
| 28 |
+
# Global Models
|
| 29 |
tokenizer = None
|
| 30 |
processor = None
|
| 31 |
sketch_decoder = None
|
| 32 |
svg_tokenizer = None
|
|
|
|
| 33 |
|
| 34 |
+
# Thread lock for model inference
|
| 35 |
+
generation_lock = threading.Lock()
|
| 36 |
+
|
| 37 |
+
# Constants
|
| 38 |
+
SYSTEM_PROMPT = """You are an expert SVG code generator.
|
| 39 |
+
Generate precise, valid SVG path commands that accurately represent the described scene or object.
|
| 40 |
+
Focus on capturing key shapes, spatial relationships, and visual composition."""
|
| 41 |
+
|
| 42 |
SUPPORTED_FORMATS = ['.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif']
|
| 43 |
+
TARGET_IMAGE_SIZE = 448
|
| 44 |
+
BLACK_COLOR_TOKEN = 40012
|
| 45 |
+
|
| 46 |
+
# Task configurations with defaults
|
| 47 |
+
TASK_CONFIGS = {
|
| 48 |
+
"text-to-svg-icon": {
|
| 49 |
+
"default_temperature": 0.5,
|
| 50 |
+
"default_top_p": 0.88,
|
| 51 |
+
"default_top_k": 50,
|
| 52 |
+
"default_repetition_penalty": 1.05,
|
| 53 |
+
},
|
| 54 |
+
"text-to-svg-illustration": {
|
| 55 |
+
"default_temperature": 0.6,
|
| 56 |
+
"default_top_p": 0.90,
|
| 57 |
+
"default_top_k": 60,
|
| 58 |
+
"default_repetition_penalty": 1.03,
|
| 59 |
+
},
|
| 60 |
+
"image-to-svg": {
|
| 61 |
+
"default_temperature": 0.3,
|
| 62 |
+
"default_top_p": 0.90,
|
| 63 |
+
"default_top_k": 50,
|
| 64 |
+
"default_repetition_penalty": 1.05,
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
# Custom CSS
|
| 69 |
+
CUSTOM_CSS = """
|
| 70 |
+
/* Main container centering */
|
| 71 |
+
.gradio-container {
|
| 72 |
+
max-width: 1400px !important;
|
| 73 |
+
margin: 0 auto !important;
|
| 74 |
+
padding: 20px !important;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
/* Header styling */
|
| 78 |
+
.header-container {
|
| 79 |
+
text-align: center;
|
| 80 |
+
margin-bottom: 20px;
|
| 81 |
+
padding: 20px;
|
| 82 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 83 |
+
border-radius: 16px;
|
| 84 |
+
color: white;
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
.header-container h1 {
|
| 88 |
+
margin: 0;
|
| 89 |
+
font-size: 2.5em;
|
| 90 |
+
font-weight: 700;
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
.header-container p {
|
| 94 |
+
margin: 10px 0 0 0;
|
| 95 |
+
opacity: 0.9;
|
| 96 |
+
font-size: 1.1em;
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
/* Tips section */
|
| 100 |
+
.tips-box {
|
| 101 |
+
background: #f8f9fa;
|
| 102 |
+
border-radius: 12px;
|
| 103 |
+
padding: 20px;
|
| 104 |
+
margin-bottom: 20px;
|
| 105 |
+
border: 1px solid #e0e0e0;
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
.tips-box h3 {
|
| 109 |
+
margin-top: 0;
|
| 110 |
+
color: #333;
|
| 111 |
+
border-bottom: 2px solid #667eea;
|
| 112 |
+
padding-bottom: 10px;
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
.tip-category {
|
| 116 |
+
background: white;
|
| 117 |
+
border-radius: 8px;
|
| 118 |
+
padding: 15px;
|
| 119 |
+
margin: 10px 0;
|
| 120 |
+
border-left: 4px solid #667eea;
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
.tip-category h4 {
|
| 124 |
+
margin: 0 0 10px 0;
|
| 125 |
+
color: #667eea;
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
.tip-category code {
|
| 129 |
+
background: #f0f0f0;
|
| 130 |
+
padding: 2px 6px;
|
| 131 |
+
border-radius: 4px;
|
| 132 |
+
font-size: 0.9em;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
.example-prompt {
|
| 136 |
+
background: #e8f4fd;
|
| 137 |
+
padding: 10px;
|
| 138 |
+
border-radius: 6px;
|
| 139 |
+
margin: 8px 0;
|
| 140 |
+
font-style: italic;
|
| 141 |
+
font-size: 0.95em;
|
| 142 |
+
color: #333;
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
.red-tip {
|
| 146 |
+
color: #dc3545;
|
| 147 |
+
font-weight: 600;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
.red-box {
|
| 151 |
+
background: #fff5f5;
|
| 152 |
+
border: 1px solid #ffcccc;
|
| 153 |
+
border-left: 4px solid #dc3545;
|
| 154 |
+
padding: 12px;
|
| 155 |
+
border-radius: 8px;
|
| 156 |
+
margin: 10px 0;
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
.red-box strong {
|
| 160 |
+
color: #dc3545;
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
.orange-box {
|
| 164 |
+
background: #fff8e6;
|
| 165 |
+
border: 1px solid #ffc107;
|
| 166 |
+
border-left: 4px solid #ff9800;
|
| 167 |
+
padding: 12px;
|
| 168 |
+
border-radius: 8px;
|
| 169 |
+
margin: 10px 0;
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
.orange-box strong {
|
| 173 |
+
color: #ff9800;
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
.green-box {
|
| 177 |
+
background: #e8f5e9;
|
| 178 |
+
border: 1px solid #81c784;
|
| 179 |
+
border-left: 4px solid #4caf50;
|
| 180 |
+
padding: 12px;
|
| 181 |
+
border-radius: 8px;
|
| 182 |
+
margin: 10px 0;
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
.green-box strong {
|
| 186 |
+
color: #4caf50;
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
/* Tab styling */
|
| 190 |
+
.tabs {
|
| 191 |
+
border-radius: 12px !important;
|
| 192 |
+
overflow: hidden;
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
.tabitem {
|
| 196 |
+
padding: 20px !important;
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
/* Button styling */
|
| 200 |
+
.primary-btn {
|
| 201 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
|
| 202 |
+
border: none !important;
|
| 203 |
+
font-weight: 600 !important;
|
| 204 |
+
padding: 12px 24px !important;
|
| 205 |
+
font-size: 1.1em !important;
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
.primary-btn:hover {
|
| 209 |
+
transform: translateY(-2px);
|
| 210 |
+
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
/* Settings group */
|
| 214 |
+
.settings-group {
|
| 215 |
+
background: #f8f9fa;
|
| 216 |
+
border-radius: 10px;
|
| 217 |
+
padding: 15px;
|
| 218 |
+
margin: 10px 0;
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
.advanced-settings {
|
| 222 |
+
background: #f0f4f8;
|
| 223 |
+
border-radius: 8px;
|
| 224 |
+
padding: 12px;
|
| 225 |
+
margin-top: 10px;
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
/* Code output */
|
| 229 |
+
.code-output textarea {
|
| 230 |
+
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace !important;
|
| 231 |
+
font-size: 12px !important;
|
| 232 |
+
background: #1e1e1e !important;
|
| 233 |
+
color: #d4d4d4 !important;
|
| 234 |
+
border-radius: 8px !important;
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
/* Input image area */
|
| 238 |
+
.input-image {
|
| 239 |
+
border: 2px dashed #ccc;
|
| 240 |
+
border-radius: 12px;
|
| 241 |
+
transition: border-color 0.3s;
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
.input-image:hover {
|
| 245 |
+
border-color: #667eea;
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
/* Footer */
|
| 249 |
+
.footer {
|
| 250 |
+
text-align: center;
|
| 251 |
+
padding: 20px;
|
| 252 |
+
color: #666;
|
| 253 |
+
font-size: 0.9em;
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
/* Responsive adjustments */
|
| 257 |
+
@media (max-width: 768px) {
|
| 258 |
+
.gradio-container {
|
| 259 |
+
padding: 10px !important;
|
| 260 |
+
}
|
| 261 |
+
.header-container h1 {
|
| 262 |
+
font-size: 1.8em;
|
| 263 |
+
}
|
| 264 |
+
}
|
| 265 |
+
"""
|
| 266 |
+
|
| 267 |
+
# Enhanced Tips HTML - Bilingual with Red Tips
|
| 268 |
+
TIPS_HTML = """
|
| 269 |
+
<div class="tips-box">
|
| 270 |
+
<h3>💡 Prompting Guide & Best Practices | 提示词指南与最佳实践</h3>
|
| 271 |
+
|
| 272 |
+
<!-- Critical Red Tips Section -->
|
| 273 |
+
<div class="red-box">
|
| 274 |
+
<strong>🔴 CRITICAL: Tips That WILL Improve Your Results | 关键:一定能提升效果的技巧</strong>
|
| 275 |
+
<ul style="margin: 8px 0 0 0; padding-left: 20px;">
|
| 276 |
+
<li style="color: #dc3545; font-weight: 600;">
|
| 277 |
+
<strong>🎲 Generate 4-8 candidates and pick the best one!</strong> Results vary significantly between generations - this is NORMAL!<br/>
|
| 278 |
+
<span style="color: #666; font-weight: normal;">生成4-8个候选结果并选择最好的!每次生成结果差异很大 - 这是正常的!</span>
|
| 279 |
+
</li>
|
| 280 |
+
<li style="color: #dc3545; font-weight: 600;">
|
| 281 |
+
<strong>📐 Use GEOMETRIC descriptions:</strong> "triangular roof", "circular head", "rectangular body", "curved tail"<br/>
|
| 282 |
+
<span style="color: #666; font-weight: normal;">使用几何描述:"三角形屋顶"、"圆形头部"、"矩形身体"、"弯曲尾巴"</span>
|
| 283 |
+
</li>
|
| 284 |
+
<li style="color: #dc3545; font-weight: 600;">
|
| 285 |
+
<strong>🎨 ALWAYS specify colors for EACH element:</strong> "black outline", "red roof", "blue shirt", "green grass"<br/>
|
| 286 |
+
<span style="color: #666; font-weight: normal;">始终为每个元素指定颜色:"黑色轮廓"、"红色屋顶"、"蓝色衬衫"、"绿色草地"</span>
|
| 287 |
+
</li>
|
| 288 |
+
<li style="color: #dc3545; font-weight: 600;">
|
| 289 |
+
<strong>⬜ Say "white background" or "on white background"</strong> for cleaner results<br/>
|
| 290 |
+
<span style="color: #666; font-weight: normal;">说"白色背景"或"在白色背景上"可获得更干净的结果</span>
|
| 291 |
+
</li>
|
| 292 |
+
<li style="color: #dc3545; font-weight: 600;">
|
| 293 |
+
<strong>📍 Describe position & orientation:</strong> "centrally positioned", "pointing upward", "facing right", "at the bottom"<br/>
|
| 294 |
+
<span style="color: #666; font-weight: normal;">描述位置和方向:"居中放置"、"指向上方"、"朝右"、"在底部"</span>
|
| 295 |
+
</li>
|
| 296 |
+
<li style="color: #dc3545; font-weight: 600;">
|
| 297 |
+
<strong>✂️ Keep it SIMPLE:</strong> Avoid complex sentences. Use short, clear phrases connected by commas.<br/>
|
| 298 |
+
<span style="color: #666; font-weight: normal;">保持简单:避免复杂句子。使用简短清晰的短语,用逗号连接。</span>
|
| 299 |
+
</li>
|
| 300 |
+
</ul>
|
| 301 |
+
</div>
|
| 302 |
+
|
| 303 |
+
<!-- Parameter Tuning Tips -->
|
| 304 |
+
<div class="orange-box">
|
| 305 |
+
<strong>🎛️ Parameter Tuning Guide | 参数调整指南</strong>
|
| 306 |
+
<table style="width: 100%; margin-top: 10px; border-collapse: collapse;">
|
| 307 |
+
<tr style="background: rgba(255,255,255,0.5);">
|
| 308 |
+
<th style="padding: 8px; text-align: left; border-bottom: 1px solid #ddd;">Scenario 场景</th>
|
| 309 |
+
<th style="padding: 8px; text-align: center; border-bottom: 1px solid #ddd;">Temperature</th>
|
| 310 |
+
<th style="padding: 8px; text-align: center; border-bottom: 1px solid #ddd;">Top-P</th>
|
| 311 |
+
<th style="padding: 8px; text-align: center; border-bottom: 1px solid #ddd;">Top-K</th>
|
| 312 |
+
<th style="padding: 8px; text-align: center; border-bottom: 1px solid #ddd;">Rep. Penalty</th>
|
| 313 |
+
</tr>
|
| 314 |
+
<tr>
|
| 315 |
+
<td style="padding: 8px;">Simple icons/shapes 简单图标</td>
|
| 316 |
+
<td style="padding: 8px; text-align: center;">0.3 - 0.5</td>
|
| 317 |
+
<td style="padding: 8px; text-align: center;">0.85 - 0.90</td>
|
| 318 |
+
<td style="padding: 8px; text-align: center;">40 - 50</td>
|
| 319 |
+
<td style="padding: 8px; text-align: center;">1.05</td>
|
| 320 |
+
</tr>
|
| 321 |
+
<tr style="background: rgba(255,255,255,0.3);">
|
| 322 |
+
<td style="padding: 8px;">Characters/Avatars 人物/头像</td>
|
| 323 |
+
<td style="padding: 8px; text-align: center;">0.5 - 0.7</td>
|
| 324 |
+
<td style="padding: 8px; text-align: center;">0.88 - 0.92</td>
|
| 325 |
+
<td style="padding: 8px; text-align: center;">50 - 70</td>
|
| 326 |
+
<td style="padding: 8px; text-align: center;">1.02 - 1.05</td>
|
| 327 |
+
</tr>
|
| 328 |
+
<tr>
|
| 329 |
+
<td style="padding: 8px;">Landscapes/Scenes 风景/场景</td>
|
| 330 |
+
<td style="padding: 8px; text-align: center;">0.5 - 0.7</td>
|
| 331 |
+
<td style="padding: 8px; text-align: center;">0.88 - 0.92</td>
|
| 332 |
+
<td style="padding: 8px; text-align: center;">50 - 60</td>
|
| 333 |
+
<td style="padding: 8px; text-align: center;">1.03</td>
|
| 334 |
+
</tr>
|
| 335 |
+
<tr style="background: rgba(255,255,255,0.3);">
|
| 336 |
+
<td style="padding: 8px;">Image-to-SVG 图像转SVG</td>
|
| 337 |
+
<td style="padding: 8px; text-align: center;">0.2 - 0.4</td>
|
| 338 |
+
<td style="padding: 8px; text-align: center;">0.88 - 0.92</td>
|
| 339 |
+
<td style="padding: 8px; text-align: center;">40 - 50</td>
|
| 340 |
+
<td style="padding: 8px; text-align: center;">1.05</td>
|
| 341 |
+
</tr>
|
| 342 |
+
</table>
|
| 343 |
+
<p style="margin: 10px 0 0 0; font-size: 0.9em; color: #856404;">
|
| 344 |
+
💡 <strong>Tip:</strong> If results are too chaotic, lower temperature. If too simple/empty, raise it slightly.<br/>
|
| 345 |
+
如果结果太混乱,降低温度。如果太简单/空白,稍微提高。
|
| 346 |
+
</p>
|
| 347 |
+
</div>
|
| 348 |
+
|
| 349 |
+
<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); gap: 15px; margin-top: 15px;">
|
| 350 |
+
|
| 351 |
+
<div class="tip-category">
|
| 352 |
+
<h4>🎯 Icons & Simple Shapes | 图标与简单形状</h4>
|
| 353 |
+
<p>Use clear geometric descriptions with explicit colors.<br/>
|
| 354 |
+
<span style="color: #666; font-size: 0.9em;">使用清晰的几何描述和明确的颜色。</span></p>
|
| 355 |
+
<div class="example-prompt">
|
| 356 |
+
"A black triangle pointing downward, centrally positioned on white background."<br/>
|
| 357 |
+
<span style="color: #666;">"黑色三角形,指向下方,居中在白色背景上。"</span>
|
| 358 |
+
</div>
|
| 359 |
+
<div class="example-prompt">
|
| 360 |
+
"A red heart shape with smooth curved edges, centered on white background."<br/>
|
| 361 |
+
<span style="color: #666;">"红色心形,边缘光滑弯曲,居中在白色背景上。"</span>
|
| 362 |
+
</div>
|
| 363 |
+
<p><strong>Keywords:</strong> <code>triangle</code> <code>circle</code> <code>arrow</code> <code>heart</code> <code>star</code> <code>centered</code></p>
|
| 364 |
+
</div>
|
| 365 |
+
|
| 366 |
+
<div class="tip-category">
|
| 367 |
+
<h4>👤 Characters & People | 人物角色</h4>
|
| 368 |
+
<p>Break down into simple geometric parts. Describe each body part with shape + color.<br/>
|
| 369 |
+
<span style="color: #666; font-size: 0.9em;">分解为简单几何部分。用形状+颜色描述每个身体部位。</span></p>
|
| 370 |
+
<div class="example-prompt">
|
| 371 |
+
"A simple person: round beige head, rectangular blue shirt body, two dark gray rectangular legs. Standing pose, arms at sides, flat colors, white background."<br/>
|
| 372 |
+
<span style="color: #666;">"简单人物:米色圆形头,蓝色矩形衬衫身体,两条深灰矩形腿。站立姿势,双臂下垂,平面颜色,白色背景。"</span>
|
| 373 |
+
</div>
|
| 374 |
+
<div class="example-prompt">
|
| 375 |
+
"A girl with long black hair, pink dress with triangular skirt shape, small circular face with dot eyes and curved smile. Simple cartoon style."<br/>
|
| 376 |
+
<span style="color: #666;">"长黑发女孩,粉色连衣裙(三角形裙摆),小圆脸配点状眼睛和弯曲微笑。简单卡通风格。"</span>
|
| 377 |
+
</div>
|
| 378 |
+
<p class="red-tip">⚠️ Keep poses SIMPLE: standing, sitting, waving. Avoid complex actions!</p>
|
| 379 |
+
</div>
|
| 380 |
+
|
| 381 |
+
<div class="tip-category">
|
| 382 |
+
<h4>😊 Avatars & Portraits | 头像与肖像</h4>
|
| 383 |
+
<p>Use circular frame, focus on face and upper body only.<br/>
|
| 384 |
+
<span style="color: #666; font-size: 0.9em;">使用圆形框架,只关注脸部和上半身。</span></p>
|
| 385 |
+
<div class="example-prompt">
|
| 386 |
+
"Circular avatar: person with short black hair, round face with two dot eyes and small curved smile, wearing blue collar shirt. Minimal style, white background."<br/>
|
| 387 |
+
<span style="color: #666;">"圆形头像:短黑发人物,圆脸配两个点状眼睛和小弯曲微笑,穿蓝色衬衫领子。极简风格,白色背景。"</span>
|
| 388 |
+
</div>
|
| 389 |
+
<div class="example-prompt">
|
| 390 |
+
"Profile avatar silhouette: black side view of head with short hair, facing right. Simple solid shape on white background."<br/>
|
| 391 |
+
<span style="color: #666;">"侧面头像剪影:黑色短发头部侧视图,朝右。简单实心形状,白色背景。"</span>
|
| 392 |
+
</div>
|
| 393 |
+
</div>
|
| 394 |
+
|
| 395 |
+
<div class="tip-category">
|
| 396 |
+
<h4>🏔️ Landscapes & Scenes | 风景与场景</h4>
|
| 397 |
+
<p>Layer elements from background to foreground. Specify color for EACH layer.<br/>
|
| 398 |
+
<span style="color: #666; font-size: 0.9em;">从背景到前景分层。为每层指定颜色。</span></p>
|
| 399 |
+
<div class="example-prompt">
|
| 400 |
+
"Layered landscape: light blue sky at top, gray triangular mountains in middle, dark green triangular pine trees at bottom. Flat colors, simple shapes."<br/>
|
| 401 |
+
<span style="color: #666;">"分层风景:顶部浅蓝天空,中间灰色三角山脉,底部深绿三角松树。平面颜色,简单形状。"</span>
|
| 402 |
+
</div>
|
| 403 |
+
<div class="example-prompt">
|
| 404 |
+
"Sunset beach: orange gradient sky at top, yellow semicircle sun on horizon, dark blue wavy ocean below, tan beach at bottom."<br/>
|
| 405 |
+
<span style="color: #666;">"日落海滩:顶部橙色渐变天空,地平线黄色半圆太阳,下方深蓝波浪海洋,底部棕褐色沙滩。"</span>
|
| 406 |
+
</div>
|
| 407 |
+
<p class="red-tip">⚠️ Use geometric shapes for nature: triangular trees, wavy water, semicircle sun!</p>
|
| 408 |
+
</div>
|
| 409 |
+
|
| 410 |
+
<div class="tip-category">
|
| 411 |
+
<h4>🐱 Animals | 动物</h4>
|
| 412 |
+
<p>Describe as geometric shapes: oval body, round head, triangular ears, curved tail.<br/>
|
| 413 |
+
<span style="color: #666; font-size: 0.9em;">描述为几何形状:椭圆身体,圆头,三角耳朵,弯曲尾巴。</span></p>
|
| 414 |
+
<div class="example-prompt">
|
| 415 |
+
"Cute cat: orange round head with two triangular ears, oval orange body, curved tail. Simple cartoon style with black outlines, sitting pose, white background."<br/>
|
| 416 |
+
<span style="color: #666;">"可爱猫咪:橙色圆头配两个三角耳朵,橙色椭圆身体,弯曲尾巴。简单卡通风格,黑色轮廓,坐姿,白色背景。"</span>
|
| 417 |
+
</div>
|
| 418 |
+
<div class="example-prompt">
|
| 419 |
+
"Simple black bird: oval body, small round head, pointed triangular beak facing right, triangular tail, two stick legs. Silhouette style on white."<br/>
|
| 420 |
+
<span style="color: #666;">"简单黑鸟:椭圆身体,小圆头,尖三角喙朝右,三角尾巴,两条棒状腿。白色背景剪影风格。"</span>
|
| 421 |
+
</div>
|
| 422 |
+
</div>
|
| 423 |
+
|
| 424 |
+
<div class="tip-category">
|
| 425 |
+
<h4>🏠 Buildings & Objects | 建筑与物体</h4>
|
| 426 |
+
<p>Use basic shapes: rectangles for walls, triangles for roofs, squares for windows.<br/>
|
| 427 |
+
<span style="color: #666; font-size: 0.9em;">使用基本形状:矩形墙壁,三角屋顶,方形窗户。</span></p>
|
| 428 |
+
<div class="example-prompt">
|
| 429 |
+
"Simple house: red triangular roof on top, beige rectangular wall, brown rectangular door in center, two small blue square windows. Green ground at bottom, white background."<br/>
|
| 430 |
+
<span style="color: #666;">"简单房屋:顶部红色三角屋顶,米色矩形墙壁,中间棕色矩形门,两个小蓝色方形窗户。底部绿色地面,白色背景。"</span>
|
| 431 |
+
</div>
|
| 432 |
+
<div class="example-prompt">
|
| 433 |
+
"Coffee mug: brown cylindrical cup shape with curved handle on right side, three wavy steam lines rising from top. Simple flat style on white."<br/>
|
| 434 |
+
<span style="color: #666;">"咖啡杯:棕色圆柱杯身,右侧弯曲把手,顶部三条波浪蒸汽线上升。简单平面风格,白色背景。"</span>
|
| 435 |
+
</div>
|
| 436 |
+
</div>
|
| 437 |
+
|
| 438 |
+
</div>
|
| 439 |
+
|
| 440 |
+
<!-- Extended Examples Section -->
|
| 441 |
+
<div style="margin-top: 20px; padding: 15px; background: #f0f7ff; border-radius: 10px; border: 1px solid #cce5ff;">
|
| 442 |
+
<h4 style="margin-top: 0; color: #0066cc;">🎨 More Complex Examples (Generate 6-8 candidates!) | 更多复杂示例(请生成6-8个候选!)</h4>
|
| 443 |
+
|
| 444 |
+
<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(280px, 1fr)); gap: 12px; margin-top: 15px;">
|
| 445 |
+
<div class="example-prompt">
|
| 446 |
+
<strong>👨💼 Business Avatar:</strong><br/>
|
| 447 |
+
"Circular professional avatar: man with short black hair, neutral skin tone round face, wearing dark navy suit with white shirt collar visible. Clean minimal style, centered in circle, white background."
|
| 448 |
+
</div>
|
| 449 |
+
<div class="example-prompt">
|
| 450 |
+
<strong>👩 Female Portrait:</strong><br/>
|
| 451 |
+
"Simple female face: oval face shape, long brown wavy hair on sides, two dot eyes, small nose, curved smile lips. Pink blush on cheeks. Cartoon portrait style, white background."
|
| 452 |
+
</div>
|
| 453 |
+
<div class="example-prompt">
|
| 454 |
+
<strong>🧒 Child Character:</strong><br/>
|
| 455 |
+
"Cute child standing: large round head with short brown hair, big circular eyes with white highlights, small body in red t-shirt and blue shorts, simple stick arms and legs. Cheerful cartoon style."
|
| 456 |
+
</div>
|
| 457 |
+
<div class="example-prompt">
|
| 458 |
+
<strong>🏃 Active Pose:</strong><br/>
|
| 459 |
+
"Person walking: side view, circular head, rectangular torso in green jacket, legs in walking position (one forward, one back). Simple geometric style, moving right, white background."
|
| 460 |
+
</div>
|
| 461 |
+
<div class="example-prompt">
|
| 462 |
+
<strong>🌲 Forest Scene:</strong><br/>
|
| 463 |
+
"Simple forest: light blue sky, row of 5 dark green triangular pine trees of varying heights, brown rectangular trunks, light green grass strip at bottom. Layered flat design."
|
| 464 |
+
</div>
|
| 465 |
+
<div class="example-prompt">
|
| 466 |
+
<strong>🌊 Ocean View:</strong><br/>
|
| 467 |
+
"Minimalist ocean: gradient blue sky at top, three horizontal wavy lines in dark blue for ocean, small white sailboat with triangular sail in center. Clean vector style."
|
| 468 |
+
</div>
|
| 469 |
+
<div class="example-prompt">
|
| 470 |
+
<strong>🌆 City Skyline:</strong><br/>
|
| 471 |
+
"Simple city skyline: orange sunset sky gradient, row of black rectangular building silhouettes of different heights, some with small yellow square windows. Minimalist style."
|
| 472 |
+
</div>
|
| 473 |
+
<div class="example-prompt">
|
| 474 |
+
<strong>🐕 Dog Character:</strong><br/>
|
| 475 |
+
"Friendly cartoon dog: brown oval body, round head with floppy ears, black dot nose, curved tail pointing up, four short legs. Sitting pose facing forward, white background."
|
| 476 |
+
</div>
|
| 477 |
+
</div>
|
| 478 |
+
</div>
|
| 479 |
+
|
| 480 |
+
<!-- Quick Troubleshooting -->
|
| 481 |
+
<div class="green-box" style="margin-top: 15px;">
|
| 482 |
+
<strong>⚡ Quick Troubleshooting | 快速故障排除</strong>
|
| 483 |
+
<ul style="margin: 8px 0 0 0; padding-left: 20px;">
|
| 484 |
+
<li><strong>Messy/chaotic? 混乱?</strong> → Lower temperature to 0.3-0.4, simplify description, reduce top_k</li>
|
| 485 |
+
<li><strong>Too simple/empty? 太简单?</strong> → Raise temperature to 0.5-0.6, add more shape details</li>
|
| 486 |
+
<li><strong>Wrong colors? 颜色错误?</strong> → Explicitly name EVERY color: "red roof", "blue shirt", "black outline"</li>
|
| 487 |
+
<li><strong>Missing elements? 元素缺失?</strong> → Add position words: "at top", "in center", "at bottom left"</li>
|
| 488 |
+
<li><strong>Repetitive patterns? 重复图案?</strong> → Increase repetition_penalty to 1.08-1.15</li>
|
| 489 |
+
<li><strong>Inconsistent? 不一致?</strong> → <span class="red-tip">Generate MORE candidates (6-8) and pick the best!</span></li>
|
| 490 |
+
</ul>
|
| 491 |
+
</div>
|
| 492 |
+
|
| 493 |
+
<!-- Prompt Template -->
|
| 494 |
+
<div style="margin-top: 15px; padding: 12px; background: #e8f5e9; border-radius: 8px; border-left: 4px solid #4caf50;">
|
| 495 |
+
<strong>✅ Recommended Prompt Structure | 推荐提示词结构</strong>
|
| 496 |
+
<div style="background: white; padding: 10px; border-radius: 6px; margin-top: 8px; font-family: monospace; font-size: 0.9em;">
|
| 497 |
+
[Subject] + [Shape descriptions with colors] + [Position/orientation] + [Style] + [Background]
|
| 498 |
+
</div>
|
| 499 |
+
<p style="margin: 10px 0 0 0; color: #2e7d32; font-size: 0.95em;">
|
| 500 |
+
✓ "A fox logo: triangular orange head, pointed ears, white chest marking, facing right. Minimalist flat style, centered on white background."
|
| 501 |
+
</p>
|
| 502 |
+
</div>
|
| 503 |
+
</div>
|
| 504 |
+
"""
|
| 505 |
+
|
| 506 |
+
# Image-to-SVG specific tips
|
| 507 |
+
IMAGE_TIPS_HTML = """
|
| 508 |
+
<div class="red-box">
|
| 509 |
+
<strong>🔴 Image-to-SVG Tips | 图片转SVG技巧</strong>
|
| 510 |
+
<ul style="margin: 8px 0 0 0; padding-left: 20px;">
|
| 511 |
+
<li><strong>Best input: Simple images with white/transparent background</strong><br/>
|
| 512 |
+
<span style="color: #666;">最佳输入:白色或透明背景的简单图片</span></li>
|
| 513 |
+
<li><strong>PNG with transparency (RGBA) works best!</strong> We auto-convert to white background.<br/>
|
| 514 |
+
<span style="color: #666;">透明背景的PNG效果最好!我们会自动转换为白色背景。</span></li>
|
| 515 |
+
<li><strong>For complex backgrounds:</strong> Enable "Replace Background" option below.<br/>
|
| 516 |
+
<span style="color: #666;">复杂背景图片:启用下方的"替换背景"选项。</span></li>
|
| 517 |
+
<li><strong>Lower temperature (0.2-0.4)</strong> for more accurate reproduction.<br/>
|
| 518 |
+
<span style="color: #666;">较低温度(0.2-0.4)可获得更准确的复制效果。</span></li>
|
| 519 |
+
<li style="color: #dc3545; font-weight: 600;"><strong>Generate 4-8 candidates!</strong> Pick the one that best matches your input.<br/>
|
| 520 |
+
<span style="color: #666; font-weight: normal;">生成4-8个候选!选择最匹配输入的那个。</span></li>
|
| 521 |
+
</ul>
|
| 522 |
+
</div>
|
| 523 |
+
"""
|
| 524 |
+
|
| 525 |
|
| 526 |
def parse_args():
|
| 527 |
parser = argparse.ArgumentParser(description='SVG Generator Service')
|
|
|
|
| 529 |
parser.add_argument('--port', type=int, default=7860)
|
| 530 |
parser.add_argument('--share', action='store_true')
|
| 531 |
parser.add_argument('--debug', action='store_true')
|
| 532 |
+
parser.add_argument('--weight_path', type=str, default="/mnt/jfs-test/OmniSVG_result/8B_1126/1688_bs_4/merge_slerp/merge_150_350_bf16")
|
| 533 |
+
parser.add_argument('--model_path', type=str, default="/mnt/jfs-test/Qwen2.5-VL-7B-Instruct")
|
| 534 |
return parser.parse_args()
|
| 535 |
|
| 536 |
+
|
| 537 |
+
def load_models(weight_path, model_path):
|
| 538 |
+
"""Load all models"""
|
| 539 |
+
global tokenizer, processor, sketch_decoder, svg_tokenizer
|
| 540 |
|
| 541 |
+
print(f"Loading models from {model_path}...")
|
| 542 |
+
print(f"Using precision: {DTYPE}")
|
| 543 |
+
|
| 544 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
|
| 545 |
+
processor = AutoProcessor.from_pretrained(model_path, padding_side="left")
|
| 546 |
+
processor.tokenizer.padding_side = "left"
|
| 547 |
|
| 548 |
+
sketch_decoder = SketchDecoder(
|
| 549 |
+
pix_len=config['model']['max_length'],
|
| 550 |
+
text_len=200,
|
| 551 |
+
model_path=model_path,
|
| 552 |
+
torch_dtype=DTYPE
|
| 553 |
+
)
|
| 554 |
+
|
| 555 |
+
bin_path = os.path.join(weight_path, "pytorch_model.bin")
|
| 556 |
+
if os.path.exists(bin_path):
|
| 557 |
+
print(f"Loading weights from: {bin_path}")
|
| 558 |
+
sketch_decoder.load_state_dict(torch.load(bin_path, map_location='cpu'))
|
| 559 |
+
else:
|
| 560 |
+
raise FileNotFoundError(f"No weights found at {bin_path}")
|
| 561 |
+
|
| 562 |
+
sketch_decoder = sketch_decoder.to(device).eval()
|
| 563 |
+
svg_tokenizer = SVGTokenizer('./config.yaml')
|
| 564 |
+
|
| 565 |
+
print("All models loaded successfully!")
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def detect_text_subtype(text_prompt):
|
| 569 |
+
"""Auto-detect text prompt subtype"""
|
| 570 |
+
text_lower = text_prompt.lower()
|
| 571 |
+
|
| 572 |
+
icon_keywords = ['icon', 'logo', 'symbol', 'badge', 'button', 'emoji', 'glyph', 'simple',
|
| 573 |
+
'arrow', 'triangle', 'circle', 'square', 'heart', 'star', 'checkmark']
|
| 574 |
+
if any(kw in text_lower for kw in icon_keywords):
|
| 575 |
+
return "icon"
|
| 576 |
+
|
| 577 |
+
illustration_keywords = [
|
| 578 |
+
'illustration', 'scene', 'person', 'people', 'character', 'man', 'woman', 'boy', 'girl',
|
| 579 |
+
'avatar', 'portrait', 'face', 'head', 'body',
|
| 580 |
+
'cat', 'dog', 'bird', 'animal', 'pet', 'fox', 'rabbit',
|
| 581 |
+
'sitting', 'standing', 'walking', 'running', 'sleeping', 'holding', 'playing',
|
| 582 |
+
'house', 'building', 'tree', 'garden', 'landscape', 'mountain', 'forest', 'city',
|
| 583 |
+
'ocean', 'beach', 'sunset', 'sunrise', 'sky'
|
| 584 |
+
]
|
| 585 |
+
|
| 586 |
+
match_count = sum(1 for kw in illustration_keywords if kw in text_lower)
|
| 587 |
+
if match_count >= 1 or len(text_prompt) > 50:
|
| 588 |
+
return "illustration"
|
| 589 |
+
|
| 590 |
+
return "icon"
|
| 591 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 592 |
|
| 593 |
+
def detect_and_replace_background(image, threshold=240, edge_sample_ratio=0.1):
|
| 594 |
+
"""
|
| 595 |
+
Detect if image has non-white background and optionally replace it.
|
| 596 |
+
|
| 597 |
+
Args:
|
| 598 |
+
image: PIL Image (RGB or RGBA)
|
| 599 |
+
threshold: Pixel values above this are considered "white"
|
| 600 |
+
edge_sample_ratio: Ratio of edge pixels to sample
|
| 601 |
+
|
| 602 |
+
Returns:
|
| 603 |
+
tuple: (processed_image, background_was_replaced)
|
| 604 |
+
"""
|
| 605 |
+
img_array = np.array(image)
|
| 606 |
+
|
| 607 |
+
# If already has alpha channel, composite onto white
|
| 608 |
+
if image.mode == 'RGBA':
|
| 609 |
+
# Create white background and composite
|
| 610 |
+
bg = Image.new('RGBA', image.size, (255, 255, 255, 255))
|
| 611 |
+
composite = Image.alpha_composite(bg, image)
|
| 612 |
+
return composite.convert('RGB'), True
|
| 613 |
+
|
| 614 |
+
# Sample edge pixels to detect background color
|
| 615 |
+
h, w = img_array.shape[:2]
|
| 616 |
+
edge_pixels = []
|
| 617 |
+
|
| 618 |
+
# Sample from all 4 edges
|
| 619 |
+
sample_count = max(10, int(min(h, w) * edge_sample_ratio))
|
| 620 |
+
|
| 621 |
+
# Top and bottom edges
|
| 622 |
+
for i in range(0, w, max(1, w // sample_count)):
|
| 623 |
+
edge_pixels.append(img_array[0, i])
|
| 624 |
+
edge_pixels.append(img_array[h-1, i])
|
| 625 |
+
|
| 626 |
+
# Left and right edges
|
| 627 |
+
for i in range(0, h, max(1, h // sample_count)):
|
| 628 |
+
edge_pixels.append(img_array[i, 0])
|
| 629 |
+
edge_pixels.append(img_array[i, w-1])
|
| 630 |
+
|
| 631 |
+
edge_pixels = np.array(edge_pixels)
|
| 632 |
+
|
| 633 |
+
# Check if background is already white-ish
|
| 634 |
+
if len(edge_pixels) > 0:
|
| 635 |
+
mean_edge = edge_pixels.mean(axis=0)
|
| 636 |
+
if np.all(mean_edge > threshold):
|
| 637 |
+
# Background is already white, just return original
|
| 638 |
+
return image, False
|
| 639 |
+
|
| 640 |
+
# Background is not white - try to replace it
|
| 641 |
+
# Use the most common edge color as the background color to replace
|
| 642 |
+
if len(img_array.shape) == 3 and img_array.shape[2] >= 3:
|
| 643 |
+
# Convert to grayscale for easier background detection
|
| 644 |
+
if img_array.shape[2] == 4:
|
| 645 |
+
gray = np.mean(img_array[:, :, :3], axis=2)
|
| 646 |
+
else:
|
| 647 |
+
gray = np.mean(img_array, axis=2)
|
| 648 |
|
| 649 |
+
# Find background color (most common color at edges)
|
| 650 |
+
edge_colors = []
|
| 651 |
+
for i in range(w):
|
| 652 |
+
edge_colors.append(tuple(img_array[0, i, :3]))
|
| 653 |
+
edge_colors.append(tuple(img_array[h-1, i, :3]))
|
| 654 |
+
for i in range(h):
|
| 655 |
+
edge_colors.append(tuple(img_array[i, 0, :3]))
|
| 656 |
+
edge_colors.append(tuple(img_array[i, w-1, :3]))
|
| 657 |
|
| 658 |
+
# Find most common edge color
|
| 659 |
+
from collections import Counter
|
| 660 |
+
color_counts = Counter(edge_colors)
|
| 661 |
+
bg_color = color_counts.most_common(1)[0][0]
|
| 662 |
+
|
| 663 |
+
# Create mask for background (colors similar to detected bg_color)
|
| 664 |
+
color_diff = np.sqrt(np.sum((img_array[:, :, :3].astype(float) - np.array(bg_color)) ** 2, axis=2))
|
| 665 |
+
bg_mask = color_diff < 30 # Threshold for color similarity
|
| 666 |
+
|
| 667 |
+
# Replace background with white
|
| 668 |
+
result = img_array.copy()
|
| 669 |
+
if result.shape[2] == 4:
|
| 670 |
+
result[bg_mask] = [255, 255, 255, 255]
|
| 671 |
+
else:
|
| 672 |
+
result[bg_mask] = [255, 255, 255]
|
| 673 |
+
|
| 674 |
+
return Image.fromarray(result).convert('RGB'), True
|
| 675 |
+
|
| 676 |
+
return image, False
|
| 677 |
|
| 678 |
+
|
| 679 |
+
def preprocess_image_for_svg(image, replace_background=True, target_size=448):
|
| 680 |
+
"""
|
| 681 |
+
Preprocess image for SVG generation.
|
| 682 |
+
|
| 683 |
+
Args:
|
| 684 |
+
image: Input PIL Image or path
|
| 685 |
+
replace_background: Whether to replace non-white backgrounds
|
| 686 |
+
target_size: Target size for resizing
|
| 687 |
+
|
| 688 |
+
Returns:
|
| 689 |
+
tuple: (processed_pil_image, was_modified)
|
| 690 |
+
"""
|
| 691 |
+
# Load image if path
|
| 692 |
+
if isinstance(image, str):
|
| 693 |
+
raw_img = Image.open(image)
|
| 694 |
+
else:
|
| 695 |
+
raw_img = image
|
| 696 |
+
|
| 697 |
+
was_modified = False
|
| 698 |
+
|
| 699 |
+
# Handle different modes
|
| 700 |
+
if raw_img.mode == 'RGBA':
|
| 701 |
+
# RGBA images: composite onto white background
|
| 702 |
+
bg = Image.new('RGBA', raw_img.size, (255, 255, 255, 255))
|
| 703 |
+
img_with_bg = Image.alpha_composite(bg, raw_img).convert('RGB')
|
| 704 |
+
was_modified = True
|
| 705 |
+
elif raw_img.mode == 'LA' or raw_img.mode == 'PA':
|
| 706 |
+
# Grayscale or Palette with alpha
|
| 707 |
+
raw_img = raw_img.convert('RGBA')
|
| 708 |
+
bg = Image.new('RGBA', raw_img.size, (255, 255, 255, 255))
|
| 709 |
+
img_with_bg = Image.alpha_composite(bg, raw_img).convert('RGB')
|
| 710 |
+
was_modified = True
|
| 711 |
+
elif raw_img.mode != 'RGB':
|
| 712 |
+
img_with_bg = raw_img.convert('RGB')
|
| 713 |
else:
|
| 714 |
+
img_with_bg = raw_img
|
| 715 |
+
|
| 716 |
+
# Optionally detect and replace non-white background
|
| 717 |
+
if replace_background:
|
| 718 |
+
img_with_bg, bg_replaced = detect_and_replace_background(img_with_bg)
|
| 719 |
+
was_modified = was_modified or bg_replaced
|
| 720 |
|
| 721 |
+
# Resize to target size
|
| 722 |
+
img_resized = img_with_bg.resize((target_size, target_size), Image.Resampling.LANCZOS)
|
| 723 |
+
|
| 724 |
+
return img_resized, was_modified
|
| 725 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 726 |
|
| 727 |
+
def prepare_inputs(task_type, content):
|
| 728 |
+
"""Prepare model inputs"""
|
| 729 |
+
if task_type == "text-to-svg":
|
| 730 |
+
prompt_text = str(content).strip()
|
| 731 |
+
|
| 732 |
+
instruction = f"""Generate an SVG illustration for: {prompt_text}
|
| 733 |
+
|
| 734 |
+
Requirements:
|
| 735 |
+
- Create complete SVG path commands
|
| 736 |
+
- Include proper coordinates and colors
|
| 737 |
+
- Maintain visual clarity and composition"""
|
| 738 |
+
|
| 739 |
+
messages = [
|
| 740 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 741 |
+
{"role": "user", "content": [{"type": "text", "text": instruction}]}
|
| 742 |
+
]
|
| 743 |
+
text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 744 |
+
inputs = processor(text=[text_input], padding=True, truncation=True, return_tensors="pt")
|
| 745 |
+
|
| 746 |
+
else: # image-to-svg
|
| 747 |
+
messages = [
|
| 748 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 749 |
+
{"role": "user", "content": [
|
| 750 |
+
{"type": "text", "text": "Generate SVG code that accurately represents this image:"},
|
| 751 |
+
{"type": "image", "image": content},
|
| 752 |
+
]}
|
| 753 |
]
|
| 754 |
+
text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 755 |
+
image_inputs, _ = process_vision_info(messages)
|
| 756 |
+
inputs = processor(text=[text_input], images=image_inputs, padding=True, truncation=True, return_tensors="pt")
|
| 757 |
+
|
| 758 |
+
return inputs
|
| 759 |
+
|
| 760 |
+
|
| 761 |
+
def render_svg_to_image(svg_str, size=512):
|
| 762 |
+
"""Render SVG to high-quality PIL Image"""
|
| 763 |
+
try:
|
| 764 |
+
png_data = cairosvg.svg2png(
|
| 765 |
+
bytestring=svg_str.encode('utf-8'),
|
| 766 |
+
output_width=size,
|
| 767 |
+
output_height=size
|
| 768 |
+
)
|
| 769 |
+
image_rgba = Image.open(io.BytesIO(png_data)).convert("RGBA")
|
| 770 |
+
bg = Image.new("RGB", image_rgba.size, (255, 255, 255))
|
| 771 |
+
bg.paste(image_rgba, mask=image_rgba.split()[3])
|
| 772 |
+
return bg
|
| 773 |
+
except Exception as e:
|
| 774 |
+
print(f"Render error: {e}")
|
| 775 |
+
return None
|
| 776 |
+
|
| 777 |
+
|
| 778 |
+
def create_gallery_html(candidates, cols=4):
|
| 779 |
+
"""Create HTML gallery for multiple SVG candidates"""
|
| 780 |
+
if not candidates:
|
| 781 |
+
return '<div style="text-align:center;color:#999;padding:50px;">No candidates generated / 未生成候选</div>'
|
| 782 |
|
| 783 |
+
items_html = []
|
| 784 |
+
for i, cand in enumerate(candidates):
|
| 785 |
+
svg_str = cand['svg']
|
| 786 |
+
if 'viewBox' not in svg_str:
|
| 787 |
+
svg_str = svg_str.replace('<svg', f'<svg viewBox="0 0 {TARGET_IMAGE_SIZE} {TARGET_IMAGE_SIZE}"', 1)
|
| 788 |
+
|
| 789 |
+
item_html = f'''
|
| 790 |
+
<div style="
|
| 791 |
+
background: white;
|
| 792 |
+
border: 1px solid #ddd;
|
| 793 |
+
border-radius: 8px;
|
| 794 |
+
padding: 10px;
|
| 795 |
+
text-align: center;
|
| 796 |
+
transition: transform 0.2s, box-shadow 0.2s;
|
| 797 |
+
cursor: pointer;
|
| 798 |
+
" onmouseover="this.style.transform='scale(1.02)';this.style.boxShadow='0 4px 12px rgba(0,0,0,0.15)';"
|
| 799 |
+
onmouseout="this.style.transform='scale(1)';this.style.boxShadow='none';">
|
| 800 |
+
<div style="width: 180px; height: 180px; margin: 0 auto; display: flex; justify-content: center; align-items: center; overflow: hidden;">
|
| 801 |
+
{svg_str}
|
| 802 |
+
</div>
|
| 803 |
+
<div style="margin-top: 8px; font-size: 12px; color: #666;">
|
| 804 |
+
#{i+1} | {cand['path_count']} paths
|
| 805 |
+
</div>
|
| 806 |
+
</div>
|
| 807 |
+
'''
|
| 808 |
+
items_html.append(item_html)
|
| 809 |
|
| 810 |
+
grid_html = f'''
|
| 811 |
+
<div style="
|
| 812 |
+
display: grid;
|
| 813 |
+
grid-template-columns: repeat({cols}, 1fr);
|
| 814 |
+
gap: 15px;
|
| 815 |
+
padding: 15px;
|
| 816 |
+
background: #fafafa;
|
| 817 |
+
border-radius: 12px;
|
| 818 |
+
">
|
| 819 |
+
{''.join(items_html)}
|
| 820 |
+
</div>
|
| 821 |
+
'''
|
| 822 |
+
return grid_html
|
| 823 |
+
|
| 824 |
+
|
| 825 |
+
def is_valid_candidate(svg_str, img, subtype="illustration"):
|
| 826 |
+
"""Check candidate validity"""
|
| 827 |
+
if not svg_str or len(svg_str) < 20:
|
| 828 |
+
return False, "too_short"
|
| 829 |
|
| 830 |
+
if '<svg' not in svg_str:
|
| 831 |
+
return False, "no_svg_tag"
|
| 832 |
|
| 833 |
+
if img is None:
|
| 834 |
+
return False, "render_failed"
|
| 835 |
+
|
| 836 |
+
img_array = np.array(img)
|
| 837 |
+
mean_val = img_array.mean()
|
| 838 |
+
|
| 839 |
+
threshold = 250 if subtype == "illustration" else 252
|
| 840 |
+
|
| 841 |
+
if mean_val > threshold:
|
| 842 |
+
return False, "empty_image"
|
| 843 |
+
|
| 844 |
+
return True, "ok"
|
| 845 |
+
|
| 846 |
+
|
| 847 |
+
def generate_candidates(inputs, task_type, subtype, temperature, top_p, top_k, repetition_penalty,
|
| 848 |
+
max_length, num_samples, progress_callback=None):
|
| 849 |
+
"""Generate candidate SVGs with full parameter control"""
|
| 850 |
|
| 851 |
input_ids = inputs['input_ids'].to(device)
|
| 852 |
attention_mask = inputs['attention_mask'].to(device)
|
|
|
|
|
|
|
| 853 |
|
| 854 |
+
model_inputs = {
|
| 855 |
+
"input_ids": input_ids,
|
| 856 |
+
"attention_mask": attention_mask
|
| 857 |
+
}
|
| 858 |
+
|
| 859 |
+
if 'pixel_values' in inputs:
|
| 860 |
+
model_inputs["pixel_values"] = inputs['pixel_values'].to(device, dtype=DTYPE)
|
| 861 |
+
|
| 862 |
+
if 'image_grid_thw' in inputs:
|
| 863 |
+
model_inputs["image_grid_thw"] = inputs['image_grid_thw'].to(device)
|
| 864 |
+
|
| 865 |
+
all_candidates = []
|
| 866 |
+
|
| 867 |
+
# Generation config with user parameters
|
| 868 |
+
gen_config = {
|
| 869 |
+
'do_sample': True,
|
| 870 |
+
'temperature': temperature,
|
| 871 |
+
'top_p': top_p,
|
| 872 |
+
'top_k': int(top_k),
|
| 873 |
+
'repetition_penalty': repetition_penalty,
|
| 874 |
+
'early_stopping': True,
|
| 875 |
+
'no_repeat_ngram_size': 0,
|
| 876 |
+
'eos_token_id': config['model']['eos_token_id'],
|
| 877 |
+
'pad_token_id': config['model']['pad_token_id'],
|
| 878 |
+
'bos_token_id': config['model']['bos_token_id'],
|
| 879 |
+
}
|
| 880 |
+
|
| 881 |
+
actual_samples = num_samples + 4
|
| 882 |
+
|
| 883 |
try:
|
| 884 |
+
if progress_callback:
|
| 885 |
+
progress_callback(0.1, "Waiting for model access / 等待模型访问...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 886 |
|
| 887 |
+
with generation_lock:
|
| 888 |
+
if progress_callback:
|
| 889 |
+
progress_callback(0.15, "Generating SVG tokens / 生成SVG令牌...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 890 |
|
| 891 |
+
with torch.no_grad():
|
| 892 |
+
results = sketch_decoder.transformer.generate(
|
| 893 |
+
**model_inputs,
|
| 894 |
+
max_new_tokens=max_length,
|
| 895 |
+
num_return_sequences=actual_samples,
|
| 896 |
+
use_cache=True,
|
| 897 |
+
**gen_config
|
| 898 |
+
)
|
| 899 |
+
|
| 900 |
+
input_len = input_ids.shape[1]
|
| 901 |
+
generated_ids_batch = results[:, input_len:]
|
| 902 |
|
| 903 |
+
if progress_callback:
|
| 904 |
+
progress_callback(0.5, "Processing generated tokens / 处理生成的令牌...")
|
| 905 |
|
| 906 |
+
for i in range(min(actual_samples, generated_ids_batch.shape[0])):
|
| 907 |
+
try:
|
| 908 |
+
current_ids = generated_ids_batch[i:i+1]
|
| 909 |
+
|
| 910 |
+
fake_wrapper = torch.cat([
|
| 911 |
+
torch.full((1, 1), config['model']['bos_token_id'], device=device),
|
| 912 |
+
current_ids,
|
| 913 |
+
torch.full((1, 1), config['model']['eos_token_id'], device=device)
|
| 914 |
+
], dim=1)
|
| 915 |
+
|
| 916 |
+
generated_xy = svg_tokenizer.process_generated_tokens(fake_wrapper)
|
| 917 |
+
if len(generated_xy) == 0:
|
| 918 |
+
continue
|
| 919 |
+
|
| 920 |
+
svg_tensors, color_tensors = svg_tokenizer.raster_svg(generated_xy)
|
| 921 |
+
if not svg_tensors or not svg_tensors[0]:
|
| 922 |
+
continue
|
| 923 |
+
|
| 924 |
+
num_paths = len(svg_tensors[0])
|
| 925 |
+
while len(color_tensors) < num_paths:
|
| 926 |
+
color_tensors.append(BLACK_COLOR_TOKEN)
|
| 927 |
+
|
| 928 |
+
svg = svg_tokenizer.apply_colors_to_svg(svg_tensors[0], color_tensors)
|
| 929 |
+
svg_str = svg.to_str()
|
| 930 |
+
|
| 931 |
+
if 'width=' not in svg_str:
|
| 932 |
+
svg_str = svg_str.replace('<svg', f'<svg width="{TARGET_IMAGE_SIZE}" height="{TARGET_IMAGE_SIZE}"', 1)
|
| 933 |
+
|
| 934 |
+
png_image = render_svg_to_image(svg_str, size=512)
|
| 935 |
+
|
| 936 |
+
is_valid, reason = is_valid_candidate(svg_str, png_image, subtype)
|
| 937 |
+
if is_valid:
|
| 938 |
+
all_candidates.append({
|
| 939 |
+
'svg': svg_str,
|
| 940 |
+
'img': png_image,
|
| 941 |
+
'path_count': num_paths,
|
| 942 |
+
'index': len(all_candidates) + 1
|
| 943 |
+
})
|
| 944 |
+
|
| 945 |
+
if progress_callback:
|
| 946 |
+
progress_callback(0.5 + 0.4 * (i / actual_samples),
|
| 947 |
+
f"Found {len(all_candidates)} valid / 找到 {len(all_candidates)} 个有效...")
|
| 948 |
+
|
| 949 |
+
if len(all_candidates) >= num_samples:
|
| 950 |
+
break
|
| 951 |
+
|
| 952 |
+
except Exception as e:
|
| 953 |
+
print(f" Candidate {i} error: {e}")
|
| 954 |
+
continue
|
| 955 |
+
|
| 956 |
except Exception as e:
|
| 957 |
+
print(f"Generation Error: {e}")
|
| 958 |
import traceback
|
| 959 |
traceback.print_exc()
|
| 960 |
+
|
| 961 |
+
if progress_callback:
|
| 962 |
+
progress_callback(0.95, f"Generated {len(all_candidates)} valid / 生成了 {len(all_candidates)} 个有效")
|
| 963 |
+
|
| 964 |
+
return all_candidates
|
| 965 |
|
| 966 |
@spaces.GPU
|
| 967 |
+
def gradio_text_to_svg(text_description, num_candidates, temperature, top_p, top_k, repetition_penalty,
|
| 968 |
+
progress=gr.Progress()):
|
| 969 |
+
"""Gradio interface - text-to-svg with advanced parameters"""
|
| 970 |
+
if not text_description or text_description.strip() == "":
|
| 971 |
+
return '<div style="text-align:center;color:#999;padding:50px;">Please enter a description / 请输入描述</div>', ""
|
| 972 |
+
|
| 973 |
+
progress(0, "Starting generation / 开始生成...")
|
| 974 |
+
|
| 975 |
+
gc.collect()
|
| 976 |
+
if torch.cuda.is_available():
|
| 977 |
+
torch.cuda.empty_cache()
|
| 978 |
+
|
| 979 |
+
start_time = time.time()
|
| 980 |
+
|
| 981 |
+
subtype = detect_text_subtype(text_description)
|
| 982 |
+
progress(0.05, f"Detected: {subtype} / 检测到: {subtype}")
|
| 983 |
+
|
| 984 |
+
inputs = prepare_inputs("text-to-svg", text_description.strip())
|
| 985 |
+
max_length = config['model']['max_length']
|
| 986 |
+
|
| 987 |
+
def update_progress(val, msg):
|
| 988 |
+
progress(val, msg)
|
| 989 |
+
|
| 990 |
+
all_candidates = generate_candidates(
|
| 991 |
+
inputs, "text-to-svg", subtype,
|
| 992 |
+
temperature, top_p, int(top_k), repetition_penalty,
|
| 993 |
+
max_length, int(num_candidates),
|
| 994 |
+
progress_callback=update_progress
|
| 995 |
+
)
|
| 996 |
+
|
| 997 |
+
elapsed = time.time() - start_time
|
| 998 |
+
|
| 999 |
+
if not all_candidates:
|
| 1000 |
+
return (
|
| 1001 |
+
'<div style="text-align:center;color:#999;padding:50px;">No valid SVG generated. Try different parameters or rephrase your prompt.<br/>未生成有效的SVG。请尝试不同参数或重新描述。</div>',
|
| 1002 |
+
f"<!-- No valid SVG (took {elapsed:.1f}s) -->"
|
| 1003 |
+
)
|
| 1004 |
|
| 1005 |
+
svg_codes = []
|
| 1006 |
+
for i, cand in enumerate(all_candidates):
|
| 1007 |
+
svg_codes.append(f"<!-- ====== Candidate {i+1} | Paths: {cand['path_count']} ====== -->\n{cand['svg']}")
|
| 1008 |
+
|
| 1009 |
+
combined_svg = "\n\n".join(svg_codes)
|
| 1010 |
+
gallery_html = create_gallery_html(all_candidates)
|
| 1011 |
+
|
| 1012 |
+
progress(1.0, f"Done! {len(all_candidates)} candidates in {elapsed:.1f}s / 完成!{len(all_candidates)} 个候选,{elapsed:.1f}秒")
|
| 1013 |
+
|
| 1014 |
+
return gallery_html, combined_svg
|
| 1015 |
+
|
| 1016 |
+
@spaces.GPU
|
| 1017 |
+
def gradio_image_to_svg(image, num_candidates, temperature, top_p, top_k, repetition_penalty,
|
| 1018 |
+
replace_background, progress=gr.Progress()):
|
| 1019 |
+
"""Gradio interface - image-to-svg with background handling"""
|
| 1020 |
+
|
| 1021 |
+
if image is None:
|
| 1022 |
+
return (
|
| 1023 |
+
'<div style="text-align:center;color:#999;padding:50px;">Please upload an image / 请上传图片</div>',
|
| 1024 |
+
"",
|
| 1025 |
+
None
|
| 1026 |
+
)
|
| 1027 |
+
|
| 1028 |
+
progress(0, "Processing input image / 处理输入图片...")
|
| 1029 |
+
|
| 1030 |
+
gc.collect()
|
| 1031 |
+
if torch.cuda.is_available():
|
| 1032 |
+
torch.cuda.empty_cache()
|
| 1033 |
+
|
| 1034 |
+
start_time = time.time()
|
| 1035 |
+
|
| 1036 |
+
# Preprocess image with optional background replacement
|
| 1037 |
+
img_processed, was_modified = preprocess_image_for_svg(
|
| 1038 |
+
image,
|
| 1039 |
+
replace_background=replace_background,
|
| 1040 |
+
target_size=TARGET_IMAGE_SIZE
|
| 1041 |
+
)
|
| 1042 |
+
|
| 1043 |
+
if was_modified:
|
| 1044 |
+
progress(0.05, "Background processed / 背景已处理")
|
| 1045 |
+
|
| 1046 |
+
# Save temp file
|
| 1047 |
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
|
| 1048 |
+
img_processed.save(tmp_file.name, format='PNG', quality=100)
|
| 1049 |
tmp_path = tmp_file.name
|
| 1050 |
|
| 1051 |
try:
|
| 1052 |
+
progress(0.1, "Preparing model inputs / 准备模型输入...")
|
| 1053 |
+
inputs = prepare_inputs("image-to-svg", tmp_path)
|
| 1054 |
+
max_length = config['model']['max_length']
|
| 1055 |
+
|
| 1056 |
+
def update_progress(val, msg):
|
| 1057 |
+
progress(val, msg)
|
| 1058 |
+
|
| 1059 |
+
all_candidates = generate_candidates(
|
| 1060 |
+
inputs, "image-to-svg", "image",
|
| 1061 |
+
temperature, top_p, int(top_k), repetition_penalty,
|
| 1062 |
+
max_length, int(num_candidates),
|
| 1063 |
+
progress_callback=update_progress
|
| 1064 |
+
)
|
| 1065 |
+
|
| 1066 |
+
elapsed = time.time() - start_time
|
| 1067 |
+
|
| 1068 |
+
if not all_candidates:
|
| 1069 |
+
return (
|
| 1070 |
+
'<div style="text-align:center;color:#999;padding:50px;">No valid SVG generated. Try adjusting parameters.<br/>未生成有效的SVG。请尝试调整参数。</div>',
|
| 1071 |
+
f"<!-- No valid SVG (took {elapsed:.1f}s) -->",
|
| 1072 |
+
img_processed
|
| 1073 |
+
)
|
| 1074 |
+
|
| 1075 |
+
svg_codes = []
|
| 1076 |
+
for i, cand in enumerate(all_candidates):
|
| 1077 |
+
svg_codes.append(f"<!-- ====== Candidate {i+1} | Paths: {cand['path_count']} ====== -->\n{cand['svg']}")
|
| 1078 |
+
|
| 1079 |
+
combined_svg = "\n\n".join(svg_codes)
|
| 1080 |
+
gallery_html = create_gallery_html(all_candidates)
|
| 1081 |
+
|
| 1082 |
+
progress(1.0, f"Done! {len(all_candidates)} candidates in {elapsed:.1f}s")
|
| 1083 |
+
|
| 1084 |
+
return gallery_html, combined_svg, img_processed
|
| 1085 |
+
|
| 1086 |
finally:
|
| 1087 |
+
if os.path.exists(tmp_path):
|
| 1088 |
+
os.unlink(tmp_path)
|
| 1089 |
|
| 1090 |
+
|
| 1091 |
+
def get_example_images():
|
| 1092 |
+
"""Get example images from the examples directory"""
|
| 1093 |
+
example_dir = "./examples"
|
| 1094 |
+
example_images = []
|
| 1095 |
|
| 1096 |
+
if os.path.exists(example_dir):
|
| 1097 |
+
for ext in SUPPORTED_FORMATS:
|
| 1098 |
+
pattern = os.path.join(example_dir, f"*{ext}")
|
| 1099 |
+
example_images.extend(glob.glob(pattern))
|
| 1100 |
+
example_images.sort()
|
| 1101 |
+
|
| 1102 |
+
return example_images
|
| 1103 |
+
|
| 1104 |
|
| 1105 |
def create_interface():
|
| 1106 |
+
"""Create Gradio interface"""
|
| 1107 |
+
|
| 1108 |
+
# 30 Example prompts covering various categories
|
| 1109 |
example_texts = [
|
| 1110 |
+
# === Simple Icons (1-6) ===
|
| 1111 |
+
"A black triangle pointing downward, centrally positioned on white background.",
|
| 1112 |
+
"A red heart shape with smooth curved edges, centered on white background.",
|
| 1113 |
+
"A yellow star with five sharp points, simple geometric design, flat color on white background.",
|
| 1114 |
+
"A blue arrow pointing to the right, thick solid shape, centered on white background.",
|
| 1115 |
+
"A green circle with a white checkmark inside, centered on white background.",
|
| 1116 |
+
"A black plus sign with equal length arms, thick lines, centered on white background.",
|
| 1117 |
+
|
| 1118 |
+
# === Characters & People (7-12) ===
|
| 1119 |
+
"A simple person standing: round beige head, rectangular blue shirt body, two dark gray rectangular legs, arms at sides. Flat colors, white background.",
|
| 1120 |
+
"A girl with long black hair, wearing pink dress with triangular skirt, small circular face with dot eyes and curved smile. Simple cartoon style, white background.",
|
| 1121 |
+
"A businessman: circular head with short black hair, rectangular dark navy suit body, straight standing pose. Professional minimal style, white background.",
|
| 1122 |
+
"A child waving: large round head with brown messy hair, big circular eyes, small body in red t-shirt and blue shorts, one arm raised. Cheerful cartoon style.",
|
| 1123 |
+
"A person sitting on chair: side view, round head, rectangular torso in green sweater, bent legs on simple chair shape. Relaxed pose, white background.",
|
| 1124 |
+
"A running person: side view silhouette in black, dynamic pose with one leg forward, arms pumping. Motion style, white background.",
|
| 1125 |
+
|
| 1126 |
+
# === Avatars & Portraits (13-17) ===
|
| 1127 |
+
"Circular avatar: person with short black hair, round face with two dot eyes and small curved smile, wearing blue collar shirt. Minimal style, centered in circle.",
|
| 1128 |
+
"Female avatar: oval face with long wavy brown hair, simple eyes, pink lips, wearing v-neck purple top. Soft cartoon style in circular frame.",
|
| 1129 |
+
"Profile silhouette avatar: black side view of head with short hair and glasses outline, facing right. Simple solid shape on white.",
|
| 1130 |
+
"Cute cartoon avatar: round face with big sparkly eyes, rosy cheeks, short bob haircut in orange. Kawaii style, circular frame.",
|
| 1131 |
+
"Professional headshot avatar: person with neat hair, neutral expression, wearing suit collar. Corporate minimal style, circular frame, white background.",
|
| 1132 |
+
|
| 1133 |
+
# === Landscapes & Scenes (18-23) ===
|
| 1134 |
+
"Layered mountain landscape: light blue sky at top, gray triangular snow-capped mountains in middle, dark green triangular pine trees at bottom. Flat colors.",
|
| 1135 |
+
"Sunset beach scene: orange gradient sky at top, yellow semicircle sun on horizon, dark blue wavy ocean, tan beach strip at bottom. Simple shapes.",
|
| 1136 |
+
"Forest scene: light blue sky, row of 5 dark green triangular pine trees of varying heights on brown trunks, light green grass at bottom.",
|
| 1137 |
+
"City skyline at dusk: purple-orange gradient sky, row of black rectangular building silhouettes of different heights, some with yellow window squares.",
|
| 1138 |
+
"Desert landscape: light orange sky with white circle sun, tan sand dunes as curved shapes, one green cactus with arms on the right side.",
|
| 1139 |
+
"Countryside scene: blue sky with white fluffy clouds, green rolling hills, small red barn with white door in the center, yellow hay bales.",
|
| 1140 |
+
|
| 1141 |
+
# === Animals (24-27) ===
|
| 1142 |
+
"Cute orange cat sitting: round head with two triangular ears, oval body, curved tail. Black outline cartoon style, facing forward, white background.",
|
| 1143 |
+
"Simple black bird: oval body, round head, pointed triangular beak facing right, triangular tail, two stick legs. Silhouette style on white.",
|
| 1144 |
+
"Friendly cartoon dog: brown oval body, round head with floppy ears, black dot nose, wagging curved tail, four short legs. Sitting pose.",
|
| 1145 |
+
"Red fox logo: triangular orange face with pointed ears, white chest marking, bushy tail. Minimalist style, facing right, centered on white.",
|
| 1146 |
+
|
| 1147 |
+
# === Objects & Misc (28-30) ===
|
| 1148 |
+
"Simple house icon: red triangular roof, beige rectangular walls, brown door in center, two blue square windows, green ground at bottom.",
|
| 1149 |
+
"Coffee mug: brown cylindrical cup with curved handle on right, three wavy steam lines rising from top. Flat style on white background.",
|
| 1150 |
+
"Open book: two rectangular white pages spread open, black text lines on each page, brown spine in center. Simple top-down view."
|
| 1151 |
]
|
| 1152 |
+
|
| 1153 |
example_images = get_example_images()
|
| 1154 |
|
| 1155 |
+
with gr.Blocks(title="OmniSVG Generator", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
|
| 1156 |
+
# Header
|
| 1157 |
+
gr.HTML("""
|
| 1158 |
+
<div class="header-container">
|
| 1159 |
+
<h1>🎨 OmniSVG Generator</h1>
|
| 1160 |
+
<p>Transform images and text descriptions into scalable vector graphics</p>
|
| 1161 |
+
<p style="font-size: 0.9em; opacity: 0.8;">将图像和文本描述转换为可缩放矢量图形</p>
|
| 1162 |
+
</div>
|
| 1163 |
+
""")
|
| 1164 |
+
|
| 1165 |
+
# Queue status
|
| 1166 |
+
gr.HTML("""
|
| 1167 |
+
<div style="background: #e7f3ff; border: 1px solid #b3d7ff; border-radius: 8px; padding: 12px 15px; margin-bottom: 15px;">
|
| 1168 |
+
<span style="font-size: 1.5em;">ℹ️</span>
|
| 1169 |
+
<strong>Queue System Active</strong> - Requests processed one at a time. Please wait patiently if busy.<br/>
|
| 1170 |
+
<span style="color: #666;">队列系统已启用 - 请求按顺序处理,繁忙时请耐心等待。</span>
|
| 1171 |
+
</div>
|
| 1172 |
+
""")
|
| 1173 |
+
|
| 1174 |
+
# Tips section
|
| 1175 |
+
gr.HTML(TIPS_HTML)
|
| 1176 |
|
| 1177 |
with gr.Tabs():
|
| 1178 |
+
# ==================== Image-to-SVG Tab ====================
|
| 1179 |
+
with gr.TabItem("🖼️ Image-to-SVG", id="image-tab"):
|
| 1180 |
+
gr.HTML(IMAGE_TIPS_HTML)
|
| 1181 |
+
|
| 1182 |
+
with gr.Row(equal_height=False):
|
| 1183 |
+
with gr.Column(scale=1, min_width=300):
|
| 1184 |
+
gr.Markdown("### 📤 Upload Image / 上传图片")
|
| 1185 |
+
image_input = gr.Image(
|
| 1186 |
+
label="Drag, upload, or Ctrl+V to paste / 拖拽、上传或Ctrl+V粘贴",
|
| 1187 |
+
type="pil",
|
| 1188 |
+
image_mode="RGBA",
|
| 1189 |
+
height=250,
|
| 1190 |
+
sources=["upload", "clipboard"],
|
| 1191 |
+
elem_classes=["input-image"]
|
| 1192 |
+
)
|
| 1193 |
+
|
| 1194 |
+
with gr.Group(elem_classes=["settings-group"]):
|
| 1195 |
+
gr.Markdown("### ⚙️ Settings / 设置")
|
| 1196 |
+
img_num_candidates = gr.Slider(
|
| 1197 |
+
minimum=1, maximum=8, value=4, step=1,
|
| 1198 |
+
label="Number of Candidates / 候选数量"
|
| 1199 |
+
)
|
| 1200 |
+
img_replace_bg = gr.Checkbox(
|
| 1201 |
+
label="Replace non-white background / 替换非白色背景",
|
| 1202 |
+
value=True,
|
| 1203 |
+
info="Enable for images with colored backgrounds / 对有色背景图片启用"
|
| 1204 |
+
)
|
| 1205 |
+
|
| 1206 |
+
with gr.Accordion("🔧 Advanced Parameters / 高级参数", open=False):
|
| 1207 |
+
img_temperature = gr.Slider(
|
| 1208 |
+
minimum=0.1, maximum=1.0, value=0.3, step=0.05,
|
| 1209 |
+
label="Temperature (Lower=accurate)",
|
| 1210 |
+
info="0.2-0.4 recommended / 建议0.2-0.4"
|
| 1211 |
+
)
|
| 1212 |
+
img_top_p = gr.Slider(
|
| 1213 |
+
minimum=0.5, maximum=1.0, value=0.90, step=0.02,
|
| 1214 |
+
label="Top-P"
|
| 1215 |
+
)
|
| 1216 |
+
img_top_k = gr.Slider(
|
| 1217 |
+
minimum=10, maximum=100, value=50, step=5,
|
| 1218 |
+
label="Top-K"
|
| 1219 |
+
)
|
| 1220 |
+
img_rep_penalty = gr.Slider(
|
| 1221 |
+
minimum=1.0, maximum=1.3, value=1.05, step=0.01,
|
| 1222 |
+
label="Repetition Penalty"
|
| 1223 |
+
)
|
| 1224 |
+
|
| 1225 |
+
image_generate_btn = gr.Button(
|
| 1226 |
+
"🚀 Generate SVG / 生成SVG",
|
| 1227 |
+
variant="primary",
|
| 1228 |
+
size="lg",
|
| 1229 |
+
elem_classes=["primary-btn"]
|
| 1230 |
+
)
|
| 1231 |
+
|
| 1232 |
if example_images:
|
| 1233 |
+
gr.Markdown("### 📁 Examples")
|
| 1234 |
+
gr.Examples(examples=example_images, inputs=[image_input], label="")
|
| 1235 |
|
| 1236 |
+
with gr.Column(scale=2, min_width=500):
|
| 1237 |
+
gr.Markdown("### 📥 Processed Input / 处理后输入")
|
| 1238 |
+
image_processed = gr.Image(label="", type="pil", height=120)
|
| 1239 |
+
|
| 1240 |
+
gr.Markdown("### 🖼️ Generated SVG Candidates / 生成的SVG候选")
|
| 1241 |
+
image_gallery = gr.HTML(
|
| 1242 |
+
value='<div style="text-align:center;color:#999;padding:50px;background:#fafafa;border-radius:12px;">Generated SVGs will appear here / 生成的SVG将显示在这里</div>'
|
| 1243 |
+
)
|
| 1244 |
+
|
| 1245 |
+
gr.Markdown("### 📝 SVG Code")
|
| 1246 |
+
image_svg_output = gr.Code(label="", language="html", lines=10, elem_classes=["code-output"])
|
| 1247 |
|
| 1248 |
+
image_generate_btn.click(
|
| 1249 |
+
fn=gradio_image_to_svg,
|
| 1250 |
+
inputs=[image_input, img_num_candidates, img_temperature, img_top_p,
|
| 1251 |
+
img_top_k, img_rep_penalty, img_replace_bg],
|
| 1252 |
+
outputs=[image_gallery, image_svg_output, image_processed],
|
| 1253 |
+
queue=True
|
| 1254 |
+
)
|
| 1255 |
+
|
| 1256 |
+
# ==================== Text-to-SVG Tab ====================
|
| 1257 |
+
with gr.TabItem("✏️ Text-to-SVG", id="text-tab"):
|
| 1258 |
+
with gr.Row(equal_height=False):
|
| 1259 |
+
with gr.Column(scale=1, min_width=300):
|
| 1260 |
+
gr.Markdown("### 📝 Description / 描述")
|
| 1261 |
+
gr.HTML("""
|
| 1262 |
+
<div style="background: #fff5f5; padding: 10px; border-radius: 8px; border-left: 4px solid #dc3545; margin-bottom: 10px;">
|
| 1263 |
+
<strong style="color: #dc3545;">🔴 Generate 4-8 candidates and pick the best!</strong><br/>
|
| 1264 |
+
生成4-8个候选结果并选择最好的!
|
| 1265 |
+
</div>
|
| 1266 |
+
""")
|
| 1267 |
+
text_input = gr.Textbox(
|
| 1268 |
+
label="",
|
| 1269 |
+
placeholder="Describe your SVG with geometric shapes and colors...\n用几何形状和颜色描述您的SVG...\n\nExample: A black triangle pointing downward, centrally positioned on white background.",
|
| 1270 |
+
lines=5
|
| 1271 |
+
)
|
| 1272 |
+
|
| 1273 |
+
with gr.Group(elem_classes=["settings-group"]):
|
| 1274 |
+
gr.Markdown("### ⚙️ Settings / 设置")
|
| 1275 |
+
text_num_candidates = gr.Slider(
|
| 1276 |
+
minimum=1, maximum=8, value=6, step=1,
|
| 1277 |
+
label="Number of Candidates / 候选数量",
|
| 1278 |
+
info="More = better chances! / 越多越好!"
|
| 1279 |
+
)
|
| 1280 |
+
|
| 1281 |
+
with gr.Accordion("🔧 Advanced Parameters / 高级参数", open=False):
|
| 1282 |
+
text_temperature = gr.Slider(
|
| 1283 |
+
minimum=0.1, maximum=1.0, value=0.5, step=0.05,
|
| 1284 |
+
label="Temperature",
|
| 1285 |
+
info="Icons: 0.3-0.5 | Complex: 0.5-0.7"
|
| 1286 |
+
)
|
| 1287 |
+
text_top_p = gr.Slider(
|
| 1288 |
+
minimum=0.5, maximum=1.0, value=0.90, step=0.02,
|
| 1289 |
+
label="Top-P"
|
| 1290 |
+
)
|
| 1291 |
+
text_top_k = gr.Slider(
|
| 1292 |
+
minimum=10, maximum=100, value=60, step=5,
|
| 1293 |
+
label="Top-K"
|
| 1294 |
+
)
|
| 1295 |
+
text_rep_penalty = gr.Slider(
|
| 1296 |
+
minimum=1.0, maximum=1.3, value=1.03, step=0.01,
|
| 1297 |
+
label="Repetition Penalty",
|
| 1298 |
+
info="Increase if you see repetitive patterns"
|
| 1299 |
+
)
|
| 1300 |
+
|
| 1301 |
+
text_generate_btn = gr.Button(
|
| 1302 |
+
"🚀 Generate SVG / 生成SVG",
|
| 1303 |
+
variant="primary",
|
| 1304 |
+
size="lg",
|
| 1305 |
+
elem_classes=["primary-btn"]
|
| 1306 |
+
)
|
| 1307 |
+
|
| 1308 |
+
gr.Markdown("### 📝 Example Prompts (30)")
|
| 1309 |
+
gr.Examples(
|
| 1310 |
+
examples=[[text] for text in example_texts],
|
| 1311 |
+
inputs=[text_input],
|
| 1312 |
+
label=""
|
| 1313 |
+
)
|
| 1314 |
|
| 1315 |
+
with gr.Column(scale=2, min_width=500):
|
| 1316 |
+
gr.Markdown("### 🖼️ Generated SVG Candidates / 生成的SVG候选")
|
| 1317 |
+
gr.HTML("""
|
| 1318 |
+
<div style="background: #d4edda; padding: 10px; border-radius: 8px; margin-bottom: 10px;">
|
| 1319 |
+
<strong>💡 Pick the best from multiple candidates! / 从多个候选中选择最好的!</strong>
|
| 1320 |
+
</div>
|
| 1321 |
+
""")
|
| 1322 |
+
text_gallery = gr.HTML(
|
| 1323 |
+
value='<div style="text-align:center;color:#999;padding:50px;background:#fafafa;border-radius:12px;">Generated SVGs will appear here / 生成的SVG将显示在这里</div>'
|
| 1324 |
+
)
|
| 1325 |
+
|
| 1326 |
+
gr.Markdown("### 📝 SVG Code")
|
| 1327 |
+
text_svg_output = gr.Code(label="", language="html", lines=12, elem_classes=["code-output"])
|
| 1328 |
|
| 1329 |
+
text_generate_btn.click(
|
| 1330 |
+
fn=gradio_text_to_svg,
|
| 1331 |
+
inputs=[text_input, text_num_candidates, text_temperature, text_top_p,
|
| 1332 |
+
text_top_k, text_rep_penalty],
|
| 1333 |
+
outputs=[text_gallery, text_svg_output],
|
| 1334 |
+
queue=True
|
| 1335 |
+
)
|
| 1336 |
+
|
| 1337 |
+
# Footer
|
| 1338 |
+
gr.HTML("""
|
| 1339 |
+
<div class="footer">
|
| 1340 |
+
<p>Built with ❤️ using OmniSVG</p>
|
| 1341 |
+
<p style="color: #dc3545; font-weight: 600;">🔴 Remember: Generate 4-8 candidates and pick the best! / 记住:生成4-8个候选并选择最好的!</p>
|
| 1342 |
+
</div>
|
| 1343 |
+
""")
|
| 1344 |
+
|
| 1345 |
+
return demo
|
| 1346 |
|
|
|
|
|
|
|
| 1347 |
|
| 1348 |
if __name__ == "__main__":
|
| 1349 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 1350 |
+
|
| 1351 |
args = parse_args()
|
| 1352 |
+
|
| 1353 |
+
print("="*60)
|
| 1354 |
+
print("OmniSVG Generator - Gradio App")
|
| 1355 |
+
print("="*60)
|
| 1356 |
+
print(f"Model path: {args.model_path}")
|
| 1357 |
+
print(f"Weight path: {args.weight_path}")
|
| 1358 |
+
print(f"Device: {device}")
|
| 1359 |
+
print("="*60)
|
| 1360 |
+
|
| 1361 |
+
print("\nLoading models...")
|
| 1362 |
+
load_models(args.weight_path, args.model_path)
|
| 1363 |
+
print("Models loaded successfully!\n")
|
| 1364 |
+
|
| 1365 |
demo = create_interface()
|
| 1366 |
+
|
| 1367 |
+
demo.queue(default_concurrency_limit=1, max_size=20)
|
| 1368 |
+
|
| 1369 |
+
demo.launch(
|
| 1370 |
+
server_name=args.listen,
|
| 1371 |
+
server_port=args.port,
|
| 1372 |
+
share=args.share,
|
| 1373 |
+
debug=args.debug,
|
| 1374 |
+
)
|