import os
import gc
import random
from typing import Iterable, List, Tuple
from huggingface_hub import login as hf_login
# Log in to the Hugging Face Hub only when an HF_TOKEN environment variable is set.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    hf_login(token=_hf_token)
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
# =========================================================
# THEME
# =========================================================
# Attach a custom "fire_red" palette to gradio's `colors` module so that it can
# be referenced as `colors.fire_red` (the default secondary hue of FireRedTheme).
colors.fire_red = colors.Color(
    name="fire_red",
    c50="#FFF5F0",
    c100="#FFE8DB",
    c200="#FFD0B5",
    c300="#FFB088",
    c400="#FF8C5A",
    c500="#FF6B35",
    c600="#E8531F",
    c700="#CC4317",
    c800="#A63812",
    c900="#80300F",
    c950="#5C220A",
)
class FireRedTheme(Soft):
    """Gradio theme derived from ``Soft`` that uses the fire-red palette as accent.

    All constructor arguments are keyword-only and are forwarded unchanged to
    ``Soft.__init__``; afterwards a fixed set of CSS variable overrides is
    applied through ``set``.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.fire_red,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_md,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Inter"),
            "system-ui",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("JetBrains Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        base_options = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**base_options)

        overrides = {
            # Page and block surfaces
            "body_background_fill": "#f0f2f6",
            "body_background_fill_dark": "*neutral_950",
            "background_fill_primary": "white",
            "background_fill_primary_dark": "*neutral_900",
            "block_background_fill": "white",
            "block_background_fill_dark": "*neutral_800",
            "block_border_width": "1px",
            "block_border_color": "*neutral_200",
            "block_border_color_dark": "*neutral_700",
            "block_shadow": "0 1px 4px rgba(0,0,0,0.05)",
            "block_shadow_dark": "0 1px 4px rgba(0,0,0,0.25)",
            "block_title_text_weight": "600",
            "block_label_background_fill": "*neutral_50",
            "block_label_background_fill_dark": "*neutral_800",
            # Primary buttons
            "button_primary_text_color": "white",
            "button_primary_text_color_hover": "white",
            "button_primary_background_fill": "linear-gradient(135deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(135deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_dark": "linear-gradient(135deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover_dark": "linear-gradient(135deg, *secondary_600, *secondary_700)",
            "button_primary_shadow": "0 4px 14px rgba(232, 83, 31, 0.25)",
            # Secondary buttons
            "button_secondary_text_color": "*secondary_700",
            "button_secondary_text_color_dark": "*secondary_300",
            "button_secondary_background_fill": "*secondary_50",
            "button_secondary_background_fill_hover": "*secondary_100",
            "button_secondary_background_fill_dark": "rgba(255, 107, 53, 0.1)",
            "button_secondary_background_fill_hover_dark": "rgba(255, 107, 53, 0.2)",
            "button_large_padding": "12px 24px",
            # Sliders, inputs and accents
            "slider_color": "*secondary_500",
            "slider_color_dark": "*secondary_500",
            "input_border_color_focus": "*secondary_400",
            "input_border_color_focus_dark": "*secondary_500",
            "color_accent_soft": "*secondary_50",
            "color_accent_soft_dark": "rgba(255, 107, 53, 0.15)",
        }
        super().set(**overrides)


# Module-level theme instance handed to gr.Blocks(theme=...).
theme = FireRedTheme()
# =========================================================
# MODEL
# =========================================================
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Startup diagnostics printed to the log.
print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.__version__ =", torch.__version__)
print("device =", device)
from diffusers import FlowMatchEulerDiscreteScheduler, QwenImageEditPlusPipeline # noqa: E402,F401
from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3 # noqa: E402
from transformers import AutoModelForImageSegmentation # noqa: E402
from torchvision import transforms # noqa: E402
import torch.nn.functional as F # noqa: E402
# Weights dtype requested for the edit pipeline.
dtype = torch.bfloat16
# ── FireRed edit model (loaded natively through the official pipeline class) ──
pipe = QwenImageEditPlusPipeline.from_pretrained(
    "FireRedTeam/FireRed-Image-Edit-1.1",
    torch_dtype=dtype,
).to(device)
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
# Best effort: a failure to install the FA3 attention processor is logged and
# the pipeline keeps whatever processor it already had.
try:
    pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
    print("Flash Attention 3 Processor set successfully.")
except Exception as e:
    print(f"Warning: Could not set FA3 processor: {e}")
# ── Lightning LoRA (4-step acceleration, matching the ComfyUI Rebels.json workflow) ──
# Best effort as well: if loading fails the app continues without the LoRA.
try:
    pipe.load_lora_weights(
        "Osrivers/Qwen-Image-Lightning-4steps-V2.0-bf16.safetensors",
        weight_name="Qwen-Image-Lightning-4steps-V2.0-bf16.safetensors",
        adapter_name="lightning",
    )
    pipe.set_adapters(["lightning"], adapter_weights=[1.0])
    print("Lightning LoRA (4steps V2.0) loaded successfully.")
except Exception as e:
    print(f"Warning: Could not load Lightning LoRA: {e}")
# ── RMBG 2.0 background-removal (matting) model ──
rmbg = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-2.0",
    trust_remote_code=True,
)
rmbg.to(device)
rmbg.eval()
# Largest int32 value; upper bound for the seed slider and for random seeds.
MAX_SEED = np.iinfo(np.int32).max
# Pre-filled value of the "Negative Prompt" textbox.
DEFAULT_NEGATIVE_PROMPT = (
    "worst quality, low quality, bad anatomy, bad hands, text, error, "
    "missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, "
    "signature, watermark, username, blurry"
)
# =========================================================
# SAFE BUCKETS (~1MP each)
# =========================================================
# (width, height) canvas sizes the inputs are fitted into before inference.
# Every dimension is a multiple of 16. The standard buckets are ~1MP; the wide
# buckets range from roughly 0.64MP (1600x400) to 1.31MP (2560x512).
SAFE_BUCKETS: List[Tuple[int, int]] = [
    # Standard buckets (~1MP)
    (1024, 1024),
    (1184, 880),
    (880, 1184),
    (1392, 752),
    (752, 1392),
    (1568, 672),
    (672, 1568),
    # Wide buckets (long strip images such as variety-show caption art)
    (1920, 640), # 3:1
    (1600, 400), # 4:1 ← same size as used in Rebels.json
    (2048, 512), # 4:1
    (1920, 384), # 5:1
    (2560, 512), # 5:1
    (2048, 336), # ~6:1
]
# Default for `allow_upscale`: when True the closest bucket is used even if it
# is larger than the input image.
UPSCALE_SMALL_IMAGES = True
# Preprocessing for the RMBG model: tensor conversion followed by per-channel
# mean/std normalization.
_rmbg_normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# Side length (pixels) images are resized to before being fed to RMBG.
RMBG_SIZE = 1024
@spaces.GPU
def run_rmbg(pil_image: Image.Image) -> Image.Image:
    """Remove the background of ``pil_image`` with the RMBG model.

    The image is squished to RMBG_SIZE x RMBG_SIZE, the model's last output is
    passed through a sigmoid, and the resulting matte is bilinearly resized
    back to the original resolution and installed as the alpha channel.

    Returns:
        An RGBA copy of ``pil_image`` whose alpha channel is the predicted matte.
    """
    width, height = pil_image.size

    squished = pil_image.convert("RGB").resize((RMBG_SIZE, RMBG_SIZE), Image.LANCZOS)
    batch = _rmbg_normalize(squished).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = rmbg(batch)

    # Pick the tensor holding the final prediction, whatever container the
    # model returned it in.
    if isinstance(outputs, list):
        logits = outputs[-1]
    elif isinstance(outputs, dict) and 'logits' in outputs:
        logits = outputs['logits']
    else:
        logits = outputs

    matte = torch.clamp(logits.sigmoid().cpu().squeeze(), 0, 1)
    # F.interpolate needs a (N, C, H, W) tensor, so add the two leading axes
    # back before resizing and drop them again afterwards.
    matte = F.interpolate(
        matte.unsqueeze(0).unsqueeze(0),
        size=(height, width),
        mode='bilinear',
    ).squeeze()

    alpha = Image.fromarray((matte.numpy() * 255).astype(np.uint8))
    cutout = pil_image.convert("RGBA")
    cutout.putalpha(alpha)
    return cutout
def color_match_reinhard(source: Image.Image, result: Image.Image) -> Image.Image:
    """Shift the per-channel RGB mean/std of ``result`` towards those of ``source``.

    Each channel of ``result`` is centred, scaled by ``source_std / result_std``
    (scaling is skipped when the result channel's std is 0.5 or less) and
    re-centred on the source mean. The output is clipped to 0..255.
    """
    reference = np.array(source.convert("RGB")).astype(np.float32)
    target = np.array(result.convert("RGB")).astype(np.float32)
    matched = np.zeros_like(target)

    for channel in range(3):
        ref_plane = reference[:, :, channel]
        tgt_plane = target[:, :, channel]
        ref_mean, ref_std = ref_plane.mean(), ref_plane.std()
        tgt_mean, tgt_std = tgt_plane.mean(), tgt_plane.std()

        # Nearly flat channels keep their contrast to avoid amplifying noise.
        if tgt_std > 0.5:
            gain = ref_std / (tgt_std + 1e-6)
        else:
            gain = 1.0

        matched[:, :, channel] = (tgt_plane - tgt_mean) * gain + ref_mean

    return Image.fromarray(np.clip(matched, 0, 255).astype(np.uint8))
def remove_black_bg(pil_image: Image.Image, dark_thresh: int = 40) -> Image.Image:
    """Make the dark background connected to the image border transparent.

    Pixels whose R, G and B values are all <= ``dark_thresh`` are grouped into
    connected regions (scipy ``ndimage.label`` with its default connectivity).
    Only regions touching one of the four image edges are treated as
    background; dark regions enclosed by brighter pixels (e.g. black details
    inside lettering) stay opaque. No AI model is involved.

    Args:
        pil_image: Input image.
        dark_thresh: Per-channel upper bound for a pixel to count as dark.

    Returns:
        An RGBA copy of the image with alpha 0 on the border-connected dark
        regions and 255 everywhere else.
    """
    from scipy import ndimage as ndi

    rgb = np.array(pil_image.convert("RGB"))
    dark_mask = np.all(rgb <= dark_thresh, axis=2)
    labeled, _ = ndi.label(dark_mask)

    # Labels present on any of the four edges; label 0 means "not dark".
    edge_labels = np.unique(
        np.concatenate(
            [labeled[0, :], labeled[-1, :], labeled[:, 0], labeled[:, -1]]
        )
    )
    edge_labels = edge_labels[edge_labels != 0]

    # One membership test over the whole label map instead of one full-image
    # comparison per border label.
    bg_mask = np.isin(labeled, edge_labels)

    alpha = np.where(bg_mask, 0, 255).astype(np.uint8)
    out = pil_image.convert("RGBA")
    out.putalpha(Image.fromarray(alpha))
    return out
def add_image_watermark(result: Image.Image, ref: Image.Image, size: int = 200, padding: int = 16) -> Image.Image:
    """Overlay a thumbnail of ``ref`` near the top-left corner of ``result``.

    The thumbnail fits inside a ``size`` x ``size`` box (aspect ratio kept) and
    is pasted ``padding`` pixels from the top and left edges, using its own
    alpha channel as the paste mask. Neither input image is modified.

    Returns:
        A new RGB image.
    """
    canvas = result.copy().convert("RGBA")

    overlay = ref.convert("RGBA")
    overlay.thumbnail((size, size), Image.LANCZOS)

    top_left = (padding, padding)
    canvas.paste(overlay, top_left, overlay)
    return canvas.convert("RGB")
def paste_png_into_mask(editor_value: dict, png_image) -> Image.Image:
    """Paste a PNG into the area the user painted in a gr.ImageEditor.

    The bounding box of the painted area is taken from the alpha channel of
    the first editor layer that contains any drawing. The PNG is scaled
    proportionally so that its longest side equals the longest side of that
    bounding box, then pasted centred inside it using its own alpha channel.

    Args:
        editor_value: ImageEditor value; the ``"background"`` and ``"layers"``
            entries are read (each a PIL image or a numpy array).
        png_image: Image to place, given as a file path, a numpy array or a
            PIL image.

    Returns:
        The composited image in RGB mode.

    Raises:
        gr.Error: If the source image, the mask or the PNG is missing, or if
            no layer contains any drawing.
    """
    if editor_value is None:
        raise gr.Error("⚠️ Please upload and draw a mask on the source image.")
    if png_image is None:
        raise gr.Error("⚠️ Please upload a PNG to place.")

    # Base image and drawing layers
    background = editor_value.get("background")
    layers = editor_value.get("layers") or []
    if background is None:
        raise gr.Error("⚠️ No source image found.")
    if not layers:
        raise gr.Error("⚠️ Please draw a mask area on the image first.")

    if isinstance(background, np.ndarray):
        background = Image.fromarray(background)
    background = background.convert("RGBA")

    # Bounding box of the painted area: use the first layer that actually has
    # non-transparent pixels, so an empty leading layer does not hide a mask
    # drawn on a later one.
    bbox = None
    for layer in layers:
        if layer is None:
            continue
        if isinstance(layer, np.ndarray):
            layer = Image.fromarray(layer)
        bbox = layer.convert("RGBA").split()[3].getbbox()
        if bbox is not None:
            break
    if bbox is None:
        raise gr.Error("⚠️ Mask area is empty. Please draw on the image.")

    x1, y1, x2, y2 = bbox
    mask_w = x2 - x1
    mask_h = y2 - y1
    mask_longest = max(mask_w, mask_h)

    # Load the PNG from whichever representation was supplied.
    if isinstance(png_image, str):
        png = Image.open(png_image)
    elif isinstance(png_image, Image.Image):
        png = png_image
    else:
        png = Image.fromarray(png_image)
    png = png.convert("RGBA")

    png_w, png_h = png.size
    png_longest = max(png_w, png_h)

    # Proportional scale: longest PNG side -> longest mask side. round() keeps
    # the longest side exact; flooring could lose a pixel to float error.
    scale = mask_longest / png_longest
    new_w = max(1, round(png_w * scale))
    new_h = max(1, round(png_h * scale))
    png_resized = png.resize((new_w, new_h), Image.LANCZOS)

    # Centre inside the mask bounding box.
    paste_x = x1 + (mask_w - new_w) // 2
    paste_y = y1 + (mask_h - new_h) // 2

    result = background.copy()
    result.paste(png_resized, (paste_x, paste_y), png_resized)
    return result.convert("RGB")
# =========================================================
# HELPERS
# =========================================================
def load_pil_image(item) -> Image.Image:
    """Turn a Gradio-style image value into an RGB PIL image.

    Accepted inputs: ``None`` (returned as ``None``), a PIL image, a file
    path string, a tuple/list whose first element is a PIL image or something
    ``Image.open`` accepts, or any object exposing a ``.name`` path attribute.
    """
    if item is None:
        return None

    if isinstance(item, (tuple, list)):
        first = item[0]
        opened = first if isinstance(first, Image.Image) else Image.open(first)
    elif isinstance(item, Image.Image):
        opened = item
    elif isinstance(item, str):
        opened = Image.open(item)
    else:
        # Fallback: file-like wrapper carrying its path in `.name`.
        opened = Image.open(item.name)

    return opened.convert("RGB")
def pick_best_bucket(
    orig_w: int,
    orig_h: int,
    buckets: List[Tuple[int, int]] = SAFE_BUCKETS,
    allow_upscale: bool = UPSCALE_SMALL_IMAGES,
) -> Tuple[int, int]:
    """Choose the bucket whose aspect ratio is closest to the input image.

    Buckets are ranked by aspect-ratio difference first and pixel-area
    difference second. With ``allow_upscale`` disabled, the best-ranked bucket
    that is no larger than the image in either dimension is preferred; if none
    qualifies the overall best bucket is returned anyway.

    Returns:
        The selected ``(width, height)`` bucket, or ``(1024, 1024)`` when
        either input dimension is not positive.
    """
    if orig_w <= 0 or orig_h <= 0:
        return 1024, 1024

    target_ratio = orig_w / orig_h
    target_area = orig_w * orig_h

    def rank(candidate):
        cand_w, cand_h = candidate
        return (
            abs((cand_w / cand_h) - target_ratio),
            abs((cand_w * cand_h) - target_area),
        )

    ranked = sorted(buckets, key=rank)

    if not allow_upscale:
        for candidate in ranked:
            if candidate[0] <= orig_w and candidate[1] <= orig_h:
                return candidate

    return ranked[0]
def prepare_images_before_pipe(
    pil_images: List[Image.Image],
    allow_upscale: bool = UPSCALE_SMALL_IMAGES,
    divisible_by: int = 16,
) -> Tuple[List[Image.Image], int, int, tuple]:
    """Fit the images into the best bucket and pad them to the bucket size.

    The bucket and the content size are derived from the FIRST image only:
    it is scaled proportionally to fit inside the bucket. Every image in the
    list is then resized to that same content size, centred on a bucket-sized
    canvas, and the padding is filled by replicating the nearest edge pixels.

    Args:
        pil_images: Input images; the first one determines the geometry.
        allow_upscale: Forwarded to ``pick_best_bucket``.
        divisible_by: Lower bound for the content width and height.

    Returns:
        ``(processed_images, width, height, pad_info)`` where ``width`` and
        ``height`` are the bucket size and
        ``pad_info = (pad_left, pad_top, content_w, content_h)`` describes the
        content rectangle, for cropping the padding off after inference.

    Raises:
        ValueError: If ``pil_images`` is empty.
    """
    if not pil_images:
        raise ValueError("No input images.")

    base_w, base_h = pil_images[0].size

    # Best bucket by aspect ratio.
    bucket_w, bucket_h = pick_best_bucket(base_w, base_h, SAFE_BUCKETS, allow_upscale)

    # Proportional fit inside the bucket (no stretching of the first image).
    # The min() clamp guarantees the content never exceeds the canvas, so the
    # pad widths below can never go negative.
    scale = min(bucket_w / base_w, bucket_h / base_h)
    content_w = min(bucket_w, max(divisible_by, round(base_w * scale)))
    content_h = min(bucket_h, max(divisible_by, round(base_h * scale)))

    # Centre the content; any odd leftover pixel goes to the right/bottom.
    pad_left = (bucket_w - content_w) // 2
    pad_top = (bucket_h - content_h) // 2
    pad_right = bucket_w - content_w - pad_left
    pad_bottom = bucket_h - content_h - pad_top
    pad_info = (pad_left, pad_top, content_w, content_h)
    pad_widths = ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0))

    processed = []
    for img in pil_images:
        # Force RGB so the array always has exactly three channels.
        resized = img.convert("RGB").resize((content_w, content_h), Image.LANCZOS)
        # Edge replication gives less visible seams than a plain black border.
        padded = np.pad(np.asarray(resized), pad_widths, mode="edge")
        processed.append(Image.fromarray(padded))

    return processed, bucket_w, bucket_h, pad_info
def extract_pil_from_source(source) -> Image.Image:
    """Get an RGB PIL image from a gr.ImageEditor dict or a plain image value.

    For editor dicts the ``"composite"`` entry (which includes brush strokes)
    is preferred, falling back to ``"background"``. Anything that is not a
    dict is handed to ``load_pil_image``. Returns ``None`` when no image is
    available.
    """
    if source is None:
        return None
    if not isinstance(source, dict):
        return load_pil_image(source)

    picked = source.get("composite")
    if picked is None:
        picked = source.get("background")
    if picked is None:
        return None

    if isinstance(picked, np.ndarray):
        picked = Image.fromarray(picked)
    return picked.convert("RGB")
def format_info(seed_val, source_img, ref_img):
    """Build the markdown summary shown under the Generate button.

    The summary always contains the seed. For the source and reference images
    that are present and loadable, it adds the original size, the bucket size
    selected by ``pick_best_bucket`` and both aspect ratios.

    This is best-effort display code: an image that cannot be loaded or
    measured is skipped and a warning is printed, so the UI never fails
    because of the info panel.

    Args:
        seed_val: Current seed (converted with ``int``).
        source_img: Source value (ImageEditor dict or plain image value), or None.
        ref_img: Reference image value, or None.

    Returns:
        The markdown text.
    """
    lines = [f"**Seed:** `{int(seed_val)}`"]
    for label, img in (("Source", source_img), ("Reference", ref_img)):
        if img is None:
            continue
        try:
            pil = extract_pil_from_source(img) if label == "Source" else load_pil_image(img)
            if pil is None:
                # e.g. an editor dict with neither composite nor background
                continue
            ow, oh = pil.size
            nw, nh = pick_best_bucket(ow, oh, SAFE_BUCKETS, UPSCALE_SMALL_IMAGES)
            lines.append(
                f"\n**{label}:** {ow}×{oh} → **{nw}×{nh}** "
                f"(ratio {ow/oh:.3f} → {nw/nh:.3f})"
            )
        except Exception as e:
            # Deliberately broad: keep the panel alive, but leave a trace.
            print(f"Warning: could not describe {label.lower()} image: {e}")
    return "\n\n".join(lines)
# =========================================================
# INFERENCE
# =========================================================
def _crop_padding(result, pad_info, bucket_size):
    """Crop the padded border off ``result``; ``pad_info`` is in bucket coordinates."""
    pad_left, pad_top, content_w, content_h = pad_info
    bucket_w, bucket_h = bucket_size
    if content_w >= bucket_w and content_h >= bucket_h:
        return result  # nothing was padded
    # The output may have been generated at an explicit size different from
    # the bucket, so map the content rectangle into output coordinates.
    sx = result.width / bucket_w
    sy = result.height / bucket_h
    box = (
        round(pad_left * sx),
        round(pad_top * sy),
        round((pad_left + content_w) * sx),
        round((pad_top + content_h) * sy),
    )
    return result.crop(box)


def _color_reference(source_image, src_pil):
    """Return the stroke-free editor background if available, else ``src_pil``."""
    if isinstance(source_image, dict):
        bg = source_image.get("background")
        if bg is not None:
            if isinstance(bg, np.ndarray):
                bg = Image.fromarray(bg)
            return bg.convert("RGB")
    return src_pil


@spaces.GPU
def infer(
    source_image,
    ref_image,
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    guidance_scale,
    steps,
    color_match,
    out_width=0,
    out_height=0,
    progress=gr.Progress(track_tqdm=True),
):
    """Run one edit request and return ``(result_image, seed)``.

    Routing:
      * If the prompt contains "抠", no diffusion is run: with "黑底" in the
        prompt the border-connected black background is removed
        (``remove_black_bg``), otherwise ``run_rmbg`` is used. The seed is
        returned unchanged.
      * Otherwise the source (and optional reference) image is padded into a
        bucket, passed through the edit pipeline, cropped back to the content
        area and resized to the source's original size. When a reference
        image was loaded its thumbnail is overlaid, and with ``color_match``
        the result's colors are aligned to the source.

    Args:
        source_image: ImageEditor dict or plain image value (required).
        ref_image: Optional reference image value.
        prompt: Edit instruction (required, non-blank).
        negative_prompt: Negative prompt passed to the pipeline.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]``.
        guidance_scale: Passed to the pipeline as ``true_cfg_scale``.
        steps: Number of inference steps.
        color_match: Apply Reinhard color matching against the source.
        out_width, out_height: Optional explicit output size; positive values
            are floored to a multiple of 16 (at least 16) and override the
            bucket size.
        progress: Gradio progress tracker (tracks tqdm).

    Raises:
        gr.Error: On a missing/unloadable source image or an empty prompt.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if source_image is None:
        raise gr.Error("⚠️ Please upload a source image.")
    if not prompt or not prompt.strip():
        raise gr.Error("⚠️ Please enter an edit prompt.")

    # Extract the source image (ImageEditor dict or plain path/PIL value).
    try:
        src_pil = extract_pil_from_source(source_image)
    except Exception as e:
        raise gr.Error(f"⚠️ Could not load source image: {e}") from e
    if src_pil is None:
        raise gr.Error("⚠️ Please upload a source image.")

    # Remember the original size; the result is resized back to it so the
    # 16-pixel alignment of the buckets does not change the output size.
    orig_size = src_pil.size  # (w, h)

    # ── Route: background removal ──
    if "抠" in prompt:
        if "黑底" in prompt:
            # Caption art on black: drop only the border-connected black
            # region, keep black elements inside the lettering.
            result = remove_black_bg(src_pil)
        else:
            # Regular matting via RMBG 2.0 segmentation.
            result = run_rmbg(src_pil)
        return result, seed

    # Collect images: the source is mandatory, the reference optional.
    pil_images = [src_pil]
    if ref_image is not None:
        try:
            pil_images.append(load_pil_image(ref_image))
        except Exception as e:
            print(f"Warning: could not load reference image: {e}")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(int(seed))

    processed_images, bucket_w, bucket_h, pad_info = prepare_images_before_pipe(
        pil_images, allow_upscale=UPSCALE_SMALL_IMAGES
    )

    # Explicit output size (mirrors ComfyUI's EmptyLatentImage behaviour).
    width, height = bucket_w, bucket_h
    if out_width > 0:
        width = max(16, (int(out_width) // 16) * 16)
    if out_height > 0:
        height = max(16, (int(out_height) // 16) * 16)

    try:
        result = pipe(
            image=processed_images,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            # UI sliders may deliver floats; the step count must be an int.
            num_inference_steps=int(steps),
            generator=generator,
            true_cfg_scale=guidance_scale,
        ).images[0]

        # ── Crop the padding off, back to the original-ratio content area ──
        result = _crop_padding(result, pad_info, (bucket_w, bucket_h))

        # ── Restore the original size ──
        if result.size != orig_size:
            result = result.resize(orig_size, Image.LANCZOS)

        if ref_image is not None and len(pil_images) > 1:
            result = add_image_watermark(result, pil_images[1])

        if color_match:
            # Use the editor background (without brush strokes) as reference.
            ref_pil = _color_reference(source_image, src_pil)
            ref_pil_resized = ref_pil.resize(result.size, Image.LANCZOS)
            result = color_match_reinhard(ref_pil_resized, result)

        return result, seed
    finally:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# =========================================================
# UI
# =========================================================
# JS passed to gr.Blocks(js=...). For every `.src-editor` element it looks for
# the first descendant whose class attribute contains "toolbar" or "tool-bar"
# (Gradio/Svelte append hashes to class names, hence the substring match) and
# inserts a toggle button just before it. The button hides/shows the toolbar
# through `visibility` and `pointer-events` rather than `display`, so the
# canvas area does not jump. Setup is retried from a MutationObserver on
# document.body and once after a 1 s timeout; `dataset.toggleReady` prevents
# adding the button twice.
_FIX_TOOLBAR_JS = """
() => {
const setup = (ed) => {
if (ed.dataset.toggleReady) return;
// 找工具栏元素(Gradio/Svelte 会给 class 加哈希,用 includes 匹配)
const toolbar = Array.from(ed.querySelectorAll('*')).find(el => {
const cls = el.getAttribute('class') || '';
return cls.includes('toolbar') || cls.includes('tool-bar');
});
if (!toolbar) return;
ed.dataset.toggleReady = '1';
// 插入切换按钮,放在 toolbar 的父容器第一位
const btn = document.createElement('button');
btn.className = 'toolbar-toggle-btn';
btn.textContent = '🎨 隐藏画笔工具栏';
let hidden = false;
btn.onclick = () => {
hidden = !hidden;
// 用 visibility 而非 display,避免画布区域跳动
toolbar.style.visibility = hidden ? 'hidden' : '';
toolbar.style.pointerEvents = hidden ? 'none' : '';
btn.textContent = hidden ? '🎨 显示画笔工具栏' : '🎨 隐藏画笔工具栏';
};
toolbar.parentNode.insertBefore(btn, toolbar);
};
const mo = new MutationObserver(() => {
document.querySelectorAll('.src-editor').forEach(setup);
});
mo.observe(document.body, { childList: true, subtree: true });
setTimeout(() => document.querySelectorAll('.src-editor').forEach(setup), 1000);
}
"""
# Gradio UI: two tabs — "AI Edit" (wired to `infer`) and "PNG Placement"
# (wired to `paste_png_into_mask`).
# NOTE(review): the container nesting below is inferred from the component
# order — confirm it against the intended layout.
with gr.Blocks(
    theme=theme,
    js=_FIX_TOOLBAR_JS,
    css="""
.gradio-container {
max-width: 1400px !important;
margin: 0 auto;
padding-top: 20px;
}
.hero {
text-align: center;
padding: 24px 0 12px 0;
}
.hero h1 {
font-size: 2.2rem;
font-weight: 800;
margin-bottom: 8px;
}
.hero p {
font-size: 1rem;
color: #666;
margin-bottom: 0;
}
/* 工具栏隐藏时,隐藏按钮仍可点击 */
.toolbar-toggle-btn {
display: block;
width: 100%;
padding: 4px 10px;
margin-bottom: 2px;
background: #f0f0f0;
border: 1px solid #ddd;
border-radius: 4px;
font-size: 12px;
cursor: pointer;
text-align: left;
color: #555;
}
""",
) as demo:
    gr.HTML("""
<div class="hero">
<h1>🔥 FireRed Image Edit 1.1 Fast</h1>
</div>
""")
    with gr.Tabs():
        # ══════════════════════════════════════════════════════
        # Tab 1: AI edit
        # ══════════════════════════════════════════════════════
        with gr.Tab("AI Edit"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Source image with a brush; the `src-editor` class is the
                    # hook used by _FIX_TOOLBAR_JS.
                    source_input = gr.ImageEditor(
                        label="Source Image — 可用画笔标注区域(红/绿/蓝等),提示词中引用颜色",
                        elem_classes=["src-editor"],
                        brush=gr.Brush(
                            colors=["#FF0000", "#00CC00", "#0066FF", "#FFFF00", "#FF00FF", "#FFFFFF"],
                            color_mode="defaults",
                        ),
                    )
                    gr.Markdown(
                        "<small>🔴红 🟢绿 🔵蓝 🟡黄 🟣紫 ⬜白 — 画好后提示词写:*去掉红色标注的区域* 等</small>"
                    )
                    with gr.Row():
                        ref_input = gr.Image(
                            label="Reference Image(参考图,可选)",
                            type="filepath",
                            sources=["upload", "clipboard"],
                        )
                    prompt_input = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe how you want to edit the image...",
                        lines=4,
                    )
                    negative_prompt_input = gr.Textbox(
                        label="Negative Prompt",
                        value=DEFAULT_NEGATIVE_PROMPT,
                        lines=3,
                    )
                    color_match_input = gr.Checkbox(label="Color Match — 色彩对齐原图", value=True)
                    with gr.Accordion("Advanced Settings", open=False):
                        seed_input = gr.Slider(
                            label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0,
                        )
                        randomize_seed_input = gr.Checkbox(label="Randomize Seed", value=True)
                        guidance_scale_input = gr.Slider(
                            label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0,
                        )
                        steps_input = gr.Slider(
                            label="Inference Steps", minimum=1, maximum=50, step=1, value=4,
                        )
                    run_button = gr.Button("Generate", variant="primary")
                    info_markdown = gr.Markdown()
                with gr.Column(scale=1):
                    output_image = gr.Image(label="Result", type="pil")
            # Refresh the info panel whenever an image or the seed changes.
            for trigger in [source_input, ref_input, seed_input]:
                trigger.change(
                    fn=format_info,
                    inputs=[seed_input, source_input, ref_input],
                    outputs=[info_markdown],
                )
            # Run inference, write the used seed back to the slider, then
            # refresh the info panel.
            run_button.click(
                fn=infer,
                inputs=[
                    source_input, ref_input, prompt_input, negative_prompt_input,
                    seed_input, randomize_seed_input, guidance_scale_input, steps_input,
                    color_match_input,
                ],
                outputs=[output_image, seed_input],
            ).then(
                fn=format_info,
                inputs=[seed_input, source_input, ref_input],
                outputs=[info_markdown],
            )
        # ══════════════════════════════════════════════════════
        # Tab 2: PNG placement (draw a mask → paste the PNG scaled proportionally)
        # ══════════════════════════════════════════════════════
        with gr.Tab("PNG Placement"):
            gr.Markdown("**用法:** 上传底图后在图上涂抹出放置区域,再上传 PNG,点击 Apply。PNG 会等比缩放,最长边对齐 mask 最长边,居中贴入。")
            with gr.Row():
                with gr.Column(scale=1):
                    mask_editor = gr.ImageEditor(
                        label="Source Image — 在图上涂抹出放置区域",
                        brush=gr.Brush(colors=["#FF6B35"], color_mode="fixed"),
                    )
                    # Delivered as an RGBA numpy array to paste_png_into_mask.
                    png_input = gr.Image(
                        label="PNG to place(支持透明背景)",
                        type="numpy",
                        sources=["upload", "clipboard"],
                        image_mode="RGBA",
                    )
                    apply_button = gr.Button("Apply", variant="primary")
                with gr.Column(scale=1):
                    placement_output = gr.Image(label="Result", type="pil")
            apply_button.click(
                fn=paste_png_into_mask,
                inputs=[mask_editor, png_input],
                outputs=[placement_output],
            )
# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()