File size: 16,689 Bytes
a9397fc
8df3f9a
 
 
901cd2d
75e28bb
 
 
901cd2d
8df3f9a
901cd2d
 
8df3f9a
 
 
0bb18a8
901cd2d
8df3f9a
 
 
 
a9397fc
 
 
75e28bb
784d43d
a9397fc
 
 
8df3f9a
 
 
 
 
 
a9397fc
 
 
31ef6d3
a9397fc
0bb18a8
8df3f9a
 
 
 
b64edd3
8df3f9a
 
 
 
 
 
 
 
 
 
784d43d
 
 
 
d9d5913
a9397fc
8df3f9a
 
 
 
 
 
 
 
 
 
 
31ef6d3
8df3f9a
 
 
 
ccebd17
 
 
 
 
 
7b8ab13
 
 
 
 
 
 
 
 
 
 
 
 
 
ccebd17
 
 
7b8ab13
ccebd17
8df3f9a
 
 
bef166b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56e808f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bef166b
 
 
8df3f9a
 
 
 
 
 
 
 
 
 
31ef6d3
d9d5913
901cd2d
8df3f9a
31ef6d3
901cd2d
 
 
 
 
 
d9d5913
0bb18a8
 
 
 
 
31ef6d3
 
 
 
 
 
 
a9397fc
31ef6d3
a9397fc
31ef6d3
8df3f9a
 
d66f250
8df3f9a
 
 
901cd2d
8df3f9a
 
31ef6d3
8df3f9a
31ef6d3
84293b6
 
 
8df3f9a
 
31ef6d3
8df3f9a
 
 
 
 
 
 
0bb18a8
8df3f9a
84293b6
 
 
 
 
31ef6d3
84293b6
31ef6d3
84293b6
8df3f9a
84293b6
31ef6d3
84293b6
8df3f9a
 
 
 
 
 
 
 
 
 
a9397fc
8df3f9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9397fc
 
 
 
8df3f9a
a9397fc
 
8df3f9a
 
 
 
 
 
 
 
901cd2d
8df3f9a
6b31dc3
 
 
901cd2d
236a930
0a42ddd
236a930
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8df3f9a
 
31ef6d3
8df3f9a
 
 
 
 
 
 
 
 
 
 
 
 
236a930
 
 
 
 
 
 
 
 
 
 
8df3f9a
a9397fc
31ef6d3
8df3f9a
901cd2d
8df3f9a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
"""SAM 3D Objects – kaolin+pytorch3d stubbed for ZeroGPU (PyTorch 2.10+cu128)."""
import os, sys, subprocess
# Toolchain locations: only filled in when the host has not exported them.
os.environ.setdefault("CONDA_PREFIX", "/usr/local")
os.environ.setdefault("CUDA_HOME", "/usr/local/cuda")

# Backend selection is forced (overwrites any inherited value). These are set
# before the SAM3D code is imported further down; presumably that code reads
# them at import time -- NOTE(review): confirm against sam3d_objects.
os.environ.update(
    LIDRA_SKIP_INIT="true",
    ATTN_BACKEND="sdpa",
    SPARSE_ATTN_BACKEND="sdpa",
    SPARSE_BACKEND="spconv",
)

# MUST import spaces before torch
import spaces
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download, login
import tempfile
from pathlib import Path

# Log in to the Hugging Face Hub only when a non-empty HF_TOKEN is provided.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

# --- Stubs (must be before sam3d imports) ---
# Every stub directory that exists on disk is pushed to the front of sys.path,
# so a later entry in the tuple ends up ahead of an earlier one.
STUB_KAOLIN = Path("/home/user/app/kaolin_stub")
STUB_PT3D = Path("/home/user/app/pytorch3d_stub")
STUB_FA = Path("/home/user/app/flash_attn_stub")
for stub in (STUB_KAOLIN, STUB_PT3D, STUB_FA):
    if not stub.exists():
        continue
    sys.path.insert(0, str(stub))
    print(f"Stub added: {stub.name}")

