File size: 27,338 Bytes
2668d28
 
a63b67d
 
 
 
a01e858
a63b67d
 
 
 
 
 
 
 
 
134053b
 
 
 
 
 
5a0778e
 
2f713b7
a01e858
2f713b7
5a0778e
2f713b7
 
 
5a0778e
2f713b7
5a0778e
2f713b7
 
 
5a0778e
2f713b7
 
 
 
 
 
 
a01e858
2f713b7
 
 
69b2678
a01e858
2f713b7
 
 
 
 
 
a01e858
2f713b7
 
69b2678
 
 
 
 
 
2f713b7
a01e858
 
 
 
2f713b7
134053b
2f713b7
 
 
 
 
 
5a0778e
2f713b7
5a0778e
2f713b7
5a0778e
 
2f713b7
 
 
 
5a0778e
 
2f713b7
5a0778e
2f713b7
5a0778e
2f713b7
a01e858
2f713b7
134053b
 
 
69b2678
 
 
 
 
 
2f713b7
a01e858
 
 
2f713b7
 
 
6f74ce3
2f713b7
 
 
134053b
6f74ce3
 
 
 
 
 
 
 
 
 
 
 
 
a01e858
6f74ce3
 
 
 
134053b
6f74ce3
 
 
 
a01e858
6f74ce3
 
a01e858
5f25c59
a01e858
5f25c59
 
 
 
a01e858
5f25c59
 
 
a01e858
 
5f25c59
 
 
a01e858
5f25c59
 
 
 
 
 
 
 
 
a01e858
5f25c59
 
 
 
 
2f713b7
 
 
134053b
 
 
a01e858
134053b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a01e858
134053b
 
 
 
 
a01e858
134053b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f713b7
 
 
134053b
 
2f713b7
 
a01e858
2f713b7
 
a01e858
2f713b7
a01e858
2f713b7
 
a01e858
2f713b7
 
 
 
134053b
2f713b7
 
 
 
 
a01e858
 
2f713b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a01e858
 
 
134053b
 
 
2f713b7
134053b
2f713b7
 
 
 
134053b
2f713b7
134053b
2f713b7
 
 
 
 
134053b
2f713b7
 
 
 
 
 
 
 
 
 
a01e858
2f713b7
 
 
134053b
2f713b7
 
 
 
 
 
a01e858
2f713b7
 
 
 
134053b
 
 
 
2f713b7
 
 
 
 
 
 
 
 
 
134053b
a01e858
134053b
a01e858
2f713b7
 
 
 
 
 
134053b
2f713b7
 
 
 
 
 
 
 
134053b
41300b4
134053b
2f713b7
 
 
 
 
 
 
 
 
 
134053b
2f713b7
 
a01e858
5a0778e
2f713b7
69b2678
2f713b7
 
 
 
 
 
 
 
 
 
 
 
 
 
134053b
2f713b7
134053b
2f713b7
 
 
 
a01e858
2f713b7
 
134053b
2f713b7
 
a01e858
 
2f713b7
 
134053b
2f713b7
 
 
69b2678
a01e858
2f713b7
a01e858
 
 
 
 
2f713b7
 
 
a01e858
2f713b7
 
 
 
 
 
 
a01e858
2f713b7
 
 
a01e858
134053b
2f713b7
 
 
134053b
 
 
 
 
 
a01e858
134053b
 
 
a01e858
134053b
a01e858
134053b
 
 
 
a01e858
134053b
 
 
 
 
 
 
 
 
 
 
 
 
a01e858
134053b
a01e858
134053b
 
 
 
a01e858
134053b
 
 
 
a01e858
 
 
134053b
a01e858
134053b
 
a01e858
134053b
a01e858
134053b
 
a01e858
 
 
 
 
134053b
 
a01e858
134053b
 
 
 
 
 
 
a01e858
 
134053b
 
a01e858
134053b
 
2f713b7
69b2678
a01e858
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f713b7
69b2678
2f713b7
a01e858
 
 
 
 
2f713b7
134053b
2f713b7
a01e858
 
 
 
2f713b7
 
a01e858
 
 
 
 
 
 
 
 
 
 
 
 
 
2f713b7
a01e858
2f713b7
a01e858
2f713b7
 
a01e858
 
 
 
 
 
2f713b7
a01e858
2f713b7
a01e858
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f713b7
 
a01e858
 
2f713b7
 
a01e858
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f713b7
 
 
 
 
 
a01e858
2f713b7
 
 
 
 
 
a01e858
2f713b7
 
 
 
 
 
a01e858
 
 
 
2f713b7
5a0778e
a01e858
134053b
 
a01e858
134053b
 
 
 
 
 
5a0778e
2f713b7
134053b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
import gradio as gr
import spaces
import sys, pathlib
BASE_DIR = pathlib.Path(__file__).resolve().parent
LOCAL_DIFFUSERS_SRC = BASE_DIR / "code_edit" / "diffusers" / "src"

# Ensure local diffusers is importable
if (LOCAL_DIFFUSERS_SRC / "diffusers").exists():
    sys.path.insert(0, str(LOCAL_DIFFUSERS_SRC))
