zixinz commited on
Commit
134053b
·
1 Parent(s): 5f25c59

chore: ignore pyc and __pycache__

Browse files
Files changed (1) hide show
  1. app.py +257 -75
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import gradio as gr
3
  import spaces
4
  import sys, pathlib
@@ -14,6 +13,12 @@ from diffusers.pipelines.flux.pipeline_flux_fill_unmasked_image_condition_versio
14
  FluxFillPipeline_token12_depth_only as FluxFillPipeline,
15
  )
16
 
 
 
 
 
 
 
17
  import os
18
  import sys
19
  import pathlib
@@ -21,7 +26,6 @@ import subprocess
21
  import random
22
  from typing import Optional, Tuple
23
 
24
-
25
  import torch
26
  from PIL import Image, ImageOps
27
  import numpy as np
@@ -63,7 +67,7 @@ def _ensure_executable(p: pathlib.Path):
63
 
64
  def ensure_assets_if_missing():
65
  if os.getenv("SKIP_ASSET_DOWNLOAD") == "1":
66
- print("↪️ SKIP_ASSET_DOWNLOAD=1 -> 跳过资产下载检查")
67
  return
68
  if _have_all_assets():
69
  print("✅ Assets already present")
@@ -89,6 +93,9 @@ except Exception as e:
89
  # ---------------- Global singletons ----------------
90
  _MODELS: dict[str, DepthModel] = {}
91
  _PIPE: Optional[FluxFillPipeline] = None
 
 
 
92
 
93
  def get_model(encoder: str) -> DepthModel:
94
  if encoder not in _MODELS:
@@ -103,14 +110,11 @@ def get_pipe() -> FluxFillPipeline:
103
  device = "cuda" if torch.cuda.is_available() else "cpu"
104
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
105
 
106
-
107
- local_flux = BASE_DIR / "code_edit" / "flux_cache"
108
  use_local = local_flux.exists()
109
 
110
-
111
  hf_token = os.environ.get("HF_TOKEN")
112
 
113
-
114
  try:
115
  from huggingface_hub import hf_hub_enable_hf_transfer
116
  hf_hub_enable_hf_transfer()
@@ -124,12 +128,11 @@ def get_pipe() -> FluxFillPipeline:
124
  local_flux, torch_dtype=dtype
125
  ).to(device)
126
  else:
127
- # 在线拉取(需要 gated 访问 + token
128
  pipe = FluxFillPipeline.from_pretrained(
129
  "black-forest-labs/FLUX.1-Fill-dev",
130
  torch_dtype=dtype,
131
- token=hf_token,
132
- # use_auth_token=hf_token,
133
  ).to(device)
134
  except Exception as e:
