Spaces:
Running on Zero
Running on Zero
fix: pad-to-bucket instead of stretch, preserve exact aspect ratio, edge-fill padding
import os
import gc
import random
from typing import Iterable, List, Tuple

from huggingface_hub import login as hf_login

# Authenticate against the Hugging Face Hub as early as possible so the
# gated model downloads below succeed; skipped silently when no HF_TOKEN
# secret is configured for this Space.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    hf_login(token=_hf_token)

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
# =========================================================
# THEME
# =========================================================
# Register a custom "fire red" palette on Gradio's color registry so it can
# be referenced by name (via *secondary_* tokens) in the theme below.
colors.fire_red = colors.Color(
    name="fire_red",
    c50="#FFF5F0",
    c100="#FFE8DB",
    c200="#FFD0B5",
    c300="#FFB088",
    c400="#FF8C5A",
    c500="#FF6B35",
    c600="#E8531F",
    c700="#CC4317",
    c800="#A63812",
    c900="#80300F",
    c950="#5C220A",
)
class FireRedTheme(Soft):
    """Soft-based Gradio theme accented with the custom fire-red palette.

    The constructor keeps Soft's signature; the palette and widget overrides
    are applied on top of the Soft defaults via ``set``.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.fire_red,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_md,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Inter"),
            "system-ui",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("JetBrains Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Visual overrides, grouped by concern and applied in one call.
        overrides = dict(
            # Page / block surfaces
            body_background_fill="#f0f2f6",
            body_background_fill_dark="*neutral_950",
            background_fill_primary="white",
            background_fill_primary_dark="*neutral_900",
            block_background_fill="white",
            block_background_fill_dark="*neutral_800",
            block_border_width="1px",
            block_border_color="*neutral_200",
            block_border_color_dark="*neutral_700",
            block_shadow="0 1px 4px rgba(0,0,0,0.05)",
            block_shadow_dark="0 1px 4px rgba(0,0,0,0.25)",
            block_title_text_weight="600",
            block_label_background_fill="*neutral_50",
            block_label_background_fill_dark="*neutral_800",
            # Primary button: fire-red gradient
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(135deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(135deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(135deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover_dark="linear-gradient(135deg, *secondary_600, *secondary_700)",
            button_primary_shadow="0 4px 14px rgba(232, 83, 31, 0.25)",
            # Secondary button: soft tint of the accent
            button_secondary_text_color="*secondary_700",
            button_secondary_text_color_dark="*secondary_300",
            button_secondary_background_fill="*secondary_50",
            button_secondary_background_fill_hover="*secondary_100",
            button_secondary_background_fill_dark="rgba(255, 107, 53, 0.1)",
            button_secondary_background_fill_hover_dark="rgba(255, 107, 53, 0.2)",
            button_large_padding="12px 24px",
            # Inputs and accents
            slider_color="*secondary_500",
            slider_color_dark="*secondary_500",
            input_border_color_focus="*secondary_400",
            input_border_color_focus_dark="*secondary_500",
            color_accent_soft="*secondary_50",
            color_accent_soft_dark="rgba(255, 107, 53, 0.15)",
        )
        super().set(**overrides)
theme = FireRedTheme()
# =========================================================
# MODEL
# =========================================================
# Prefer CUDA when available; everything below is placed on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.__version__ =", torch.__version__)
print("device =", device)
# Deferred imports: pulled in after the device probe so the startup prints
# above appear before the heavyweight model libraries load (hence noqa E402).
from diffusers import FlowMatchEulerDiscreteScheduler, QwenImageEditPlusPipeline  # noqa: E402,F401
from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3  # noqa: E402
from transformers import AutoModelForImageSegmentation  # noqa: E402
from torchvision import transforms  # noqa: E402
import torch.nn.functional as F  # noqa: E402
# bf16 throughout: matches the pretrained checkpoint and halves memory.
dtype = torch.bfloat16
# -- FireRed edit model (loaded from the official checkpoint) --
pipe = QwenImageEditPlusPipeline.from_pretrained(
    "FireRedTeam/FireRed-Image-Edit-1.1",
    torch_dtype=dtype,
).to(device)
# Tiled/sliced VAE decoding keeps peak VRAM bounded for large buckets.
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
try:
    # Optional Flash Attention 3 processor; fall back gracefully if the
    # custom processor is incompatible with the installed stack.
    pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
    print("Flash Attention 3 Processor set successfully.")
except Exception as e:
    print(f"Warning: Could not set FA3 processor: {e}")
# -- Lightning LoRA (4-step acceleration; mirrors the ComfyUI Rebels.json setup) --
try:
    pipe.load_lora_weights(
        "Osrivers/Qwen-Image-Lightning-4steps-V2.0-bf16.safetensors",
        weight_name="Qwen-Image-Lightning-4steps-V2.0-bf16.safetensors",
        adapter_name="lightning",
    )
    pipe.set_adapters(["lightning"], adapter_weights=[1.0])
    print("Lightning LoRA (4steps V2.0) loaded successfully.")
except Exception as e:
    print(f"Warning: Could not load Lightning LoRA: {e}")
# -- RMBG 2.0 background-removal (matting) model --
rmbg = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-2.0",
    trust_remote_code=True,
)
rmbg.to(device)
rmbg.eval()
MAX_SEED = np.iinfo(np.int32).max
DEFAULT_NEGATIVE_PROMPT = (
    "worst quality, low quality, bad anatomy, bad hands, text, error, "
    "missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, "
    "signature, watermark, username, blurry"
)
# =========================================================
# SAFE BUCKETS (~1MP each)
# =========================================================
# Resolution buckets the pipeline is known to handle well; all dimensions
# are multiples of 16 so no extra alignment cropping is needed.
SAFE_BUCKETS: List[Tuple[int, int]] = [
    # Standard buckets (~1MP)
    (1024, 1024),
    (1184, 880),
    (880, 1184),
    (1392, 752),
    (752, 1392),
    (1568, 672),
    (672, 1568),
    # Extra-wide buckets (long banner-style title cards)
    (1920, 640),  # 3:1
    (1600, 400),  # 4:1 -- same as Rebels.json
    (2048, 512),  # 4:1
    (1920, 384),  # 5:1
    (2560, 512),  # 5:1
    (2048, 336),  # ~6:1
]
# Allow inputs smaller than the chosen bucket to be scaled up.
UPSCALE_SMALL_IMAGES = True
# ImageNet mean/std normalization expected by RMBG-2.0.
_rmbg_normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# RMBG inference always runs at a fixed square resolution.
RMBG_SIZE = 1024
def run_rmbg(pil_image: Image.Image) -> Image.Image:
    """Remove the background with RMBG-2.0 (mirrors ComfyUI's comfyui-rmbg):
    squish to 1024x1024, sigmoid activation, bilinear resize back to the
    original size, then attach the mask as the alpha channel.
    """
    width, height = pil_image.size
    # Fixed-size squish (not aspect-preserving) — intentional, to match ComfyUI.
    squished = pil_image.convert("RGB").resize((RMBG_SIZE, RMBG_SIZE), Image.LANCZOS)
    batch = _rmbg_normalize(squished).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = rmbg(batch)
    # Pick the final prediction head regardless of output container shape.
    if isinstance(outputs, list):
        logits = outputs[-1]
    elif isinstance(outputs, dict) and 'logits' in outputs:
        logits = outputs['logits']
    else:
        logits = outputs
    mask = torch.clamp(logits.sigmoid().cpu().squeeze(), 0, 1)
    # Bilinear upsample back to the source resolution.
    mask = F.interpolate(
        mask.unsqueeze(0).unsqueeze(0), size=(height, width), mode='bilinear'
    ).squeeze()
    alpha = Image.fromarray((mask.numpy() * 255).astype(np.uint8))
    rgba = pil_image.convert("RGBA")
    rgba.putalpha(alpha)
    return rgba
def color_match_reinhard(source: Image.Image, result: Image.Image) -> Image.Image:
    """Reinhard-style per-channel RGB mean/std transfer: shift *result*'s
    color statistics toward those of *source* and return the matched image."""
    src_arr = np.array(source.convert("RGB")).astype(np.float32)
    res_arr = np.array(result.convert("RGB")).astype(np.float32)
    matched = np.zeros_like(res_arr)
    for ch in range(3):
        s_chan = src_arr[:, :, ch]
        r_chan = res_arr[:, :, ch]
        s_mu, s_sigma = s_chan.mean(), s_chan.std()
        r_mu, r_sigma = r_chan.mean(), r_chan.std()
        # Skip gain when the result channel is nearly flat, to avoid
        # amplifying noise; epsilon guards the division regardless.
        gain = s_sigma / (r_sigma + 1e-6) if r_sigma > 0.5 else 1.0
        matched[:, :, ch] = (r_chan - r_mu) * gain + s_mu
    return Image.fromarray(np.clip(matched, 0, 255).astype(np.uint8))
def remove_black_bg(pil_image: Image.Image, dark_thresh: int = 40) -> Image.Image:
    """Remove only the black background connected to the image border,
    keeping black pixels inside the lettering. Implemented with
    connected-component labeling (flood-fill style) — no AI model involved.
    """
    from scipy import ndimage as ndi
    rgb = np.array(pil_image.convert("RGB"))
    # A pixel is "dark" when every channel is at or below the threshold.
    labeled, _ = ndi.label(np.all(rgb <= dark_thresh, axis=2))
    # Every dark component that touches any edge counts as background.
    edge_labels = (
        set(labeled[0, :].tolist())
        | set(labeled[-1, :].tolist())
        | set(labeled[:, 0].tolist())
        | set(labeled[:, -1].tolist())
    )
    edge_labels.discard(0)  # label 0 = non-dark pixels
    background = np.isin(labeled, list(edge_labels))
    alpha = np.where(background, 0, 255).astype(np.uint8)
    rgba = pil_image.convert("RGBA")
    rgba.putalpha(Image.fromarray(alpha))
    return rgba
def add_image_watermark(result: Image.Image, ref: Image.Image, size: int = 200, padding: int = 16) -> Image.Image:
    """Paste a thumbnail of *ref* into the top-left corner of *result*.

    The thumbnail fits within size x size (aspect preserved) and its own
    alpha channel is used as the paste mask. Returns a new RGB image.
    """
    canvas = result.copy().convert("RGBA")
    badge = ref.convert("RGBA")
    badge.thumbnail((size, size), Image.LANCZOS)
    canvas.paste(badge, (padding, padding), badge)
    return canvas.convert("RGB")
def paste_png_into_mask(editor_value: dict, png_image) -> Image.Image:
    """
    Take the bounding box of the mask layer drawn in the ImageEditor,
    scale the PNG proportionally (its longest side matched to the mask's
    longest side), and paste it centered into that region.

    Raises gr.Error with a user-facing message for every missing/empty
    input, in the order the user would fix them.
    """
    if editor_value is None:
        raise gr.Error("โ ๏ธ Please upload and draw a mask on the source image.")
    if png_image is None:
        raise gr.Error("โ ๏ธ Please upload a PNG to place.")
    # Split the editor payload into the base image and the brush layers.
    background: Image.Image = editor_value.get("background")
    layers: list = editor_value.get("layers", [])
    if background is None:
        raise gr.Error("โ ๏ธ No source image found.")
    if not layers:
        raise gr.Error("โ ๏ธ Please draw a mask area on the image first.")
    if isinstance(background, np.ndarray):
        background = Image.fromarray(background)
    background = background.convert("RGBA")
    mask_layer = layers[0]
    if isinstance(mask_layer, np.ndarray):
        mask_layer = Image.fromarray(mask_layer)
    mask_layer = mask_layer.convert("RGBA")
    # The brush strokes live in the layer's alpha channel: its bbox is the
    # placement region.
    alpha = mask_layer.split()[3]
    bbox = alpha.getbbox()
    if bbox is None:
        raise gr.Error("โ ๏ธ Mask area is empty. Please draw on the image.")
    x1, y1, x2, y2 = bbox
    mask_w = x2 - x1
    mask_h = y2 - y1
    mask_longest = max(mask_w, mask_h)
    # Load the PNG (file path or numpy array from the gr.Image component).
    if isinstance(png_image, str):
        png = Image.open(png_image).convert("RGBA")
    else:
        png = Image.fromarray(png_image).convert("RGBA")
    png_w, png_h = png.size
    png_longest = max(png_w, png_h)
    # Proportional scale: longest PNG side aligned to the mask's longest side.
    scale = mask_longest / png_longest
    new_w = max(1, int(png_w * scale))
    new_h = max(1, int(png_h * scale))
    png_resized = png.resize((new_w, new_h), Image.LANCZOS)
    # Center the scaled PNG inside the mask bbox; the PNG's own alpha is
    # used as the paste mask so transparency is preserved.
    paste_x = x1 + (mask_w - new_w) // 2
    paste_y = y1 + (mask_h - new_h) // 2
    result = background.copy()
    result.paste(png_resized, (paste_x, paste_y), png_resized)
    return result.convert("RGB")
# =========================================================
# HELPERS
# =========================================================
def load_pil_image(item) -> Image.Image:
    """Coerce the many shapes Gradio may hand us — PIL image, file path,
    (path, caption) tuple/list, or a file-like object with a .name — into
    an RGB PIL image. Returns None when *item* is None."""
    if item is None:
        return None
    if isinstance(item, Image.Image):
        return item.convert("RGB")
    if isinstance(item, str):
        return Image.open(item).convert("RGB")
    if isinstance(item, (tuple, list)):
        first = item[0]
        if isinstance(first, Image.Image):
            return first.convert("RGB")
        return Image.open(first).convert("RGB")
    # Last resort: file-like objects exposing a filesystem path.
    return Image.open(item.name).convert("RGB")
def pick_best_bucket(
    orig_w: int,
    orig_h: int,
    buckets: List[Tuple[int, int]] = SAFE_BUCKETS,
    allow_upscale: bool = UPSCALE_SMALL_IMAGES,
) -> Tuple[int, int]:
    """Pick the bucket whose aspect ratio (primary key) and area (tie
    breaker) best match the original dimensions.

    When *allow_upscale* is False, prefer the best-ranked bucket that fits
    entirely inside the original image; fall back to the global best when
    none fits. Degenerate dimensions default to the 1024x1024 bucket.
    """
    if orig_w <= 0 or orig_h <= 0:
        return 1024, 1024
    target_ratio = orig_w / orig_h
    target_area = orig_w * orig_h
    ranked = sorted(
        buckets,
        key=lambda wh: (abs(wh[0] / wh[1] - target_ratio), abs(wh[0] * wh[1] - target_area)),
    )
    if allow_upscale:
        return ranked[0]
    fitting = [wh for wh in ranked if wh[0] <= orig_w and wh[1] <= orig_h]
    return fitting[0] if fitting else ranked[0]
def _edge_pad_to_canvas(
    resized: Image.Image,
    bucket_w: int,
    bucket_h: int,
    pad_left: int,
    pad_top: int,
) -> Image.Image:
    """Paste *resized* onto a bucket-sized canvas and replicate its edge
    pixels into the padding bands (leaves fewer seams than pure black)."""
    content_w, content_h = resized.size
    canvas = Image.new("RGB", (bucket_w, bucket_h), (0, 0, 0))
    canvas.paste(resized, (pad_left, pad_top))
    arr = np.array(canvas)
    res_arr = np.array(resized)
    # Left / right bands: replicate the outermost content columns.
    if pad_left > 0:
        left_col = res_arr[:, 0:1, :]
        arr[pad_top:pad_top + content_h, :pad_left, :] = np.broadcast_to(
            left_col, (content_h, pad_left, 3)
        )
    right_start = pad_left + content_w
    if right_start < bucket_w:
        right_col = res_arr[:, -1:, :]
        arr[pad_top:pad_top + content_h, right_start:, :] = np.broadcast_to(
            right_col, (content_h, bucket_w - right_start, 3)
        )
    # Top / bottom bands: replicate full canvas rows, so the corners inherit
    # the already-filled left/right band values.
    if pad_top > 0:
        top_row = arr[pad_top:pad_top + 1, :, :]
        arr[:pad_top, :, :] = np.broadcast_to(top_row, (pad_top, bucket_w, 3))
    bottom_start = pad_top + content_h
    if bottom_start < bucket_h:
        bottom_row = arr[bottom_start - 1:bottom_start, :, :]
        arr[bottom_start:, :, :] = np.broadcast_to(
            bottom_row, (bucket_h - bottom_start, bucket_w, 3)
        )
    return Image.fromarray(arr)


def prepare_images_before_pipe(
    pil_images: List[Image.Image],
    allow_upscale: bool = UPSCALE_SMALL_IMAGES,
    divisible_by: int = 16,
) -> Tuple[List[Image.Image], int, int, tuple]:
    """Prepare images for the pipeline: aspect-preserving resize plus
    edge-fill padding into the best-matching bucket (no stretching).

    All images are resized to the content size derived from the FIRST
    image's aspect ratio, then centered on a bucket-sized canvas whose
    padding bands replicate the nearest content pixels.

    Returns:
        (processed_images, width, height, pad_info) where
        pad_info = (pad_left, pad_top, content_w, content_h) is used after
        inference to crop the padding back off.

    Raises:
        ValueError: if *pil_images* is empty.
    """
    if not pil_images:
        raise ValueError("No input images.")
    base_w, base_h = pil_images[0].size
    # Pick the ~1MP bucket with the closest aspect ratio.
    bucket_w, bucket_h = pick_best_bucket(base_w, base_h, SAFE_BUCKETS, allow_upscale)
    # Aspect-preserving "fit" scale into the bucket.
    scale = min(bucket_w / base_w, bucket_h / base_h)
    content_w = max(divisible_by, round(base_w * scale))
    content_h = max(divisible_by, round(base_h * scale))
    # Center the content on the bucket canvas.
    pad_left = (bucket_w - content_w) // 2
    pad_top = (bucket_h - content_h) // 2
    pad_info = (pad_left, pad_top, content_w, content_h)
    processed = [
        _edge_pad_to_canvas(
            img.resize((content_w, content_h), Image.LANCZOS),
            bucket_w,
            bucket_h,
            pad_left,
            pad_top,
        )
        for img in pil_images
    ]
    return processed, bucket_w, bucket_h, pad_info
def extract_pil_from_source(source) -> Image.Image:
    """Extract an RGB PIL image from a gr.ImageEditor dict or from a plain
    path/PIL value. For editor dicts the "composite" layer is preferred so
    brush annotations survive; falls back to "background"."""
    if source is None:
        return None
    if not isinstance(source, dict):
        return load_pil_image(source)
    # NOTE: explicit None checks — numpy arrays must not be truth-tested.
    picked = source.get("composite")
    if picked is None:
        picked = source.get("background")
    if picked is None:
        return None
    if isinstance(picked, np.ndarray):
        picked = Image.fromarray(picked)
    return picked.convert("RGB")
def format_info(seed_val, source_img, ref_img):
    """Build the Markdown status text shown under the Generate button:
    the current seed plus, for each provided image, its original size and
    the bucket it would be mapped to (with both aspect ratios)."""
    lines = [f"**Seed:** `{int(seed_val)}`"]
    for label, img in [("Source", source_img), ("Reference", ref_img)]:
        if img is None:
            continue
        try:
            # The source may be an ImageEditor dict; the reference is a path.
            pil = extract_pil_from_source(img) if label == "Source" else load_pil_image(img)
            ow, oh = pil.size
            nw, nh = pick_best_bucket(ow, oh, SAFE_BUCKETS, UPSCALE_SMALL_IMAGES)
            lines.append(
                f"\n**{label}:** {ow}ร{oh} โ **{nw}ร{nh}** "
                f"(ratio {ow/oh:.3f} โ {nw/nh:.3f})"
            )
        except Exception:
            # Best-effort preview only; never break the UI over a bad image.
            pass
    return "\n\n".join(lines)
# =========================================================
# INFERENCE
# =========================================================
def infer(
    source_image,
    ref_image,
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    guidance_scale,
    steps,
    color_match,
    out_width=0,
    out_height=0,
    progress=gr.Progress(track_tqdm=True),
):
    """Main edit entry point wired to the Gradio "Generate" button.

    Routes to the background-removal helpers when the prompt contains the
    matting keyword; otherwise runs the FireRed edit pipeline on the
    bucket-padded inputs, then crops/resizes and post-processes the result.

    Returns (result_image, seed) so the UI can display the seed used.
    Raises gr.Error for missing source image or empty prompt.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if source_image is None:
        raise gr.Error("โ ๏ธ Please upload a source image.")
    if not prompt or not prompt.strip():
        raise gr.Error("โ ๏ธ Please enter an edit prompt.")
    # Extract the source image (handles both ImageEditor dicts and plain paths).
    try:
        src_pil = extract_pil_from_source(source_image)
    except Exception as e:
        raise gr.Error(f"โ ๏ธ Could not load source image: {e}")
    if src_pil is None:
        raise gr.Error("โ ๏ธ Please upload a source image.")
    # Remember the original size so the result is resized back afterwards and
    # the 16-pixel alignment never crops the output.
    orig_size = src_pil.size  # (w, h)
    # -- Routing: background removal, triggered by a keyword in the prompt --
    if "ๆ " in prompt:
        if "้ปๅบ" in prompt:
            # Black-background title card: flood-fill removal of the outer
            # black region only; black pixels inside the lettering survive.
            result = remove_black_bg(src_pil)
        else:
            # Generic matting: RMBG 2.0 semantic segmentation.
            result = run_rmbg(src_pil)
        return result, seed
    # Collect pipeline inputs: source required, reference optional.
    pil_images = [src_pil]
    if ref_image is not None:
        try:
            pil_images.append(load_pil_image(ref_image))
        except Exception as e:
            print(f"Warning: could not load reference image: {e}")
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(int(seed))
    processed_images, width, height, pad_info = prepare_images_before_pipe(
        pil_images, allow_upscale=UPSCALE_SMALL_IMAGES
    )
    # Explicit output-size override (mirrors ComfyUI's EmptyLatentImage),
    # floored to a multiple of 16.
    if out_width > 0:
        width = (out_width // 16) * 16
    if out_height > 0:
        height = (out_height // 16) * 16
    try:
        result = pipe(
            image=processed_images,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_inference_steps=steps,
            generator=generator,
            true_cfg_scale=guidance_scale,
        ).images[0]
        # -- Crop the padding: keep only the original-aspect content region --
        pad_left, pad_top, content_w, content_h = pad_info
        if pad_left > 0 or pad_top > 0 or content_w < width or content_h < height:
            result = result.crop((pad_left, pad_top, pad_left + content_w, pad_top + content_h))
        # -- Resize back to the original input dimensions --
        if result.size != orig_size:
            result = result.resize(orig_size, Image.LANCZOS)
        if ref_image is not None and len(pil_images) > 1:
            result = add_image_watermark(result, pil_images[1])
        if color_match:
            # Use the clean source background (without brush strokes) as the
            # color reference when the editor dict provides it.
            if isinstance(source_image, dict):
                bg = source_image.get("background")
                if bg is not None:
                    ref_pil = Image.fromarray(bg).convert("RGB") if isinstance(bg, np.ndarray) else bg.convert("RGB")
                else:
                    ref_pil = src_pil
            else:
                ref_pil = src_pil
            ref_pil_resized = ref_pil.resize(result.size, Image.LANCZOS)
            result = color_match_reinhard(ref_pil_resized, result)
        return result, seed
    finally:
        # Always release CUDA memory, even when the pipeline raises.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# =========================================================
# UI
# =========================================================
# JS: wait until the ImageEditor has rendered, then patch its
# absolutely-positioned toolbar so it no longer floats over the canvas.
# Plain CSS selectors don't work here because Svelte hashes the class
# names, so a MutationObserver finds the toolbar at runtime and inserts a
# toggle button in front of it.
_FIX_TOOLBAR_JS = """
() => {
  const setup = (ed) => {
    if (ed.dataset.toggleReady) return;
    // ๆพๅทฅๅ ทๆ ๅ ็ด ๏ผGradio/Svelte ไผ็ป class ๅ ๅๅธ๏ผ็จ includes ๅน้ ๏ผ
    const toolbar = Array.from(ed.querySelectorAll('*')).find(el => {
      const cls = el.getAttribute('class') || '';
      return cls.includes('toolbar') || cls.includes('tool-bar');
    });
    if (!toolbar) return;
    ed.dataset.toggleReady = '1';
    // ๆๅ ฅๅๆขๆ้ฎ๏ผๆพๅจ toolbar ็็ถๅฎนๅจ็ฌฌไธไฝ
    const btn = document.createElement('button');
    btn.className = 'toolbar-toggle-btn';
    btn.textContent = '๐จ ้่็ป็ฌๅทฅๅ ทๆ ';
    let hidden = false;
    btn.onclick = () => {
      hidden = !hidden;
      // ็จ visibility ่้ display๏ผ้ฟๅ ็ปๅธๅบๅ่ทณๅจ
      toolbar.style.visibility = hidden ? 'hidden' : '';
      toolbar.style.pointerEvents = hidden ? 'none' : '';
      btn.textContent = hidden ? '๐จ ๆพ็คบ็ป็ฌๅทฅๅ ทๆ ' : '๐จ ้่็ป็ฌๅทฅๅ ทๆ ';
    };
    toolbar.parentNode.insertBefore(btn, toolbar);
  };
  const mo = new MutationObserver(() => {
    document.querySelectorAll('.src-editor').forEach(setup);
  });
  mo.observe(document.body, { childList: true, subtree: true });
  setTimeout(() => document.querySelectorAll('.src-editor').forEach(setup), 1000);
}
"""
# Top-level UI layout; all component wiring happens inside this context.
with gr.Blocks(
    theme=theme,
    js=_FIX_TOOLBAR_JS,
    css="""
    .gradio-container {
        max-width: 1400px !important;
        margin: 0 auto;
        padding-top: 20px;
    }
    .hero {
        text-align: center;
        padding: 24px 0 12px 0;
    }
    .hero h1 {
        font-size: 2.2rem;
        font-weight: 800;
        margin-bottom: 8px;
    }
    .hero p {
        font-size: 1rem;
        color: #666;
        margin-bottom: 0;
    }
    /* ๅทฅๅ ทๆ ้่ๆถ๏ผ้่ๆ้ฎไปๅฏ็นๅป */
    .toolbar-toggle-btn {
        display: block;
        width: 100%;
        padding: 4px 10px;
        margin-bottom: 2px;
        background: #f0f0f0;
        border: 1px solid #ddd;
        border-radius: 4px;
        font-size: 12px;
        cursor: pointer;
        text-align: left;
        color: #555;
    }
    """,
) as demo:
    gr.HTML("""
    <div class="hero">
        <h1>๐ฅ FireRed Image Edit 1.1 Fast</h1>
    </div>
    """)
    with gr.Tabs():
        # ------------------------------------------------------
        # Tab 1: AI edit
        # ------------------------------------------------------
        with gr.Tab("AI Edit"):
            with gr.Row():
                with gr.Column(scale=1):
                    source_input = gr.ImageEditor(
                        label="Source Image โ ๅฏ็จ็ป็ฌๆ ๆณจๅบๅ๏ผ็บข/็ปฟ/่็ญ๏ผ๏ผๆ็คบ่ฏไธญๅผ็จ้ข่ฒ",
                        elem_classes=["src-editor"],
                        brush=gr.Brush(
                            colors=["#FF0000", "#00CC00", "#0066FF", "#FFFF00", "#FF00FF", "#FFFFFF"],
                            color_mode="defaults",
                        ),
                    )
                    gr.Markdown(
                        "<small>๐ด็บข ๐ข็ปฟ ๐ต่ ๐ก้ป ๐ฃ็ดซ โฌ็ฝ โ ็ปๅฅฝๅๆ็คบ่ฏๅ๏ผ*ๅปๆ็บข่ฒๆ ๆณจ็ๅบๅ* ็ญ</small>"
                    )
                    with gr.Row():
                        ref_input = gr.Image(
                            label="Reference Image๏ผๅ่ๅพ๏ผๅฏ้๏ผ",
                            type="filepath",
                            sources=["upload", "clipboard"],
                        )
                    prompt_input = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe how you want to edit the image...",
                        lines=4,
                    )
                    negative_prompt_input = gr.Textbox(
                        label="Negative Prompt",
                        value=DEFAULT_NEGATIVE_PROMPT,
                        lines=3,
                    )
                    color_match_input = gr.Checkbox(label="Color Match โ ่ฒๅฝฉๅฏน้ฝๅๅพ", value=True)
                    with gr.Accordion("Advanced Settings", open=False):
                        seed_input = gr.Slider(
                            label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0,
                        )
                        randomize_seed_input = gr.Checkbox(label="Randomize Seed", value=True)
                        guidance_scale_input = gr.Slider(
                            label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0,
                        )
                        steps_input = gr.Slider(
                            label="Inference Steps", minimum=1, maximum=50, step=1, value=4,
                        )
                    run_button = gr.Button("Generate", variant="primary")
                    info_markdown = gr.Markdown()
                with gr.Column(scale=1):
                    output_image = gr.Image(label="Result", type="pil")
            # Keep the size/bucket preview in sync with the inputs.
            for trigger in [source_input, ref_input, seed_input]:
                trigger.change(
                    fn=format_info,
                    inputs=[seed_input, source_input, ref_input],
                    outputs=[info_markdown],
                )
            # Run inference, then refresh the info line with the actual seed.
            run_button.click(
                fn=infer,
                inputs=[
                    source_input, ref_input, prompt_input, negative_prompt_input,
                    seed_input, randomize_seed_input, guidance_scale_input, steps_input,
                    color_match_input,
                ],
                outputs=[output_image, seed_input],
            ).then(
                fn=format_info,
                inputs=[seed_input, source_input, ref_input],
                outputs=[info_markdown],
            )
        # ------------------------------------------------------
        # Tab 2: PNG placement (draw mask -> proportional paste)
        # ------------------------------------------------------
        with gr.Tab("PNG Placement"):
            gr.Markdown("**็จๆณ๏ผ** ไธไผ ๅบๅพๅๅจๅพไธๆถๆนๅบๆพ็ฝฎๅบๅ๏ผๅไธไผ PNG๏ผ็นๅป ApplyใPNG ไผ็ญๆฏ็ผฉๆพ๏ผๆ้ฟ่พนๅฏน้ฝ mask ๆ้ฟ่พน๏ผๅฑ ไธญ่ดดๅ ฅใ")
            with gr.Row():
                with gr.Column(scale=1):
                    mask_editor = gr.ImageEditor(
                        label="Source Image โ ๅจๅพไธๆถๆนๅบๆพ็ฝฎๅบๅ",
                        brush=gr.Brush(colors=["#FF6B35"], color_mode="fixed"),
                    )
                    png_input = gr.Image(
                        label="PNG to place๏ผๆฏๆ้ๆ่ๆฏ๏ผ",
                        type="numpy",
                        sources=["upload", "clipboard"],
                        image_mode="RGBA",
                    )
                    apply_button = gr.Button("Apply", variant="primary")
                with gr.Column(scale=1):
                    placement_output = gr.Image(label="Result", type="pil")
            apply_button.click(
                fn=paste_png_into_mask,
                inputs=[mask_editor, png_input],
                outputs=[placement_output],
            )
if __name__ == "__main__":
    demo.launch()