else:
    raise RuntimeError(f"Local diffusers not found at: {LOCAL_DIFFUSERS_SRC}")

from diffusers.pipelines.flux.pipeline_flux_fill_unmasked_image_condition_version import (
    FluxFillPipeline_token12_depth_only as FluxFillPipeline,
)

# ==== STAGE-2 ONLY ADDED: import Stage-2 Pipeline (do not touch Stage-1) ====
from diffusers.pipelines.flux.pipeline_flux_fill_unmasked_image_condition_version import (
    FluxFillPipeline_token12_depth as FluxFillPipelineStage2,
)
# ===========================================================================

import os
import subprocess
import random
from typing import Optional, Tuple, Dict, Any

import torch
from PIL import Image, ImageOps
import numpy as np
import cv2

# ---------------- Paths & assets ----------------
# All locations are resolved relative to the directory containing this file.
BASE_DIR = pathlib.Path(__file__).resolve().parent
CODE_DEPTH = BASE_DIR / "code_depth"      # depth helper code (added to sys.path below)
CODE_EDIT = BASE_DIR / "code_edit"        # local diffusers fork and LoRA checkpoints
GET_ASSETS = BASE_DIR / "get_assets.sh"   # script run when any expected asset is missing

# Files that must exist as regular files for the assets to count as present:
# two depth checkpoints plus the Stage-1 and Stage-2 LoRA weights.
EXPECTED_ASSETS = [
    BASE_DIR / "code_depth" / "checkpoints" / "video_depth_anything_vits.pth",
    BASE_DIR / "code_depth" / "checkpoints" / "video_depth_anything_vitl.pth",
    BASE_DIR / "code_edit" / "stage1" / "checkpoint-4800" / "pytorch_lora_weights.safetensors",
    BASE_DIR / "code_edit" / "stage2" / "checkpoint-20000" / "pytorch_lora_weights.safetensors",
]

# Import depth helper
if str(CODE_DEPTH) not in sys.path:
    sys.path.insert(0, str(CODE_DEPTH))
from depth_infer import DepthModel  # noqa: E402

# Import your custom diffusers (local fork)
if str(CODE_EDIT / "diffusers") not in sys.path:
    sys.path.insert(0, str(CODE_EDIT / "diffusers"))
from diffusers.pipelines.flux.pipeline_flux_fill_unmasked_image_condition_version import (  # type: ignore # noqa: E402
    FluxFillPipeline_token12_depth_only as FluxFillPipeline,
)

# ---------------- Asset preparation (on-demand) ----------------
def _have_all_assets() -> bool:
    """Return True only if every path in EXPECTED_ASSETS is an existing regular file."""
    for asset in EXPECTED_ASSETS:
        if not asset.is_file():
            return False
    return True

def _ensure_executable(p: pathlib.Path):
    """
    Add the execute bits (user/group/other) to `p`, keeping its other mode bits.

    Raises:
        FileNotFoundError: if `p` does not exist.
    """
    if p.exists():
        current_mode = os.stat(p).st_mode
        os.chmod(p, current_mode | 0o111)
        return
    raise FileNotFoundError(f"Not found: {p}")

def ensure_assets_if_missing():
    """
    Make sure every file in EXPECTED_ASSETS is present, fetching them if needed.

    Behavior:
      - If the environment variable SKIP_ASSET_DOWNLOAD is "1", do nothing.
      - If all expected assets are already regular files, do nothing.
      - Otherwise make get_assets.sh executable and run it with bash from
        BASE_DIR (with HF_HUB_DISABLE_TELEMETRY=1 added to the environment),
        then re-check.

    Raises:
        FileNotFoundError: if get_assets.sh itself does not exist.
        subprocess.CalledProcessError: if get_assets.sh exits non-zero.
        RuntimeError: if assets are still missing after the script ran.
    """
    if os.getenv("SKIP_ASSET_DOWNLOAD") == "1":
        print("↪️  SKIP_ASSET_DOWNLOAD=1 -> skip asset download check")
        return
    if _have_all_assets():
        print("✅ Assets already present")
        return
    print("⬇️  Missing assets, running get_assets.sh ...")
    _ensure_executable(GET_ASSETS)
    subprocess.run(
        ["bash", str(GET_ASSETS)],
        check=True,
        cwd=str(BASE_DIR),
        env={**os.environ, "HF_HUB_DISABLE_TELEMETRY": "1"},
    )
    if not _have_all_assets():
        # Use the same is_file() predicate as _have_all_assets(), so a path that
        # exists but is not a regular file (e.g. a directory) is still reported.
        missing = [str(p.relative_to(BASE_DIR)) for p in EXPECTED_ASSETS if not p.is_file()]
        raise RuntimeError(f"Assets missing after get_assets.sh: {missing}")
    print("✅ Assets ready.")

# Best-effort asset preparation at import time: any failure is only printed,
# so the module (and the UI defined below) still loads without the assets.
try:
    ensure_assets_if_missing()
except Exception as e:
    print(f"⚠️ Asset prepare failed: {e}")

