import os
import gc
import random
from typing import Iterable, List, Tuple

from huggingface_hub import login as hf_login

# Log in to the Hugging Face Hub only when a token is supplied via HF_TOKEN.
if (_hf_token := os.environ.get("HF_TOKEN")):
    hf_login(token=_hf_token)

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes

# =========================================================
# THEME
# =========================================================

# Orange/red palette attached to gradio's `colors` module; used below as the
# default secondary hue of FireRedTheme.
colors.fire_red = colors.Color(
    name="fire_red",
    c50="#FFF5F0",
    c100="#FFE8DB",
    c200="#FFD0B5",
    c300="#FFB088",
    c400="#FF8C5A",
    c500="#FF6B35",
    c600="#E8531F",
    c700="#CC4317",
    c800="#A63812",
    c900="#80300F",
    c950="#5C220A",
)


class FireRedTheme(Soft):
    """Soft-based Gradio theme with gray primary, fire-red accents and slate neutrals.

    All hue/size/font arguments are forwarded to ``Soft``; afterwards a fixed
    set of light/dark overrides is applied for page and block surfaces,
    buttons, sliders, input focus rings and accent fills.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.fire_red,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_md,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Inter"),
            "system-ui",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("JetBrains Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )

        # Values starting with "*" reference palette entries of this theme.
        surface_overrides = dict(
            body_background_fill="#f0f2f6",
            body_background_fill_dark="*neutral_950",
            background_fill_primary="white",
            background_fill_primary_dark="*neutral_900",
            block_background_fill="white",
            block_background_fill_dark="*neutral_800",
            block_border_width="1px",
            block_border_color="*neutral_200",
            block_border_color_dark="*neutral_700",
            block_shadow="0 1px 4px rgba(0,0,0,0.05)",
            block_shadow_dark="0 1px 4px rgba(0,0,0,0.25)",
            block_title_text_weight="600",
            block_label_background_fill="*neutral_50",
            block_label_background_fill_dark="*neutral_800",
        )
        button_overrides = dict(
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(135deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(135deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(135deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover_dark="linear-gradient(135deg, *secondary_600, *secondary_700)",
            button_primary_shadow="0 4px 14px rgba(232, 83, 31, 0.25)",
            button_secondary_text_color="*secondary_700",
            button_secondary_text_color_dark="*secondary_300",
            button_secondary_background_fill="*secondary_50",
            button_secondary_background_fill_hover="*secondary_100",
            button_secondary_background_fill_dark="rgba(255, 107, 53, 0.1)",
            button_secondary_background_fill_hover_dark="rgba(255, 107, 53, 0.2)",
            button_large_padding="12px 24px",
        )
        control_overrides = dict(
            slider_color="*secondary_500",
            slider_color_dark="*secondary_500",
            input_border_color_focus="*secondary_400",
            input_border_color_focus_dark="*secondary_500",
            color_accent_soft="*secondary_50",
            color_accent_soft_dark="rgba(255, 107, 53, 0.15)",
        )
        super().set(**surface_overrides, **button_overrides, **control_overrides)


theme = FireRedTheme()

# =========================================================
# MODEL
# =========================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.__version__ =", torch.__version__)
print("device =", device)

from diffusers import FlowMatchEulerDiscreteScheduler, QwenImageEditPlusPipeline  # noqa: E402,F401
from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3  # noqa: E402
from transformers import AutoModelForImageSegmentation  # noqa: E402
from torchvision import transforms  # noqa: E402
import torch.nn.functional as F  # noqa: E402

dtype = torch.bfloat16

# FireRed edit model, loaded through the stock diffusers pipeline.
pipe = QwenImageEditPlusPipeline.from_pretrained(
    "FireRedTeam/FireRed-Image-Edit-1.1",
    torch_dtype=dtype,
).to(device)

# Tiled + sliced VAE processing (presumably to keep peak memory down on
# large / wide buckets).
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
# Optional FlashAttention-3 attention processor. If it cannot be installed the
# pipeline keeps its default processor and only a warning is printed.
try:
    pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
    print("Flash Attention 3 Processor set successfully.")
except Exception as e:
    print(f"Warning: Could not set FA3 processor: {e}")

# Lightning LoRA (4-step acceleration; same weights as the ComfyUI Rebels.json
# workflow). Best-effort: on failure the base model runs without it.
try:
    pipe.load_lora_weights(
        "Osrivers/Qwen-Image-Lightning-4steps-V2.0-bf16.safetensors",
        weight_name="Qwen-Image-Lightning-4steps-V2.0-bf16.safetensors",
        adapter_name="lightning",
    )
    pipe.set_adapters(["lightning"], adapter_weights=[1.0])
    print("Lightning LoRA (4steps V2.0) loaded successfully.")
except Exception as e:
    print(f"Warning: Could not load Lightning LoRA: {e}")

# RMBG 2.0 background-removal model.
# NOTE(review): trust_remote_code=True executes Python code shipped with the
# Hub repo; consider pinning a revision.
rmbg = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-2.0",
    trust_remote_code=True,
)
rmbg.to(device)
rmbg.eval()

MAX_SEED = np.iinfo(np.int32).max

DEFAULT_NEGATIVE_PROMPT = (
    "worst quality, low quality, bad anatomy, bad hands, text, error, "
    "missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, "
    "signature, watermark, username, blurry"
)

# =========================================================
# SAFE BUCKETS (~0.6-1.3 MP each; every edge is a multiple of 16)
# =========================================================
SAFE_BUCKETS: List[Tuple[int, int]] = [
    # Standard buckets (~1MP)
    (1024, 1024),
    (1184, 880),
    (880, 1184),
    (1392, 752),
    (752, 1392),
    (1568, 672),
    (672, 1568),
    # Wide buckets (long banner-style images, e.g. variety-show caption text)
    (1920, 640),  # 3:1
    (1600, 400),  # 4:1, same as Rebels.json
    (2048, 512),  # 4:1
    (1920, 384),  # 5:1
    (2560, 512),  # 5:1
    (2048, 336),  # ~6:1
]

# When False, pick_best_bucket prefers buckets that do not exceed the input size.
UPSCALE_SMALL_IMAGES = True

# ImageNet mean/std normalisation applied to the RMBG input tensor.
_rmbg_normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# RMBG input resolution (the image is squished to RMBG_SIZE x RMBG_SIZE).
RMBG_SIZE = 1024


@spaces.GPU
def run_rmbg(pil_image: Image.Image) -> Image.Image:
    """Remove the background with RMBG-2.0 and return an RGBA image.

    Mirrors the ComfyUI comfyui-rmbg node: the image is squished to
    RMBG_SIZE x RMBG_SIZE, the last output map goes through a sigmoid, and the
    resulting mask is bilinearly resized back to the original size and used as
    the alpha channel of the original pixels.
    """
    orig_w, orig_h = pil_image.size
    squished = pil_image.convert("RGB").resize((RMBG_SIZE, RMBG_SIZE), Image.LANCZOS)
    inp = _rmbg_normalize(squished).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = rmbg(inp)

    # Take the final output map (list/tuple of maps, HF-style dict, or bare tensor).
    if isinstance(outputs, (list, tuple)):
        logits = outputs[-1]
    elif isinstance(outputs, dict) and "logits" in outputs:
        logits = outputs["logits"]
    else:
        logits = outputs

    # float() so .numpy() also works if the model ever runs in bf16.
    prob = torch.clamp(logits.sigmoid().float().cpu(), 0, 1)
    # Reshape to (1, 1, H, W) explicitly instead of squeeze/unsqueeze, which
    # would collapse a 1-pixel dimension.
    prob = prob.reshape(1, 1, *prob.shape[-2:])
    mask = F.interpolate(prob, size=(orig_h, orig_w), mode="bilinear")[0, 0]

    mask_pil = Image.fromarray((mask.numpy() * 255).astype(np.uint8))
    out = pil_image.convert("RGBA")
    out.putalpha(mask_pil)
    return out


def color_match_reinhard(source: Image.Image, result: Image.Image) -> Image.Image:
    """Reinhard-style RGB colour transfer: align ``result``'s tones with ``source``.

    Each RGB channel of ``result`` is shifted/scaled so its global mean and
    standard deviation match the same channel of ``source``. Channels of
    ``result`` that are almost flat (std <= 0.5) are only shifted, not scaled,
    so noise is not amplified. Returns a new RGB image.
    """
    src = np.array(source.convert("RGB")).astype(np.float32)
    res = np.array(result.convert("RGB")).astype(np.float32)
    out = np.zeros_like(res)
    for c in range(3):
        s_mean, s_std = src[:, :, c].mean(), src[:, :, c].std()
        r_mean, r_std = res[:, :, c].mean(), res[:, :, c].std()
        ratio = s_std / (r_std + 1e-6) if r_std > 0.5 else 1.0
        out[:, :, c] = (res[:, :, c] - r_mean) * ratio + s_mean
    return Image.fromarray(np.clip(out, 0, 255).astype(np.uint8))


def remove_black_bg(pil_image: Image.Image, dark_thresh: int = 40) -> Image.Image:
    """Make transparent only the dark background connected to the image border.

    A pixel is "dark" when all of its RGB channels are <= ``dark_thresh``.
    Dark regions are labelled by connectivity (scipy.ndimage.label), and only
    the regions touching any edge become transparent. Dark pixels enclosed
    inside the subject (e.g. black strokes inside caption text) are kept. No AI
    model is involved. Returns an RGBA image.
    """
    from scipy import ndimage as ndi

    arr = np.array(pil_image.convert("RGB"))
    dark_mask = np.all(arr <= dark_thresh, axis=2)
    labeled, _ = ndi.label(dark_mask)

    # Labels of every dark component that touches one of the four edges
    # (label 0 = non-dark pixels, never background).
    edge_labels = np.unique(np.concatenate(
        (labeled[0, :], labeled[-1, :], labeled[:, 0], labeled[:, -1])
    ))
    edge_labels = edge_labels[edge_labels != 0]

    # One vectorised membership test instead of one full-image pass per label.
    bg_mask = np.isin(labeled, edge_labels)

    alpha = np.where(bg_mask, 0, 255).astype(np.uint8)
    out = pil_image.convert("RGBA")
    out.putalpha(Image.fromarray(alpha))
    return out


def add_image_watermark(result: Image.Image, ref: Image.Image,
                        size: int = 200, padding: int = 16) -> Image.Image:
    """Paste a thumbnail of ``ref`` into the top-left corner of ``result``.

    The thumbnail fits within ``size`` x ``size`` (aspect preserved) and sits
    ``padding`` pixels from the top and left edges; its alpha channel is used as
    the paste mask. Returns an RGB copy; ``result`` is not modified.
    """
    result = result.copy().convert("RGBA")
    thumb = ref.convert("RGBA")
    thumb.thumbnail((size, size), Image.LANCZOS)
    result.paste(thumb, (padding, padding), thumb)
    return result.convert("RGB")


def paste_png_into_mask(editor_value: dict, png_image) -> Image.Image:
    """Paste a PNG, scaled to fit, into the region painted on an ImageEditor.

    The bounding box of the first layer's painted (non-transparent) pixels
    defines the target region. The PNG is scaled with its aspect ratio
    preserved so that its longest edge equals the box's longest edge, then
    centred on the box and composited using its own alpha. The paint strokes
    themselves are not part of the output (the background layer is used).

    Args:
        editor_value: gr.ImageEditor value dict with "background" and "layers".
        png_image: file path, numpy array, or PIL image of the PNG to place.

    Returns:
        The composited RGB image.

    Raises:
        gr.Error: if an input is missing or nothing has been painted.
    """
    if editor_value is None:
        raise gr.Error("⚠️ Please upload and draw a mask on the source image.")
    if png_image is None:
        raise gr.Error("⚠️ Please upload a PNG to place.")

    background = editor_value.get("background")
    layers: list = editor_value.get("layers", [])
    if background is None:
        raise gr.Error("⚠️ No source image found.")
    if not layers:
        raise gr.Error("⚠️ Please draw a mask area on the image first.")

    if isinstance(background, np.ndarray):
        background = Image.fromarray(background)
    background = background.convert("RGBA")

    mask_layer = layers[0]
    if isinstance(mask_layer, np.ndarray):
        mask_layer = Image.fromarray(mask_layer)
    mask_layer = mask_layer.convert("RGBA")

    # Bounding box of painted pixels, taken from the mask layer's alpha channel.
    bbox = mask_layer.split()[3].getbbox()
    if bbox is None:
        raise gr.Error("⚠️ Mask area is empty. Please draw on the image.")
    x1, y1, x2, y2 = bbox
    mask_w = x2 - x1
    mask_h = y2 - y1
    mask_longest = max(mask_w, mask_h)

    if isinstance(png_image, str):
        png = Image.open(png_image).convert("RGBA")
    elif isinstance(png_image, Image.Image):
        png = png_image.convert("RGBA")
    else:
        png = Image.fromarray(png_image).convert("RGBA")

    # Uniform scale: PNG longest edge == mask longest edge.
    png_w, png_h = png.size
    scale = mask_longest / max(png_w, png_h)
    new_w = max(1, int(png_w * scale))
    new_h = max(1, int(png_h * scale))
    png_resized = png.resize((new_w, new_h), Image.LANCZOS)

    # Centre inside the mask bounding box.
    paste_x = x1 + (mask_w - new_w) // 2
    paste_y = y1 + (mask_h - new_h) // 2
    result = background.copy()
    result.paste(png_resized, (paste_x, paste_y), png_resized)
    return result.convert("RGB")


# =========================================================
# HELPERS
# =========================================================
def load_pil_image(item) -> Image.Image:
    """Load ``item`` as an RGB PIL image (``None`` passes through as ``None``).

    Accepts a PIL image, a file path, a tuple/list whose first element is a
    PIL image or a path (presumably gallery-style values — verify), or any
    object exposing a ``.name`` file path.
    """
    if item is None:
        return None
    if isinstance(item, Image.Image):
        return item.convert("RGB")
    if isinstance(item, str):
        return Image.open(item).convert("RGB")
    if isinstance(item, (tuple, list)):
        path = item[0]
        if isinstance(path, Image.Image):
            return path.convert("RGB")
        return Image.open(path).convert("RGB")
    return Image.open(item.name).convert("RGB")


def pick_best_bucket(
    orig_w: int,
    orig_h: int,
    buckets: List[Tuple[int, int]] = SAFE_BUCKETS,
    allow_upscale: bool = UPSCALE_SMALL_IMAGES,
) -> Tuple[int, int]:
    """Return the (width, height) bucket best matching the original size.

    Buckets are ranked by aspect-ratio difference, ties broken by pixel-area
    difference. With ``allow_upscale=False`` the best bucket that fits within
    the original size on both axes is preferred; if none fits, the overall best
    is returned. Non-positive input sizes yield (1024, 1024).
    """
    if orig_w <= 0 or orig_h <= 0:
        return 1024, 1024

    orig_ratio = orig_w / orig_h

    def score(bucket):
        bw, bh = bucket
        ratio_diff = abs((bw / bh) - orig_ratio)
        area_diff = abs((bw * bh) - (orig_w * orig_h))
        return (ratio_diff, area_diff)

    sorted_buckets = sorted(buckets, key=score)
    if allow_upscale:
        return sorted_buckets[0]

    not_larger = [b for b in sorted_buckets if b[0] <= orig_w and b[1] <= orig_h]
    return not_larger[0] if not_larger else sorted_buckets[0]


def prepare_images_before_pipe(
    pil_images: List[Image.Image],
    allow_upscale: bool = UPSCALE_SMALL_IMAGES,
    divisible_by: int = 16,
) -> Tuple[List[Image.Image], int, int, tuple]:
    """Letterbox the source image into its best bucket; pass references through.

    The first image (the edit source) is scaled with its aspect ratio preserved
    to fit inside the bucket chosen by ``pick_best_bucket``. It is then centred
    and padded out to the full bucket size by replicating its edge pixels,
    which leaves fewer seams than black bars.

    Additional images (references) are only converted to RGB. They keep their
    own size and aspect ratio rather than being stretched to the source's
    content size.
    NOTE(review): relies on the edit pipeline resizing each condition image on
    its own — confirm against the installed diffusers version.

    Args:
        pil_images: source image first, then optional reference images.
        allow_upscale: forwarded to ``pick_best_bucket``.
        divisible_by: minimum content edge length in pixels.

    Returns:
        ``(processed_images, width, height, pad_info)``. ``width``/``height``
        are the bucket size. ``pad_info = (pad_left, pad_top, content_w,
        content_h)`` locates the source content inside the bucket, so the
        padding can be cropped off after inference.

    Raises:
        ValueError: if ``pil_images`` is empty.
    """
    if not pil_images:
        raise ValueError("No input images.")

    source = pil_images[0].convert("RGB")
    base_w, base_h = source.size
    bucket_w, bucket_h = pick_best_bucket(base_w, base_h, SAFE_BUCKETS, allow_upscale)

    # Fit inside the bucket without stretching; clamp so padding is never negative.
    scale = min(bucket_w / base_w, bucket_h / base_h)
    content_w = min(bucket_w, max(divisible_by, round(base_w * scale)))
    content_h = min(bucket_h, max(divisible_by, round(base_h * scale)))

    pad_left = (bucket_w - content_w) // 2
    pad_top = (bucket_h - content_h) // 2
    pad_info = (pad_left, pad_top, content_w, content_h)

    # Edge-replicate padding (corners take the nearest corner pixel).
    resized = np.asarray(source.resize((content_w, content_h), Image.LANCZOS))
    canvas = np.pad(
        resized,
        (
            (pad_top, bucket_h - pad_top - content_h),
            (pad_left, bucket_w - pad_left - content_w),
            (0, 0),
        ),
        mode="edge",
    )

    processed = [Image.fromarray(canvas)]
    processed.extend(img.convert("RGB") for img in pil_images[1:])
    return processed, bucket_w, bucket_h, pad_info


def extract_pil_from_source(source) -> Image.Image:
    """Extract an RGB image from a gr.ImageEditor dict, a path, or a PIL image.

    For editor dicts the "composite" (background plus brush strokes) is
    preferred, so colour annotations reach the model; "background" is the
    fallback. Returns ``None`` when nothing is available.
    """
    if source is None:
        return None
    if isinstance(source, dict):
        img = source.get("composite")
        if img is None:
            img = source.get("background")
        if img is None:
            return None
        if isinstance(img, np.ndarray):
            return Image.fromarray(img).convert("RGB")
        return img.convert("RGB")
    return load_pil_image(source)


def format_info(seed_val, source_img, ref_img):
    """Build the Markdown info panel: seed plus original size -> bucket size.

    Best-effort: an input that is empty is skipped; one that cannot be read is
    skipped with a console warning, so the panel never raises.
    """
    lines = [f"**Seed:** `{int(seed_val)}`"]
    for label, img in (("Source", source_img), ("Reference", ref_img)):
        if img is None:
            continue
        try:
            pil = extract_pil_from_source(img) if label == "Source" else load_pil_image(img)
            if pil is None:
                continue
            ow, oh = pil.size
            nw, nh = pick_best_bucket(ow, oh, SAFE_BUCKETS, UPSCALE_SMALL_IMAGES)
            lines.append(
                f"\n**{label}:** {ow}×{oh} → **{nw}×{nh}** "
                f"(ratio {ow/oh:.3f} → {nw/nh:.3f})"
            )
        except Exception as e:
            print(f"Warning: could not describe {label.lower()} image: {e}")
    return "\n\n".join(lines)


# =========================================================
# INFERENCE
# =========================================================
@spaces.GPU
def infer(
    source_image,
    ref_image,
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    guidance_scale,
    steps,
    color_match,
    out_width=0,
    out_height=0,
    progress=gr.Progress(track_tqdm=True),
):
    """Run one edit and return ``(result_image, seed_used)``.

    Routing: if the prompt contains "抠" (cut out), no diffusion is run. In
    that case, "黑底" (black background) selects ``remove_black_bg`` and
    anything else uses RMBG-2.0.

    Otherwise the source is letterboxed into a bucket and edited by ``pipe``,
    with the optional reference image passed as a further condition image. The
    padding is cropped off and the result resized back to the source size.
    Optionally it is colour-matched to the un-annotated source background.
    Finally, a thumbnail of the reference is pasted in the top-left corner.

    Args:
        out_width / out_height: optional explicit generation size (rounded
            down to a multiple of 16, at least 16); 0 means "use the bucket".

    Raises:
        gr.Error: on missing/unreadable source image or empty prompt.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if source_image is None:
        raise gr.Error("⚠️ Please upload a source image.")
    if not prompt or not prompt.strip():
        raise gr.Error("⚠️ Please enter an edit prompt.")

    # Accepts both ImageEditor dicts and plain paths / PIL images.
    try:
        src_pil = extract_pil_from_source(source_image)
    except Exception as e:
        raise gr.Error(f"⚠️ Could not load source image: {e}") from e
    if src_pil is None:
        raise gr.Error("⚠️ Please upload a source image.")

    # Output is resized back to this so bucket/16-alignment never crops the user's image.
    orig_size = src_pil.size  # (w, h)

    # ── Routing: background removal, no diffusion ──
    if "抠" in prompt:
        if "黑底" in prompt:
            # Caption text on black: drop only border-connected black, keep inner black.
            result = remove_black_bg(src_pil)
        else:
            # General cut-out via RMBG 2.0 segmentation.
            result = run_rmbg(src_pil)
        return result, seed

    # Source is required; the reference image is optional and best-effort.
    pil_images = [src_pil]
    if ref_image is not None:
        try:
            pil_images.append(load_pil_image(ref_image))
        except Exception as e:
            print(f"Warning: could not load reference image: {e}")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(int(seed))

    processed_images, bucket_w, bucket_h, pad_info = prepare_images_before_pipe(
        pil_images, allow_upscale=UPSCALE_SMALL_IMAGES
    )

    # Explicit output size overrides the bucket (like ComfyUI EmptyLatentImage).
    width, height = bucket_w, bucket_h
    if out_width and out_width > 0:
        width = max(16, int(out_width) // 16 * 16)
    if out_height and out_height > 0:
        height = max(16, int(out_height) // 16 * 16)

    try:
        result = pipe(
            image=processed_images,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_inference_steps=int(steps),
            generator=generator,
            true_cfg_scale=guidance_scale,
        ).images[0]

        # ── Crop the letterbox padding ──
        # pad_info is in bucket coordinates; rescale it to the actual output
        # size, which differs from the bucket when an explicit size was given.
        pad_left, pad_top, content_w, content_h = pad_info
        sx = result.width / bucket_w
        sy = result.height / bucket_h
        box = (
            round(pad_left * sx),
            round(pad_top * sy),
            round((pad_left + content_w) * sx),
            round((pad_top + content_h) * sy),
        )
        if box != (0, 0, result.width, result.height):
            result = result.crop(box)

        # ── Back to the original source size ──
        if result.size != orig_size:
            result = result.resize(orig_size, Image.LANCZOS)

        if color_match:
            # Use the editor background (without brush strokes) as colour reference.
            ref_pil = src_pil
            if isinstance(source_image, dict):
                bg = source_image.get("background")
                if bg is not None:
                    ref_pil = Image.fromarray(bg).convert("RGB") if isinstance(bg, np.ndarray) else bg.convert("RGB")
            result = color_match_reinhard(ref_pil.resize(result.size, Image.LANCZOS), result)

        # Paste the reference thumbnail last, so colour matching neither
        # recolours it nor has its statistics skewed by it.
        if len(pil_images) > 1:
            result = add_image_watermark(result, pil_images[1])

        return result, seed
    finally:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# =========================================================
# UI
# =========================================================
# JS run on page load: for every `.src-editor` ImageEditor it finds the brush
# toolbar (the first descendant whose class contains "toolbar"/"tool-bar";
# Svelte hashes class names, so CSS selectors cannot target it). It then
# inserts a toggle button right before the toolbar that hides/shows it via
# `visibility` and `pointer-events`. A MutationObserver (plus a 1 s fallback)
# re-runs setup as the DOM renders; each editor is only set up once.
_FIX_TOOLBAR_JS = """
() => {
  const setup = (ed) => {
    if (ed.dataset.toggleReady) return;
    // 找工具栏元素(Gradio/Svelte 会给 class 加哈希,用 includes 匹配)
    const toolbar = Array.from(ed.querySelectorAll('*')).find(el => {
      const cls = el.getAttribute('class') || '';
      return cls.includes('toolbar') || cls.includes('tool-bar');
    });
    if (!toolbar) return;
    ed.dataset.toggleReady = '1';

    // 插入切换按钮,放在 toolbar 的父容器第一位
    const btn = document.createElement('button');
    btn.className = 'toolbar-toggle-btn';
    btn.textContent = '🎨 隐藏画笔工具栏';
    let hidden = false;
    btn.onclick = () => {
      hidden = !hidden;
      // 用 visibility 而非 display,避免画布区域跳动
      toolbar.style.visibility = hidden ? 'hidden' : '';
      toolbar.style.pointerEvents = hidden ? 'none' : '';
      btn.textContent = hidden ? '🎨 显示画笔工具栏' : '🎨 隐藏画笔工具栏';
    };
    toolbar.parentNode.insertBefore(btn, toolbar);
  };

  const mo = new MutationObserver(() => {
    document.querySelectorAll('.src-editor').forEach(setup);
  });
  mo.observe(document.body, { childList: true, subtree: true });
  setTimeout(() => document.querySelectorAll('.src-editor').forEach(setup), 1000);
}
"""

with gr.Blocks(
    theme=theme,
    js=_FIX_TOOLBAR_JS,
    css="""
.gradio-container { max-width: 1400px !important; margin: 0 auto; padding-top: 20px; }
.hero { text-align: center; padding: 24px 0 12px 0; }
.hero h1 { font-size: 2.2rem; font-weight: 800; margin-bottom: 8px; }
.hero p { font-size: 1rem; color: #666; margin-bottom: 0; }
/* 工具栏隐藏时,隐藏按钮仍可点击 */
.toolbar-toggle-btn {
  display: block; width: 100%; padding: 4px 10px; margin-bottom: 2px;
  background: #f0f0f0; border: 1px solid #ddd; border-radius: 4px;
  font-size: 12px; cursor: pointer; text-align: left; color: #555;
}
""",
) as demo:
    # Header; wrapped in .hero so the header CSS rules above apply.
    gr.HTML(
        """
<div class="hero">
<h1>🔥 FireRed Image Edit 1.1 Fast</h1>
</div>
"""
    )

    with gr.Tabs():
        # ══════════════════════════════════════════════════════
        # Tab 1: AI edit
        # ══════════════════════════════════════════════════════
        with gr.Tab("AI Edit"):
            with gr.Row():
                with gr.Column(scale=1):
                    source_input = gr.ImageEditor(
                        label="Source Image — 可用画笔标注区域(红/绿/蓝等),提示词中引用颜色",
                        elem_classes=["src-editor"],
                        brush=gr.Brush(
                            colors=["#FF0000", "#00CC00", "#0066FF", "#FFFF00", "#FF00FF", "#FFFFFF"],
                            color_mode="defaults",
                        ),
                    )
                    gr.Markdown(
                        "🔴红 🟢绿 🔵蓝 🟡黄 🟣紫 ⬜白 — 画好后提示词写:*去掉红色标注的区域* 等"
                    )
                    with gr.Row():
                        ref_input = gr.Image(
                            label="Reference Image(参考图,可选)",
                            type="filepath",
                            sources=["upload", "clipboard"],
                        )
                    prompt_input = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe how you want to edit the image...",
                        lines=4,
                    )
                    negative_prompt_input = gr.Textbox(
                        label="Negative Prompt",
                        value=DEFAULT_NEGATIVE_PROMPT,
                        lines=3,
                    )
                    color_match_input = gr.Checkbox(label="Color Match — 色彩对齐原图", value=True)
                    with gr.Accordion("Advanced Settings", open=False):
                        seed_input = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=0,
                        )
                        randomize_seed_input = gr.Checkbox(label="Randomize Seed", value=True)
                        guidance_scale_input = gr.Slider(
                            label="Guidance Scale",
                            minimum=1.0,
                            maximum=10.0,
                            step=0.1,
                            value=1.0,
                        )
                        steps_input = gr.Slider(
                            label="Inference Steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=4,
                        )
                    run_button = gr.Button("Generate", variant="primary")
                    info_markdown = gr.Markdown()
                with gr.Column(scale=1):
                    output_image = gr.Image(label="Result", type="pil")

            # Refresh the size/seed info panel whenever an input changes.
            for trigger in [source_input, ref_input, seed_input]:
                trigger.change(
                    fn=format_info,
                    inputs=[seed_input, source_input, ref_input],
                    outputs=[info_markdown],
                )

            run_button.click(
                fn=infer,
                inputs=[
                    source_input,
                    ref_input,
                    prompt_input,
                    negative_prompt_input,
                    seed_input,
                    randomize_seed_input,
                    guidance_scale_input,
                    steps_input,
                    color_match_input,
                ],
                outputs=[output_image, seed_input],
            ).then(
                fn=format_info,
                inputs=[seed_input, source_input, ref_input],
                outputs=[info_markdown],
            )

        # ══════════════════════════════════════════════════════
        # Tab 2: PNG placement (paint a mask -> paste PNG scaled to fit)
        # ══════════════════════════════════════════════════════
        with gr.Tab("PNG Placement"):
            gr.Markdown("**用法:** 上传底图后在图上涂抹出放置区域,再上传 PNG,点击 Apply。PNG 会等比缩放,最长边对齐 mask 最长边,居中贴入。")
            with gr.Row():
                with gr.Column(scale=1):
                    mask_editor = gr.ImageEditor(
                        label="Source Image — 在图上涂抹出放置区域",
                        brush=gr.Brush(colors=["#FF6B35"], color_mode="fixed"),
                    )
                    png_input = gr.Image(
                        label="PNG to place(支持透明背景)",
                        type="numpy",
                        sources=["upload", "clipboard"],
                        image_mode="RGBA",
                    )
                    apply_button = gr.Button("Apply", variant="primary")
                with gr.Column(scale=1):
                    placement_output = gr.Image(label="Result", type="pil")

            apply_button.click(
                fn=paste_png_into_mask,
                inputs=[mask_editor, png_input],
                outputs=[placement_output],
            )


if __name__ == "__main__":
    demo.launch()