# --- Runtime pip installs ---
def _pip(*a):
    """Run ``pip install --no-cache-dir <a...>`` with the current interpreter.

    This is a best-effort installer: the outcome is printed and reported via
    the return value, never raised, so callers can try a fallback.

    Args:
        *a: Extra command-line arguments for ``pip install``. The last one
            (truncated to 50 characters) is used as the label in log lines.

    Returns:
        True when pip exits with status 0; False on a non-zero exit, when the
        1200 s timeout expires, or when the process cannot be started.
    """
    cmd = [sys.executable, "-m", "pip", "install", "--no-cache-dir", *a]
    tag = a[-1][:50] if a else "?"
    try:
        r = subprocess.run(cmd, capture_output=True, text=True, timeout=1200)
    except (subprocess.TimeoutExpired, OSError) as exc:
        # Without this, a hung download would abort the whole startup instead
        # of being reported as a failed install.
        print(f"  pip FAIL: {tag}")
        print(f"    {exc}")
        return False
    if r.returncode == 0:
        print(f"  pip OK: {tag}")
        return True
    print(f"  pip FAIL: {tag}")
    # Only the tail of stderr is shown; that is where pip puts the final error.
    print(f"    {r.stderr[-300:]}")
    return False

print("=== Runtime installs ===")
# Installed in this order; each entry is the argument list for one _pip call.
# Failures are reported by _pip and do not stop the sequence.
for _pip_args in (
    ("open3d>=0.18.0",),
    # --no-deps: skip jupyter dependency
    ("--no-deps", "git+https://github.com/EasternJournalist/utils3d.git"),
    ("iopath",),
    ("--no-deps", "sam2>=1.1.0"),
    ("--no-deps", "git+https://github.com/microsoft/MoGe.git@a8c37341bc0325ca99b9d57981cc3bb2bd3e255b"),
):
    _pip(*_pip_args)

# gsplat: try each wheel index in turn and stop at the first that installs.
any(
    _pip("--no-deps", f"--extra-index-url={idx}", "gsplat")
    for idx in (
        "https://docs.gsplat.studio/whl/pt210cu128",
        "https://docs.gsplat.studio/whl/pt28cu128",
    )
)

# spconv (sparse convolution – needed for SAM3D's SLatFlowModel)
# cu124 wheel is forward-compatible with cu128
_pip("spconv-cu124==2.3.8")

# DO NOT import CUDA-dependent packages here!

# --- Clone sam-3d-objects ---
SAM3D_PATH = Path("/home/user/app/sam-3d-objects")
if not SAM3D_PATH.exists():
    print("Cloning sam-3d-objects...")
    # check=True: without the checkout nothing below can work, so fail loudly.
    subprocess.run(["git", "clone", "--depth", "1",
        "https://github.com/facebookresearch/sam-3d-objects.git",
        str(SAM3D_PATH)], check=True)

# Editable install without dependencies. A failure is reported but not fatal:
# the checkout is also put on sys.path further down.
_install = subprocess.run(
    [sys.executable, "-m", "pip", "install", "-e", str(SAM3D_PATH), "--no-deps"],
    capture_output=True, text=True)
if _install.returncode != 0:
    print("  pip FAIL: sam-3d-objects (editable install)")
    print(f"    {_install.stderr[-300:]}")

# Hydra patch
patch = SAM3D_PATH / "patching" / "hydra"
if patch.exists():
    _hydra = subprocess.run(["bash", str(patch)], capture_output=True, text=True,
                            cwd=str(SAM3D_PATH))
    if _hydra.returncode != 0:
        # Previously a failing patch script went unnoticed.
        print(f"  hydra patch FAIL (exit {_hydra.returncode})")
        print(f"    {_hydra.stderr[-300:]}")

