import math
import os
import threading
import time

try:
    import spaces

    _HAS_SPACES = True
except ImportError:
    _HAS_SPACES = False

import torch
from diffusers import FlowMatchEulerDiscreteScheduler, QwenImageEditPlusPipeline
from huggingface_hub import hf_hub_download


def calculate_dimensions(target_area: int, ratio: float):
    """Return ``(width, height, None)`` approximating *target_area* at *ratio*.

    Both sides are rounded to a multiple of 32. They are clamped to at least
    32 so that extreme aspect ratios cannot round a side down to 0 (a zero
    side makes the later ``Image.resize`` call fail).
    """
    width = math.sqrt(target_area * ratio)
    height = width / ratio
    width = max(32, round(width / 32) * 32)
    height = max(32, round(height / 32) * 32)
    return int(width), int(height), None


def vit_resize_dims(src_w: int, src_h: int, vit_resize_size: int = 384) -> tuple[int, int]:
    """Return the (width, height) used for the ViT-sized preview of an image.

    The target area is ``vit_resize_size ** 2``; the aspect ratio is taken
    from the source size. A degenerate source (zero width or height) falls
    back to a square ratio instead of dividing by zero.
    """
    ratio = float(src_w) / float(src_h) if src_w > 0 and src_h > 0 else 1.0
    new_w, new_h, _ = calculate_dimensions(vit_resize_size * vit_resize_size, ratio)
    return new_w, new_h


def scale_bbox_xyxy(
    bbox_xyxy: tuple[int, int, int, int],
    src_w: int,
    src_h: int,
    dst_w: int,
    dst_h: int,
) -> tuple[int, int, int, int]:
    """Rescale an ``(x1, y1, x2, y2)`` box from a source size to a target size."""
    sx = float(dst_w) / float(src_w) if src_w else 1.0
    sy = float(dst_h) / float(src_h) if src_h else 1.0
    x1, y1, x2, y2 = bbox_xyxy
    return (
        int(round(x1 * sx)),
        int(round(y1 * sy)),
        int(round(x2 * sx)),
        int(round(y2 * sy)),
    )


def format_bbox_xyxy(bbox_xyxy: tuple[int, int, int, int]) -> str:
    """Format a box as ``"[x1, y1, x2, y2]"``."""
    x1, y1, x2, y2 = bbox_xyxy
    return f"[{x1}, {y1}, {x2}, {y2}]"


def draw_bbox_on_image(image, bbox_xyxy: tuple[int, int, int, int]):
    """Return a copy of *image* with the box outlined in red."""
    from PIL import ImageDraw

    x1, y1, x2, y2 = bbox_xyxy
    vis = image.copy()
    draw = ImageDraw.Draw(vis)
    # Outline thickness scales with the shorter image side, minimum 2 px.
    w = max(2, int(round(min(vis.size) * 0.006)))
    draw.rectangle((x1, y1, x2, y2), outline=(255, 64, 64), width=w)
    return vis


