# GeoRemover / app.py
# (Hugging Face Space export header — author: zixinz, commit 2f713b7,
#  "chore: ignore pyc and __pycache__", file size 14.9 kB; page chrome
#  preserved here as comments so the module stays importable.)
import os
import sys
import math
import pathlib
import subprocess
import random
from typing import Optional, Tuple

import gradio as gr
import spaces
import torch
from PIL import Image, ImageOps
import numpy as np
import cv2
# ---------------- Paths & assets ----------------
BASE_DIR = pathlib.Path(__file__).resolve().parent
CODE_DEPTH = BASE_DIR / "code_depth"
CODE_EDIT = BASE_DIR / "code_edit"
GET_ASSETS = BASE_DIR / "get_assets.sh"
EXPECTED_ASSETS = [
BASE_DIR / "code_depth" / "checkpoints" / "video_depth_anything_vits.pth",
BASE_DIR / "code_depth" / "checkpoints" / "video_depth_anything_vitl.pth",
BASE_DIR / "code_edit" / "stage1" / "checkpoint-4800" / "pytorch_lora_weights.safetensors",
BASE_DIR / "code_edit" / "stage2" / "checkpoint-20000" / "pytorch_lora_weights.safetensors",
]
# import depth helper
if str(CODE_DEPTH) not in sys.path:
sys.path.insert(0, str(CODE_DEPTH))
from depth_infer import DepthModel # noqa: E402
# import your custom diffusers
if str(CODE_EDIT / "diffusers") not in sys.path:
sys.path.insert(0, str(CODE_EDIT / "diffusers"))
from diffusers.pipelines.flux.pipeline_flux_fill_unmasked_image_condition_version import ( # type: ignore # noqa: E402
FluxFillPipeline_token12_depth_only as FluxFillPipeline,
)
# ---------------- Assets ensure (on-demand) ----------------
def _have_all_assets() -> bool:
    """Return True only when every expected checkpoint file exists on disk."""
    for asset_path in EXPECTED_ASSETS:
        if not asset_path.is_file():
            return False
    return True
def _ensure_executable(p: pathlib.Path):
if not p.exists():
raise FileNotFoundError(f"Not found: {p}")
os.chmod(p, os.stat(p).st_mode | 0o111)
def ensure_assets_if_missing():
    """Download model checkpoints via get_assets.sh unless already present.

    Honors SKIP_ASSET_DOWNLOAD=1 to bypass the check entirely. Raises
    RuntimeError when the download script finishes but expected files are
    still missing, and propagates CalledProcessError if the script fails.
    """
    if os.getenv("SKIP_ASSET_DOWNLOAD") == "1":
        print("↪️ SKIP_ASSET_DOWNLOAD=1 -> 跳过资产下载检查")
        return
    if _have_all_assets():
        print("✅ Assets already present")
        return
    print("⬇️ Missing assets, running get_assets.sh ...")
    _ensure_executable(GET_ASSETS)
    script_env = dict(os.environ)
    script_env["HF_HUB_DISABLE_TELEMETRY"] = "1"
    subprocess.run(
        ["bash", str(GET_ASSETS)],
        check=True,
        cwd=str(BASE_DIR),
        env=script_env,
    )
    if not _have_all_assets():
        missing = [
            str(asset.relative_to(BASE_DIR))
            for asset in EXPECTED_ASSETS
            if not asset.exists()
        ]
        raise RuntimeError(f"Assets missing after get_assets.sh: {missing}")
    print("✅ Assets ready.")
try:
ensure_assets_if_missing()
except Exception as e:
print(f"⚠️ Asset prepare failed: {e}")
# ---------------- Global singletons ----------------
_MODELS: dict[str, DepthModel] = {}
_PIPE: Optional[FluxFillPipeline] = None
def get_model(encoder: str) -> DepthModel:
    """Return the cached DepthModel for *encoder*, creating it on first use."""
    model = _MODELS.get(encoder)
    if model is None:
        model = DepthModel(BASE_DIR, encoder=encoder)
        _MODELS[encoder] = model
    return model