# CRITICAL PATCH: Prevent SAM3D from overriding ATTN_BACKEND to flash_attn
# inference_pipeline.py auto-detects H200/A100 and forces flash_attn,
# but we don't have the real flash_attn package.
ip_file = SAM3D_PATH / "sam3d_objects" / "pipeline" / "inference_pipeline.py"
if ip_file.exists():
    ip_src = ip_file.read_text()
    # Cheap presence test: does the file still force flash_attn anywhere?
    old_marker = 'os.environ["ATTN_BACKEND"] = "flash_attn"'
    if old_marker in ip_src:
        # Replace the entire if-block that forces flash_attn. The replacement
        # is indented by four spaces, i.e. it assumes the original `if` sits
        # one level deep -- NOTE(review): confirm against the upstream file.
        patched_src = ip_src.replace(
            'if "A100" in gpu_name or "H100" in gpu_name or "H200" in gpu_name:\n'
            '        # logger.info("Use flash_attn")\n'
            '        os.environ["ATTN_BACKEND"] = "flash_attn"\n'
            '        os.environ["SPARSE_ATTN_BACKEND"] = "flash_attn"',
            '# PATCHED: Always use sdpa backend (flash_attn not available on ZeroGPU)\n'
            '    logger.info("Using sdpa backend (patched for ZeroGPU)")\n'
            '    os.environ.setdefault("ATTN_BACKEND", "sdpa")\n'
            '    os.environ.setdefault("SPARSE_ATTN_BACKEND", "sdpa")'
        )
        if patched_src != ip_src:
            ip_src = patched_src
            ip_file.write_text(ip_src)
            print("PATCHED: inference_pipeline.py - forced sdpa backend")
        else:
            # The short marker matched but the exact multi-line block did not
            # (e.g. upstream reformatted it). Do not claim success.
            print("WARNING: inference_pipeline.py still forces flash_attn; "
                  "expected block not found, file left unpatched")
    else:
        print("INFO: inference_pipeline.py already patched or different version")

# Make the checkout and its notebook/ directory importable (the endpoints
# below do `from inference import Inference`).
sys.path.insert(0, str(SAM3D_PATH))
sys.path.insert(0, str(SAM3D_PATH / "notebook"))

# --- Monkey-patch: inject depth_edge into utils3d.numpy ---
# utils3d package lacks depth_edge in newer versions; SAM3D needs it for layout post-optimization
def _depth_edge(depth, rtol=0.03, mask=None):
    """Flag pixels whose relative depth gradient exceeds *rtol*.

    Args:
        depth: 2-D array of depth values. Non-floating input is converted to
            float64 first so the Sobel filter cannot wrap or truncate.
        rtol: Threshold on ``|Sobel gradient| / |depth|``.
        mask: Optional boolean array; pixels outside it are treated as depth 0
            and are never reported as edges.

    Returns:
        Boolean array with the same shape as *depth*.
    """
    from scipy.ndimage import sobel
    import numpy as _np
    d = _np.asarray(depth)
    if not _np.issubdtype(d.dtype, _np.floating):
        d = d.astype(_np.float64)
    d = _np.where(mask, d, 0.0) if mask is not None else d.copy()
    gx = sobel(d, axis=1)
    gy = sobel(d, axis=0)
    grad = _np.sqrt(gx**2 + gy**2)
    # Clamp the denominator so zero depth does not divide by zero.
    denom = _np.maximum(_np.abs(d), 1e-6)
    edge = (grad / denom) > rtol
    if mask is not None:
        edge = edge & mask
    return edge


def _normals_edge(normals, tol=0.1, mask=None):
    """Detect normal discontinuities.

    A pixel is an edge when the Sobel gradient magnitude of any component of
    *normals* (last axis) exceeds *tol*. Pixels outside *mask* are zeroed
    before filtering and excluded from the result.
    """
    import numpy as _np
    from scipy.ndimage import sobel
    edges = _np.zeros(normals.shape[:2], dtype=bool)
    for c in range(normals.shape[-1]):
        ch = normals[..., c]
        if mask is not None:
            ch = _np.where(mask, ch, 0.0)
        gx = sobel(ch, axis=1)
        gy = sobel(ch, axis=0)
        grad = _np.sqrt(gx**2 + gy**2)
        edges |= (grad > tol)
    if mask is not None:
        edges = edges & mask
    return edges


def _install_utils3d_shims(module):
    """Inject ``depth_edge``/``normals_edge`` and a catch-all into *module*.

    Nothing is touched when *module* already exposes ``depth_edge``.

    Returns:
        True when the shims were installed, False otherwise.
    """
    if hasattr(module, 'depth_edge'):
        return False

    # Remember any existing module-level __getattr__ (e.g. a lazy loader) so
    # the catch-all can defer to it instead of shadowing it.
    _orig_getattr = getattr(module, '__getattr__', None)

    module.depth_edge = _depth_edge
    module.normals_edge = _normals_edge

    def _u3d_catchall(name):
        # Dunder probes (copy, pickle, inspect, ...) must fail normally.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        if _orig_getattr is not None:
            try:
                return _orig_getattr(name)
            except AttributeError:
                pass
        import warnings
        warnings.warn(f"utils3d.numpy stub: {name} not implemented, returning dummy")
        def _dummy(*a, **kw):
            import numpy as _np
            return _np.zeros(1)
        return _dummy

    module.__getattr__ = _u3d_catchall
    return True