135
  raise RuntimeError(
@@ -140,25 +143,25 @@ def get_pipe() -> FluxFillPipeline:
140
 
141
  # -------- LoRA (stage1) --------
142
  lora_dir = CODE_EDIT / "stage1" / "checkpoint-4800"
143
- lora_file = "pytorch_lora_weights.safetensors" # 你的实际文件名
144
  adapter_name = "stage1"
145
 
146
  if lora_dir.exists():
147
  try:
148
- import peft # just to assert backend is present
149
  print(f"[pipe] loading LoRA from: {lora_dir}/{lora_file}")
150
  pipe.load_lora_weights(
151
  str(lora_dir),
152
- weight_name=lora_file, # 关键:指定文件名
153
- adapter_name=adapter_name # 给一个可切换的名字
154
  )
155
- # 新版 diffusers:优先 set_adapters
156
  try:
157
  pipe.set_adapters(adapter_name, scale=1.0)
158
  print(f"[pipe] set_adapters('{adapter_name}', scale=1.0)")
159
  except Exception as e_set:
160
  print(f"[pipe] set_adapters not available ({e_set}); trying fuse_lora()")
161
- # 旧版/或不支持 set_adapters pipeline:融合 LoRA
162
  try:
163
  pipe.fuse_lora(lora_scale=1.0)
164
  print("[pipe] fuse_lora(lora_scale=1.0) done")
@@ -175,39 +178,126 @@ def get_pipe() -> FluxFillPipeline:
175
  _PIPE = pipe
176
  return pipe
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  # ---------------- Mask helpers ----------------
179
  def to_grayscale_mask(im: Image.Image) -> Image.Image:
180
  """
181
- 将任意 RGBA/RGB/L 的图转为 L
182
- 输出:白=需要移除/填充区域,黑=保留。
183
  """
184
  if im.mode == "RGBA":
185
  mask = im.split()[-1] # alpha as mask
186
  else:
187
  mask = im.convert("L")
188
- # 简单二值化,去噪
189
  mask = mask.point(lambda p: 255 if p > 16 else 0)
190
- return mask # 不做 invert,白色=mask区域
191
 
192
  def dilate_mask(mask_l: Image.Image, px: int) -> Image.Image:
193
- """对白色区域做膨胀;px 约等于扩大像素。"""
194
  if px <= 0:
195
  return mask_l
196
  arr = np.array(mask_l, dtype=np.uint8)
197
  kernel = np.ones((3, 3), np.uint8)
198
- iters = max(1, int(px // 2)) # 经验
199
  dilated = cv2.dilate(arr, kernel, iterations=iters)
200
  return Image.fromarray(dilated, mode="L")
201
 
202
  def _mask_from_red(img: Image.Image, out_size: Tuple[int, int]) -> Image.Image:
203
  """
204
- 从一张 RGBA/RGB 图里提取“纯红笔迹”为二值蒙版(白=画笔,黑=其他)。
205
- 阈值稍微宽一点以容忍压缩/插值。
206
  """
207
  arr = np.array(img.convert("RGBA"))
208
  r, g, b, a = arr[..., 0], arr[..., 1], arr[..., 2], arr[..., 3]
209
 
210
- # 条件:红高、绿低、蓝低、且 alpha>0
211
  red_hit = (r >= 200) & (g <= 40) & (b <= 40) & (a > 0)
212
 
213
  mask = (red_hit.astype(np.uint8) * 255)
@@ -221,27 +311,27 @@ def pick_mask(
221
  dilate_px: int = 0,
222
  ) -> Optional[Image.Image]:
223
  """
224
- 规则:
225
- 1) 若用户上传了 mask:直接用(白=mask
226
- 2) 否则从 ImageEditor 返回里只“认红色笔迹”为 mask
227
- - 先看 sketch_data['mask'](有些版本会给)
228
- - 不然遍历 sketch_data['layers'][*]['image'],合并其中的红色笔迹
229
- - 若还没有,再退到 sketch_data['composite'] 里找红色笔迹
230
  """
231
- # 1) 上传优先
232
  if isinstance(upload_mask, Image.Image):
233
  m = to_grayscale_mask(upload_mask).resize(base_image.size, Image.NEAREST)
234
  return dilate_mask(m, dilate_px) if dilate_px > 0 else m
235
 
236
- # 2) 手绘(ImageEditor
237
  if isinstance(sketch_data, dict):
238
- # 2a) 显式 mask(仍然支持)
239
  m = sketch_data.get("mask")
240
  if isinstance(m, Image.Image):
241
  m = to_grayscale_mask(m).resize(base_image.size, Image.NEAREST)
242
  return dilate_mask(m, dilate_px) if dilate_px > 0 else m
243
 
244
- # 2b) layers 里合并红色笔迹
245
  layers = sketch_data.get("layers")
246
  acc = None
247
  if isinstance(layers, list) and layers:
@@ -252,28 +342,28 @@ def pick_mask(
252
  li = lyr.get("image") or lyr.get("mask")
253
  if isinstance(li, Image.Image):
254
  m_layer = _mask_from_red(li, base_image.size)
255
- # 合并:有任一层画过就算 mask
256
  acc = ImageOps.lighter(acc, m_layer)
257
  if acc.getbbox() is not None:
258
  return dilate_mask(acc, dilate_px) if dilate_px > 0 else acc
259
 
260
- # 2c) 最后从 composite 里找红色笔迹
261
  comp = sketch_data.get("composite")
262
  if isinstance(comp, Image.Image):
263
  m_comp = _mask_from_red(comp, base_image.size)
264
  if m_comp.getbbox() is not None:
265
  return dilate_mask(m_comp, dilate_px) if dilate_px > 0 else m_comp
266
 
267
- # 3) 没拿到就返回 None(后面会提示“需要掩码”)
268
  return None
269
 
270
 
271
  def _round_mult64(x: float, mode: str = "nearest") -> int:
272
  """
273
- x 对齐到 64 的倍数:
274
- - mode="ceil" 向上取整
275
- - mode="floor" 向下取整
276
- - mode="nearest" 最近的倍数
277
  """
278
  if mode == "ceil":
279
  return int((x + 63) // 64) * 64
@@ -284,20 +374,20 @@ def _round_mult64(x: float, mode: str = "nearest") -> int:
284
 
285
  def prepare_size_for_flux(img: Image.Image, target_max: int = 1024) -> tuple[int, int]:
286
  """
287
- 步骤:
288
- 1) 先把原始 w,h 向上对齐到 64 的倍数(避免小图过小)
289
- 2) 把长边固定为 target_max(默认1024)
290
- 3) 短边按比例缩放并对齐到 64 的倍数(至少 64
291
  """
292
  w, h = img.size
293
 
294
- # 1) 先各自向上对齐到 64 的倍数
295
  w1 = max(64, _round_mult64(w, mode="ceil"))
296
  h1 = max(64, _round_mult64(h, mode="ceil"))
297
 
298
- # 2) 固定长边为 target_max,短边按比例
299
  if w1 >= h1:
300
- out_w = target_max # 长边固定 1024
301
  scaled_h = h1 * (target_max / w1)
302
  out_h = max(64, _round_mult64(scaled_h, mode="nearest"))
303
  else:
@@ -306,21 +396,22 @@ def prepare_size_for_flux(img: Image.Image, target_max: int = 1024) -> tuple[int
306
  out_w = max(64, _round_mult64(scaled_w, mode="nearest"))
307
 
308
  return int(out_w), int(out_h)
 
309
  @spaces.GPU
310
- # ---------------- Preview depth for canvas (彩色) ----------------
311
  def preview_depth(image: Optional[Image.Image], encoder: str, max_res: int, input_size: int, fp32: bool):
312
  if image is None:
313
  return None
314
  dm = get_model(encoder)
315
- # 彩色可视化(RGB),严格按你之前的 colormap 风格
316
  d_rgb = dm.infer(image=image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=False)
317
  return d_rgb
318
 
319
  def prepare_canvas(image, depth_img, source):
320
  base = depth_img if source == "depth" else image
321
  if base is None:
322
- raise gr.Error("请先上传图片(并等待深度预览出来),再点击\"Prepare canvas\"。")
323
- # ImageEditor 用通用的 gr.update 来设置 value
324
  return gr.update(value=base)
325
 
326
  # ---------------- Two-stage pipeline: depth(color) -> fill ----------------
@@ -341,9 +432,9 @@ def run_depth_and_fill(
341
  seed: Optional[int],
342
  ) -> Tuple[Image.Image, Image.Image]:
343
  if image is None:
344
- raise gr.Error("请先上传一张图片。")
345
 
346
- # 1) 生成彩色深度图(RGB
347
  depth_model = get_model(encoder)
348
  depth_rgb: Image.Image = depth_model.infer(
349
  image=image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=False
@@ -351,26 +442,26 @@ def run_depth_and_fill(
351
 
352
  print(f"[DEBUG] Depth RGB: mode={depth_rgb.mode}, size={depth_rgb.size}")
353
 
354
- # 2) 提取 mask(上传 > 手绘)
355
  mask_l = pick_mask(mask_upload, sketch, image, dilate_px=mask_dilate_px)
356
  if (mask_l is None) or (mask_l.getbbox() is None):
357
- raise gr.Error("没有检测到有效的 mask:请确认已在画布上涂抹或上传 mask 图片。")
358
 
359
  print(f"[DEBUG] Mask: mode={mask_l.mode}, size={mask_l.size}, bbox={mask_l.getbbox()}")
360
 
361
- # 3) 确定输出尺寸
362
  width, height = prepare_size_for_flux(depth_rgb, target_max=max_side)
363
  orig_w, orig_h = image.size
364
  print(f"[DEBUG] FLUX size: {width}x{height}, original: {orig_w}x{orig_h}")
365
 
366
- # 4) 运行 FLUX pipeline
367
- # 关键修复:image 参数应该传入 depth_rgb 而不是原图
368
  pipe = get_pipe()
369
  generator = torch.Generator("cpu").manual_seed(int(seed)) if (seed is not None and seed >= 0) else torch.Generator("cpu").manual_seed(random.randint(0, 2**31 - 1))
370
 
371
  result = pipe(
372
  prompt=prompt,
373
- image=depth_rgb, # 修复:传入彩色深度图而不是原图
374
  mask_image=mask_l,
375
  width=width,
376
  height=height,
@@ -378,26 +469,101 @@ def run_depth_and_fill(
378
  num_inference_steps=int(steps),
379
  max_sequence_length=512,
380
  generator=generator,
381
- depth=depth_rgb, # depth 参数也传入彩色深度图
382
  ).images[0]
383
 
384
  final_result = result.resize((orig_w, orig_h), Image.BICUBIC)
385
 
386
- # 返回结果和 mask 预览
387
  mask_preview = mask_l.resize((orig_w, orig_h), Image.NEAREST).convert("RGB")
388
  return final_result, mask_preview
389
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
 
391
  # ---------------- UI ----------------
392
  with gr.Blocks() as demo:
393
- gr.Markdown("## GeoRemover · Depth Removal (Depth(color) → FLUX Fill)")
394
 
395
  with gr.Row():
396
  with gr.Column(scale=1):
397
- # 输入图
398
  img = gr.Image(label="Upload image", type="pil")
399
 
400
- # Mask 两种方式:上传 or
401
  with gr.Tab("Upload mask"):
402
  mask_upload = gr.Image(label="Mask (optional)", type="pil")
403
 
@@ -407,15 +573,14 @@ with gr.Blocks() as demo:
407
  sketch = gr.ImageEditor(
408
  label="Sketch mask (draw with brush)",
409
  type="pil",
410
- # 画笔只给纯红,方便我们精确提取笔迹
411
  brush=gr.Brush(colors=["#FF0000"], default_size=24)
412
  )
413
 
414
-
415
  # prompt
416
  prompt = gr.Textbox(label="Prompt", value="A beautiful scene")
417
 
418
- # 可调参数
419
  with gr.Accordion("Advanced (Depth & FLUX)", open=False):
420
  encoder = gr.Dropdown(["vits", "vitl"], value="vitl", label="Depth encoder")
421
  max_res = gr.Slider(512, 2048, value=1280, step=64, label="Depth: max_res")
@@ -425,30 +590,36 @@ with gr.Blocks() as demo:
425
  mask_dilate_px = gr.Slider(0, 128, value=0, step=1, label="Mask dilation (px)")
426
  guidance_scale = gr.Slider(0, 50, value=30, step=0.5, label="FLUX: guidance_scale")
427
  steps = gr.Slider(10, 75, value=50, step=1, label="FLUX: steps")
428
- seed = gr.Number(value=0, precision=0, label="Seed (>=0 固定;留空随机)")
429
 
430
  run_btn = gr.Button("Run", variant="primary")
 
 
 
431
 
432
  with gr.Column(scale=1):
433
  depth_preview = gr.Image(label="Depth preview (colored)", interactive=False)
434
- mask_preview = gr.Image(label="Mask preview (what will be removed)", interactive=False)
435
  out = gr.Image(label="Output")
 
 
 
436
 
437
- # 事件:上传图片后生成"彩色深度预览"
438
  img.change(
439
  fn=preview_depth,
440
  inputs=[img, encoder, max_res, input_size, fp32],
441
  outputs=[depth_preview],
442
  )
443
 
444
- # 准备画布:把原图或"彩色深度图"放进 ImageEditor
445
  prepare_btn.click(
446
  fn=prepare_canvas,
447
  inputs=[img, depth_preview, draw_source],
448
  outputs=[sketch],
449
  )
450
 
451
- # 运行
452
  run_btn.click(
453
  fn=run_depth_and_fill,
454
  inputs=[img, mask_upload, sketch, prompt, encoder, max_res, input_size, fp32,
@@ -457,6 +628,17 @@ with gr.Blocks() as demo:
457
  api_name="run",
458
  )
459
 
 
 
 
 
 
 
 
 
 
 
 
460
  if __name__ == "__main__":
461
  os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
462
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
1
  import gradio as gr
2
  import spaces
3
  import sys, pathlib
 
13
  FluxFillPipeline_token12_depth_only as FluxFillPipeline,
14
  )
15
 
16
+ # ==== STAGE-2 ONLY ADDED: import Stage-2 Pipeline (do not touch Stage-1) ====
17
+ from diffusers.pipelines.flux.pipeline_flux_fill_unmasked_image_condition_version import (
18
+ FluxFillPipeline_token12_depth as FluxFillPipelineStage2,
19
+ )
20
+ # ===========================================================================
21
+
22
  import os
23
  import sys
24
  import pathlib
 
26
  import random
27
  from typing import Optional, Tuple
28
 
 
29
  import torch
30
  from PIL import Image, ImageOps
31
  import numpy as np
 
67
 
68
  def ensure_assets_if_missing():
69
  if os.getenv("SKIP_ASSET_DOWNLOAD") == "1":
70
+ print("↪️ SKIP_ASSET_DOWNLOAD=1 -> skip asset download check")
71
  return
72
  if _have_all_assets():
73
  print("✅ Assets already present")
 
93
  # ---------------- Global singletons ----------------
94
  _MODELS: dict[str, DepthModel] = {}
95
  _PIPE: Optional[FluxFillPipeline] = None
96
+ # ==== STAGE-2 ONLY ADDED: singleton ====
97
+ _PIPE_STAGE2: Optional[FluxFillPipelineStage2] = None
98
+ # ======================================
99
 
100
  def get_model(encoder: str) -> DepthModel:
101
  if encoder not in _MODELS:
 
110
  device = "cuda" if torch.cuda.is_available() else "cpu"
111
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
112
 
113
+ local_flux = BASE_DIR / "code_edit" / "flux_cache"
 
114
  use_local = local_flux.exists()
115
 
 
116
  hf_token = os.environ.get("HF_TOKEN")
117
 
 
118
  try:
119
  from huggingface_hub import hf_hub_enable_hf_transfer
120
  hf_hub_enable_hf_transfer()
 
128
  local_flux, torch_dtype=dtype
129
  ).to(device)
130
  else:
131
+ # Fetch online (requires gated access + token)
132
  pipe = FluxFillPipeline.from_pretrained(
133
  "black-forest-labs/FLUX.1-Fill-dev",
134
  torch_dtype=dtype,
135
+ token=hf_token,
 
136
  ).to(device)
137
  except Exception as e:
138
  raise RuntimeError(
 
143
 
144
  # -------- LoRA (stage1) --------
145
  lora_dir = CODE_EDIT / "stage1" / "checkpoint-4800"
146
+ lora_file = "pytorch_lora_weights.safetensors" # your actual file name
147
  adapter_name = "stage1"
148
 
149
  if lora_dir.exists():
150
  try:
151
+ import peft # assert backend is present
152
  print(f"[pipe] loading LoRA from: {lora_dir}/{lora_file}")
153
  pipe.load_lora_weights(
154
  str(lora_dir),
155
+ weight_name=lora_file, # important: specify filename
156
+ adapter_name=adapter_name # a switchable name
157
  )
158
+ # Newer diffusers prefer set_adapters
159
  try:
160
  pipe.set_adapters(adapter_name, scale=1.0)
161
  print(f"[pipe] set_adapters('{adapter_name}', scale=1.0)")
162
  except Exception as e_set:
163
  print(f"[pipe] set_adapters not available ({e_set}); trying fuse_lora()")
164
+ # Older / pipelines without set_adapters: fuse LoRA
165
  try:
166
  pipe.fuse_lora(lora_scale=1.0)
167
  print("[pipe] fuse_lora(lora_scale=1.0) done")
 
178
  _PIPE = pipe
179
  return pipe
180
 
181
+ # ==== STAGE-2 ONLY ADDED: Stage-2 loader (no change to Stage-1 logic) ====
182
+ def get_pipe_stage2() -> FluxFillPipelineStage2:
183
+ """
184
+ Load Stage-2 FluxFillPipeline_token12_depth and mount the Stage-2 LoRA.
185
+ """
186
+ global _PIPE_STAGE2
187
+ if _PIPE_STAGE2 is not None:
188
+ return _PIPE_STAGE2
189
+
190
+ device = "cuda" if torch.cuda.is_available() else "cpu"
191
+ dtype = torch.bfloat16 if device == "cuda" else torch.float32
192
+
193
+ local_flux = BASE_DIR / "code_edit" / "flux_cache"
194
+ use_local = local_flux.exists()
195
+ hf_token = os.environ.get("HF_TOKEN")
196
+
197
+ try:
198
+ from huggingface_hub import hf_hub_enable_hf_transfer
199
+ hf_hub_enable_hf_transfer()
200
+ except Exception:
201
+ pass
202
+
203
+ print(f"[stage2] loading FLUX.1-Fill-dev (dtype={dtype}, device={device}, local={use_local})")
204
+ try:
205
+ if use_local:
206
+ pipe2 = FluxFillPipelineStage2.from_pretrained(local_flux, torch_dtype=dtype).to(device)
207
+ else:
208
+ pipe2 = FluxFillPipelineStage2.from_pretrained(
209
+ "black-forest-labs/FLUX.1-Fill-dev",
210
+ torch_dtype=dtype,
211
+ token=hf_token,
212
+ ).to(device)
213
+ except Exception as e:
214
+ raise RuntimeError("Stage-2: Failed to load FLUX.1-Fill-dev.") from e
215
+
216
+ # Load Stage-2 LoRA
217
+ lora_dir2 = CODE_EDIT / "stage2" / "checkpoint-20000"
218
+ candidate_names = [
219
+ "pytorch_lora_weights.safetensors",
220
+ "adapter_model.safetensors",
221
+ "lora.safetensors",
222
+ ]
223
+ weight_name = None
224
+ for name in candidate_names:
225
+ if (lora_dir2 / name).is_file():
226
+ weight_name = name
227
+ break
228
+
229
+ if not lora_dir2.exists():
230
+ raise RuntimeError(f"Stage-2 LoRA dir not found: {lora_dir2}")
231
+ if weight_name is None:
232
+ raise RuntimeError(
233
+ f"Stage-2 LoRA weight not found under {lora_dir2}. "
234
+ f"Tried: {candidate_names}"
235
+ )
236
+
237
+ try:
238
+ import peft # noqa: F401
239
+ except Exception as e:
240
+ raise RuntimeError(
241
+ "peft is not installed (requires peft>=0.11 to load LoRA)."
242
+ ) from e
243
+
244
+ try:
245
+ print(f"[stage2] loading LoRA: {lora_dir2}/{weight_name}")
246
+ pipe2.load_lora_weights(
247
+ str(lora_dir2),
248
+ weight_name=weight_name,
249
+ adapter_name="stage2",
250
+ )
251
+ try:
252
+ pipe2.set_adapters("stage2", scale=1.0)
253
+ print("[stage2] set_adapters('stage2', 1.0)")
254
+ except Exception as e_set:
255
+ print(f"[stage2] set_adapters not available ({e_set}); trying fuse_lora()")
256
+ try:
257
+ pipe2.fuse_lora(lora_scale=1.0)
258
+ print("[stage2] fuse_lora(lora_scale=1.0) done")
259
+ except Exception as e_fuse:
260
+ raise RuntimeError(f"Stage-2 fuse_lora failed: {e_fuse}") from e_fuse
261
+ except Exception as e:
262
+ raise RuntimeError(f"Stage-2 LoRA load failed: {e}") from e
263
+
264
+ _PIPE_STAGE2 = pipe2
265
+ return pipe2
266
+ # ==========================================================================
267
+
268
  # ---------------- Mask helpers ----------------
269
  def to_grayscale_mask(im: Image.Image) -> Image.Image:
270
  """
271
+ Convert any RGBA/RGB/L image to L mode.
272
+ Output: white = region to remove/fill, black = keep.
273
  """
274
  if im.mode == "RGBA":
275
  mask = im.split()[-1] # alpha as mask
276
  else:
277
  mask = im.convert("L")
278
+ # simple binarization & denoise
279
  mask = mask.point(lambda p: 255 if p > 16 else 0)
280
+ return mask # do not invert; white = mask region
281
 
282
  def dilate_mask(mask_l: Image.Image, px: int) -> Image.Image:
283
+ """Dilate white region by ~px pixels."""
284
  if px <= 0:
285
  return mask_l
286
  arr = np.array(mask_l, dtype=np.uint8)
287
  kernel = np.ones((3, 3), np.uint8)
288
+ iters = max(1, int(px // 2)) # heuristic
289
  dilated = cv2.dilate(arr, kernel, iterations=iters)
290
  return Image.fromarray(dilated, mode="L")
291
 
292
  def _mask_from_red(img: Image.Image, out_size: Tuple[int, int]) -> Image.Image:
293
  """
294
+ Extract "pure red strokes" as a binary mask (white=brush, black=others) from an RGBA/RGB image.
295
+ Thresholds are a bit lenient to tolerate compression/resampling.
296
  """
297
  arr = np.array(img.convert("RGBA"))
298
  r, g, b, a = arr[..., 0], arr[..., 1], arr[..., 2], arr[..., 3]
299
 
300
+ # condition: high red, low green/blue, and alpha>0
301
  red_hit = (r >= 200) & (g <= 40) & (b <= 40) & (a > 0)
302
 
303
  mask = (red_hit.astype(np.uint8) * 255)
 
311
  dilate_px: int = 0,
312
  ) -> Optional[Image.Image]:
313
  """
314
+ Rules:
315
+ 1) If user uploaded a mask: use it directly (white=mask)
316
+ 2) Otherwise, from ImageEditor output, only recognize "red strokes" as mask:
317
+ - Try sketch_data['mask'] first (some versions provide it)
318
+ - Else merge red strokes from sketch_data['layers'][*]['image']
319
+ - If still none, try sketch_data['composite'] for red strokes
320
  """
321
+ # 1) Uploaded mask has highest priority
322
  if isinstance(upload_mask, Image.Image):
323
  m = to_grayscale_mask(upload_mask).resize(base_image.size, Image.NEAREST)
324
  return dilate_mask(m, dilate_px) if dilate_px > 0 else m
325
 
326
+ # 2) Hand-drawn (ImageEditor)
327
  if isinstance(sketch_data, dict):
328
+ # 2a) explicit mask (still supported)
329
  m = sketch_data.get("mask")
330
  if isinstance(m, Image.Image):
331
  m = to_grayscale_mask(m).resize(base_image.size, Image.NEAREST)
332
  return dilate_mask(m, dilate_px) if dilate_px > 0 else m
333
 
334
+ # 2b) merge red strokes from layers
335
  layers = sketch_data.get("layers")
336
  acc = None
337
  if isinstance(layers, list) and layers:
 
342
  li = lyr.get("image") or lyr.get("mask")
343
  if isinstance(li, Image.Image):
344
  m_layer = _mask_from_red(li, base_image.size)
345
+ # merge: any layer with strokes contributes to mask
346
  acc = ImageOps.lighter(acc, m_layer)
347
  if acc.getbbox() is not None:
348
  return dilate_mask(acc, dilate_px) if dilate_px > 0 else acc
349
 
350
+ # 2c) finally, search composite for red strokes
351
  comp = sketch_data.get("composite")
352
  if isinstance(comp, Image.Image):
353
  m_comp = _mask_from_red(comp, base_image.size)
354
  if m_comp.getbbox() is not None:
355
  return dilate_mask(m_comp, dilate_px) if dilate_px > 0 else m_comp
356
 
357
+ # 3) still none -> return None (caller will prompt for a mask)
358
  return None
359
 
360
 
361
  def _round_mult64(x: float, mode: str = "nearest") -> int:
362
  """
363
+ Align x to a multiple of 64:
364
+ - mode="ceil" round up
365
+ - mode="floor" round down
366
+ - mode="nearest" nearest multiple
367
  """
368
  if mode == "ceil":
369
  return int((x + 63) // 64) * 64
 
374
 
375
  def prepare_size_for_flux(img: Image.Image, target_max: int = 1024) -> tuple[int, int]:
376
  """
377
+ Steps:
378
+ 1) First round w,h up to multiples of 64 (avoid too-small sizes)
379
+ 2) Fix the long side to target_max (default 1024)
380
+ 3) Scale the short side proportionally and align to a multiple of 64 (at least 64)
381
  """
382
  w, h = img.size
383
 
384
+ # 1) round each up to multiple of 64
385
  w1 = max(64, _round_mult64(w, mode="ceil"))
386
  h1 = max(64, _round_mult64(h, mode="ceil"))
387
 
388
+ # 2) fix long side to target_max; scale short side
389
  if w1 >= h1:
390
+ out_w = target_max
391
  scaled_h = h1 * (target_max / w1)
392
  out_h = max(64, _round_mult64(scaled_h, mode="nearest"))
393
  else:
 
396
  out_w = max(64, _round_mult64(scaled_w, mode="nearest"))
397
 
398
  return int(out_w), int(out_h)
399
+
400
  @spaces.GPU
401
+ # ---------------- Preview depth for canvas (colored) ----------------
402
  def preview_depth(image: Optional[Image.Image], encoder: str, max_res: int, input_size: int, fp32: bool):
403
  if image is None:
404
  return None
405
  dm = get_model(encoder)
406
+ # colored visualization (RGB), consistent with your previous colormap style
407
  d_rgb = dm.infer(image=image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=False)
408
  return d_rgb
409
 
410
  def prepare_canvas(image, depth_img, source):
411
  base = depth_img if source == "depth" else image
412
  if base is None:
413
+ raise gr.Error('Please upload an image (and wait for the depth preview), then click "Prepare canvas".')
414
+ # Use a generic gr.update to set ImageEditor value
415
  return gr.update(value=base)
416
 
417
  # ---------------- Two-stage pipeline: depth(color) -> fill ----------------
 
432
  seed: Optional[int],
433
  ) -> Tuple[Image.Image, Image.Image]:
434
  if image is None:
435
+ raise gr.Error("Please upload an image first.")
436
 
437
+ # 1) produce a colored depth map (RGB)
438
  depth_model = get_model(encoder)
439
  depth_rgb: Image.Image = depth_model.infer(
440
  image=image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=False
 
442
 
443
  print(f"[DEBUG] Depth RGB: mode={depth_rgb.mode}, size={depth_rgb.size}")
444
 
445
+ # 2) extract mask (uploaded > drawn)
446
  mask_l = pick_mask(mask_upload, sketch, image, dilate_px=mask_dilate_px)
447
  if (mask_l is None) or (mask_l.getbbox() is None):
448
+ raise gr.Error("No valid mask detected: please draw on the canvas or upload a mask image.")
449
 
450
  print(f"[DEBUG] Mask: mode={mask_l.mode}, size={mask_l.size}, bbox={mask_l.getbbox()}")
451
 
452
+ # 3) decide output size
453
  width, height = prepare_size_for_flux(depth_rgb, target_max=max_side)
454
  orig_w, orig_h = image.size
455
  print(f"[DEBUG] FLUX size: {width}x{height}, original: {orig_w}x{orig_h}")
456
 
457
+ # 4) run FLUX pipeline
458
+ # Key fix: pass depth_rgb as `image` instead of the original image
459
  pipe = get_pipe()
460
  generator = torch.Generator("cpu").manual_seed(int(seed)) if (seed is not None and seed >= 0) else torch.Generator("cpu").manual_seed(random.randint(0, 2**31 - 1))
461
 
462
  result = pipe(
463
  prompt=prompt,
464
+ image=depth_rgb, # FIX: pass the colored depth map, not the original image
465
  mask_image=mask_l,
466
  width=width,
467
  height=height,
 
469
  num_inference_steps=int(steps),
470
  max_sequence_length=512,
471
  generator=generator,
472
+ depth=depth_rgb, # also feed depth input (colored depth)
473
  ).images[0]
474
 
475
  final_result = result.resize((orig_w, orig_h), Image.BICUBIC)
476
 
477
+ # return result and mask preview
478
  mask_preview = mask_l.resize((orig_w, orig_h), Image.NEAREST).convert("RGB")
479
  return final_result, mask_preview
480
 
481
+ def _to_pil_rgb(img_like) -> Image.Image:
482
+ """Normalize input to PIL RGB. Supports PIL/L/RGBA/np.array."""
483
+ if isinstance(img_like, Image.Image):
484
+ return img_like.convert("RGB")
485
+ # numpy array -> PIL
486
+ try:
487
+ arr = np.array(img_like)
488
+ if arr.ndim == 2: # grayscale
489
+ arr = np.stack([arr, arr, arr], axis=-1)
490
+ return Image.fromarray(arr.astype(np.uint8), mode="RGB")
491
+ except Exception:
492
+ raise gr.Error("Stage-2: `depth` / `depth_image` is not a valid image. Please check the provided objects.")
493
+
494
+ # ==== STAGE-2 ONLY ADDED: Stage-2 inference (takes Stage-1 output + Stage-1 depth preview) ====
495
+ @spaces.GPU
496
+ def run_stage2_refine(
497
+ image: Image.Image, # original image (RGB)
498
+ stage1_out: Image.Image, # output from Stage-1
499
+ depth_img_from_stage1_input: Image.Image, # ★ new: Stage-1 depth preview (from UI)
500
+ mask_upload: Optional[Image.Image],
501
+ sketch: Optional[dict],
502
+ prompt: str,
503
+ encoder: str,
504
+ max_res: int,
505
+ input_size: int,
506
+ fp32: bool,
507
+ max_side: int,
508
+ guidance_scale: float,
509
+ steps: int,
510
+ seed: Optional[int],
511
+ ) -> Image.Image:
512
+ if image is None or stage1_out is None:
513
+ raise gr.Error("Please complete Stage-1 generation first (needs original image and Stage-1 output).")
514
+
515
+ # allow refine without mask (use all-black -> no masked area)
516
+ mask_l = pick_mask(mask_upload, sketch, image, dilate_px=0)
517
+ if (mask_l is None) or (mask_l.getbbox() is None):
518
+ mask_l = Image.new("L", image.size, 0)
519
+
520
+ # unify sizes (based on original image)
521
+ width, height = prepare_size_for_flux(image, target_max=max_side)
522
+ orig_w, orig_h = image.size
523
+
524
+ pipe2 = get_pipe_stage2()
525
+ g2 = torch.Generator("cpu").manual_seed(int(seed)) if (seed is not None and seed >= 0) \
526
+ else torch.Generator("cpu").manual_seed(random.randint(0, 2**31 - 1))
527
+ depth_pil = _to_pil_rgb(stage1_out) # for `depth`
528
+ depth_image_pil = _to_pil_rgb(depth_img_from_stage1_input) # for `depth_image`
529
+ image_rgb = _to_pil_rgb(image) # normalize original image to RGB
530
+
531
+ # resize to (width, height)
532
+ depth_pil = depth_pil.resize((width, height), Image.BICUBIC)
533
+ depth_image_pil = depth_image_pil.resize((width, height), Image.BICUBIC)
534
+ # ★★ Mapping:
535
+ # - image = original RGB
536
+ # - depth = Stage-1 output (treated as updated geometry)
537
+ # - depth_image = Stage-1 input depth (UI's depth preview)
538
+ out2 = pipe2(
539
+ prompt=prompt,
540
+ image=image, # ← original RGB
541
+ mask_image=mask_l,
542
+ width=width,
543
+ height=height,
544
+ guidance_scale=float(guidance_scale),
545
+ num_inference_steps=int(steps),
546
+ max_sequence_length=512,
547
+ generator=g2,
548
+ depth=depth_pil, # ← Stage-1 output as `depth`
549
+ depth_image=depth_image_pil, # ← Stage-1 depth preview as `depth_image`
550
+ ).images[0]
551
+
552
+ out2 = out2.resize((orig_w * 3, orig_h), Image.BICUBIC) # preserve your original ×3 display layout
553
+ return out2
554
+
555
+ # ===================================================================
556
 
557
  # ---------------- UI ----------------
558
  with gr.Blocks() as demo:
559
+ gr.Markdown("## GeoRemover · Depth Removal (Depth (colored) → FLUX Fill)")
560
 
561
  with gr.Row():
562
  with gr.Column(scale=1):
563
+ # input image
564
  img = gr.Image(label="Upload image", type="pil")
565
 
566
+ # Mask: upload or draw
567
  with gr.Tab("Upload mask"):
568
  mask_upload = gr.Image(label="Mask (optional)", type="pil")
569
 
 
573
  sketch = gr.ImageEditor(
574
  label="Sketch mask (draw with brush)",
575
  type="pil",
576
+ # Provide red-only brush for precise extraction of strokes
577
  brush=gr.Brush(colors=["#FF0000"], default_size=24)
578
  )
579
 
 
580
  # prompt
581
  prompt = gr.Textbox(label="Prompt", value="A beautiful scene")
582
 
583
+ # tunables
584
  with gr.Accordion("Advanced (Depth & FLUX)", open=False):
585
  encoder = gr.Dropdown(["vits", "vitl"], value="vitl", label="Depth encoder")
586
  max_res = gr.Slider(512, 2048, value=1280, step=64, label="Depth: max_res")
 
590
  mask_dilate_px = gr.Slider(0, 128, value=0, step=1, label="Mask dilation (px)")
591
  guidance_scale = gr.Slider(0, 50, value=30, step=0.5, label="FLUX: guidance_scale")
592
  steps = gr.Slider(10, 75, value=50, step=1, label="FLUX: steps")
593
+ seed = gr.Number(value=0, precision=0, label="Seed (>=0 fixed; empty = random)")
594
 
595
  run_btn = gr.Button("Run", variant="primary")
596
+ # ==== STAGE-2 ONLY ADDED: add Stage-2 button ====
597
+ run_btn_stage2 = gr.Button("Run Stage-2 (Refine)", variant="secondary")
598
+ # =================================================
599
 
600
  with gr.Column(scale=1):
601
  depth_preview = gr.Image(label="Depth preview (colored)", interactive=False)
602
+ mask_preview = gr.Image(label="Mask preview (to be removed)", interactive=False)
603
  out = gr.Image(label="Output")
604
+ # ==== STAGE-2 ONLY ADDED: Stage-2 output ====
605
+ out_stage2 = gr.Image(label="Output (Stage-2 refine)")
606
+ # ============================================
607
 
608
+ # Event: when image changes, compute the colored depth preview
609
  img.change(
610
  fn=preview_depth,
611
  inputs=[img, encoder, max_res, input_size, fp32],
612
  outputs=[depth_preview],
613
  )
614
 
615
+ # Prepare canvas: put original image or colored depth image into ImageEditor
616
  prepare_btn.click(
617
  fn=prepare_canvas,
618
  inputs=[img, depth_preview, draw_source],
619
  outputs=[sketch],
620
  )
621
 
622
+ # Run Stage-1 (wiring unchanged)
623
  run_btn.click(
624
  fn=run_depth_and_fill,
625
  inputs=[img, mask_upload, sketch, prompt, encoder, max_res, input_size, fp32,
 
628
  api_name="run",
629
  )
630
 
631
+ # ==== STAGE-2 ONLY ADDED: run after Stage-1 has produced a result ====
632
+ run_btn_stage2.click(
633
+ fn=run_stage2_refine,
634
+ inputs=[img, out, depth_preview, # ← pass depth_preview as the 3rd input to Stage-2
635
+ mask_upload, sketch, prompt, encoder, max_res, input_size, fp32,
636
+ max_side, guidance_scale, steps, seed],
637
+ outputs=[out_stage2],
638
+ api_name="run_stage2",
639
+ )
640
+ # ====================================================================
641
+
642
  if __name__ == "__main__":
643
  os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
644
+ demo.launch(server_name="0.0.0.0", server_port=7860)