# GeoRemover / app.py
# (Hugging Face Space file header — author: zixinz,
#  commit 5f25c59 "chore: ignore pyc and __pycache__", ~17.6 kB)
import os
import pathlib
import random
import subprocess
import sys
from typing import Optional, Tuple

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image, ImageOps

# ---------------- Paths & assets ----------------
BASE_DIR = pathlib.Path(__file__).resolve().parent
CODE_DEPTH = BASE_DIR / "code_depth"
CODE_EDIT = BASE_DIR / "code_edit"
GET_ASSETS = BASE_DIR / "get_assets.sh"
# vendored diffusers sources (must shadow any pip-installed diffusers)
LOCAL_DIFFUSERS_SRC = BASE_DIR / "code_edit" / "diffusers" / "src"

# files that get_assets.sh is expected to produce
EXPECTED_ASSETS = [
    BASE_DIR / "code_depth" / "checkpoints" / "video_depth_anything_vits.pth",
    BASE_DIR / "code_depth" / "checkpoints" / "video_depth_anything_vitl.pth",
    BASE_DIR / "code_edit" / "stage1" / "checkpoint-4800" / "pytorch_lora_weights.safetensors",
    BASE_DIR / "code_edit" / "stage2" / "checkpoint-20000" / "pytorch_lora_weights.safetensors",
]

# import the custom diffusers build (fail fast if the vendored copy is absent)
if (LOCAL_DIFFUSERS_SRC / "diffusers").exists():
    sys.path.insert(0, str(LOCAL_DIFFUSERS_SRC))
else:
    raise RuntimeError(f"Local diffusers not found at: {LOCAL_DIFFUSERS_SRC}")
from diffusers.pipelines.flux.pipeline_flux_fill_unmasked_image_condition_version import (  # type: ignore # noqa: E402
    FluxFillPipeline_token12_depth_only as FluxFillPipeline,
)

# import depth helper
if str(CODE_DEPTH) not in sys.path:
    sys.path.insert(0, str(CODE_DEPTH))
from depth_infer import DepthModel  # noqa: E402

# Kept from the original layout: the diffusers module above is already
# cached in sys.modules, so this extra path entry is inert, but it is
# preserved in case other code relies on it being on sys.path.
if str(CODE_EDIT / "diffusers") not in sys.path:
    sys.path.insert(0, str(CODE_EDIT / "diffusers"))
# ---------------- Assets ensure (on-demand) ----------------
def _have_all_assets() -> bool:
    """Return True iff every expected checkpoint/weight file is on disk."""
    for asset in EXPECTED_ASSETS:
        if not asset.is_file():
            return False
    return True
def _ensure_executable(p: pathlib.Path):
if not p.exists():
raise FileNotFoundError(f"Not found: {p}")
os.chmod(p, os.stat(p).st_mode | 0o111)
def ensure_assets_if_missing():
    """
    Download model assets via get_assets.sh when any expected file is missing.

    Honors SKIP_ASSET_DOWNLOAD=1 to bypass the check entirely.
    Raises RuntimeError if the script runs but assets are still missing,
    and CalledProcessError if the script itself fails (check=True).
    """
    if os.getenv("SKIP_ASSET_DOWNLOAD") == "1":
        print("↪️ SKIP_ASSET_DOWNLOAD=1 -> 跳过资产下载检查")
        return
    if _have_all_assets():
        print("✅ Assets already present")
        return
    print("⬇️ Missing assets, running get_assets.sh ...")
    _ensure_executable(GET_ASSETS)
    # run from the repo root with HF telemetry disabled for the download
    subprocess.run(
        ["bash", str(GET_ASSETS)],
        check=True,
        cwd=str(BASE_DIR),
        env={**os.environ, "HF_HUB_DISABLE_TELEMETRY": "1"},
    )
    # verify the script actually produced everything it was supposed to
    if not _have_all_assets():
        missing = [str(p.relative_to(BASE_DIR)) for p in EXPECTED_ASSETS if not p.exists()]
        raise RuntimeError(f"Assets missing after get_assets.sh: {missing}")
    print("✅ Assets ready.")
# Best-effort asset preparation at import time: failures are logged rather
# than fatal, so the UI can still start (inference will fail later if the
# assets are truly absent).
try:
    ensure_assets_if_missing()