try:
    import utils3d.numpy as _u3d_np
    if _install_utils3d_shims(_u3d_np):
        print("Injected depth_edge + normals_edge + catch-all into utils3d.numpy")
except Exception as e:
    # Best effort: a missing or broken utils3d must not stop startup.
    print(f"depth_edge patch skipped: {e}")

# --- Pre-download checkpoints ---
print("Downloading SAM3D checkpoints...")
CKPT_DIR = snapshot_download(repo_id="facebook/sam-3d-objects",
                              token=os.environ.get("HF_TOKEN"))
hf_ckpt = Path(CKPT_DIR) / "checkpoints"
local_ckpt = SAM3D_PATH / "checkpoints" / "hf"
if hf_ckpt.exists():
    # A dangling symlink reports exists() == False, yet symlink_to() would
    # raise FileExistsError on it. Remove it so it can be recreated.
    if local_ckpt.is_symlink() and not local_ckpt.exists():
        local_ckpt.unlink()
    if not local_ckpt.exists():
        local_ckpt.parent.mkdir(parents=True, exist_ok=True)
        local_ckpt.symlink_to(hf_ckpt)
# Path handed to Inference(...) by the endpoints below.
CONFIG_PATH = str(local_ckpt / "pipeline.yaml")
print(f"Config exists: {Path(CONFIG_PATH).exists()}")
print("=== Startup complete ===")

# --- Endpoints ---

@spaces.GPU(duration=60)
def diagnose():
    """Build a plain-text report of the runtime environment.

    The report lists the torch version, CUDA availability (and the device
    name when available), whether each optional module can be imported,
    whether the SAM2 mask generator and the SAM3D ``Inference`` class are
    importable, and whether the pipeline config file exists.

    Returns:
        The report as one newline-joined string.
    """
    import torch

    has_cuda = torch.cuda.is_available()
    report = [f"torch={torch.__version__}", f"cuda={has_cuda}"]
    if has_cuda:
        report.append(f"gpu={torch.cuda.get_device_name()}")

    # Import errors of any kind are recorded instead of raised.
    for mod in ("kaolin", "utils3d", "iopath", "pytorch3d", "open3d", "gsplat", "moge"):
        try:
            version = getattr(__import__(mod), "__version__", "-")
            report.append(f"{mod}: OK ({version})")
        except Exception as e:
            report.append(f"{mod}: FAIL - {e}")

    try:
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
    except Exception as e:
        report.append(f"sam2: FAIL - {e}")
    else:
        report.append("sam2: OK")

    try:
        from inference import Inference
    except Exception as e:
        report.append(f"SAM3D Inference: FAIL - {e}")
    else:
        report.append("SAM3D Inference: importable")

    config_present = Path(CONFIG_PATH).exists()
    report.append(f"config: {config_present}")
    return "\n".join(report)

def _export_result_to_glb(result, out_dir):
    """Write the 3D content of a SAM3D *result* to ``<out_dir>/object.glb``.

    Supported shapes of *result*, tried in this order:
      * an object with ``save_ply`` (or a dict holding one under ``gs``,
        ``gaussian``, ``gaussians`` or ``scene``, optionally wrapped in a
        list/tuple): saved as PLY, re-read as a point cloud, meshed;
      * such an object with only ``_xyz``: its points are meshed directly;
      * a dict whose ``mesh`` entry has an ``export`` method.

    Point clouds are meshed with Open3D Poisson reconstruction (depth=8).

    Returns:
        ``(glb_path, None)`` on success, ``(None, message)`` when nothing
        exportable was found or no file was written.
    """
    glb = f"{out_dir}/object.glb"

    gs = None
    if hasattr(result, "save_ply"):
        gs = result
    elif isinstance(result, dict):
        for k in ("gs", "gaussian", "gaussians", "scene"):
            v = result.get(k)
            if v is None:
                continue
            if isinstance(v, (list, tuple)):
                if not v:
                    # Empty container: nothing to take, try the next key.
                    continue
                v = v[0]
            gs = v
            break

    pcd = None
    if gs is not None and hasattr(gs, "save_ply"):
        import open3d as o3d
        ply = f"{out_dir}/temp.ply"
        gs.save_ply(ply)
        pcd = o3d.io.read_point_cloud(ply)
    elif gs is not None and hasattr(gs, "_xyz"):
        import open3d as o3d
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(gs._xyz.detach().cpu().numpy())
    elif isinstance(result, dict) and hasattr(result.get("mesh"), "export"):
        result["mesh"].export(glb)
    else:
        keys = list(result.keys()) if isinstance(result, dict) else dir(result)
        return None, f"Cannot extract 3D. Keys: {keys}"

    if pcd is not None:
        import open3d as o3d
        pcd.estimate_normals()
        mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=8)
        o3d.io.write_triangle_mesh(glb, mesh)

    # Never hand a path to the UI unless the file is really there.
    if not Path(glb).is_file():
        return None, "GLB export failed: no file was written"
    return glb, None


