import os
import gc
import random
from typing import Iterable, List, Tuple
from huggingface_hub import login as hf_login
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
hf_login(token=_hf_token)
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
# =========================================================
# THEME
# =========================================================
colors.fire_red = colors.Color(
name="fire_red",
c50="#FFF5F0",
c100="#FFE8DB",
c200="#FFD0B5",
c300="#FFB088",
c400="#FF8C5A",
c500="#FF6B35",
c600="#E8531F",
c700="#CC4317",
c800="#A63812",
c900="#80300F",
c950="#5C220A",
)
class FireRedTheme(Soft):
def __init__(
self,
*,
primary_hue: colors.Color | str = colors.gray,
secondary_hue: colors.Color | str = colors.fire_red,
neutral_hue: colors.Color | str = colors.slate,
text_size: sizes.Size | str = sizes.text_md,
font: fonts.Font | str | Iterable[fonts.Font | str] = (
fonts.GoogleFont("Inter"),
"system-ui",
"sans-serif",
),
font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
fonts.GoogleFont("JetBrains Mono"),
"ui-monospace",
"monospace",
),
):
super().__init__(
primary_hue=primary_hue,
secondary_hue=secondary_hue,
neutral_hue=neutral_hue,
text_size=text_size,
font=font,
font_mono=font_mono,
)
super().set(
body_background_fill="#f0f2f6",
body_background_fill_dark="*neutral_950",
background_fill_primary="white",
background_fill_primary_dark="*neutral_900",
block_background_fill="white",
block_background_fill_dark="*neutral_800",
block_border_width="1px",
block_border_color="*neutral_200",
block_border_color_dark="*neutral_700",
block_shadow="0 1px 4px rgba(0,0,0,0.05)",
block_shadow_dark="0 1px 4px rgba(0,0,0,0.25)",
block_title_text_weight="600",
block_label_background_fill="*neutral_50",
block_label_background_fill_dark="*neutral_800",
button_primary_text_color="white",
button_primary_text_color_hover="white",
button_primary_background_fill="linear-gradient(135deg, *secondary_500, *secondary_600)",
button_primary_background_fill_hover="linear-gradient(135deg, *secondary_600, *secondary_700)",
button_primary_background_fill_dark="linear-gradient(135deg, *secondary_500, *secondary_600)",
button_primary_background_fill_hover_dark="linear-gradient(135deg, *secondary_600, *secondary_700)",
button_primary_shadow="0 4px 14px rgba(232, 83, 31, 0.25)",
button_secondary_text_color="*secondary_700",
button_secondary_text_color_dark="*secondary_300",
button_secondary_background_fill="*secondary_50",
button_secondary_background_fill_hover="*secondary_100",
button_secondary_background_fill_dark="rgba(255, 107, 53, 0.1)",
button_secondary_background_fill_hover_dark="rgba(255, 107, 53, 0.2)",
button_large_padding="12px 24px",
slider_color="*secondary_500",
slider_color_dark="*secondary_500",
input_border_color_focus="*secondary_400",
input_border_color_focus_dark="*secondary_500",
color_accent_soft="*secondary_50",
color_accent_soft_dark="rgba(255, 107, 53, 0.15)",
)
theme = FireRedTheme()
# =========================================================
# MODEL
# =========================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.__version__ =", torch.__version__)
print("device =", device)
from diffusers import FlowMatchEulerDiscreteScheduler, QwenImageEditPlusPipeline # noqa: E402,F401
from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3 # noqa: E402
from transformers import AutoModelForImageSegmentation # noqa: E402
from torchvision import transforms # noqa: E402
import torch.nn.functional as F # noqa: E402
dtype = torch.bfloat16
# ── FireRed edit model (loaded natively from the official checkpoint) ──
pipe = QwenImageEditPlusPipeline.from_pretrained(
"FireRedTeam/FireRed-Image-Edit-1.1",
torch_dtype=dtype,
).to(device)
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
try:
pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
print("Flash Attention 3 Processor set successfully.")
except Exception as e:
print(f"Warning: Could not set FA3 processor: {e}")
# ── Lightning LoRA (4-step speedup, identical to the ComfyUI Rebels.json setup) ──
try:
pipe.load_lora_weights(
"Osrivers/Qwen-Image-Lightning-4steps-V2.0-bf16.safetensors",
weight_name="Qwen-Image-Lightning-4steps-V2.0-bf16.safetensors",
adapter_name="lightning",
)
pipe.set_adapters(["lightning"], adapter_weights=[1.0])
print("Lightning LoRA (4steps V2.0) loaded successfully.")
except Exception as e:
print(f"Warning: Could not load Lightning LoRA: {e}")
# ── RMBG 2.0 background-removal (matting) model ──
rmbg = AutoModelForImageSegmentation.from_pretrained(
"briaai/RMBG-2.0",
trust_remote_code=True,
)
rmbg.to(device)
rmbg.eval()
MAX_SEED = np.iinfo(np.int32).max
DEFAULT_NEGATIVE_PROMPT = (
"worst quality, low quality, bad anatomy, bad hands, text, error, "
"missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, "
"signature, watermark, username, blurry"
)
# =========================================================
# SAFE BUCKETS (~1MP each)
# =========================================================
SAFE_BUCKETS: List[Tuple[int, int]] = [
# Standard buckets (~1MP)
(1024, 1024),
(1184, 880),
(880, 1184),
(1392, 752),
(752, 1392),
(1568, 672),
(672, 1568),
# Wide buckets (long banner shapes, e.g. variety-show caption art)
(1920, 640), # 3:1
(1600, 400), # 4:1 ← same bucket as Rebels.json
(2048, 512), # 4:1
(1920, 384), # 5:1
(2560, 512), # 5:1
(2048, 336), # ~6:1
]
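# All buckets are 16-aligned on both axes, so the width/height handed to the
# pipeline never need extra rounding. A cheap sanity check at import time:
assert all(w % 16 == 0 and h % 16 == 0 for w, h in SAFE_BUCKETS)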
UPSCALE_SMALL_IMAGES = True
_rmbg_normalize = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
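# (The mean/std above are the standard ImageNet normalization statistics, the
# same preprocessing used in the RMBG-2.0 model card.)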
RMBG_SIZE = 1024
@spaces.GPU
def run_rmbg(pil_image: Image.Image) -> Image.Image:
"""็”จ RMBG-2.0 ๅŽป้™ค่ƒŒๆ™ฏ๏ผŒไธŽ ComfyUI comfyui-rmbg ๅฎŒๅ…จไธ€่‡ด๏ผš
squish ๅˆฐ 1024ร—1024๏ผŒsigmoid ๆฟ€ๆดป๏ผŒbilinear resize ๅ›žๅŽŸๅฐบๅฏธใ€‚
"""
orig_w, orig_h = pil_image.size
inp = _rmbg_normalize(pil_image.convert("RGB").resize((RMBG_SIZE, RMBG_SIZE), Image.LANCZOS))
inp = inp.unsqueeze(0).to(device)
with torch.no_grad():
outputs = rmbg(inp)
# Match ComfyUI exactly: take the last output layer and apply sigmoid
if isinstance(outputs, list):
result = outputs[-1].sigmoid().cpu()
elif isinstance(outputs, dict) and 'logits' in outputs:
result = outputs['logits'].sigmoid().cpu()
else:
result = outputs.sigmoid().cpu()
result = torch.clamp(result.squeeze(), 0, 1)
result = F.interpolate(result.unsqueeze(0).unsqueeze(0), size=(orig_h, orig_w), mode='bilinear').squeeze()
mask_pil = Image.fromarray((result.numpy() * 255).astype(np.uint8))
out = pil_image.convert("RGBA")
out.putalpha(mask_pil)
return out
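# Illustrative usage (the path is hypothetical):
#   cutout = run_rmbg(Image.open("subject.jpg").convert("RGB"))  # RGBA, input size
#   cutout.save("subject_cutout.png")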
def color_match_reinhard(source: Image.Image, result: Image.Image) -> Image.Image:
"""Reinhard RGB ๅ‡ๅ€ผ/ๆ ‡ๅ‡†ๅทฎ่‰ฒๅฝฉ่ฟ็งป๏ผšๅฐ† result ็š„่‰ฒ่ฐƒๅฏน้ฝ sourceใ€‚"""
src = np.array(source.convert("RGB")).astype(np.float32)
res = np.array(result.convert("RGB")).astype(np.float32)
out = np.zeros_like(res)
for c in range(3):
s_mean, s_std = src[:, :, c].mean(), src[:, :, c].std()
r_mean, r_std = res[:, :, c].mean(), res[:, :, c].std()
ratio = s_std / (r_std + 1e-6) if r_std > 0.5 else 1.0
out[:, :, c] = (res[:, :, c] - r_mean) * ratio + s_mean
return Image.fromarray(np.clip(out, 0, 255).astype(np.uint8))
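# Per channel c this computes out_c = (res_c - mean(res_c)) * std(src_c) / std(res_c) + mean(src_c),
# except that the ratio is forced to 1.0 when std(res_c) <= 0.5, so a nearly
# flat channel is only mean-shifted instead of being blown up.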
def remove_black_bg(pil_image: Image.Image, dark_thresh: int = 40) -> Image.Image:
"""ไป…ๅŽป้™คไธŽๅ››่พน่ฟž้€š็š„้ป‘่‰ฒ่ƒŒๆ™ฏ๏ผŒไฟ็•™ๆ–‡ๅญ—ๅ†…้ƒจ็š„้ป‘่‰ฒๅ…ƒ็ด ใ€‚
็”จ่ฟž้€šๅŒบๅŸŸๆ ‡่ฎฐ๏ผˆflood fill๏ผ‰ๅฎž็Žฐ๏ผŒไธไพ่ต– AI ๆจกๅž‹ใ€‚
"""
from scipy import ndimage as ndi
arr = np.array(pil_image.convert("RGB"))
dark_mask = np.all(arr <= dark_thresh, axis=2)
labeled, _ = ndi.label(dark_mask)
# Find all connected regions that touch the image border
border_labels = set()
border_labels.update(labeled[0, :].tolist())
border_labels.update(labeled[-1, :].tolist())
border_labels.update(labeled[:, 0].tolist())
border_labels.update(labeled[:, -1].tolist())
border_labels.discard(0) # 0 = non-black region
bg_mask = np.zeros(arr.shape[:2], dtype=bool)
for lbl in border_labels:
bg_mask |= (labeled == lbl)
alpha = np.where(bg_mask, 0, 255).astype(np.uint8)
out = pil_image.convert("RGBA")
out.putalpha(Image.fromarray(alpha))
return out
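# Example: for white lettering on a black card, the surrounding black touches
# the border and becomes transparent, while the enclosed black counter of an
# "O" is a separate component and stays opaque.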
def add_image_watermark(result: Image.Image, ref: Image.Image, size: int = 200, padding: int = 16) -> Image.Image:
result = result.copy().convert("RGBA")
thumb = ref.convert("RGBA")
thumb.thumbnail((size, size), Image.LANCZOS)
result.paste(thumb, (padding, padding), thumb)
return result.convert("RGB")
def paste_png_into_mask(editor_value: dict, png_image) -> Image.Image:
"""
Extract the bounding box from the ImageEditor mask layer, scale the PNG
proportionally (longest side = the mask's longest side), then paste it centered.
"""
if editor_value is None:
raise gr.Error("โš ๏ธ Please upload and draw a mask on the source image.")
if png_image is None:
raise gr.Error("โš ๏ธ Please upload a PNG to place.")
# ๅ–ๅบ•ๅ›พๅ’Œ mask ๅฑ‚
background: Image.Image = editor_value.get("background")
layers: list = editor_value.get("layers", [])
if background is None:
raise gr.Error("โš ๏ธ No source image found.")
if not layers:
raise gr.Error("โš ๏ธ Please draw a mask area on the image first.")
if isinstance(background, np.ndarray):
background = Image.fromarray(background)
background = background.convert("RGBA")
mask_layer = layers[0]
if isinstance(mask_layer, np.ndarray):
mask_layer = Image.fromarray(mask_layer)
mask_layer = mask_layer.convert("RGBA")
# Find the bounding box from the mask layer's alpha channel
alpha = mask_layer.split()[3]
bbox = alpha.getbbox()
if bbox is None:
raise gr.Error("โš ๏ธ Mask area is empty. Please draw on the image.")
x1, y1, x2, y2 = bbox
mask_w = x2 - x1
mask_h = y2 - y1
mask_longest = max(mask_w, mask_h)
# Load the PNG
if isinstance(png_image, str):
png = Image.open(png_image).convert("RGBA")
else:
png = Image.fromarray(png_image).convert("RGBA")
png_w, png_h = png.size
png_longest = max(png_w, png_h)
# Proportional scaling: longest side aligned to the mask's longest side
scale = mask_longest / png_longest
new_w = max(1, int(png_w * scale))
new_h = max(1, int(png_h * scale))
png_resized = png.resize((new_w, new_h), Image.LANCZOS)
# Paste centered into the mask region
paste_x = x1 + (mask_w - new_w) // 2
paste_y = y1 + (mask_h - new_h) // 2
result = background.copy()
result.paste(png_resized, (paste_x, paste_y), png_resized)
return result.convert("RGB")
# =========================================================
# HELPERS
# =========================================================
def load_pil_image(item) -> Image.Image:
if item is None:
return None
if isinstance(item, Image.Image):
return item.convert("RGB")
if isinstance(item, str):
return Image.open(item).convert("RGB")
if isinstance(item, (tuple, list)):
path = item[0]
if isinstance(path, Image.Image):
return path.convert("RGB")
return Image.open(path).convert("RGB")
return Image.open(item.name).convert("RGB")
def pick_best_bucket(
orig_w: int,
orig_h: int,
buckets: List[Tuple[int, int]] = SAFE_BUCKETS,
allow_upscale: bool = UPSCALE_SMALL_IMAGES,
) -> Tuple[int, int]:
if orig_w <= 0 or orig_h <= 0:
return 1024, 1024
orig_ratio = orig_w / orig_h
def score(bucket):
bw, bh = bucket
ratio_diff = abs((bw / bh) - orig_ratio)
area_diff = abs((bw * bh) - (orig_w * orig_h))
return (ratio_diff, area_diff)
sorted_buckets = sorted(buckets, key=score)
if allow_upscale:
return sorted_buckets[0]
not_larger = [b for b in sorted_buckets if b[0] <= orig_w and b[1] <= orig_h]
return not_larger[0] if not_larger else sorted_buckets[0]
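# Illustrative (not executed): a 1000x250 banner has ratio 4.0, so both 4:1
# buckets tie on ratio_diff and (1600, 400) wins on area_diff:
#   pick_best_bucket(1000, 250)  ->  (1600, 400)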
def prepare_images_before_pipe(
pil_images: List[Image.Image],
allow_upscale: bool = UPSCALE_SMALL_IMAGES,
divisible_by: int = 16,
) -> Tuple[List[Image.Image], int, int, tuple]:
"""ๅ‡†ๅค‡ๅ›พ็‰‡๏ผš็ญ‰ๆฏ”็ผฉๆ”พ + ่กฅ่พนๅˆฐๆœ€ไฝณ bucket๏ผŒไฟ็•™ๅŽŸๅง‹ๆฏ”ไพ‹ใ€‚
่ฟ”ๅ›ž (processed_images, width, height, pad_info)
pad_info = (pad_left, pad_top, content_w, content_h) ็”จไบŽๆŽจ็†ๅŽ่ฃๅ‰ช่กฅ่พนใ€‚
"""
if not pil_images:
raise ValueError("No input images.")
base_w, base_h = pil_images[0].size
# Pick the best bucket (~1MP, closest aspect ratio)
bucket_w, bucket_h = pick_best_bucket(base_w, base_h, SAFE_BUCKETS, allow_upscale)
# Proportionally fit inside the bucket (no stretching)
scale = min(bucket_w / base_w, bucket_h / base_h)
content_w = max(divisible_by, round(base_w * scale))
content_h = max(divisible_by, round(base_h * scale))
# Center-pad up to the bucket size
pad_left = (bucket_w - content_w) // 2
pad_top = (bucket_h - content_h) // 2
pad_info = (pad_left, pad_top, content_w, content_h)
processed = []
for img in pil_images:
# Proportional resize
resized = img.resize((content_w, content_h), Image.LANCZOS)
# Create a bucket-sized canvas; the borders are filled by replicating edge pixels to reduce seams
canvas = Image.new("RGB", (bucket_w, bucket_h), (0, 0, 0))
canvas.paste(resized, (pad_left, pad_top))
# Fill the padded regions with edge pixels (looks better than plain black)
arr = np.array(canvas)
res_arr = np.array(resized)
# ๅกซๅ……ๅทฆๅณ
if pad_left > 0:
left_col = res_arr[:, 0:1, :]
arr[pad_top:pad_top+content_h, :pad_left, :] = np.broadcast_to(left_col, (content_h, pad_left, 3))
right_start = pad_left + content_w
if right_start < bucket_w:
right_col = res_arr[:, -1:, :]
arr[pad_top:pad_top+content_h, right_start:, :] = np.broadcast_to(right_col, (content_h, bucket_w - right_start, 3))
# Fill top/bottom
if pad_top > 0:
top_row = arr[pad_top:pad_top+1, :, :]
arr[:pad_top, :, :] = np.broadcast_to(top_row, (pad_top, bucket_w, 3))
bottom_start = pad_top + content_h
if bottom_start < bucket_h:
bottom_row = arr[bottom_start-1:bottom_start, :, :]
arr[bottom_start:, :, :] = np.broadcast_to(bottom_row, (bucket_h - bottom_start, bucket_w, 3))
processed.append(Image.fromarray(arr))
return processed, bucket_w, bucket_h, pad_info
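# Worked example: an 800x600 input picks the (1184, 880) bucket (ratio 1.345 vs
# 1.333); scale = min(1184/800, 880/600) ~= 1.467, so the content is 1173x880
# with pad_left = 5 and pad_top = 0.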
def extract_pil_from_source(source) -> Image.Image:
"""ไปŽ gr.ImageEditor dict ๆˆ–ๆ™ฎ้€š่ทฏๅพ„/PIL ไธญๆๅ–ๅ›พ็‰‡๏ผˆไฝฟ็”จ composite ไฟ็•™ๆถ‚่‰ฒๆ ‡ๆณจ๏ผ‰ใ€‚"""
if source is None:
return None
if isinstance(source, dict):
img = source.get("composite")
if img is None:
img = source.get("background")
if img is None:
return None
if isinstance(img, np.ndarray):
return Image.fromarray(img).convert("RGB")
return img.convert("RGB")
return load_pil_image(source)
def format_info(seed_val, source_img, ref_img):
lines = [f"**Seed:** `{int(seed_val)}`"]
for label, img in [("Source", source_img), ("Reference", ref_img)]:
if img is None:
continue
try:
pil = extract_pil_from_source(img) if label == "Source" else load_pil_image(img)
ow, oh = pil.size
nw, nh = pick_best_bucket(ow, oh, SAFE_BUCKETS, UPSCALE_SMALL_IMAGES)
lines.append(
f"\n**{label}:** {ow}ร—{oh} โ†’ **{nw}ร—{nh}** "
f"(ratio {ow/oh:.3f} โ†’ {nw/nh:.3f})"
)
except Exception:
pass
return "\n\n".join(lines)
# =========================================================
# INFERENCE
# =========================================================
@spaces.GPU
def infer(
source_image,
ref_image,
prompt,
negative_prompt,
seed,
randomize_seed,
guidance_scale,
steps,
color_match,
out_width=0,
out_height=0,
progress=gr.Progress(track_tqdm=True),
):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if source_image is None:
raise gr.Error("โš ๏ธ Please upload a source image.")
if not prompt or not prompt.strip():
raise gr.Error("โš ๏ธ Please enter an edit prompt.")
# ๆๅ–ๅŽŸๅ›พ๏ผˆๅ…ผๅฎน ImageEditor dict ๅ’Œๆ™ฎ้€š่ทฏๅพ„๏ผ‰
try:
src_pil = extract_pil_from_source(source_image)
except Exception as e:
raise gr.Error(f"โš ๏ธ Could not load source image: {e}")
if src_pil is None:
raise gr.Error("โš ๏ธ Please upload a source image.")
# Record the original size so we can resize back after inference, avoiding crops caused by 16-alignment
orig_size = src_pil.size # (w, h)
# ── Route: background removal / matting ──
if "抠" in prompt:  # Chinese for "cut out / matte"
if "黑底" in prompt:  # "black background"
# Black-background lettering: connected components strip the surrounding black while keeping black inside the glyphs
result = remove_black_bg(src_pil)
else:
# Plain matting: RMBG 2.0 semantic segmentation
result = run_rmbg(src_pil)
return result, seed
# Collect images: the source is required, the reference is optional
pil_images = [src_pil]
if ref_image is not None:
try:
pil_images.append(load_pil_image(ref_image))
except Exception as e:
print(f"Warning: could not load reference image: {e}")
if randomize_seed:
seed = random.randint(0, MAX_SEED)
generator = torch.Generator(device=device).manual_seed(int(seed))
processed_images, width, height, pad_info = prepare_images_before_pipe(
pil_images, allow_upscale=UPSCALE_SMALL_IMAGES
)
# ๆ˜พๅผๆŒ‡ๅฎš่พ“ๅ‡บๅฐบๅฏธ๏ผˆๅฏน้ฝ ComfyUI EmptyLatentImage ่กŒไธบ๏ผ‰
if out_width > 0:
width = (out_width // 16) * 16
if out_height > 0:
height = (out_height // 16) * 16
try:
result = pipe(
image=processed_images,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_inference_steps=steps,
generator=generator,
true_cfg_scale=guidance_scale,
).images[0]
# โ”€โ”€ ่ฃๆމ่กฅ่พน๏ผŒ่ฟ˜ๅŽŸๅˆฐๅŽŸๅง‹ๆฏ”ไพ‹ๅ†…ๅฎนๅŒบๅŸŸ โ”€โ”€
pad_left, pad_top, content_w, content_h = pad_info
if pad_left > 0 or pad_top > 0 or content_w < width or content_h < height:
result = result.crop((pad_left, pad_top, pad_left + content_w, pad_top + content_h))
# ── Resize back to the original size ──
if result.size != orig_size:
result = result.resize(orig_size, Image.LANCZOS)
if ref_image is not None and len(pil_images) > 1:
result = add_image_watermark(result, pil_images[1])
if color_match:
# Use the source background (without brush strokes) as the color reference
if isinstance(source_image, dict):
bg = source_image.get("background")
if bg is not None:
ref_pil = Image.fromarray(bg).convert("RGB") if isinstance(bg, np.ndarray) else bg.convert("RGB")
else:
ref_pil = src_pil
else:
ref_pil = src_pil
ref_pil_resized = ref_pil.resize(result.size, Image.LANCZOS)
result = color_match_reinhard(ref_pil_resized, result)
return result, seed
finally:
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# =========================================================
# UI
# =========================================================
# JS: once the ImageEditor has rendered, insert a toggle button that hides/shows
# the floating brush toolbar so it no longer covers the canvas. Plain CSS
# selectors are defeated by Svelte's scoped class hashes, so the toolbar is
# located from JS by class-name substring matching.
_FIX_TOOLBAR_JS = """
() => {
const setup = (ed) => {
if (ed.dataset.toggleReady) return;
// Find the toolbar element (Gradio/Svelte hash the class names, so match by substring)
const toolbar = Array.from(ed.querySelectorAll('*')).find(el => {
const cls = el.getAttribute('class') || '';
return cls.includes('toolbar') || cls.includes('tool-bar');
});
if (!toolbar) return;
ed.dataset.toggleReady = '1';
// ๆ’ๅ…ฅๅˆ‡ๆขๆŒ‰้’ฎ๏ผŒๆ”พๅœจ toolbar ็š„็ˆถๅฎนๅ™จ็ฌฌไธ€ไฝ
const btn = document.createElement('button');
btn.className = 'toolbar-toggle-btn';
btn.textContent = '🎨 Hide brush toolbar';
let hidden = false;
btn.onclick = () => {
hidden = !hidden;
// Use visibility rather than display so the canvas area doesn't jump
toolbar.style.visibility = hidden ? 'hidden' : '';
toolbar.style.pointerEvents = hidden ? 'none' : '';
btn.textContent = hidden ? '🎨 Show brush toolbar' : '🎨 Hide brush toolbar';
};
toolbar.parentNode.insertBefore(btn, toolbar);
};
const mo = new MutationObserver(() => {
document.querySelectorAll('.src-editor').forEach(setup);
});
mo.observe(document.body, { childList: true, subtree: true });
setTimeout(() => document.querySelectorAll('.src-editor').forEach(setup), 1000);
}
"""
with gr.Blocks(
theme=theme,
js=_FIX_TOOLBAR_JS,
css="""
.gradio-container {
max-width: 1400px !important;
margin: 0 auto;
padding-top: 20px;
}
.hero {
text-align: center;
padding: 24px 0 12px 0;
}
.hero h1 {
font-size: 2.2rem;
font-weight: 800;
margin-bottom: 8px;
}
.hero p {
font-size: 1rem;
color: #666;
margin-bottom: 0;
}
/* The toggle button stays clickable even while the toolbar is hidden */
.toolbar-toggle-btn {
display: block;
width: 100%;
padding: 4px 10px;
margin-bottom: 2px;
background: #f0f0f0;
border: 1px solid #ddd;
border-radius: 4px;
font-size: 12px;
cursor: pointer;
text-align: left;
color: #555;
}
""",
) as demo:
gr.HTML("""
<div class="hero">
<h1>🔥 FireRed Image Edit 1.1 Fast</h1>
</div>
""")
with gr.Tabs():
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# Tab 1: AI ็ผ–่พ‘
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
with gr.Tab("AI Edit"):
with gr.Row():
with gr.Column(scale=1):
source_input = gr.ImageEditor(
label="Source Image โ€” ๅฏ็”จ็”ป็ฌ”ๆ ‡ๆณจๅŒบๅŸŸ๏ผˆ็บข/็ปฟ/่“็ญ‰๏ผ‰๏ผŒๆ็คบ่ฏไธญๅผ•็”จ้ขœ่‰ฒ",
elem_classes=["src-editor"],
brush=gr.Brush(
colors=["#FF0000", "#00CC00", "#0066FF", "#FFFF00", "#FF00FF", "#FFFFFF"],
color_mode="defaults",
),
)
gr.Markdown(
"<small>๐Ÿ”ด็บข ๐ŸŸข็ปฟ ๐Ÿ”ต่“ ๐ŸŸก้ป„ ๐ŸŸฃ็ดซ โฌœ็™ฝ โ€” ็”ปๅฅฝๅŽๆ็คบ่ฏๅ†™๏ผš*ๅŽปๆމ็บข่‰ฒๆ ‡ๆณจ็š„ๅŒบๅŸŸ* ็ญ‰</small>"
)
with gr.Row():
ref_input = gr.Image(
label="Reference Image๏ผˆๅ‚่€ƒๅ›พ๏ผŒๅฏ้€‰๏ผ‰",
type="filepath",
sources=["upload", "clipboard"],
)
prompt_input = gr.Textbox(
label="Prompt",
placeholder="Describe how you want to edit the image...",
lines=4,
)
negative_prompt_input = gr.Textbox(
label="Negative Prompt",
value=DEFAULT_NEGATIVE_PROMPT,
lines=3,
)
color_match_input = gr.Checkbox(label="Color Match — align colors to the source image", value=True)
with gr.Accordion("Advanced Settings", open=False):
seed_input = gr.Slider(
label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0,
)
randomize_seed_input = gr.Checkbox(label="Randomize Seed", value=True)
guidance_scale_input = gr.Slider(
label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0,
)
steps_input = gr.Slider(
label="Inference Steps", minimum=1, maximum=50, step=1, value=4,
)
run_button = gr.Button("Generate", variant="primary")
info_markdown = gr.Markdown()
with gr.Column(scale=1):
output_image = gr.Image(label="Result", type="pil")
for trigger in [source_input, ref_input, seed_input]:
trigger.change(
fn=format_info,
inputs=[seed_input, source_input, ref_input],
outputs=[info_markdown],
)
run_button.click(
fn=infer,
inputs=[
source_input, ref_input, prompt_input, negative_prompt_input,
seed_input, randomize_seed_input, guidance_scale_input, steps_input,
color_match_input,
],
outputs=[output_image, seed_input],
).then(
fn=format_info,
inputs=[seed_input, source_input, ref_input],
outputs=[info_markdown],
)
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# Tab 2: PNG ่ดดๅ›พ๏ผˆ็”ป mask โ†’ ็ญ‰ๆฏ”่ดดๅ…ฅ๏ผ‰
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
with gr.Tab("PNG Placement"):
gr.Markdown("**็”จๆณ•๏ผš** ไธŠไผ ๅบ•ๅ›พๅŽๅœจๅ›พไธŠๆถ‚ๆŠนๅ‡บๆ”พ็ฝฎๅŒบๅŸŸ๏ผŒๅ†ไธŠไผ  PNG๏ผŒ็‚นๅ‡ป Applyใ€‚PNG ไผš็ญ‰ๆฏ”็ผฉๆ”พ๏ผŒๆœ€้•ฟ่พนๅฏน้ฝ mask ๆœ€้•ฟ่พน๏ผŒๅฑ…ไธญ่ดดๅ…ฅใ€‚")
with gr.Row():
with gr.Column(scale=1):
mask_editor = gr.ImageEditor(
label="Source Image โ€” ๅœจๅ›พไธŠๆถ‚ๆŠนๅ‡บๆ”พ็ฝฎๅŒบๅŸŸ",
brush=gr.Brush(colors=["#FF6B35"], color_mode="fixed"),
)
png_input = gr.Image(
label="PNG to place๏ผˆๆ”ฏๆŒ้€ๆ˜Ž่ƒŒๆ™ฏ๏ผ‰",
type="numpy",
sources=["upload", "clipboard"],
image_mode="RGBA",
)
apply_button = gr.Button("Apply", variant="primary")
with gr.Column(scale=1):
placement_output = gr.Image(label="Result", type="pil")
apply_button.click(
fn=paste_png_into_mask,
inputs=[mask_editor, png_input],
outputs=[placement_output],
)
if __name__ == "__main__":
demo.launch()