except Exception as e:
    print(f"⚠️ Asset prepare failed: {e}")
# ---------------- Global singletons ----------------
_MODELS: dict[str, DepthModel] = {}  # one cached DepthModel per encoder name
_PIPE: Optional[FluxFillPipeline] = None  # lazily constructed FLUX fill pipeline
def get_model(encoder: str) -> DepthModel:
    """Return the cached DepthModel for *encoder*, constructing it on first use."""
    try:
        return _MODELS[encoder]
    except KeyError:
        model = DepthModel(BASE_DIR, encoder=encoder)
        _MODELS[encoder] = model
        return model
def get_pipe() -> FluxFillPipeline:
    """
    Lazily build and cache the FLUX.1-Fill pipeline (module-level singleton).

    Loads weights from the local cache directory when present, otherwise
    pulls the gated Hub repo (requires HF_TOKEN). Afterwards it tries to
    attach the stage-1 LoRA; any LoRA failure is logged and the base
    pipeline is used as-is. Raises RuntimeError if the base model cannot
    be loaded at all.
    """
    global _PIPE
    if _PIPE is not None:
        return _PIPE
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # bfloat16 on GPU; fall back to float32 on CPU
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    local_flux = BASE_DIR / "code_edit" / "flux_cache"
    use_local = local_flux.exists()
    hf_token = os.environ.get("HF_TOKEN")
    try:
        # NOTE(review): `hf_hub_enable_hf_transfer` does not appear to be a
        # public huggingface_hub API — confirm; harmless either way since
        # any failure here is swallowed.
        from huggingface_hub import hf_hub_enable_hf_transfer
        hf_hub_enable_hf_transfer()
    except Exception:
        pass
    print(f"[pipe] loading FLUX.1-Fill-dev (dtype={dtype}, device={device}, local={use_local})")
    try:
        if use_local:
            pipe = FluxFillPipeline.from_pretrained(
                local_flux, torch_dtype=dtype
            ).to(device)
        else:
            # pull from the Hub (requires gated-repo access + token)
            pipe = FluxFillPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-Fill-dev",
                torch_dtype=dtype,
                token=hf_token,
                # use_auth_token=hf_token,
            ).to(device)
    except Exception as e:
        raise RuntimeError(
            "Failed to load FLUX.1-Fill-dev. "
            "Make sure your account has access to the gated repo and HF_TOKEN is set as a Space secret, "
            "or pre-download to a local cache directory."
        ) from e
    # -------- LoRA (stage1) --------
    lora_dir = CODE_EDIT / "stage1" / "checkpoint-4800"
    lora_file = "pytorch_lora_weights.safetensors"  # the actual weight file name
    adapter_name = "stage1"
    if lora_dir.exists():
        try:
            import peft  # just to assert backend is present
            print(f"[pipe] loading LoRA from: {lora_dir}/{lora_file}")
            pipe.load_lora_weights(
                str(lora_dir),
                weight_name=lora_file,  # key: pin the exact file name
                adapter_name=adapter_name  # named so the adapter can be toggled later
            )
            # newer diffusers: prefer set_adapters
            try:
                pipe.set_adapters(adapter_name, scale=1.0)
                print(f"[pipe] set_adapters('{adapter_name}', scale=1.0)")
            except Exception as e_set:
                print(f"[pipe] set_adapters not available ({e_set}); trying fuse_lora()")
                # older diffusers / pipelines without set_adapters: fuse the LoRA
                try:
                    pipe.fuse_lora(lora_scale=1.0)
                    print("[pipe] fuse_lora(lora_scale=1.0) done")
                except Exception as e_fuse:
                    print(f"[pipe] fuse_lora failed: {e_fuse}")
            print("[pipe] LoRA ready ✅")
        except ImportError:
            print("[pipe] peft not installed; LoRA will be skipped (add `peft>=0.11` to requirements).")
        except Exception as e:
            print(f"[pipe] load_lora_weights failed (continue without): {e}")
    else:
        print(f"[pipe] LoRA path not found: {lora_dir} (continue without)")
    _PIPE = pipe
    return pipe