def _count_faces(glb_path):
    """Return the face count of the mesh at *glb_path*, or 0 if unavailable.

    The count is informational only, so any failure (including a missing
    ``trimesh``) is logged and reported as 0.
    """
    try:
        import trimesh
        return len(trimesh.load(glb_path, force="mesh").faces)
    except Exception as exc:
        print(f"  Face count unavailable: {exc}")
        return 0


@spaces.GPU(duration=300)
def reconstruct_objects(image: np.ndarray):
    """Segment *image* with SAM2 and reconstruct the largest object with SAM3D.

    Steps: generate masks with ``facebook/sam2-hiera-small``, keep the mask
    with the largest ``area``, free the SAM2 generator, run SAM3D with
    ``seed=42`` and export the result as GLB.

    Args:
        image: Input picture; anything that is not already an ndarray is
            converted with ``np.array``.

    Returns:
        ``(glb_path, preview, status)``. ``glb_path`` is None on failure.
        ``preview`` is the input with the chosen mask blended 50% green, the
        plain input when no mask was found, or None when there is no input or
        an exception occurred. ``status`` is a human-readable message; on an
        exception it carries the last 1500 characters of the traceback.
    """
    if image is None:
        return None, None, "No image"
    try:
        import time
        import torch
        t0 = time.time()
        print(f"GPU: {torch.cuda.get_device_name()}")

        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
        print(f"  Loading SAM2... (VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        sam2_gen = SAM2AutomaticMaskGenerator.from_pretrained("facebook/sam2-hiera-small")
        print(f"  SAM2 loaded ({time.time()-t0:.0f}s, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")

        image_np = np.array(image) if not isinstance(image, np.ndarray) else image
        masks = sam2_gen.generate(image_np)
        if not masks:
            return None, image_np, "No objects detected"
        masks = sorted(masks, key=lambda x: x["area"], reverse=True)
        best_mask = masks[0]["segmentation"]

        preview = image_np.copy()
        preview[best_mask] = (preview[best_mask] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
        print(f"  {len(masks)} masks ({time.time()-t0:.0f}s)")

        # Free SAM2 to save VRAM for SAM3D
        del sam2_gen
        torch.cuda.empty_cache()
        print(f"  SAM2 freed (VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")

        from inference import Inference
        print(f"  Loading SAM3D... (VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        sam3d = Inference(CONFIG_PATH, compile=False)
        print(f"  SAM3D loaded ({time.time()-t0:.0f}s, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")

        print(f"  Running reconstruction... (VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        result = sam3d(image=image_np, mask=best_mask, seed=42)
        print(f"  Reconstructed ({time.time()-t0:.0f}s, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        if result is None:
            return None, preview, "Reconstruction returned None"

        od = tempfile.mkdtemp()
        glb, problem = _export_result_to_glb(result, od)
        if glb is None:
            return None, preview, problem

        n = _count_faces(glb)
        elapsed = int(time.time() - t0)
        return glb, preview, f"OK: {len(masks)} objects, {n:,} faces ({elapsed}s)"
    except Exception:
        # Top-level UI boundary: show the failure instead of crashing the app.
        import traceback
        tb = traceback.format_exc()
        print(tb)
        return None, None, f"Error:\n{tb[-1500:]}"


@spaces.GPU(duration=60)
def test_sam3d_only(image: np.ndarray):
    """Test SAM3D reconstruction with a fixed central mask (no SAM2).

    The mask is the rectangle spanning 20%-80% of the height and width.
    SAM3D runs with ``seed=42`` and the result is exported as GLB.

    Args:
        image: Input picture; anything that is not already an ndarray is
            converted with ``np.array``.

    Returns:
        ``(glb_path, preview, status)`` with the same conventions as
        ``reconstruct_objects``: ``glb_path`` is None on failure, ``preview``
        is the input with the mask blended 50% green (None when there is no
        input or an exception occurred), ``status`` is a message.
    """
    if image is None:
        return None, None, "No image"
    try:
        import time
        import torch
        t0 = time.time()
        print(f"GPU: {torch.cuda.get_device_name()}, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB")

        image_np = np.array(image) if not isinstance(image, np.ndarray) else image
        h, w = image_np.shape[:2]

        # Central rectangle: 20%..80% along each axis.
        mask = np.zeros((h, w), dtype=bool)
        y1, y2 = int(h * 0.2), int(h * 0.8)
        x1, x2 = int(w * 0.2), int(w * 0.8)
        mask[y1:y2, x1:x2] = True

        preview = image_np.copy()
        preview[mask] = (preview[mask] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
        print(f"  Mask created: {mask.sum()} pixels ({time.time()-t0:.0f}s)")

        from inference import Inference
        print(f"  Loading SAM3D... VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB")
        sam3d = Inference(CONFIG_PATH, compile=False)
        print(f"  SAM3D loaded ({time.time()-t0:.0f}s, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")

        print(f"  Running reconstruction...")
        result = sam3d(image=image_np, mask=mask, seed=42)
        print(f"  Done ({time.time()-t0:.0f}s, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")

        if result is None:
            return None, preview, "Reconstruction returned None"

        od = tempfile.mkdtemp()
        glb, problem = _export_result_to_glb(result, od)
        if glb is None:
            return None, preview, problem

        n = _count_faces(glb)
        elapsed = int(time.time() - t0)
        return glb, preview, f"OK: {n:,} faces ({elapsed}s)"
    except Exception:
        # Top-level UI boundary: show the failure instead of crashing the app.
        import traceback
        tb = traceback.format_exc()
        print(tb)
        return None, None, f"Error:\n{tb[-1500:]}"


# --- UI ---
# Gradio front end with three tabs, one per endpoint defined above.
with gr.Blocks(title="SAM 3D Objects") as demo:
    gr.Markdown("# SAM 3D Objects\nImage → 3D (GLB). SAM2 detection + SAM3D reconstruction.")
    # Tab 1: full pipeline (reconstruct_objects).
    with gr.Tab("Reconstruct"):
        with gr.Row():
            with gr.Column():
                inp = gr.Image(label="Input", type="numpy")
                btn = gr.Button("Reconstruct", variant="primary", size="lg")
            with gr.Column():
                prev = gr.Image(label="Detection", type="numpy", interactive=False)
                stat = gr.Textbox(label="Status")
        with gr.Row():
            m3d = gr.Model3D(label="3D Preview")
            dl = gr.File(label="Download GLB")
        # The endpoint's (glb, preview, status) tuple maps onto these outputs in order.
        btn.click(reconstruct_objects, inputs=[inp], outputs=[m3d, prev, stat])
        # Mirror the 3D preview's value into the download component whenever it changes.
        m3d.change(lambda x: x, inputs=[m3d], outputs=[dl])
    # Tab 2: SAM3D with a fixed central mask (test_sam3d_only); no download component here.
    with gr.Tab("Test SAM3D Only"):
        with gr.Row():
            with gr.Column():
                tinp = gr.Image(label="Input", type="numpy")
                tbtn = gr.Button("Test SAM3D (no SAM2)", variant="primary")
            with gr.Column():
                tprev = gr.Image(label="Mask Preview", type="numpy", interactive=False)
                tstat = gr.Textbox(label="Status")
        with gr.Row():
            tm3d = gr.Model3D(label="3D Preview")
        tbtn.click(test_sam3d_only, inputs=[tinp], outputs=[tm3d, tprev, tstat])
    # Tab 3: environment report (diagnose); takes no inputs.
    with gr.Tab("Diagnose"):
        dbtn = gr.Button("Diagnose GPU & Modules")
        dout = gr.Textbox(lines=15)
        dbtn.click(diagnose, outputs=[dout])

# Start the app at import time; mcp_server=True is passed straight to Gradio's launch().
demo.launch(mcp_server=True)