def get_pipe() -> FluxFillPipeline:
    """Return the process-wide FLUX fill pipeline, loading it lazily.

    Loads black-forest-labs/FLUX.1-Fill-dev (bf16 on CUDA, fp32 on CPU),
    then makes a best-effort attempt to apply the stage-1 LoRA weights;
    a missing directory or a load failure is logged and ignored so the
    base pipeline still works.
    """
    global _PIPE
    if _PIPE is None:
        use_cuda = torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"
        dtype = torch.bfloat16 if device == "cuda" else torch.float32
        print(f"[pipe] load FLUX.1-Fill-dev dtype={dtype}, device={device}")
        pipe = FluxFillPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=dtype
        ).to(device)
        # Stage-1 LoRA weights (loading requires peft).
        lora_dir = CODE_EDIT / "stage1" / "checkpoint-4800"
        if not lora_dir.exists():
            print(f"[pipe] LoRA path not found: {lora_dir} (continue without)")
        else:
            try:
                pipe.load_lora_weights(str(lora_dir))
                print(f"[pipe] loaded LoRA from: {lora_dir}")
            except Exception as e:
                print(f"[pipe] load LoRA failed (continue without): {e}")
        _PIPE = pipe
    return _PIPE
# ---------------- Mask helpers ----------------
def to_grayscale_mask(im: Image.Image) -> Image.Image:
    """Convert an RGBA/RGB/L image to a binary L-mode mask.

    White (255) marks the region to remove/fill, black (0) is kept. RGBA
    input contributes its alpha channel as the mask; anything else is
    converted to grayscale. A light threshold (>16) binarizes/de-noises.
    """
    if im.mode == "RGBA":
        channel = im.split()[-1]  # alpha channel acts as the mask
    else:
        channel = im.convert("L")
    # Binarize; no inversion — white already means "masked".
    return channel.point(lambda v: 255 if v > 16 else 0)