# ---------------- Mask helpers ----------------
def to_grayscale_mask(im: Image.Image) -> Image.Image:
    """
    Convert an arbitrary RGBA/RGB/L image to mode "L".

    Output convention: white = region to remove/fill, black = keep.
    No inversion is applied.
    """
    if im.mode == "RGBA":
        channel = im.getchannel("A")  # alpha channel acts as the mask
    else:
        channel = im.convert("L")
    # simple binarization to drop noise (threshold at 16)
    return channel.point(lambda v: 0 if v <= 16 else 255)
def dilate_mask(mask_l: Image.Image, px: int) -> Image.Image:
    """Grow the white region of *mask_l* by roughly *px* pixels (no-op if px <= 0)."""
    if px <= 0:
        return mask_l
    # heuristic: one 3x3 dilation pass grows the region by ~2 px
    n_iter = max(1, int(px // 2))
    src = np.array(mask_l, dtype=np.uint8)
    grown = cv2.dilate(src, np.ones((3, 3), np.uint8), iterations=n_iter)
    return Image.fromarray(grown, mode="L")
def _mask_from_red(img: Image.Image, out_size: Tuple[int, int]) -> Image.Image:
    """
    Extract pure-red brush strokes from an RGBA/RGB image as a binary mask
    (white = brush, black = everything else), resized to *out_size*.

    Thresholds are slightly loose to tolerate compression/interpolation.
    """
    rgba = np.asarray(img.convert("RGBA"))
    red, green, blue, alpha = (rgba[..., i] for i in range(4))
    # criteria: strong red, low green, low blue, and alpha > 0
    is_red = (red >= 200) & (green <= 40) & (blue <= 40) & (alpha > 0)
    mask_arr = np.where(is_red, 255, 0).astype(np.uint8)
    return Image.fromarray(mask_arr, mode="L").resize(out_size, Image.NEAREST)
def pick_mask(
    upload_mask: Optional[Image.Image],
    sketch_data: Optional[dict],
    base_image: Image.Image,
    dilate_px: int = 0,
) -> Optional[Image.Image]:
    """
    Resolve the final mask (mode "L", white = region to remove), or None.

    Priority:
      1) An uploaded mask is used directly (white = mask).
      2) Otherwise only *red* brush strokes from the ImageEditor payload
         count as mask:
         - first check sketch_data['mask'] (some Gradio versions provide it)
         - else merge red strokes from each sketch_data['layers'][*]['image']
         - else fall back to red strokes in sketch_data['composite']

    The result is resized to base_image.size and optionally dilated.
    """
    # 1) uploaded mask wins
    if isinstance(upload_mask, Image.Image):
        m = to_grayscale_mask(upload_mask).resize(base_image.size, Image.NEAREST)
        return dilate_mask(m, dilate_px) if dilate_px > 0 else m
    # 2) hand-drawn (ImageEditor)
    if isinstance(sketch_data, dict):
        # 2a) explicit mask (still supported)
        m = sketch_data.get("mask")
        if isinstance(m, Image.Image):
            m = to_grayscale_mask(m).resize(base_image.size, Image.NEAREST)
            return dilate_mask(m, dilate_px) if dilate_px > 0 else m
        # 2b) merge red strokes from all layers
        layers = sketch_data.get("layers")
        acc = None
        if isinstance(layers, list) and layers:
            acc = Image.new("L", base_image.size, 0)
            for lyr in layers:
                if not isinstance(lyr, dict):
                    continue
                li = lyr.get("image") or lyr.get("mask")
                if isinstance(li, Image.Image):
                    m_layer = _mask_from_red(li, base_image.size)
                    # merge: a pixel painted in any layer counts as mask
                    acc = ImageOps.lighter(acc, m_layer)
            if acc.getbbox() is not None:
                return dilate_mask(acc, dilate_px) if dilate_px > 0 else acc
        # 2c) finally look for red strokes in the composite image
        comp = sketch_data.get("composite")
        if isinstance(comp, Image.Image):
            m_comp = _mask_from_red(comp, base_image.size)
            if m_comp.getbbox() is not None:
                return dilate_mask(m_comp, dilate_px) if dilate_px > 0 else m_comp
    # 3) nothing usable found -> None (caller raises a "mask required" error)
    return None
def _round_mult64(x: float, mode: str = "nearest") -> int:
"""
把 x 对齐到 64 的倍数:
- mode="ceil" 向上取整
- mode="floor" 向下取整
- mode="nearest" 最近的倍数
"""
if mode == "ceil":
return int((x + 63) // 64) * 64
elif mode == "floor":
return int(x // 64) * 64
else: # nearest
return int((x + 32) // 64) * 64
def prepare_size_for_flux(img: Image.Image, target_max: int = 1024) -> tuple[int, int]:
    """
    Compute a FLUX-friendly (width, height):
      1) round the raw dimensions up to multiples of 64 (guards tiny images)
      2) pin the longer side to *target_max* (default 1024)
      3) scale the shorter side proportionally, snapped to a multiple of 64
         (minimum 64)
    """
    src_w, src_h = img.size
    # step 1: each dimension up to a multiple of 64
    aligned_w = max(64, _round_mult64(src_w, mode="ceil"))
    aligned_h = max(64, _round_mult64(src_h, mode="ceil"))
    # steps 2-3: long side fixed at target_max, short side scaled + snapped
    if aligned_w >= aligned_h:
        ratio = target_max / aligned_w
        short = max(64, _round_mult64(aligned_h * ratio, mode="nearest"))
        return int(target_max), int(short)
    ratio = target_max / aligned_h
    short = max(64, _round_mult64(aligned_w * ratio, mode="nearest"))
    return int(short), int(target_max)
# ---------------- Preview depth for canvas (colored) ----------------
@spaces.GPU
def preview_depth(image: Optional[Image.Image], encoder: str, max_res: int, input_size: int, fp32: bool):
    """Run depth inference on *image* and return the colorized RGB depth map (None if no image)."""
    if image is None:
        return None
    depth_model = get_model(encoder)
    # colorized visualization (RGB), matching the established colormap style
    return depth_model.infer(image=image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=False)
def prepare_canvas(image, depth_img, source):
    """Select the base picture for the sketch canvas: the depth map when
    *source* == "depth", otherwise the uploaded image. Raises gr.Error if
    the selected picture is not available yet."""
    base = image if source != "depth" else depth_img
    if base is None:
        raise gr.Error("请先上传图片(并等待深度预览出来),再点击\"Prepare canvas\"。")
    # gr.update is the generic way to set an ImageEditor's value
    return gr.update(value=base)
# ---------------- Two-stage pipeline: depth(color) -> fill ----------------
@spaces.GPU
def run_depth_and_fill(
    image: Image.Image,
    mask_upload: Optional[Image.Image],
    sketch: Optional[dict],
    prompt: str,
    encoder: str,
    max_res: int,
    input_size: int,
    fp32: bool,
    max_side: int,
    mask_dilate_px: int,
    guidance_scale: float,
    steps: int,
    seed: Optional[int],
) -> Tuple[Image.Image, Image.Image]:
    """
    Two-stage pipeline: colored depth inference -> FLUX fill.

    Returns (final_result, mask_preview), both at the original image size.
    Raises gr.Error when no image or no usable mask is provided.
    """
    if image is None:
        raise gr.Error("请先上传一张图片。")
    # 1) colored depth map (RGB)
    depth_model = get_model(encoder)
    depth_rgb: Image.Image = depth_model.infer(
        image=image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=False
    ).convert("RGB")
    print(f"[DEBUG] Depth RGB: mode={depth_rgb.mode}, size={depth_rgb.size}")
    # 2) resolve mask (uploaded mask takes priority over the sketch)
    mask_l = pick_mask(mask_upload, sketch, image, dilate_px=mask_dilate_px)
    if (mask_l is None) or (mask_l.getbbox() is None):
        raise gr.Error("没有检测到有效的 mask:请确认已在画布上涂抹或上传 mask 图片。")
    print(f"[DEBUG] Mask: mode={mask_l.mode}, size={mask_l.size}, bbox={mask_l.getbbox()}")
    # 3) choose a FLUX-compatible output size (multiples of 64)
    width, height = prepare_size_for_flux(depth_rgb, target_max=max_side)
    orig_w, orig_h = image.size
    print(f"[DEBUG] FLUX size: {width}x{height}, original: {orig_w}x{orig_h}")
    # 4) run the FLUX pipeline
    # key fix: the `image` argument must receive depth_rgb, not the original photo
    pipe = get_pipe()
    # seed >= 0 pins the generator; otherwise a fresh random seed is drawn
    generator = torch.Generator("cpu").manual_seed(int(seed)) if (seed is not None and seed >= 0) else torch.Generator("cpu").manual_seed(random.randint(0, 2**31 - 1))
    result = pipe(
        prompt=prompt,
        image=depth_rgb,  # fix: pass the colored depth map, not the original image
        mask_image=mask_l,
        width=width,
        height=height,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(steps),
        max_sequence_length=512,
        generator=generator,
        depth=depth_rgb,  # the depth argument also receives the colored depth map
    ).images[0]
    final_result = result.resize((orig_w, orig_h), Image.BICUBIC)
    # return the result plus a mask preview at original resolution
    mask_preview = mask_l.resize((orig_w, orig_h), Image.NEAREST).convert("RGB")
    return final_result, mask_preview
# ---------------- UI ----------------
with gr.Blocks() as demo:
    gr.Markdown("## GeoRemover · Depth Removal (Depth(color) → FLUX Fill)")
    with gr.Row():
        with gr.Column(scale=1):
            # input image
            img = gr.Image(label="Upload image", type="pil")
            # two ways to supply a mask: upload one, or draw it
            with gr.Tab("Upload mask"):
                mask_upload = gr.Image(label="Mask (optional)", type="pil")
            with gr.Tab("Draw mask"):
                draw_source = gr.Radio(["image", "depth"], value="image", label="Draw on")
                prepare_btn = gr.Button("Prepare canvas")
                sketch = gr.ImageEditor(
                    label="Sketch mask (draw with brush)",
                    type="pil",
                    # brush restricted to pure red so strokes can be extracted precisely
                    brush=gr.Brush(colors=["#FF0000"], default_size=24)
                )
            # prompt
            prompt = gr.Textbox(label="Prompt", value="A beautiful scene")
            # tunable parameters
            with gr.Accordion("Advanced (Depth & FLUX)", open=False):
                encoder = gr.Dropdown(["vits", "vitl"], value="vitl", label="Depth encoder")
                max_res = gr.Slider(512, 2048, value=1280, step=64, label="Depth: max_res")
                input_size = gr.Slider(256, 1024, value=518, step=2, label="Depth: input_size")
                fp32 = gr.Checkbox(False, label="Depth: use FP32 (default FP16)")
                max_side = gr.Slider(512, 1536, value=1024, step=64, label="FLUX: max side (px)")
                mask_dilate_px = gr.Slider(0, 128, value=0, step=1, label="Mask dilation (px)")
                guidance_scale = gr.Slider(0, 50, value=30, step=0.5, label="FLUX: guidance_scale")
                steps = gr.Slider(10, 75, value=50, step=1, label="FLUX: steps")
                seed = gr.Number(value=0, precision=0, label="Seed (>=0 固定;留空随机)")
            run_btn = gr.Button("Run", variant="primary")
        with gr.Column(scale=1):
            depth_preview = gr.Image(label="Depth preview (colored)", interactive=False)
            mask_preview = gr.Image(label="Mask preview (what will be removed)", interactive=False)
            out = gr.Image(label="Output")
    # event: generate the colored depth preview once an image is uploaded
    img.change(
        fn=preview_depth,
        inputs=[img, encoder, max_res, input_size, fp32],
        outputs=[depth_preview],
    )
    # prepare canvas: load the original image or the colored depth map into the ImageEditor
    prepare_btn.click(
        fn=prepare_canvas,
        inputs=[img, depth_preview, draw_source],
        outputs=[sketch],
    )
    # run the two-stage pipeline
    run_btn.click(
        fn=run_depth_and_fill,
        inputs=[img, mask_upload, sketch, prompt, encoder, max_res, input_size, fp32,
                max_side, mask_dilate_px, guidance_scale, steps, seed],
        outputs=[out, mask_preview],
        api_name="run",
    )
if __name__ == "__main__":
    # disable HF telemetry unless the environment already sets it explicitly
    os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
    # serve on all interfaces at port 7860 (the conventional Spaces port)
    demo.launch(server_name="0.0.0.0", server_port=7860)