Spaces: Running on Zero
| import os | |
| import sys | |
| import pathlib | |
| import subprocess | |
| import random | |
| from typing import Optional, Tuple | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from PIL import Image, ImageOps | |
| import numpy as np | |
| import cv2 | |
# ---------------- Paths & assets ----------------
# Resolve every path relative to this file so the app works regardless of the CWD.
BASE_DIR = pathlib.Path(__file__).resolve().parent
CODE_DEPTH = BASE_DIR / "code_depth"  # vendored depth-inference helper code
CODE_EDIT = BASE_DIR / "code_edit"  # vendored custom diffusers + LoRA checkpoints
GET_ASSETS = BASE_DIR / "get_assets.sh"  # script that downloads the checkpoints below
# Checkpoint files that must be present before the models can be loaded.
EXPECTED_ASSETS = [
    BASE_DIR / "code_depth" / "checkpoints" / "video_depth_anything_vits.pth",
    BASE_DIR / "code_depth" / "checkpoints" / "video_depth_anything_vitl.pth",
    BASE_DIR / "code_edit" / "stage1" / "checkpoint-4800" / "pytorch_lora_weights.safetensors",
    BASE_DIR / "code_edit" / "stage2" / "checkpoint-20000" / "pytorch_lora_weights.safetensors",
]
# Make the vendored depth helper importable.
if str(CODE_DEPTH) not in sys.path:
    sys.path.insert(0, str(CODE_DEPTH))
from depth_infer import DepthModel  # noqa: E402
# Make the vendored (customized) diffusers package importable ahead of any installed copy.
if str(CODE_EDIT / "diffusers") not in sys.path:
    sys.path.insert(0, str(CODE_EDIT / "diffusers"))
from diffusers.pipelines.flux.pipeline_flux_fill_unmasked_image_condition_version import (  # type: ignore # noqa: E402
    FluxFillPipeline_token12_depth_only as FluxFillPipeline,
)
# ---------------- Assets ensure (on-demand) ----------------
def _have_all_assets() -> bool:
    """Return True iff every checkpoint in EXPECTED_ASSETS exists as a regular file."""
    for asset in EXPECTED_ASSETS:
        if not asset.is_file():
            return False
    return True
def _ensure_executable(p: pathlib.Path):
    """Add the executable bits (u+x, g+x, o+x) to *p*.

    Raises FileNotFoundError when the path does not exist.
    """
    if p.exists():
        current_mode = os.stat(p).st_mode
        os.chmod(p, current_mode | 0o111)
    else:
        raise FileNotFoundError(f"Not found: {p}")
def ensure_assets_if_missing():
    """Download the model checkpoints via get_assets.sh unless they already exist.

    Honors SKIP_ASSET_DOWNLOAD=1 to bypass the check entirely (useful for local dev).

    Raises:
        FileNotFoundError: get_assets.sh is missing.
        subprocess.CalledProcessError: the download script exited non-zero.
        RuntimeError: assets are still missing after the script ran.
    """
    if os.getenv("SKIP_ASSET_DOWNLOAD") == "1":
        print("↪️ SKIP_ASSET_DOWNLOAD=1 -> 跳过资产下载检查")
        return
    if _have_all_assets():
        print("✅ Assets already present")
        return
    print("⬇️ Missing assets, running get_assets.sh ...")
    _ensure_executable(GET_ASSETS)
    subprocess.run(
        ["bash", str(GET_ASSETS)],
        check=True,
        cwd=str(BASE_DIR),
        env={**os.environ, "HF_HUB_DISABLE_TELEMETRY": "1"},
    )
    if not _have_all_assets():
        # Fix: report using is_file() to stay consistent with _have_all_assets();
        # with exists() a directory sitting at a checkpoint path would fail the
        # check above yet be omitted from the reported `missing` list.
        missing = [str(p.relative_to(BASE_DIR)) for p in EXPECTED_ASSETS if not p.is_file()]
        raise RuntimeError(f"Assets missing after get_assets.sh: {missing}")
    print("✅ Assets ready.")
# Best-effort asset preparation at import time; a failure is logged but does not
# abort the app so the UI can still come up (e.g. when downloads are skipped).
try:
    ensure_assets_if_missing()
except Exception as e:
    print(f"⚠️ Asset prepare failed: {e}")
# ---------------- Global singletons ----------------
_MODELS: dict[str, DepthModel] = {}  # depth models cached per encoder name ("vits"/"vitl")
_PIPE: Optional[FluxFillPipeline] = None  # lazily constructed FLUX fill pipeline
def get_model(encoder: str) -> DepthModel:
    """Return the cached DepthModel for *encoder*, constructing it on first use."""
    try:
        return _MODELS[encoder]
    except KeyError:
        model = DepthModel(BASE_DIR, encoder=encoder)
        _MODELS[encoder] = model
        return model