def dilate_mask(mask_l: Image.Image, px: int) -> Image.Image:
    """Dilate the white region of an L-mode mask by roughly *px* pixels."""
    if px <= 0:
        return mask_l
    # One 3x3 dilation pass grows the region ~1px per side; px//2 passes is
    # an empirical approximation (minimum one pass).
    n_iter = max(1, int(px // 2))
    source = np.array(mask_l, dtype=np.uint8)
    grown = cv2.dilate(source, np.ones((3, 3), np.uint8), iterations=n_iter)
    return Image.fromarray(grown, mode="L")
def _mask_from_red(img: Image.Image, out_size: Tuple[int, int]) -> Image.Image:
    """Extract pure-red brush strokes as a binary mask (white=stroke, black=rest).

    Thresholds are slightly loose to tolerate compression/interpolation
    artifacts. The result is resized to *out_size* with nearest-neighbor so
    the mask stays binary.
    """
    rgba = np.array(img.convert("RGBA"))
    red, green, blue, alpha = (rgba[..., c] for c in range(4))
    # A pixel counts as "brush" when red is high, green/blue low, and alpha > 0.
    is_stroke = (red >= 200) & (green <= 40) & (blue <= 40) & (alpha > 0)
    mask_arr = (is_stroke.astype(np.uint8) * 255)
    return Image.fromarray(mask_arr, mode="L").resize(out_size, Image.NEAREST)
def pick_mask(
    upload_mask: Optional[Image.Image],
    sketch_data: Optional[dict],
    base_image: Image.Image,
    dilate_px: int = 0,
) -> Optional[Image.Image]:
    """Resolve the user-provided mask from either an upload or the sketch pad.

    Priority:
      1) An uploaded mask image is used directly (white = masked).
      2) Otherwise only *red* brush strokes from the ImageEditor count:
         - first an explicit sketch_data['mask'] (some Gradio versions provide it)
         - else red strokes merged across sketch_data['layers'][*]['image']
         - else red strokes found in sketch_data['composite']

    Returns an L-mode mask resized to base_image.size (dilated by *dilate_px*
    when > 0), or None when no usable mask was found (caller shows an error).
    """
    # 1) uploaded mask takes priority
    if isinstance(upload_mask, Image.Image):
        m = to_grayscale_mask(upload_mask).resize(base_image.size, Image.NEAREST)
        return dilate_mask(m, dilate_px) if dilate_px > 0 else m
    # 2) hand-drawn mask (ImageEditor payload)
    if isinstance(sketch_data, dict):
        # 2a) explicit mask (still supported by some Gradio versions)
        m = sketch_data.get("mask")
        if isinstance(m, Image.Image):
            m = to_grayscale_mask(m).resize(base_image.size, Image.NEAREST)
            return dilate_mask(m, dilate_px) if dilate_px > 0 else m
        # 2b) merge red strokes from every layer
        layers = sketch_data.get("layers")
        acc = None
        if isinstance(layers, list) and layers:
            acc = Image.new("L", base_image.size, 0)
            for lyr in layers:
                if not isinstance(lyr, dict):
                    continue
                li = lyr.get("image") or lyr.get("mask")
                if isinstance(li, Image.Image):
                    m_layer = _mask_from_red(li, base_image.size)
                    # merge: a pixel painted in any layer counts as mask
                    acc = ImageOps.lighter(acc, m_layer)
            if acc.getbbox() is not None:
                return dilate_mask(acc, dilate_px) if dilate_px > 0 else acc
        # 2c) finally, look for red strokes in the composite image
        comp = sketch_data.get("composite")
        if isinstance(comp, Image.Image):
            m_comp = _mask_from_red(comp, base_image.size)
            if m_comp.getbbox() is not None:
                return dilate_mask(m_comp, dilate_px) if dilate_px > 0 else m_comp
    # 3) nothing usable -> None (caller reports that a mask is required)
    return None
def _round_mult64(x: float, mode: str = "nearest") -> int:
"""
把 x 对齐到 64 的倍数:
- mode="ceil" 向上取整
- mode="floor" 向下取整
- mode="nearest" 最近的倍数
"""
if mode == "ceil":
return int((x + 63) // 64) * 64
elif mode == "floor":
return int(x // 64) * 64
else: # nearest
return int((x + 32) // 64) * 64
def prepare_size_for_flux(img: Image.Image, target_max: int = 1024) -> tuple[int, int]:
    """Compute a FLUX-friendly (width, height); both are multiples of 64.

    Steps:
      1) round each original dimension up to a multiple of 64 (protects
         very small images from collapsing)
      2) pin the long side to *target_max* (default 1024)
      3) scale the short side proportionally and snap it to the nearest
         multiple of 64, never below 64
    """
    orig_w, orig_h = img.size
    # 1) round each side up to a multiple of 64 independently
    aligned_w = max(64, _round_mult64(orig_w, mode="ceil"))
    aligned_h = max(64, _round_mult64(orig_h, mode="ceil"))
    # 2)+3) fix the long side at target_max, scale the short side to match
    if aligned_w >= aligned_h:
        short = aligned_h * (target_max / aligned_w)
        return int(target_max), int(max(64, _round_mult64(short, mode="nearest")))
    short = aligned_w * (target_max / aligned_h)
    return int(max(64, _round_mult64(short, mode="nearest"))), int(target_max)
# ---------------- Preview depth for canvas (彩色) ----------------
def preview_depth(image: Optional[Image.Image], encoder: str, max_res: int, input_size: int, fp32: bool):
if image is None:
return None
dm = get_model(encoder)
# 彩色可视化(RGB),严格按你之前的 colormap 风格
d_rgb = dm.infer(image=image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=False)
return d_rgb
def prepare_canvas(image, depth_img, source):
    """Load the original image or its depth preview into the sketch editor."""
    base = image if source != "depth" else depth_img
    if base is None:
        raise gr.Error("请先上传图片(并等待深度预览出来),再点击\"Prepare canvas\"。")
    # gr.update is the version-agnostic way to set an ImageEditor's value.
    return gr.update(value=base)
# ---------------- Two-stage pipeline: depth(color) -> fill ----------------
# ---------------- Two-stage pipeline: depth(color) -> fill ----------------
@spaces.GPU
def run_depth_and_fill(
    image: Image.Image,
    mask_upload: Optional[Image.Image],
    sketch: Optional[dict],
    prompt: str,
    encoder: str,
    max_res: int,
    input_size: int,
    fp32: bool,
    max_side: int,
    mask_dilate_px: int,
    guidance_scale: float,
    steps: int,
    seed: Optional[int],
) -> Tuple[Image.Image, Image.Image]:
    """Full removal pipeline: colored depth estimation followed by FLUX fill.

    Returns (result resized back to the original image size, an RGB preview
    of the mask that was used). Raises gr.Error when the input image or a
    usable mask is missing.
    """
    if image is None:
        raise gr.Error("请先上传一张图片。")
    # 1) produce the colored (RGB) depth map
    depth_model = get_model(encoder)
    depth_rgb: Image.Image = depth_model.infer(
        image=image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=False
    ).convert("RGB")
    print(f"[DEBUG] Depth RGB: mode={depth_rgb.mode}, size={depth_rgb.size}")
    # 2) resolve the mask (uploaded mask takes priority over the sketch)
    mask_l = pick_mask(mask_upload, sketch, image, dilate_px=mask_dilate_px)
    if (mask_l is None) or (mask_l.getbbox() is None):
        raise gr.Error("没有检测到有效的 mask:请确认已在画布上涂抹或上传 mask 图片。")
    print(f"[DEBUG] Mask: mode={mask_l.mode}, size={mask_l.size}, bbox={mask_l.getbbox()}")
    # 3) choose the output size (multiples of 64, long side = max_side)
    width, height = prepare_size_for_flux(depth_rgb, target_max=max_side)
    orig_w, orig_h = image.size
    print(f"[DEBUG] FLUX size: {width}x{height}, original: {orig_w}x{orig_h}")
    # 4) run the FLUX pipeline
    # Key point: the `image` argument must receive depth_rgb, NOT the
    # original photo (per the upstream fix noted in the repo).
    pipe = get_pipe()
    # seed >= 0 -> deterministic; None or negative -> random seed
    generator = torch.Generator("cpu").manual_seed(int(seed)) if (seed is not None and seed >= 0) else torch.Generator("cpu").manual_seed(random.randint(0, 2**31 - 1))
    result = pipe(
        prompt=prompt,
        image=depth_rgb,  # colored depth map, not the original image
        mask_image=mask_l,
        width=width,
        height=height,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(steps),
        max_sequence_length=512,
        generator=generator,
        depth=depth_rgb,  # the `depth` conditioning also gets the colored depth map
    ).images[0]
    # Resize the generated image back to the original resolution for display.
    final_result = result.resize((orig_w, orig_h), Image.BICUBIC)
    # Mask preview at the original size so the UI shows what was removed.
    mask_preview = mask_l.resize((orig_w, orig_h), Image.NEAREST).convert("RGB")
    return final_result, mask_preview
# ---------------- UI ----------------
# ---------------- UI ----------------
with gr.Blocks() as demo:
    gr.Markdown("## GeoRemover · Depth Removal (Depth(color) → FLUX Fill)")
    with gr.Row():
        with gr.Column(scale=1):
            # input image
            img = gr.Image(label="Upload image", type="pil")
            # two ways to supply a mask: upload a file, or draw on a canvas
            with gr.Tab("Upload mask"):
                mask_upload = gr.Image(label="Mask (optional)", type="pil")
            with gr.Tab("Draw mask"):
                draw_source = gr.Radio(["image", "depth"], value="image", label="Draw on")
                prepare_btn = gr.Button("Prepare canvas")
                sketch = gr.ImageEditor(
                    label="Sketch mask (draw with brush)",
                    type="pil",
                    # brush restricted to pure red so strokes can be extracted exactly
                    brush=gr.Brush(colors=["#FF0000"], default_size=24)
                )
            # prompt
            prompt = gr.Textbox(label="Prompt", value="A beautiful scene")
            # tunable parameters
            with gr.Accordion("Advanced (Depth & FLUX)", open=False):
                encoder = gr.Dropdown(["vits", "vitl"], value="vitl", label="Depth encoder")
                max_res = gr.Slider(512, 2048, value=1280, step=64, label="Depth: max_res")
                input_size = gr.Slider(256, 1024, value=518, step=2, label="Depth: input_size")
                fp32 = gr.Checkbox(False, label="Depth: use FP32 (default FP16)")
                max_side = gr.Slider(512, 1536, value=1024, step=64, label="FLUX: max side (px)")
                mask_dilate_px = gr.Slider(0, 128, value=0, step=1, label="Mask dilation (px)")
                guidance_scale = gr.Slider(0, 50, value=30, step=0.5, label="FLUX: guidance_scale")
                steps = gr.Slider(10, 75, value=50, step=1, label="FLUX: steps")
                seed = gr.Number(value=0, precision=0, label="Seed (>=0 固定;留空随机)")
            run_btn = gr.Button("Run", variant="primary")
        with gr.Column(scale=1):
            depth_preview = gr.Image(label="Depth preview (colored)", interactive=False)
            mask_preview = gr.Image(label="Mask preview (what will be removed)", interactive=False)
            out = gr.Image(label="Output")
    # event: generate the colored depth preview when an image is uploaded
    img.change(
        fn=preview_depth,
        inputs=[img, encoder, max_res, input_size, fp32],
        outputs=[depth_preview],
    )
    # prepare the canvas: load the original image or the colored depth map
    # into the ImageEditor, depending on the "Draw on" choice
    prepare_btn.click(
        fn=prepare_canvas,
        inputs=[img, depth_preview, draw_source],
        outputs=[sketch],
    )
    # run the full depth -> fill pipeline
    run_btn.click(
        fn=run_depth_and_fill,
        inputs=[img, mask_upload, sketch, prompt, encoder, max_res, input_size, fp32,
                max_side, mask_dilate_px, guidance_scale, steps, seed],
        outputs=[out, mask_preview],
        api_name="run",
    )
if __name__ == "__main__":
    # Opt out of HF Hub telemetry unless the environment already set a value.
    os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
    # Bind on all interfaces so the app is reachable from outside the container.
    demo.launch(server_name="0.0.0.0", server_port=7860)