# ---------------- Global singletons ----------------
# Depth models cached per encoder name (filled lazily by get_model).
_MODELS: Dict[str, DepthModel] = {}
# Stage-1 pipeline, created on the first get_pipe() call and reused afterwards.
_PIPE: Optional[FluxFillPipeline] = None
# ==== STAGE-2 ONLY ADDED: singleton ====
# Stage-2 pipeline, created on the first get_pipe_stage2() call and reused afterwards.
_PIPE_STAGE2: Optional[FluxFillPipelineStage2] = None
# ======================================

def get_model(encoder: str) -> DepthModel:
    """Return the cached DepthModel for `encoder`, building and caching it on first use."""
    cached = _MODELS.get(encoder)
    if cached is None:
        cached = DepthModel(BASE_DIR, encoder=encoder)
        _MODELS[encoder] = cached
    return cached

def get_pipe() -> FluxFillPipeline:
    """
    Load Stage-1 pipeline (FluxFillPipeline_token12_depth_only) and mount Stage-1 LoRA if present.

    The pipeline is built once and cached in the module-level `_PIPE`; later
    calls return the cached instance.

    Loading rules:
      - device is "cuda" when available, otherwise "cpu"; dtype is bfloat16 on
        cuda and float32 on cpu.
      - If `code_edit/flux_cache` exists under BASE_DIR the weights are loaded
        from there; otherwise from "black-forest-labs/FLUX.1-Fill-dev" using
        the HF_TOKEN environment variable.
      - The Stage-1 LoRA is optional: a missing directory, a missing `peft`
        package, or a load failure is printed and the pipeline is returned
        without it.

    Raises:
        RuntimeError: if the base FLUX.1-Fill-dev pipeline cannot be loaded.
    """
    global _PIPE
    if _PIPE is not None:
        return _PIPE

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    local_flux = BASE_DIR / "code_edit" / "flux_cache"
    use_local = local_flux.exists()

    hf_token = os.environ.get("HF_TOKEN")

    # Optional download acceleration; any failure (including the import) is ignored.
    # NOTE(review): confirm that huggingface_hub actually exports
    # `hf_hub_enable_hf_transfer`; if it does not, this block is a silent no-op.
    try:
        from huggingface_hub import hf_hub_enable_hf_transfer
        hf_hub_enable_hf_transfer()
    except Exception:
        pass

    print(f"[pipe] loading FLUX.1-Fill-dev (dtype={dtype}, device={device}, local={use_local})")
    try:
        if use_local:
            pipe = FluxFillPipeline.from_pretrained(local_flux, torch_dtype=dtype).to(device)
        else:
            pipe = FluxFillPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-Fill-dev",
                torch_dtype=dtype,
                token=hf_token,
            ).to(device)
    except Exception as e:
        # Base model is mandatory: surface a single actionable error, keeping the cause chained.
        raise RuntimeError(
            "Failed to load FLUX.1-Fill-dev. "
            "Ensure gated access and HF_TOKEN; or pre-download to local cache."
        ) from e

    # -------- LoRA (Stage-1) --------
    lora_dir = CODE_EDIT / "stage1" / "checkpoint-4800"
    lora_file = "pytorch_lora_weights.safetensors"
    adapter_name = "stage1"

    if lora_dir.exists():
        try:
            import peft  # assert backend presence
            print(f"[pipe] loading LoRA from: {lora_dir}/{lora_file}")
            pipe.load_lora_weights(
                str(lora_dir),
                weight_name=lora_file,
                adapter_name=adapter_name,
            )
            # Prefer activating the named adapter; fall back to fusing the LoRA
            # into the base weights if set_adapters raises.
            try:
                pipe.set_adapters(adapter_name, scale=1.0)
                print(f"[pipe] set_adapters('{adapter_name}', 1.0)")
            except Exception as e_set:
                print(f"[pipe] set_adapters not available ({e_set}); trying fuse_lora()")
                try:
                    pipe.fuse_lora(lora_scale=1.0)
                    print("[pipe] fuse_lora(lora_scale=1.0) done")
                except Exception as e_fuse:
                    print(f"[pipe] fuse_lora failed: {e_fuse}")
            # NOTE: this line is reached (and printed) even when both
            # set_adapters and fuse_lora failed above.
            print("[pipe] LoRA ready ✅")
        except ImportError:
            print("[pipe] peft not installed; LoRA skipped (add `peft>=0.11`).")
        except Exception as e:
            print(f"[pipe] load_lora_weights failed (continue without): {e}")
    else:
        print(f"[pipe] LoRA path not found: {lora_dir} (continue without)")

    # Cache for subsequent calls.
    _PIPE = pipe
    return pipe