def get_pipe() -> FluxFillPipeline:
    """Lazily build and cache the FLUX.1-Fill pipeline, attaching the stage1 LoRA if present.

    Returns the module-level singleton; subsequent calls reuse the loaded pipeline.
    """
    global _PIPE
    if _PIPE is not None:
        return _PIPE
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # bfloat16 on GPU; plain float32 on CPU where bf16 kernels are not generally available
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    print(f"[pipe] load FLUX.1-Fill-dev dtype={dtype}, device={device}")
    pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=dtype).to(device)
    # LoRA (stage1) — best effort: the base pipeline still works without it
    lora_dir = CODE_EDIT / "stage1" / "checkpoint-4800"
    if lora_dir.exists():
        try:
            pipe.load_lora_weights(str(lora_dir))  # requires peft
            print(f"[pipe] loaded LoRA from: {lora_dir}")
        except Exception as e:
            print(f"[pipe] load LoRA failed (continue without): {e}")
    else:
        print(f"[pipe] LoRA path not found: {lora_dir} (continue without)")
    _PIPE = pipe
    return pipe
# ---------------- Mask helpers ----------------
def to_grayscale_mask(im: Image.Image) -> Image.Image:
    """Convert any RGBA/RGB/L image into a binary L-mode mask.

    Output convention: white = region to remove/fill, black = keep.
    For RGBA input the alpha channel is taken as the mask source.
    """
    if im.mode == "RGBA":
        source = im.split()[-1]  # alpha channel carries the painted region
    else:
        source = im.convert("L")
    # Simple binarization with a small threshold to drop compression noise.
    binary = source.point(lambda v: 255 if v > 16 else 0)
    return binary  # no inversion: white marks the mask region