def draw_points_on_image(image, points: list[tuple[int, int]], *, connect: bool = False):
    """Return a copy of *image* with *points* drawn as green dots.

    When *connect* is true and there are at least two points, the points are
    also joined by a closed red polyline.
    """
    from PIL import ImageDraw

    vis = image.copy()
    draw = ImageDraw.Draw(vis)
    w, h = vis.size
    r = max(2, int(round(min(w, h) * 0.004)))
    if connect and len(points) >= 2:
        draw.line(points + [points[0]], fill=(255, 64, 64), width=max(1, r // 2))
    for x, y in points:
        draw.ellipse((x - r, y - r, x + r, y + r), fill=(64, 255, 64), outline=(0, 0, 0))
    return vis


_HF_LORA_REPO = "limuloo1999/RefineAnything"
_HF_LORA_FILENAME = "Qwen-Image-Edit-2511-RefineAny.safetensors"
_HF_LORA_ADAPTER = "refine_anything"
_LIGHTNING_LOADED = False
_PIPELINE_LOCK = threading.Lock()


def _build_pipeline(model_dir: str):
    """Build the pipeline at module level.

    ZeroGPU intercepts .to('cuda') and keeps the model on CPU until a
    @spaces.GPU function runs.
    """
    scheduler_config = {
        "base_image_seq_len": 256,
        "base_shift": math.log(3),
        "invert_sigmas": False,
        "max_image_seq_len": 8192,
        "max_shift": math.log(3),
        "num_train_timesteps": 1000,
        "shift": 1.0,
        "shift_terminal": None,
        "stochastic_sampling": False,
        "time_shift_type": "exponential",
        "use_beta_sigmas": False,
        "use_dynamic_shifting": True,
        "use_exponential_sigmas": False,
        "use_karras_sigmas": False,
    }
    scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
    pipe = QwenImageEditPlusPipeline.from_pretrained(
        model_dir,
        torch_dtype=torch.bfloat16,
        scheduler=scheduler,
    )
    pipe.set_progress_bar_config(disable=None)

    # Download the RefineAnything LoRA and register it under a named adapter.
    local_path = hf_hub_download(
        repo_id=_HF_LORA_REPO,
        filename=_HF_LORA_FILENAME,
    )
    lora_dir = os.path.dirname(local_path)
    weight_name = os.path.basename(local_path)
    pipe.load_lora_weights(lora_dir, weight_name=weight_name, adapter_name=_HF_LORA_ADAPTER)
    pipe.to("cuda")
    return pipe


_DEFAULT_MODEL_DIR = os.environ.get("MODEL_DIR", "Qwen/Qwen-Image-Edit-2511")
print(f"[startup] Loading pipeline from {_DEFAULT_MODEL_DIR} ...")
_PIPELINE = _build_pipeline(_DEFAULT_MODEL_DIR)
print("[startup] Pipeline ready.")


def _get_pipeline(load_lightning_lora: bool):
    """Return the shared pipeline with adapters configured for this request.

    The Lightning LoRA is downloaded and loaded at most once, the first time
    it is requested. Once loaded it stays registered and is switched on or
    off per call through its adapter weight (1.0 or 0.0). Adapter changes are
    made while holding ``_PIPELINE_LOCK``.
    """
    global _LIGHTNING_LOADED
    with _PIPELINE_LOCK:
        if load_lightning_lora and not _LIGHTNING_LOADED:
            lightning_path = hf_hub_download(
                repo_id="lightx2v/Qwen-Image-Edit-2511-Lightning",
                filename="Qwen-Image-Edit-2511-Lightning-8steps-V1.0-bf16.safetensors",
            )
            lightning_dir = os.path.dirname(lightning_path)
            lightning_weight = os.path.basename(lightning_path)
            _PIPELINE.load_lora_weights(
                lightning_dir,
                weight_name=lightning_weight,
                adapter_name="lightning",
            )
            _LIGHTNING_LOADED = True

        adapter_names: list[str] = [_HF_LORA_ADAPTER]
        adapter_weights: list[float] = [1.0]
        if _LIGHTNING_LOADED:
            adapter_names.append("lightning")
            adapter_weights.append(1.0 if load_lightning_lora else 0.0)

        if hasattr(_PIPELINE, "set_adapters"):
            try:
                _PIPELINE.set_adapters(adapter_names, adapter_weights=adapter_weights)
            except TypeError:
                # NOTE(review): this fallback enables every adapter at weight
                # 1.0, including Lightning when the caller asked for it to be
                # off -- confirm this is the intended behavior.
                _PIPELINE.set_adapters(adapter_names, adapter_weights=[1.0] * len(adapter_names))
    return _PIPELINE


def build_app():
    """Build and return the Gradio ``Blocks`` demo for RefineAnything."""
    import base64
    import gradio as gr
    import inspect
    import io
    import numpy as np
    import random
    import re
    from PIL import Image

    def _to_float01_rgb(img: Image.Image) -> np.ndarray:
        """Return *img* as a float32 RGB array scaled to [0, 1]."""
        arr = np.asarray(img.convert("RGB")).astype(np.float32) / 255.0
        return arr

    def _to_float01_mask(mask_img: Image.Image) -> np.ndarray:
        """Return *mask_img* as a float32 single-channel array scaled to [0, 1]."""
        arr = np.asarray(mask_img.convert("L")).astype(np.float32) / 255.0
        return arr

    def _int_or(value, default: int) -> int:
        """Parse *value* as int; use *default* when it is None or blank."""
        if value is not None and str(value).strip():
            return int(value)
        return default

    def _float_or(value, default: float) -> float:
        """Parse *value* as float; use *default* when it is None or blank."""
        if value is not None and str(value).strip():
            return float(value)
        return default

    def composite_masked(
        *,
        destination: Image.Image,
        source: Image.Image,
        mask: Image.Image,
        resize_source: bool = True,
    ) -> Image.Image:
        """Alpha-blend *source* over *destination* using *mask* as the alpha.

        The mask is always resized to the destination size; the source is
        resized only when *resize_source* is true.
        """
        dst = destination.convert("RGB")
        if resize_source and getattr(source, "size", None) != dst.size:
            src = source.convert("RGB").resize(dst.size, resample=Image.BICUBIC)
        else:
            src = source.convert("RGB")
        m = mask.convert("L")
        if getattr(m, "size", None) != dst.size:
            m = m.resize(dst.size, resample=Image.BILINEAR)
        dst_f = _to_float01_rgb(dst)
        src_f = _to_float01_rgb(src)
        m_f = _to_float01_mask(m)[:, :, None]
        out = src_f * m_f + dst_f * (1.0 - m_f)
        out = np.clip(out * 255.0 + 0.5, 0, 255).astype(np.uint8)
        return Image.fromarray(out, mode="RGB")

    def prepare_paste_mask(
        mask_l: Image.Image,
        *,
        mask_grow: int = 0,
        blend_kernel: int = 0,
    ) -> Image.Image:
        """Dilate the mask by *mask_grow* pixels, then blur by *blend_kernel*."""
        from PIL import ImageFilter

        m = mask_l.convert("L")
        if mask_grow and int(mask_grow) > 0:
            k = 2 * int(mask_grow) + 1
            m = m.filter(ImageFilter.MaxFilter(size=k))
        if blend_kernel and int(blend_kernel) > 0:
            m = m.filter(ImageFilter.GaussianBlur(radius=float(blend_kernel)))
        return m

    def make_bbox_mask(
        *,
        size: tuple[int, int],
        bbox_xyxy: tuple[int, int, int, int],
        mask_grow: int = 0,
        blend_kernel: int = 0,
    ) -> Image.Image:
        """Return an "L" mask of *size* with the box filled, grown and blurred."""
        from PIL import ImageDraw

        w, h = size
        x1, y1, x2, y2 = bbox_xyxy
        x1 = max(0, min(w - 1, int(x1)))
        y1 = max(0, min(h - 1, int(y1)))
        x2 = max(1, min(w, int(x2)))
        y2 = max(1, min(h, int(y2)))
        m = Image.new("L", (w, h), 0)
        draw = ImageDraw.Draw(m)
        # x2/y2 are exclusive; ImageDraw.rectangle takes inclusive corners.
        draw.rectangle((x1, y1, max(x1, x2 - 1), max(y1, y2 - 1)), fill=255)
        return prepare_paste_mask(m, mask_grow=mask_grow, blend_kernel=blend_kernel)

    def compute_crop_box_xyxy(
        *,
        image_size: tuple[int, int],
        bbox_xyxy: tuple[int, int, int, int],
        margin: int,
    ) -> tuple[int, int, int, int]:
        """Expand the box by *margin* pixels and clamp it to the image."""
        w, h = image_size
        x1, y1, x2, y2 = bbox_xyxy
        m = max(0, int(margin))
        cx1 = max(0, min(w - 1, int(x1) - m))
        cy1 = max(0, min(h - 1, int(y1) - m))
        cx2 = max(1, min(w, int(x2) + m))
        cy2 = max(1, min(h, int(y2) + m))
        if cx2 <= cx1:
            cx2 = min(w, cx1 + 1)
        if cy2 <= cy1:
            cy2 = min(h, cy1 + 1)
        return (cx1, cy1, cx2, cy2)

    def crop_box_from_1024_area_margin(
        *,
        image_size: tuple[int, int],
        bbox_xyxy: tuple[int, int, int, int],
        margin: int,
    ) -> tuple[int, int, int, int]:
        """Compute a crop box whose margin is measured at 1024x1024-area scale.

        The box is scaled into a virtual image of area 1024*1024, expanded by
        *margin* there, clamped, then mapped back to source pixels (floor for
        the top-left corner, ceil for the bottom-right).
        """
        iw, ih = image_size
        if iw <= 0 or ih <= 0:
            return compute_crop_box_xyxy(image_size=image_size, bbox_xyxy=bbox_xyxy, margin=margin)
        s = math.sqrt(1024 * 1024 / float(iw * ih))
        vw, vh = float(iw) * s, float(ih) * s
        x1, y1, x2, y2 = bbox_xyxy
        vx1 = max(0.0, min(vw - 1.0, float(x1) * s - float(margin)))
        vy1 = max(0.0, min(vh - 1.0, float(y1) * s - float(margin)))
        vx2 = max(1.0, min(vw, float(x2) * s + float(margin)))
        vy2 = max(1.0, min(vh, float(y2) * s + float(margin)))
        if vx2 <= vx1:
            vx2 = min(vw, vx1 + 1.0)
        if vy2 <= vy1:
            vy2 = min(vh, vy1 + 1.0)
        cx1 = max(0, min(iw - 1, int(math.floor(vx1 / s))))
        cy1 = max(0, min(ih - 1, int(math.floor(vy1 / s))))
        cx2 = max(1, min(iw, int(math.ceil(vx2 / s))))
        cy2 = max(1, min(ih, int(math.ceil(vy2 / s))))
        if cx2 <= cx1:
            cx2 = min(iw, cx1 + 1)
        if cy2 <= cy1:
            cy2 = min(ih, cy1 + 1)
        return (cx1, cy1, cx2, cy2)

    def offset_bbox_xyxy(bbox_xyxy: tuple[int, int, int, int], dx: int, dy: int) -> tuple[int, int, int, int]:
        """Translate the box by ``(-dx, -dy)`` (image coords -> crop coords)."""
        x1, y1, x2, y2 = bbox_xyxy
        return (int(x1) - int(dx), int(y1) - int(dy), int(x2) - int(dx), int(y2) - int(dy))

    def _decode_data_url(x):
        """Decode a base64 string or ``data:`` URL into a PIL image.

        Returns None when *x* is not a string or cannot be decoded; this is a
        deliberate best-effort helper, so all decode errors are swallowed.
        """
        if not isinstance(x, str):
            return None
        s = x
        if s.startswith("data:") and "," in s:
            s = s.split(",", 1)[1]
        try:
            data = base64.b64decode(s)
        except Exception:
            return None
        try:
            return Image.open(io.BytesIO(data))
        except Exception:
            return None

    def _to_rgb_pil(x, *, label: str):
        """Coerce a string / ndarray / PIL value to an RGB PIL image.

        Returns None for None input and raises ``gr.Error`` (mentioning
        *label*) when the value is unsupported or cannot be converted.
        """
        if x is None:
            return None
        if isinstance(x, str):
            x2 = _decode_data_url(x)
            if x2 is None:
                raise gr.Error(f"{label} 数据格式不支持")
            x = x2
        if isinstance(x, np.ndarray):
            x = Image.fromarray(x.astype(np.uint8))
        if not hasattr(x, "convert"):
            raise gr.Error(f"{label} 数据格式不支持")
        try:
            return x.convert("RGB")
        except Exception as e:
            raise gr.Error(f"{label} 转换 RGB 失败: {type(e).__name__}: {e}") from e

    def mask_to_points_sample_list(
        mask_img: Image.Image, *, num_points: int = 64, seed: int = 0
    ) -> tuple[str, list[tuple[int, int]]]:
        """Sample up to *num_points* foreground pixels from the mask.

        Returns the points both as a formatted string and as a list of
        ``(x, y)`` tuples. Sampling is deterministic for a given *seed*.
        """
        arr = np.array(mask_img.convert("L"), dtype=np.uint8)
        if arr.max() <= 1:
            mask = arr.astype(bool)
        else:
            mask = arr > 0
        ys, xs = np.where(mask)
        if xs.size == 0:
            raise gr.Error("mask 为空,无法从中采样点")
        rng = random.Random(int(seed))
        idxs = list(range(int(xs.size)))
        rng.shuffle(idxs)
        idxs = idxs[: int(num_points)]
        pts = [(int(xs[i]), int(ys[i])) for i in idxs]
        s = "[" + ", ".join(f"({int(x)},{int(y)})" for (x, y) in pts) + "]"
        return s, pts

    def strip_special_region(prompt: str) -> str:
        """Replace newlines with spaces and collapse repeated whitespace.

        NOTE(review): this function used to start with ``.replace("", " ")``.
        Replacing the empty string inserts a space between every character,
        which erased word boundaries once whitespace was collapsed. The empty
        literal looks like a region-marker token that was lost; if such a
        token exists, strip it here -- TODO confirm the intended token.
        """
        p = (prompt or "").replace("\n", " ")
        p = re.sub(r"\s{2,}", " ", p).strip()
        return p

    def strip_location_text(prompt: str) -> str:
        """Remove ``[x1, y1, x2, y2]`` style coordinate lists from the prompt."""
        p = strip_special_region(prompt)
        p = re.sub(r"\[\s*\d+\s*,\s*\d+\s*,\s*\d+\s*,\s*\d+\s*\]", "", p)
        p = re.sub(r"\s{2,}", " ", p).strip()
        return p

    def mask_has_foreground(mask_l: Image.Image) -> bool:
        """Return True when the mask contains at least one non-zero pixel."""
        arr = np.array(mask_l.convert("L"), dtype=np.uint8)
        return bool(arr.max() > 0)

    def _array_bbox_xyxy(arr: np.ndarray) -> tuple[int, int, int, int] | None:
        """Return the exclusive-end bounding box of non-zero pixels, or None.

        *arr* is expected to be a 2-D mask array indexed ``[y, x]``.
        """
        ys, xs = np.where(arr > 0)
        if xs.size == 0 or ys.size == 0:
            return None
        return (int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1)

    def mask_bbox_xyxy(mask_img_l: Image.Image) -> tuple[int, int, int, int] | None:
        """Return the bounding box of the mask foreground, or None if empty."""
        arr = np.array(mask_img_l.convert("L"), dtype=np.uint8)
        return _array_bbox_xyxy(arr)

    def render_spatial_prompt(
        mask_img_l: Image.Image, *, source: str, bbox_margin: int = 0
    ) -> Image.Image | None:
        """Render the spatial prompt image passed to the model.

        With ``source == "bbox"`` the result is a filled rectangle around the
        mask foreground (expanded by *bbox_margin*), or None when the mask is
        empty. Otherwise the mask is binarized to 0/255.
        """
        src = (source or "mask").strip().lower()
        if src == "bbox":
            bbox = mask_bbox_xyxy(mask_img_l)
            if bbox is None:
                return None
            w, h = mask_img_l.size
            out = Image.new("L", (w, h), 0)
            x1, y1, x2, y2 = bbox
            m = max(0, int(bbox_margin))
            x1 = max(0, x1 - m)
            y1 = max(0, y1 - m)
            x2 = min(w, x2 + m)
            y2 = min(h, y2 + m)
            from PIL import ImageDraw

            draw = ImageDraw.Draw(out)
            draw.rectangle((x1, y1, max(x1, x2 - 1), max(y1, y2 - 1)), fill=255)
            return out
        arr = np.array(mask_img_l.convert("L"), dtype=np.uint8)
        arr = np.where(arr > 0, 255, 0).astype(np.uint8)
        return Image.fromarray(arr, mode="L")

    def overlay_mask_on_image(image_rgb: Image.Image, mask_l: Image.Image) -> Image.Image:
        """Tint the masked area of the image green at 35% opacity."""
        base = image_rgb.convert("RGB")
        m = mask_l.convert("L")
        if getattr(m, "size", None) != base.size:
            m = m.resize(base.size, resample=Image.NEAREST)
        base_f = np.asarray(base).astype(np.float32)
        mf = (np.asarray(m).astype(np.float32) > 0)[:, :, None].astype(np.float32)
        color = np.array([64.0, 255.0, 64.0], dtype=np.float32)[None, None, :]
        alpha = 0.35
        out = base_f * (1.0 - alpha * mf) + color * (alpha * mf)
        out = np.clip(out + 0.5, 0, 255).astype(np.uint8)
        return Image.fromarray(out, mode="RGB")

    def _mask_to_array(mask, size: tuple[int, int]) -> np.ndarray:
        """Convert an editor mask value into a uint8 ``(h, w)`` array.

        *mask* may be a list of layers (data-URL strings, ndarrays or PIL
        images), a single ndarray, or a single PIL image. Layers are merged
        with a per-pixel maximum; a layer with four or more channels
        contributes its fourth channel (index 3), any other multi-channel
        layer its channel-wise maximum. Everything is resized to *size*
        (``(width, height)``).
        """
        if isinstance(mask, list):
            merged = np.zeros((size[1], size[0]), dtype=np.uint8)
            for layer in mask:
                if isinstance(layer, str):
                    layer = _decode_data_url(layer)
                if layer is None:
                    # Undecodable or missing layers are skipped.
                    continue
                if isinstance(layer, np.ndarray):
                    layer = Image.fromarray(layer.astype(np.uint8))
                if getattr(layer, "size", None) != size:
                    layer = layer.resize(size)
                layer_arr = np.array(layer, dtype=np.uint8)
                if layer_arr.ndim == 3 and layer_arr.shape[2] >= 4:
                    layer_mask = layer_arr[:, :, 3]
                elif layer_arr.ndim == 3:
                    layer_mask = layer_arr.max(axis=2)
                else:
                    layer_mask = layer_arr
                merged = np.maximum(merged, layer_mask.astype(np.uint8))
            return merged

        if isinstance(mask, np.ndarray):
            arr = mask.astype(np.uint8)
            if arr.ndim == 3:
                arr = arr.max(axis=2)
            mask_l = Image.fromarray(arr, mode="L")
        else:
            mask_l = mask.convert("L")
        if getattr(mask_l, "size", None) != size:
            mask_l = mask_l.resize(size, resample=Image.NEAREST)
        return np.array(mask_l, dtype=np.uint8)

    def extract_bbox_from_image1(image1_value):
        """Unpack the source-image editor value.

        Returns ``(img_rgb, mask_l, bbox_raw, bbox_vit, (vit_w, vit_h))``.
        ``mask_l``, ``bbox_raw`` and ``bbox_vit`` are None when nothing was
        painted. Raises ``gr.Error`` when the image is missing or the value
        has an unsupported format.
        """
        if image1_value is None:
            raise gr.Error("image1 必须上传")
        if not isinstance(image1_value, dict):
            raise gr.Error("image1 数据格式不支持")

        if "image" in image1_value and "mask" in image1_value:
            img = image1_value["image"]
            mask = image1_value["mask"]
            if img is None:
                raise gr.Error("image1 必须上传")
        elif "background" in image1_value and "layers" in image1_value:
            # Explicit None check: ``a or b`` raises on ndarray values.
            img = image1_value.get("background")
            if img is None:
                img = image1_value.get("composite")
            layers = image1_value.get("layers") or []
            if img is None:
                raise gr.Error("image1 数据缺少 background/composite")
            mask = layers if layers else None
        else:
            raise gr.Error("请在 image1 上涂抹选择区域")

        if isinstance(img, str):
            img2 = _decode_data_url(img)
            if img2 is None:
                raise gr.Error("image1 数据格式不支持(image)")
            img = img2
        if isinstance(mask, str):
            mask2 = _decode_data_url(mask)
            if mask2 is None:
                raise gr.Error("image1 数据格式不支持(mask)")
            mask = mask2

        img_pil = _to_rgb_pil(img, label="image1")
        iw, ih = img_pil.size
        vit_w, vit_h = vit_resize_dims(iw, ih, vit_resize_size=384)
        if mask is None:
            return img_pil, None, None, None, (vit_w, vit_h)

        mask_arr = _mask_to_array(mask, img_pil.size)
        bbox_raw = _array_bbox_xyxy(mask_arr)
        if bbox_raw is None:
            return img_pil, None, None, None, (vit_w, vit_h)

        mask_pil_l = Image.fromarray(mask_arr, mode="L")
        bbox_vit = scale_bbox_xyxy(bbox_raw, iw, ih, vit_w, vit_h)
        return img_pil, mask_pil_l, bbox_raw, bbox_vit, (vit_w, vit_h)

    def extract_ref_from_image2(image2_value):
        """Return (ref_pil_rgb | None, crop_info_str | None).

        If the user painted on image2, crop to the brush bounding-box and
        return only that region. Otherwise return the full image. Raises
        ``gr.Error`` when the image value has an unsupported type.
        """
        if image2_value is None:
            return None, None
        if not isinstance(image2_value, dict):
            return _to_rgb_pil(image2_value, label="image2"), None

        if "image" in image2_value and "mask" in image2_value:
            img = image2_value["image"]
            mask = image2_value["mask"]
        elif "background" in image2_value and "layers" in image2_value:
            # Explicit None check: ``a or b`` raises on ndarray values.
            img = image2_value.get("background")
            if img is None:
                img = image2_value.get("composite")
            layers = image2_value.get("layers") or []
            mask = layers if layers else None
        else:
            img = image2_value
            mask = None

        if img is None:
            return None, None
        if isinstance(img, str):
            img2 = _decode_data_url(img)
            if img2 is None:
                return None, None
            img = img2
        img_pil = _to_rgb_pil(img, label="image2")

        if isinstance(mask, str):
            # Same handling as image1, but lenient: an undecodable mask means
            # "no region selected" rather than an error.
            mask = _decode_data_url(mask)
        if mask is None:
            return img_pil, None

        bbox = _array_bbox_xyxy(_mask_to_array(mask, img_pil.size))
        if bbox is None:
            return img_pil, None
        x1, y1, x2, y2 = bbox
        cropped = img_pil.crop((x1, y1, x2, y2))
        crop_info = f"ref_crop=[{x1},{y1},{x2},{y2}] ({x2 - x1}x{y2 - y1})"
        return cropped, crop_info

    def _predict_impl(
        image1_value,
        image2,
        prompt,
        mode,
        spatial_source,
        seed,
        steps,
        true_cfg_scale,
        guidance_scale,
        load_lightning_lora,
        paste_back_bbox,
        paste_back_mode,
        focus_crop_for_bbox,
        focus_crop_margin,
        paste_mask_grow,
        paste_blend_kernel,
    ):
        """Run one preview or inference request from the UI.

        Returns ``((before, after), prompt_sent, preview_image, status)``.
        In "Preview only" mode no inference is run and ``after`` is the
        unmodified source image.
        """
        prompt = (prompt or "").strip()
        if not prompt:
            raise gr.Error("prompt 为空")

        img_pil, mask_pil_l, bbox_raw, bbox_vit, (vit_w, vit_h) = extract_bbox_from_image1(image1_value)
        img_pil = _to_rgb_pil(img_pil, label="image1")
        image2, ref_crop_info = extract_ref_from_image2(image2)

        has_mask = (mask_pil_l is not None) and mask_has_foreground(mask_pil_l)
        has_bbox = bbox_raw is not None
        use_focus_crop = bool(paste_back_bbox) and bool(focus_crop_for_bbox) and has_bbox

        crop_xyxy = None
        bbox_for_model_raw = bbox_raw
        img_for_model = img_pil
        image2_for_model = image2
        mask_for_model_l = mask_pil_l if has_mask else None
        vit_wh_for_prompt = (vit_w, vit_h)

        if use_focus_crop:
            # Work on a crop around the painted region; everything handed to
            # the model is re-expressed in crop coordinates.
            iw, ih = img_pil.size
            margin = _int_or(focus_crop_margin, 0)
            crop_xyxy = crop_box_from_1024_area_margin(image_size=(iw, ih), bbox_xyxy=bbox_raw, margin=margin)
            cx1, cy1, cx2, cy2 = crop_xyxy
            img_for_model = img_pil.crop((cx1, cy1, cx2, cy2))
            bbox_for_model_raw = offset_bbox_xyxy(bbox_raw, cx1, cy1)
            if has_mask and mask_pil_l is not None:
                mask_for_model_l = mask_pil_l.crop((cx1, cy1, cx2, cy2))
            vit_w2, vit_h2 = vit_resize_dims(img_for_model.size[0], img_for_model.size[1], vit_resize_size=384)
            vit_wh_for_prompt = (vit_w2, vit_h2)
            bbox_vit = scale_bbox_xyxy(
                bbox_for_model_raw,
                img_for_model.size[0],
                img_for_model.size[1],
                vit_w2,
                vit_h2,
            )

        prompt_for_model = strip_location_text(prompt)
        spatial_source = (spatial_source or "mask").strip().lower()
        spatial_mask_l = None
        if mask_for_model_l is not None and mask_has_foreground(mask_for_model_l):
            spatial_mask_l = render_spatial_prompt(mask_for_model_l, source=spatial_source, bbox_margin=0)

        info = ""
        if has_bbox:
            info = f"BBox(raw)={format_bbox_xyxy(bbox_raw)}"
        else:
            info = "未检测到涂抹区域"
        if has_bbox:
            info += f" -> QwenVit(384-area)={format_bbox_xyxy(bbox_vit)} vit_wh=({vit_wh_for_prompt[0]},{vit_wh_for_prompt[1]})"
        if ref_crop_info:
            info += f" {ref_crop_info}"
        if spatial_mask_l is not None:
            info += f" spatial={spatial_source}"
        if crop_xyxy is not None:
            info += f" crop={format_bbox_xyxy(crop_xyxy)} bbox_in_crop={format_bbox_xyxy(bbox_for_model_raw)}"

        vis_base = img_for_model.resize(vit_wh_for_prompt, resample=Image.BICUBIC)
        if spatial_mask_l is not None:
            spatial_vis = spatial_mask_l.resize(vit_wh_for_prompt, resample=Image.NEAREST)
            vis = overlay_mask_on_image(vis_base, spatial_vis)
        elif has_bbox:
            vis = draw_bbox_on_image(vis_base, bbox_vit)
        else:
            vis = vis_base

        if mode == "Preview only":
            return (img_pil, img_pil), prompt_for_model, vis, "Done"

        seed = _int_or(seed, 0)
        steps = _int_or(steps, 8)
        true_cfg_scale = _float_or(true_cfg_scale, 4.0)
        guidance_scale = _float_or(guidance_scale, 1.0)

        pipe = _get_pipeline(load_lightning_lora=bool(load_lightning_lora))

        # Image order handed to the pipeline: source, optional reference,
        # optional spatial prompt.
        img = img_for_model if image2_for_model is None else [img_for_model, image2_for_model]
        if spatial_mask_l is not None:
            spatial_rgb = spatial_mask_l.convert("RGB")
            if isinstance(img, list):
                img = img + [spatial_rgb]
            else:
                img = [img, spatial_rgb]

        gen = torch.Generator(device="cuda")
        gen.manual_seed(seed)
        t0 = time.perf_counter()
        with torch.inference_mode():
            try:
                out = pipe(
                    image=img,
                    prompt=prompt_for_model,
                    generator=gen,
                    true_cfg_scale=true_cfg_scale,
                    negative_prompt=" ",
                    num_inference_steps=steps,
                    guidance_scale=guidance_scale,
                    num_images_per_prompt=1,
                )
            except Exception as e:
                raise gr.Error(f"Inference failed: {type(e).__name__}: {e}") from e
        dt = time.perf_counter() - t0
        out_img = out.images[0]

        if paste_back_bbox:
            paste_back_mode = (paste_back_mode or "mask").strip().lower()
            mg = _int_or(paste_mask_grow, 0)
            bk = _int_or(paste_blend_kernel, 0)
            paste_mask = None
            if (
                paste_back_mode.startswith("mask")
                and mask_for_model_l is not None
                and mask_has_foreground(mask_for_model_l)
            ):
                paste_mask = prepare_paste_mask(mask_for_model_l, mask_grow=mg, blend_kernel=bk)
            elif bbox_for_model_raw is not None:
                paste_mask = make_bbox_mask(
                    size=img_for_model.size,
                    bbox_xyxy=bbox_for_model_raw,
                    mask_grow=mg,
                    blend_kernel=bk,
                )
            if paste_mask is not None:
                out_img_crop = composite_masked(
                    destination=img_for_model,
                    source=out_img,
                    mask=paste_mask,
                    resize_source=True,
                )
                if crop_xyxy is not None:
                    # Put the composited crop back into the full-size source.
                    cx1, cy1, cx2, cy2 = crop_xyxy
                    out_full = img_pil.copy()
                    out_full.paste(out_img_crop, (cx1, cy1))
                    out_img = out_full
                else:
                    out_img = out_img_crop

        status = f"Done ({dt:.2f}s)"
        return (img_pil, out_img), prompt_for_model, vis, status

    # On Hugging Face Spaces, wrap the handler so it runs with a GPU attached.
    if _HAS_SPACES:
        predict = spaces.GPU(duration=180)(_predict_impl)
    else:
        predict = _predict_impl

    _DESCRIPTION_EN = (
        "**RefineAnything** refines local regions of an image guided by a text prompt. "
        "Upload a source image, **brush over the area** you want to edit, and describe the desired change. "
        "Optionally upload a reference image for style/content guidance — "
        "leave it as-is to reference the whole image, or **brush on it** to specify exactly which region to reference."
    )
    _DESCRIPTION_CN = (
        "**RefineAnything** 根据文字提示精修图片的局部区域。"
        "上传一张源图,**用画笔涂抹**需要编辑的区域,再输入想要的修改描述即可。"
        "可选上传第二张参考图来引导风格/内容——不涂抹则参考整张图,**涂抹则精确指定参考区域**。"
    )
    _NOTE_EN = (
        "For refinement tasks, prompts starting with **refine** usually work better. "
        "The model also shows some grounding edit ability (for example **add**, **remove**, **modify**) "
        "even without dedicated grounding training."
    )
    _NOTE_CN = (
        "做 refine 任务时,prompt 以 **refine** 开头通常效果更好。"
        "此外模型也具备一定 grounding edit 能力(如 **add**、**remove**、**modify**),"
        "虽然我们没有使用专门的 grounding 数据训练。"
    )

    def _randomize_seed():
        """Return a random non-negative 31-bit seed for the seed field."""
        return random.randint(0, 2**31 - 1)

    with gr.Blocks(title="RefineAnything", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# RefineAnything")
        gr.Markdown(_DESCRIPTION_EN)
        gr.Markdown(_DESCRIPTION_CN)
        gr.Markdown(_NOTE_EN)
        gr.Markdown(_NOTE_CN)

        with gr.Row():
            with gr.Column():
                if hasattr(gr, "ImageMask"):
                    image1 = gr.ImageMask(label="Source image (brush to select region)", type="pil")
                else:
                    image1 = gr.Image(label="Source image", type="pil")
            with gr.Column():
                if hasattr(gr, "ImageMask"):
                    image2 = gr.ImageMask(label="Reference image (optional, brush to crop)", type="pil")
                else:
                    image2 = gr.Image(label="Reference image (optional)", type="pil")

        prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Describe the edit you want...")

        with gr.Row():
            mode = gr.Radio(["Run inference", "Preview only"], value="Run inference", label="Mode", scale=2)
            seed = gr.Number(label="Seed", value=0, precision=0, scale=1)
            seed_dice = gr.Button("🎲 Random", scale=0, min_width=110)
            steps = gr.Number(label="Steps", value=8, precision=0, scale=1)

        with gr.Row():
            spatial_source = gr.Radio(["mask", "bbox"], value="mask", label="Spatial prompt source", scale=2)
            load_lightning_lora = gr.Checkbox(label="Lightning LoRA (faster)", value=True, scale=1)

        with gr.Row():
            paste_back_mode = gr.Radio(["mask", "bbox"], value="mask", label="Paste-back mode", scale=1)
            with gr.Column(scale=1):
                focus_crop_margin = gr.Number(label="Crop margin (px)", value=64, precision=0)
                gr.Markdown(
                    "Note: Increasing this value usually improves harmony between the refined region and surrounding areas; decreasing it usually improves fine-detail recovery."
                )

        with gr.Accordion("Advanced settings", open=False):
            with gr.Row():
                true_cfg_scale = gr.Number(label="True CFG scale", value=4.0)
                guidance_scale = gr.Number(label="Guidance scale", value=1.0)
            with gr.Row():
                paste_back_bbox = gr.Checkbox(label="Composite paste-back", value=True)
                focus_crop_for_bbox = gr.Checkbox(label="Focus-crop edit region", value=True)
            with gr.Row():
                paste_mask_grow = gr.Number(label="Mask grow", value=3, precision=0)
                paste_blend_kernel = gr.Number(label="Blend kernel", value=5, precision=0)

        run_btn = gr.Button("Run", variant="primary", size="lg")

        gr.Markdown("### Output")
        out_image = gr.ImageSlider(label="Before / After")
        with gr.Row():
            replaced_prompt = gr.Textbox(label="Actual prompt sent", lines=2)
            status = gr.Textbox(label="Status", lines=1)
        image1_vis = gr.Image(label="Input preview (ViT 384) + region overlay", type="pil")

        run_btn.click(
            fn=predict,
            inputs=[
                image1,
                image2,
                prompt,
                mode,
                spatial_source,
                seed,
                steps,
                true_cfg_scale,
                guidance_scale,
                load_lightning_lora,
                paste_back_bbox,
                paste_back_mode,
                focus_crop_for_bbox,
                focus_crop_margin,
                paste_mask_grow,
                paste_blend_kernel,
            ],
            outputs=[out_image, replaced_prompt, image1_vis, status],
        )
        seed_dice.click(fn=_randomize_seed, inputs=None, outputs=seed)

    return demo


demo = build_app()

if __name__ == "__main__":
    demo.launch(show_error=True, ssr_mode=False)