# ==== STAGE-2 ONLY ADDED: Stage-2 loader (no change to Stage-1 logic) ====
def get_pipe_stage2() -> FluxFillPipelineStage2:
    """
    Load Stage-2 FluxFillPipeline_token12_depth and mount Stage-2 LoRA.

    The pipeline is built once and cached in the module-level `_PIPE_STAGE2`.
    Device/dtype selection and the local-cache-vs-hub choice mirror get_pipe().
    Unlike Stage-1, the Stage-2 LoRA is mandatory: every failure in locating
    or loading it raises instead of continuing without it.

    Raises:
        RuntimeError: if the base pipeline cannot be loaded, the LoRA directory
            or weight file is missing, `peft` cannot be imported, or loading /
            activating the LoRA fails.
    """
    global _PIPE_STAGE2
    if _PIPE_STAGE2 is not None:
        return _PIPE_STAGE2

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    local_flux = BASE_DIR / "code_edit" / "flux_cache"
    use_local = local_flux.exists()
    hf_token = os.environ.get("HF_TOKEN")

    # Optional download acceleration; any failure (including the import) is ignored.
    # NOTE(review): confirm that huggingface_hub actually exports
    # `hf_hub_enable_hf_transfer`; if it does not, this block is a silent no-op.
    try:
        from huggingface_hub import hf_hub_enable_hf_transfer
        hf_hub_enable_hf_transfer()
    except Exception:
        pass

    print(f"[stage2] loading FLUX.1-Fill-dev (dtype={dtype}, device={device}, local={use_local})")
    try:
        if use_local:
            pipe2 = FluxFillPipelineStage2.from_pretrained(local_flux, torch_dtype=dtype).to(device)
        else:
            pipe2 = FluxFillPipelineStage2.from_pretrained(
                "black-forest-labs/FLUX.1-Fill-dev",
                torch_dtype=dtype,
                token=hf_token,
            ).to(device)
    except Exception as e:
        raise RuntimeError("Stage-2: Failed to load FLUX.1-Fill-dev.") from e

    # Load Stage-2 LoRA
    lora_dir2 = CODE_EDIT / "stage2" / "checkpoint-20000"
    # Accepted weight file names, tried in this order; the first existing file wins.
    candidate_names = [
        "pytorch_lora_weights.safetensors",
        "adapter_model.safetensors",
        "lora.safetensors",
    ]
    weight_name = None
    for name in candidate_names:
        if (lora_dir2 / name).is_file():
            weight_name = name
            break

    # Report a missing directory separately from a missing weight file.
    if not lora_dir2.exists():
        raise RuntimeError(f"Stage-2 LoRA dir not found: {lora_dir2}")
    if weight_name is None:
        raise RuntimeError(
            f"Stage-2 LoRA weight not found under {lora_dir2}. Tried: {candidate_names}"
        )

    # Check up front that peft is importable so the failure message is explicit.
    try:
        import peft  # noqa: F401
    except Exception as e:
        raise RuntimeError("peft is not installed (requires peft>=0.11).") from e

    try:
        print(f"[stage2] loading LoRA: {lora_dir2}/{weight_name}")
        pipe2.load_lora_weights(
            str(lora_dir2),
            weight_name=weight_name,
            adapter_name="stage2",
        )
        # Prefer activating the named adapter; fall back to fusing if that raises.
        try:
            pipe2.set_adapters("stage2", scale=1.0)
            print("[stage2] set_adapters('stage2', 1.0)")
        except Exception as e_set:
            print(f"[stage2] set_adapters not available ({e_set}); trying fuse_lora()")
            try:
                pipe2.fuse_lora(lora_scale=1.0)
                print("[stage2] fuse_lora(lora_scale=1.0) done")
            except Exception as e_fuse:
                # This RuntimeError is caught and re-wrapped by the outer handler below.
                raise RuntimeError(f"Stage-2 fuse_lora failed: {e_fuse}") from e_fuse
    except Exception as e:
        raise RuntimeError(f"Stage-2 LoRA load failed: {e}") from e

    # Cache for subsequent calls.
    _PIPE_STAGE2 = pipe2
    return pipe2
# ==========================================================================

# ---------------- Mask helpers ----------------
def to_grayscale_mask(im: Image.Image) -> Image.Image:
    """
    Turn an arbitrary image into a binary single-channel ("L") mask.

    For RGBA input the alpha channel is taken as the mask; any other mode is
    converted to grayscale. Values above 16 become 255, everything else 0.
    No inversion is applied: white marks the region to remove/fill, black is kept.
    """
    source = im.split()[-1] if im.mode == "RGBA" else im.convert("L")
    # Hard threshold doubles as a light denoise of near-black pixels.
    return source.point(lambda level: 255 if level > 16 else 0)