def dilate_mask(mask_l: Image.Image, px: int) -> Image.Image:
    """Grow the white region of an L-mode mask by roughly *px* pixels.

    Non-positive *px* returns the input unchanged.
    """
    if px <= 0:
        return mask_l
    data = np.array(mask_l, dtype=np.uint8)
    # Each 3x3 dilation pass grows the region by ~1px; px//2 passes is an empirical choice.
    passes = max(1, int(px // 2))
    grown = cv2.dilate(data, np.ones((3, 3), np.uint8), iterations=passes)
    return Image.fromarray(grown, mode="L")
def _mask_from_red(img: Image.Image, out_size: Tuple[int, int]) -> Image.Image:
    """Extract pure-red brush strokes from an RGBA/RGB image as a binary mask.

    White = brush stroke, black = everything else, resized to *out_size*.
    Thresholds are kept slightly loose to tolerate compression/interpolation.
    """
    rgba = np.array(img.convert("RGBA"))
    red = rgba[..., 0]
    green = rgba[..., 1]
    blue = rgba[..., 2]
    alpha = rgba[..., 3]
    # Hit when: strong red, weak green, weak blue, and not fully transparent.
    hits = (red >= 200) & (green <= 40) & (blue <= 40) & (alpha > 0)
    mask_arr = hits.astype(np.uint8) * 255
    return Image.fromarray(mask_arr, mode="L").resize(out_size, Image.NEAREST)
def pick_mask(
    upload_mask: Optional[Image.Image],
    sketch_data: Optional[dict],
    base_image: Image.Image,
    dilate_px: int = 0,
) -> Optional[Image.Image]:
    """
    Selection rules:
      1) If the user uploaded a mask: use it directly (white = mask region).
      2) Otherwise only *red* brush strokes from the ImageEditor payload count as mask:
         - first check sketch_data['mask'] (some Gradio versions provide it)
         - else iterate sketch_data['layers'][*]['image'] and merge their red strokes
         - else fall back to red strokes found in sketch_data['composite']
    Returns an L-mode mask resized to base_image.size, or None when nothing usable
    was found (the caller then reports that a mask is required).
    """
    # 1) uploaded mask takes priority
    if isinstance(upload_mask, Image.Image):
        m = to_grayscale_mask(upload_mask).resize(base_image.size, Image.NEAREST)
        return dilate_mask(m, dilate_px) if dilate_px > 0 else m
    # 2) hand-drawn (ImageEditor payload)
    if isinstance(sketch_data, dict):
        # 2a) explicit mask (still supported)
        m = sketch_data.get("mask")
        if isinstance(m, Image.Image):
            m = to_grayscale_mask(m).resize(base_image.size, Image.NEAREST)
            return dilate_mask(m, dilate_px) if dilate_px > 0 else m
        # 2b) merge red strokes from every layer
        layers = sketch_data.get("layers")
        acc = None
        if isinstance(layers, list) and layers:
            acc = Image.new("L", base_image.size, 0)
            for lyr in layers:
                if not isinstance(lyr, dict):
                    continue
                li = lyr.get("image") or lyr.get("mask")
                if isinstance(li, Image.Image):
                    m_layer = _mask_from_red(li, base_image.size)
                    # merge: a pixel painted on any layer counts as mask
                    acc = ImageOps.lighter(acc, m_layer)
            if acc.getbbox() is not None:
                return dilate_mask(acc, dilate_px) if dilate_px > 0 else acc
        # 2c) finally look for red strokes in the composite image
        comp = sketch_data.get("composite")
        if isinstance(comp, Image.Image):
            m_comp = _mask_from_red(comp, base_image.size)
            if m_comp.getbbox() is not None:
                return dilate_mask(m_comp, dilate_px) if dilate_px > 0 else m_comp
    # 3) nothing found: return None (caller will prompt the user for a mask)
    return None
def _round_mult64(x: float, mode: str = "nearest") -> int:
    """Align *x* to a multiple of 64.

    mode="ceil"    -> round up
    mode="floor"   -> round down
    mode="nearest" -> closest multiple (ties at 32 round up)
    """
    if mode == "ceil":
        offset = 63
    elif mode == "floor":
        offset = 0
    else:  # nearest
        offset = 32
    return int((x + offset) // 64) * 64
def prepare_size_for_flux(img: Image.Image, target_max: int = 1024) -> tuple[int, int]:
    """Compute a FLUX-friendly (width, height) for *img*.

    Steps:
      1) round the raw w,h up to multiples of 64 (protects very small inputs)
      2) pin the long side to *target_max* (default 1024)
      3) scale the short side proportionally and snap it to a multiple of 64 (min 64)
    """
    w, h = img.size
    # 1) round each dimension up to a multiple of 64
    w64 = max(64, _round_mult64(w, mode="ceil"))
    h64 = max(64, _round_mult64(h, mode="ceil"))
    # 2) + 3) long side fixed at target_max, short side scaled and snapped
    if w64 >= h64:
        short = max(64, _round_mult64(h64 * (target_max / w64), mode="nearest"))
        return int(target_max), int(short)
    short = max(64, _round_mult64(w64 * (target_max / h64), mode="nearest"))
    return int(short), int(target_max)
# ---------------- Preview depth for canvas (colored) ----------------
def preview_depth(image: Optional[Image.Image], encoder: str, max_res: int, input_size: int, fp32: bool):
    """Return a colored (RGB) depth visualization for the preview pane, or None without an image."""
    if image is None:
        return None
    # grayscale=False keeps the colormap rendering style used throughout the app
    return get_model(encoder).infer(
        image=image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=False
    )
def prepare_canvas(image, depth_img, source):
    """Load either the uploaded image or the depth preview into the sketch ImageEditor.

    Raises gr.Error when the selected base image is not available yet.
    """
    base = image if source != "depth" else depth_img
    if base is None:
        raise gr.Error("请先上传图片(并等待深度预览出来),再点击\"Prepare canvas\"。")
    # gr.update(value=...) is the version-agnostic way to set an ImageEditor's value
    return gr.update(value=base)
# ---------------- Two-stage pipeline: depth(color) -> fill ----------------
# NOTE(review): `spaces` is imported at file top but this heavy function carries no
# @spaces.GPU decorator — confirm whether ZeroGPU allocation is handled elsewhere.
def run_depth_and_fill(
    image: Image.Image,
    mask_upload: Optional[Image.Image],
    sketch: Optional[dict],
    prompt: str,
    encoder: str,
    max_res: int,
    input_size: int,
    fp32: bool,
    max_side: int,
    mask_dilate_px: int,
    guidance_scale: float,
    steps: int,
    seed: Optional[int],
) -> Tuple[Image.Image, Image.Image]:
    """Two-stage removal: colored depth inference, then FLUX fill conditioned on it.

    Returns (result resized back to the input size, RGB mask preview).
    Raises gr.Error when the image or a usable mask is missing.
    """
    if image is None:
        raise gr.Error("请先上传一张图片。")
    # 1) colored (RGB) depth map from the selected encoder
    depth_model = get_model(encoder)
    depth_rgb: Image.Image = depth_model.infer(
        image=image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=False
    ).convert("RGB")
    print(f"[DEBUG] Depth RGB: mode={depth_rgb.mode}, size={depth_rgb.size}")
    # 2) extract the mask (uploaded mask wins over hand-drawn strokes)
    mask_l = pick_mask(mask_upload, sketch, image, dilate_px=mask_dilate_px)
    if (mask_l is None) or (mask_l.getbbox() is None):
        raise gr.Error("没有检测到有效的 mask:请确认已在画布上涂抹或上传 mask 图片。")
    print(f"[DEBUG] Mask: mode={mask_l.mode}, size={mask_l.size}, bbox={mask_l.getbbox()}")
    # 3) output size aligned for FLUX (multiples of 64, long side = max_side)
    width, height = prepare_size_for_flux(depth_rgb, target_max=max_side)
    orig_w, orig_h = image.size
    print(f"[DEBUG] FLUX size: {width}x{height}, original: {orig_w}x{orig_h}")
    # 4) run the FLUX pipeline
    # Key point: the `image` argument receives depth_rgb, not the original photo
    pipe = get_pipe()
    # seed >= 0 pins the generator; None or negative falls back to a random seed
    generator = torch.Generator("cpu").manual_seed(int(seed)) if (seed is not None and seed >= 0) else torch.Generator("cpu").manual_seed(random.randint(0, 2**31 - 1))
    result = pipe(
        prompt=prompt,
        image=depth_rgb,  # feed the colored depth map rather than the original image
        mask_image=mask_l,
        width=width,
        height=height,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(steps),
        max_sequence_length=512,
        generator=generator,
        depth=depth_rgb,  # the depth argument also receives the colored depth map
    ).images[0]
    final_result = result.resize((orig_w, orig_h), Image.BICUBIC)
    # return the result plus a mask preview for the UI
    mask_preview = mask_l.resize((orig_w, orig_h), Image.NEAREST).convert("RGB")
    return final_result, mask_preview
# ---------------- UI ----------------
with gr.Blocks() as demo:
    gr.Markdown("## GeoRemover · Depth Removal (Depth(color) → FLUX Fill)")
    with gr.Row():
        with gr.Column(scale=1):
            # input image
            img = gr.Image(label="Upload image", type="pil")
            # two ways to supply the mask: upload one, or draw it
            with gr.Tab("Upload mask"):
                mask_upload = gr.Image(label="Mask (optional)", type="pil")
            with gr.Tab("Draw mask"):
                draw_source = gr.Radio(["image", "depth"], value="image", label="Draw on")
                prepare_btn = gr.Button("Prepare canvas")
                sketch = gr.ImageEditor(
                    label="Sketch mask (draw with brush)",
                    type="pil",
                    # restrict the brush to pure red so strokes can be extracted precisely
                    brush=gr.Brush(colors=["#FF0000"], default_size=24)
                )
            # prompt
            prompt = gr.Textbox(label="Prompt", value="A beautiful scene")
            # tunable parameters
            with gr.Accordion("Advanced (Depth & FLUX)", open=False):
                encoder = gr.Dropdown(["vits", "vitl"], value="vitl", label="Depth encoder")
                max_res = gr.Slider(512, 2048, value=1280, step=64, label="Depth: max_res")
                input_size = gr.Slider(256, 1024, value=518, step=2, label="Depth: input_size")
                fp32 = gr.Checkbox(False, label="Depth: use FP32 (default FP16)")
                max_side = gr.Slider(512, 1536, value=1024, step=64, label="FLUX: max side (px)")
                mask_dilate_px = gr.Slider(0, 128, value=0, step=1, label="Mask dilation (px)")
                guidance_scale = gr.Slider(0, 50, value=30, step=0.5, label="FLUX: guidance_scale")
                steps = gr.Slider(10, 75, value=50, step=1, label="FLUX: steps")
                seed = gr.Number(value=0, precision=0, label="Seed (>=0 固定;留空随机)")
            run_btn = gr.Button("Run", variant="primary")
        with gr.Column(scale=1):
            depth_preview = gr.Image(label="Depth preview (colored)", interactive=False)
            mask_preview = gr.Image(label="Mask preview (what will be removed)", interactive=False)
            out = gr.Image(label="Output")
    # event: generate the colored depth preview whenever an image is uploaded
    img.change(
        fn=preview_depth,
        inputs=[img, encoder, max_res, input_size, fp32],
        outputs=[depth_preview],
    )
    # prepare the canvas: load the original image or the colored depth map into the ImageEditor
    prepare_btn.click(
        fn=prepare_canvas,
        inputs=[img, depth_preview, draw_source],
        outputs=[sketch],
    )
    # run the two-stage removal pipeline
    run_btn.click(
        fn=run_depth_and_fill,
        inputs=[img, mask_upload, sketch, prompt, encoder, max_res, input_size, fp32,
                max_side, mask_dilate_px, guidance_scale, steps, seed],
        outputs=[out, mask_preview],
        api_name="run",
    )
if __name__ == "__main__":
    # Disable HF Hub telemetry unless the environment already expressed a preference.
    os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
    demo.launch(server_name="0.0.0.0", server_port=7860)