def dilate_mask(mask_l: Image.Image, px: int) -> Image.Image:
    """
    Grow the white region of an "L" mask by roughly `px` pixels.

    A non-positive `px` returns the input mask object unchanged. Otherwise a
    3x3 all-ones kernel is applied `max(1, px // 2)` times (a heuristic).
    """
    if px <= 0:
        return mask_l
    pixels = np.array(mask_l, dtype=np.uint8)
    structuring = np.ones((3, 3), np.uint8)
    passes = max(1, int(px // 2))  # heuristic
    grown = cv2.dilate(pixels, structuring, iterations=passes)
    return Image.fromarray(grown, mode="L")

def _mask_from_red(img: Image.Image, out_size: Tuple[int, int]) -> Image.Image:
    """
    Build a binary "L" mask of (near-)pure red, non-transparent pixels.

    A pixel counts as a brush stroke when R >= 200, G <= 40, B <= 40 and
    alpha > 0; the thresholds are lenient to survive compression/resampling.
    Strokes become 255, everything else 0, and the mask is resized to
    `out_size` with nearest-neighbour sampling.
    """
    rgba = np.array(img.convert("RGBA"))
    red, green, blue, alpha = (rgba[..., ch] for ch in range(4))
    is_stroke = (red >= 200) & (green <= 40) & (blue <= 40) & (alpha > 0)
    binary = is_stroke.astype(np.uint8) * 255
    return Image.fromarray(binary, mode="L").resize(out_size, Image.NEAREST)

def pick_mask(
    upload_mask: Optional[Image.Image],
    sketch_data: Optional[dict],
    base_image: Image.Image,
    dilate_px: int = 0,
) -> Optional[Image.Image]:
    """
    Choose the removal mask, sized to `base_image` (white = remove, black = keep).

    Selection rules:
      1) If a mask is uploaded: use it directly (white=mask)
      2) Else from ImageEditor output, only red strokes are recognized as mask:
         - Try sketch_data['mask'] first (some versions provide it)
         - Else merge red strokes from sketch_data['layers']; each layer may be
           a PIL image itself or a dict carrying one under 'image' / 'mask'
         - If still none, try sketch_data['composite'] for red strokes

    Args:
        upload_mask: optional uploaded mask image; takes priority when given.
        sketch_data: optional ImageEditor value (dict).
        base_image: image whose size the returned mask must match.
        dilate_px: when > 0 the chosen mask is dilated by about this many pixels.

    Returns:
        An "L" mode mask, or None when no usable mask was found.
    """

    def _finish(mask: Image.Image) -> Image.Image:
        # Apply the optional dilation exactly once, on whichever mask was chosen.
        return dilate_mask(mask, dilate_px) if dilate_px > 0 else mask

    # 1) Uploaded mask has highest priority
    if isinstance(upload_mask, Image.Image):
        return _finish(to_grayscale_mask(upload_mask).resize(base_image.size, Image.NEAREST))

    # 2) Hand-drawn (ImageEditor)
    if isinstance(sketch_data, dict):
        # 2a) explicit mask (still supported)
        m = sketch_data.get("mask")
        if isinstance(m, Image.Image):
            return _finish(to_grayscale_mask(m).resize(base_image.size, Image.NEAREST))

        # 2b) merge red strokes from layers
        layers = sketch_data.get("layers")
        if isinstance(layers, list) and layers:
            acc = Image.new("L", base_image.size, 0)
            for lyr in layers:
                # With type="pil" the editor hands over layers as PIL images;
                # the dict form is kept for payloads that wrap the image.
                if isinstance(lyr, Image.Image):
                    li = lyr
                elif isinstance(lyr, dict):
                    li = lyr.get("image") or lyr.get("mask")
                else:
                    continue
                if isinstance(li, Image.Image):
                    m_layer = _mask_from_red(li, base_image.size)
                    acc = ImageOps.lighter(acc, m_layer)  # union
            if acc.getbbox() is not None:
                return _finish(acc)

        # 2c) finally, search composite for red strokes
        # (less precise: red pixels of the underlying picture also match)
        comp = sketch_data.get("composite")
        if isinstance(comp, Image.Image):
            m_comp = _mask_from_red(comp, base_image.size)
            if m_comp.getbbox() is not None:
                return _finish(m_comp)

    # 3) No valid mask
    return None

def _round_mult64(x: float, mode: str = "nearest") -> int:
    """
    Align x to a multiple of 64:
      - mode="ceil"    round up (smallest multiple of 64 that is >= x)
      - mode="floor"   round down
      - mode="nearest" nearest multiple (ties round up)

    Any other `mode` value is treated as "nearest".
    """
    if mode == "ceil":
        # Ceiling via negated floor division. Unlike (x + 63) // 64, this is
        # also correct for non-integer x (e.g. 64.5 -> 128, not 64).
        return -int((-x) // 64) * 64
    elif mode == "floor":
        return int(x // 64) * 64
    else:  # nearest
        return int((x + 32) // 64) * 64

def prepare_size_for_flux(img: Image.Image, target_max: int = 1024) -> tuple[int, int]:
    """
    Compute the (width, height) to run FLUX at for `img`.

    1) Both sides are rounded up to multiples of 64 (minimum 64).
    2) The longer side (width wins ties) is set to `target_max`.
    3) The shorter side is scaled by the same ratio and snapped to the
       nearest multiple of 64 (minimum 64).
    """
    src_w, src_h = img.size
    aligned_w = max(64, _round_mult64(src_w, mode="ceil"))
    aligned_h = max(64, _round_mult64(src_h, mode="ceil"))

    landscape = aligned_w >= aligned_h
    if landscape:
        long_side, short_side = aligned_w, aligned_h
    else:
        long_side, short_side = aligned_h, aligned_w

    scaled_short = short_side * (target_max / long_side)
    short_out = max(64, _round_mult64(scaled_short, mode="nearest"))

    if landscape:
        return int(target_max), int(short_out)
    return int(short_out), int(target_max)

# ---------------- Preview depth for canvas (colored) ----------------
@spaces.GPU
def preview_depth(image: Optional[Image.Image], encoder: str, max_res: int, input_size: int, fp32: bool):
    """Return the non-grayscale depth rendering of `image`, or None when no image is given."""
    if image is not None:
        return get_model(encoder).infer(
            image=image,
            max_res=max_res,
            input_size=input_size,
            fp32=fp32,
            grayscale=False,
        )
    return None

def prepare_canvas(image, depth_img, source):
    """
    Pick the drawing background for the sketch editor.

    Uses `depth_img` when `source` is "depth", otherwise `image`, and returns a
    Gradio update carrying it. Raises gr.Error if the chosen picture is None.
    """
    if source == "depth":
        canvas_base = depth_img
    else:
        canvas_base = image
    if canvas_base is None:
        raise gr.Error('Please upload an image (and wait for the depth preview), then click "Prepare canvas".')
    return gr.update(value=canvas_base)

# ---------------- Stage-1: depth(color) -> fill ----------------
@spaces.GPU
def run_depth_and_fill(
    image: Image.Image,
    mask_upload: Optional[Image.Image],
    sketch: Optional[dict],
    prompt: str,
    encoder: str,
    max_res: int,
    input_size: int,
    fp32: bool,
    max_side: int,
    mask_dilate_px: int,
    guidance_scale: float,
    steps: int,
    seed: Optional[int],
) -> Tuple[Image.Image, Image.Image]:
    """
    Stage-1: compute a colored depth map, then run the FLUX fill pipeline on it.

    The colored depth map is passed to the pipeline both as `image` and as
    `depth`. The mask comes from pick_mask (uploaded mask first, then drawn
    strokes). A seed that is None or negative is replaced by a random one.

    Returns:
        (result resized to the input image size, RGB preview of the mask).

    Raises:
        gr.Error: if no image was given or no non-empty mask could be found.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # 1) colored (RGB) depth map of the input
    raw_depth = get_model(encoder).infer(
        image=image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=False
    )
    depth_rgb: Image.Image = raw_depth.convert("RGB")

    print(f"[DEBUG] Depth RGB: mode={depth_rgb.mode}, size={depth_rgb.size}")

    # 2) mask selection (uploaded > drawn); an all-black mask counts as missing
    mask_l = pick_mask(mask_upload, sketch, image, dilate_px=mask_dilate_px)
    if mask_l is None or mask_l.getbbox() is None:
        raise gr.Error("No valid mask detected: please draw with the red brush or upload a binary mask.")

    print(f"[DEBUG] Mask: mode={mask_l.mode}, size={mask_l.size}, bbox={mask_l.getbbox()}")

    # 3) working resolution for FLUX, plus the size to restore afterwards
    width, height = prepare_size_for_flux(depth_rgb, target_max=max_side)
    orig_w, orig_h = image.size
    print(f"[DEBUG] FLUX size: {width}x{height}, original: {orig_w}x{orig_h}")

    # 4) run the Stage-1 pipeline on the depth map
    pipe = get_pipe()
    if seed is not None and seed >= 0:
        chosen_seed = int(seed)
    else:
        chosen_seed = random.randint(0, 2**31 - 1)
    generator = torch.Generator("cpu").manual_seed(chosen_seed)

    pipe_output = pipe(
        prompt=prompt,
        image=depth_rgb,           # colored depth map stands in for the original image
        mask_image=mask_l,
        width=width,
        height=height,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(steps),
        max_sequence_length=512,
        generator=generator,
        depth=depth_rgb,           # same colored depth as the depth condition
    )
    result = pipe_output.images[0]

    # Bring both outputs back to the input resolution.
    original_size = (orig_w, orig_h)
    final_result = result.resize(original_size, Image.BICUBIC)
    mask_preview = mask_l.resize(original_size, Image.NEAREST).convert("RGB")
    return final_result, mask_preview

def _to_pil_rgb(img_like) -> Image.Image:
    """
    Normalize input to PIL RGB. Supports PIL/L/RGBA/np.array.

    PIL images are converted with `.convert("RGB")`. Anything else is passed
    through `np.array`, cast to uint8 and handed to PIL, which infers the mode
    from the array shape (2-D -> L, 3 channels -> RGB, 4 channels -> RGBA)
    before the final conversion to RGB.

    Raises:
        gr.Error: if the input cannot be interpreted as an image.
    """
    if isinstance(img_like, Image.Image):
        return img_like.convert("RGB")
    try:
        arr = np.array(img_like).astype(np.uint8)
        # Let PIL infer the mode instead of forcing "RGB": forcing it
        # misreads 4-channel (RGBA) arrays.
        return Image.fromarray(arr).convert("RGB")
    except Exception as e:
        raise gr.Error("Stage-2: `depth` / `depth_image` is not a valid image object.") from e

# ---------------- Stage-2: REQUIRED refine/render ----------------
@spaces.GPU
def run_stage2_refine(
    image: Image.Image,              # original image (RGB)
    stage1_out: Image.Image,         # output from Stage-1
    depth_img_from_stage1_input: Image.Image,  # Stage-1 depth preview (from UI)
    mask_upload: Optional[Image.Image],
    sketch: Optional[dict],
    prompt: str,
    encoder: str,
    max_res: int,
    input_size: int,
    fp32: bool,
    max_side: int,
    guidance_scale: float,
    steps: int,
    seed: Optional[int],
) -> Image.Image:
    """
    Stage-2: render the final image from the Stage-1 output and the original depth.

    Pipeline inputs:
      - image        = original picture, normalized to RGB
      - depth        = Stage-1 output (updated geometry)
      - depth_image  = Stage-1 input depth (UI depth preview)

    The mask is chosen by pick_mask without dilation; when none is found an
    all-black mask is used so the refine can still run. A seed that is None or
    negative is replaced by a random one. `encoder`, `max_res`, `input_size`
    and `fp32` are accepted to keep the UI wiring uniform but are not used here.

    Returns:
        The pipeline output resized to (3 * original width, original height).

    Raises:
        gr.Error: if the original image or the Stage-1 output is missing, or
            if a depth input cannot be converted to an image.
    """
    if image is None or stage1_out is None:
        raise gr.Error("Please complete Stage-1 first (needs original image and Stage-1 output).")

    # Allow refine without mask (use all-black)
    mask_l = pick_mask(mask_upload, sketch, image, dilate_px=0)
    if (mask_l is None) or (mask_l.getbbox() is None):
        mask_l = Image.new("L", image.size, 0)

    # Unify sizes
    width, height = prepare_size_for_flux(image, target_max=max_side)
    orig_w, orig_h = image.size

    pipe2 = get_pipe_stage2()
    g2 = (
        torch.Generator("cpu").manual_seed(int(seed))
        if (seed is not None and seed >= 0)
        else torch.Generator("cpu").manual_seed(random.randint(0, 2**31 - 1))
    )
    depth_pil = _to_pil_rgb(stage1_out)                      # for `depth`
    depth_image_pil = _to_pil_rgb(depth_img_from_stage1_input)  # for `depth_image`
    image_rgb = _to_pil_rgb(image)                           # for `image`

    # Resize to (width, height)
    depth_pil = depth_pil.resize((width, height), Image.BICUBIC)
    depth_image_pil = depth_image_pil.resize((width, height), Image.BICUBIC)

    out2 = pipe2(
        prompt=prompt,
        image=image_rgb,           # original image, normalized to RGB (was computed but unused)
        mask_image=mask_l,
        width=width,
        height=height,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(steps),
        max_sequence_length=512,
        generator=g2,
        depth=depth_pil,
        depth_image=depth_image_pil,
    ).images[0]

    out2 = out2.resize((orig_w * 3, orig_h), Image.BICUBIC)  # keep your 3× showcase layout
    return out2

# ---------------- UI ----------------
with gr.Blocks() as demo:
    gr.Markdown(
        """
# GeoRemover · Depth-Guided Object Removal (Two-Stage, Stage-2 REQUIRED)

**Pipeline overview**  
1) Compute a **colored depth map** from your input image.  
2) You create a **removal mask** (red brush or upload).  
3) **Stage-1** runs FLUX Fill with depth guidance to get a first pass.  
4) **Stage-2 (REQUIRED)** renders the final result from depth → image using Stage-1 output and the original depth.

> ⚠️ **Stage-2 is required.** Always click **Run Stage-2 (Render)** *after* Stage-1 finishes. Stage-1 alone is not the final output.

---

### Quick start
1. **Upload image** (left). Wait for **Depth preview (colored)** (right).  
2. In **Draw mask**, pick **Draw on: _image_** or **_depth_**, then click **Prepare canvas**.  
3. Paint the region to remove using the **red brush** (**red = remove**).  
4. Optionally adjust **Mask dilation** for thin edges.  
5. Enter a concise **Prompt** describing the fill content.  
6. Click **Run** → produces **Stage-1** (first pass).  
7. Click **Run Stage-2 (Render)** → produces the **final** result.

---

### Mask rules & tips
- Only **red strokes** are treated as mask (**white = remove, black = keep** internally).  
- Paint **slightly larger** than the object boundary to avoid seams/halos.  
- If you have a binary mask already, use **Upload mask**.  
- **Mask dilation (px)** expands the mask to cover thin borders.
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Input image
            img = gr.Image(
                label="Upload image",
                type="pil",
            )

            # Mask: upload or draw
            with gr.Tab("Upload mask"):
                mask_upload = gr.Image(
                    label="Mask (optional)",
                    type="pil",
                )

            with gr.Tab("Draw mask"):
                draw_source = gr.Radio(
                    ["image", "depth"],
                    value="image",
                    label="Draw on",
                )
                prepare_btn = gr.Button("Prepare canvas", variant="secondary")
                gr.Markdown(
                    """
**Canvas usage**  
- Click **Prepare canvas** after selecting *image* or *depth*.  
- Use the **red brush** only—red strokes are extracted as the removal mask.  
- Switch tabs anytime if you prefer uploading a ready-made mask.
"""
                )
                sketch = gr.ImageEditor(
                    label="Sketch mask (red = remove)",
                    type="pil",
                    brush=gr.Brush(colors=["#FF0000"], default_size=24),
                )

            # Prompt
            prompt = gr.Textbox(
                label="Prompt",
                value="A beautiful scene",
                placeholder="don't change it",
            )

            # Tunables
            with gr.Accordion("Advanced (Depth & FLUX)", open=False):
                encoder = gr.Dropdown(
                    ["vits", "vitl"],
                    value="vitl",
                    label="Depth encoder",
                )
                max_res = gr.Slider(
                    512, 2048, value=1280, step=64,
                    label="Depth: max_res",
                )
                input_size = gr.Slider(
                    256, 1024, value=518, step=2,
                    label="Depth: input_size",
                )
                fp32 = gr.Checkbox(
                    False,
                    label="Depth: use FP32 (default FP16)",
                )
                max_side = gr.Slider(
                    512, 1536, value=1024, step=64,
                    label="FLUX: max side (px)",
                )
                mask_dilate_px = gr.Slider(
                    0, 128, value=0, step=1,
                    label="Mask dilation (px)",
                )
                guidance_scale = gr.Slider(
                    0, 50, value=30, step=0.5,
                    label="FLUX: guidance_scale",
                )
                steps = gr.Slider(
                    10, 75, value=50, step=1,
                    label="FLUX: steps",
                )
                seed = gr.Number(
                    value=0, precision=0,
                    label="Seed (>=0 = fixed; empty = random)",
                )

            run_btn = gr.Button("Run", variant="primary")
            # Stage-2 is REQUIRED: keep disabled until Stage-1 finishes
            run_btn_stage2 = gr.Button("Run Stage-2 (Render)", variant="secondary", interactive=False)

        with gr.Column(scale=1):
            depth_preview = gr.Image(
                label="Depth preview (colored)",
                interactive=False,
            )
            mask_preview = gr.Image(
                label="Mask preview (areas to remove)",
                interactive=False,
            )
            out = gr.Image(
                label="Output (Stage-1 first pass)",
            )
            out_stage2 = gr.Image(
                label="Final Output (Stage-2)",
            )

    gr.Markdown(
        """
### Why Stage-2 is required
Stage-1 provides a depth-guided fill that is *not final*. **Stage-2 renders** the definitive image by leveraging:
- **Stage-1 output** as updated geometry hints, and  
- **Original colored depth** as `depth_image` guidance.  
Skipping Stage-2 will leave the process incomplete.

### Troubleshooting
- **“No valid mask detected”**: Either upload a binary mask (white=remove) **or** draw with **red brush** after clicking **Prepare canvas**.  
- **Seams/halos**: Increase **Mask dilation (px)** (e.g., 8–16) and re-run both stages.  
- **Prompt not followed**: Lower **guidance_scale** (e.g., 18–24) and make the prompt more concrete.  
- **Depth looks noisy**: Use **vitl**, increase **Depth: max_res**, or enable **FP32**.
"""
    )

    # ===== Helpers to toggle Stage-2 button =====
    def _enable_button():
        """Return a Gradio update that makes the target component interactive."""
        return gr.update(interactive=True)

    # Auto depth preview on image change
    img.change(
        fn=preview_depth,
        inputs=[img, encoder, max_res, input_size, fp32],
        outputs=[depth_preview],
    )

    # Prepare canvas for drawing on image or depth
    prepare_btn.click(
        fn=prepare_canvas,
        inputs=[img, depth_preview, draw_source],
        outputs=[sketch],
    )

    # Stage-1
    run_btn.click(
        fn=run_depth_and_fill,
        inputs=[img, mask_upload, sketch, prompt, encoder, max_res, input_size, fp32,
                max_side, mask_dilate_px, guidance_scale, steps, seed],
        outputs=[out, mask_preview],
        api_name="run",
    ).success(  # Enable Stage-2 only after Stage-1 completes without error
        # .then() would also fire when Stage-1 raised, unlocking Stage-2
        # without a valid Stage-1 output; .success() does not.
        fn=_enable_button,
        inputs=[],
        outputs=[run_btn_stage2],
    )

    # Stage-2 (REQUIRED; unlocked after Stage-1)
    # Input order must match the run_stage2_refine signature.
    run_btn_stage2.click(
        fn=run_stage2_refine,
        inputs=[img, out, depth_preview,
                mask_upload, sketch, prompt, encoder, max_res, input_size, fp32,
                max_side, guidance_scale, steps, seed],
        outputs=[out_stage2],
        api_name="run_stage2",
    )

# Script entry point: only runs when this file is executed directly.
if __name__ == "__main__":
    # Set the telemetry opt-out only if the environment has not already set it.
    os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
    # Serve the Blocks app on all network interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)