"""SAM 3D Objects – kaolin+pytorch3d stubbed for ZeroGPU (PyTorch 2.10+cu128).""" import os, sys, subprocess os.environ.setdefault("CUDA_HOME", "/usr/local/cuda") os.environ.setdefault("CONDA_PREFIX", "/usr/local") os.environ["LIDRA_SKIP_INIT"] = "true" # MUST import spaces before torch import spaces import gradio as gr import numpy as np from PIL import Image from huggingface_hub import snapshot_download, login import tempfile from pathlib import Path if os.environ.get("HF_TOKEN"): login(token=os.environ["HF_TOKEN"]) # --- Stubs (must be before sam3d imports) --- STUB_KAOLIN = Path("/home/user/app/kaolin_stub") STUB_PT3D = Path("/home/user/app/pytorch3d_stub") STUB_SPCONV = Path("/home/user/app/spconv_stub") for stub in [STUB_KAOLIN, STUB_PT3D, STUB_SPCONV]: if stub.exists(): sys.path.insert(0, str(stub)) print(f"Stub added: {stub.name}") # --- Runtime pip installs --- def _pip(*a): r = subprocess.run([sys.executable, "-m", "pip", "install", "--no-cache-dir"] + list(a), capture_output=True, text=True, timeout=1200) ok = r.returncode == 0 tag = a[-1][:50] if a else "?" if ok: print(f" pip OK: {tag}") else: print(f" pip FAIL: {tag}") print(f" {r.stderr[-300:]}") return ok print("=== Runtime installs ===") _pip("open3d>=0.18.0") _pip("--no-deps", "utils3d") # --no-deps: skip jupyter dependency _pip("iopath") _pip("--no-deps", "sam2>=1.1.0") _pip("--no-deps", "git+https://github.com/microsoft/MoGe.git@a8c37341bc0325ca99b9d57981cc3bb2bd3e255b") # gsplat for idx in ["https://docs.gsplat.studio/whl/pt210cu128", "https://docs.gsplat.studio/whl/pt28cu128"]: if _pip("--no-deps", f"--extra-index-url={idx}", "gsplat"): break # DO NOT import CUDA-dependent packages here! 
# --- Clone sam-3d-objects ---
SAM3D_PATH = Path("/home/user/app/sam-3d-objects")
if not SAM3D_PATH.exists():
    print("Cloning sam-3d-objects...")
    subprocess.run(["git", "clone", "--depth", "1", "https://github.com/facebookresearch/sam-3d-objects.git", str(SAM3D_PATH)], check=True)
    _editable = subprocess.run([sys.executable, "-m", "pip", "install", "-e", str(SAM3D_PATH), "--no-deps"], capture_output=True, text=True)
    if _editable.returncode != 0:
        # Not fatal: the checkout is also put on sys.path below.
        print(" pip FAIL: sam-3d-objects (editable)")
        print(f" {_editable.stderr[-300:]}")
    # Hydra patch
    patch = SAM3D_PATH / "patching" / "hydra"
    if patch.exists():
        subprocess.run(["bash", str(patch)], capture_output=True, cwd=str(SAM3D_PATH))
sys.path.insert(0, str(SAM3D_PATH))
sys.path.insert(0, str(SAM3D_PATH / "notebook"))

# --- Pre-download checkpoints ---
print("Downloading SAM3D checkpoints...")
CKPT_DIR = snapshot_download(repo_id="facebook/sam-3d-objects", token=os.environ.get("HF_TOKEN"))
hf_ckpt = Path(CKPT_DIR) / "checkpoints"
local_ckpt = SAM3D_PATH / "checkpoints" / "hf"
if hf_ckpt.exists():
    # A dangling link reports exists() == False but still occupies the path,
    # which would make symlink_to() raise FileExistsError; drop it first.
    if local_ckpt.is_symlink() and not local_ckpt.exists():
        local_ckpt.unlink()
    if not local_ckpt.exists():
        local_ckpt.parent.mkdir(parents=True, exist_ok=True)
        local_ckpt.symlink_to(hf_ckpt)
CONFIG_PATH = str(local_ckpt / "pipeline.yaml")
print(f"Config exists: {Path(CONFIG_PATH).exists()}")
print("=== Startup complete ===")


# --- Endpoints ---
@spaces.GPU(duration=60)
def diagnose():
    """Report the torch/CUDA state and which optional modules import cleanly.

    Returns:
        A newline-separated text report: torch version, CUDA availability,
        GPU name when CUDA is available, one OK/FAIL line per probed module,
        and whether the pipeline config file exists.
    """
    import torch

    lines = [f"torch={torch.__version__}", f"cuda={torch.cuda.is_available()}"]
    if torch.cuda.is_available():
        lines.append(f"gpu={torch.cuda.get_device_name()}")
    for mod in ["kaolin", "utils3d", "iopath", "pytorch3d", "open3d", "gsplat", "moge"]:
        try:
            m = __import__(mod)
            lines.append(f"{mod}: OK ({getattr(m, '__version__', '-')})")
        except Exception as e:
            # Diagnostic endpoint: any import-time failure is reported, not raised.
            lines.append(f"{mod}: FAIL - {e}")
    try:
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
        lines.append("sam2: OK")
    except Exception as e:
        lines.append(f"sam2: FAIL - {e}")
    try:
        from inference import Inference
        lines.append("SAM3D Inference: importable")
    except Exception as e:
        lines.append(f"SAM3D Inference: FAIL - {e}")
    lines.append(f"config: {Path(CONFIG_PATH).exists()}")
    return "\n".join(lines)


def _find_gaussians(result):
    """Return the splat-like object carried by ``result``, or None.

    ``result`` itself is returned when it has ``save_ply``. For a dict, the
    first non-None value among the keys "gs", "gaussian", "gaussians" and
    "scene" is used, unwrapping the first element of a list or tuple.
    """
    if hasattr(result, "save_ply"):
        return result
    if isinstance(result, dict):
        for key in ("gs", "gaussian", "gaussians", "scene"):
            value = result.get(key)
            if value is not None:
                return value[0] if isinstance(value, (list, tuple)) else value
    return None


def _poisson_to_glb(pcd, glb):
    """Estimate normals on ``pcd``, mesh it with Poisson (depth 8), write ``glb``."""
    import open3d as o3d

    pcd.estimate_normals()
    mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=8)
    o3d.io.write_triangle_mesh(glb, mesh)


@spaces.GPU(duration=300)
def reconstruct_objects(image: np.ndarray):
    """Segment the largest object in an image and reconstruct it as a GLB mesh.

    SAM2 generates masks, the mask with the largest area is passed to the
    SAM3D ``Inference`` pipeline (seed 42), and the result is exported to a
    GLB file in a fresh temporary directory.

    Args:
        image: Input picture as a numpy array (anything else is converted
            with ``np.array``). Assumed to be H x W x 3 uint8, since the
            preview blends the mask with a 3-component colour.

    Returns:
        A ``(glb_path, preview, status)`` tuple. ``glb_path`` is None when no
        mesh file was produced; ``preview`` is the input with the chosen mask
        tinted green (None on early failure); ``status`` is a readable
        message. Exceptions are caught and reported through ``status``.
    """
    if image is None:
        return None, None, "No image"
    try:
        # Heavy / CUDA-dependent imports stay local to the GPU worker.
        import torch, trimesh, time

        t0 = time.time()
        print(f"GPU: {torch.cuda.get_device_name()}")

        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
        sam2_gen = SAM2AutomaticMaskGenerator.from_pretrained("facebook/sam2-hiera-large")
        print(f" SAM2 loaded ({time.time()-t0:.0f}s)")

        image_np = np.array(image) if not isinstance(image, np.ndarray) else image
        masks = sam2_gen.generate(image_np)
        if not masks:
            return None, image_np, "No objects detected"
        # Largest mask first; only that one is reconstructed.
        masks = sorted(masks, key=lambda x: x["area"], reverse=True)
        best_mask = masks[0]["segmentation"]
        preview = image_np.copy()
        preview[best_mask] = (preview[best_mask] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
        print(f" {len(masks)} masks ({time.time()-t0:.0f}s)")

        from inference import Inference
        sam3d = Inference(CONFIG_PATH, compile=False)
        print(f" SAM3D loaded ({time.time()-t0:.0f}s)")
        result = sam3d(image=image_np, mask=best_mask, seed=42)
        print(f" Reconstructed ({time.time()-t0:.0f}s)")
        if result is None:
            return None, preview, "Reconstruction returned None"

        # Not cleaned up here: the returned path must outlive this call.
        od = tempfile.mkdtemp()
        glb = f"{od}/object.glb"
        gs = _find_gaussians(result)
        if gs is not None and hasattr(gs, "save_ply"):
            ply = f"{od}/temp.ply"
            gs.save_ply(ply)
            import open3d as o3d
            _poisson_to_glb(o3d.io.read_point_cloud(ply), glb)
        elif gs is not None and hasattr(gs, "_xyz"):
            import open3d as o3d
            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(gs._xyz.detach().cpu().numpy())
            _poisson_to_glb(pcd, glb)
        elif isinstance(result, dict) and "mesh" in result:
            m = result["mesh"]
            if hasattr(m, "export"):
                m.export(glb)
        else:
            keys = list(result.keys()) if isinstance(result, dict) else dir(result)
            return None, preview, f"Cannot extract 3D. Keys: {keys}"

        # Every branch above can finish without writing a file (e.g. a mesh
        # object without export()); never report success for a missing GLB.
        if not Path(glb).is_file():
            return None, preview, "Reconstruction finished but no GLB file was written"

        n = 0
        try:
            n = len(trimesh.load(glb, force="mesh").faces)
        except Exception as e:
            # Best effort: the face count only feeds the status text.
            print(f" Face count unavailable: {e}")
        elapsed = int(time.time() - t0)
        return glb, preview, f"OK: {len(masks)} objects, {n:,} faces ({elapsed}s)"
    except Exception as e:
        # Top-level boundary of the endpoint: log the trace, surface the message.
        import traceback
        traceback.print_exc()
        return None, None, f"Error: {e}"


# --- UI ---
# NOTE(review): layout nesting was reconstructed from a whitespace-collapsed
# source — confirm against the deployed app.
with gr.Blocks(title="SAM 3D Objects") as demo:
    gr.Markdown("# SAM 3D Objects\nImage → 3D (GLB). SAM2 detection + SAM3D reconstruction.")
    with gr.Tab("Reconstruct"):
        with gr.Row():
            with gr.Column():
                inp = gr.Image(label="Input", type="numpy")
                btn = gr.Button("Reconstruct", variant="primary", size="lg")
            with gr.Column():
                prev = gr.Image(label="Detection", type="numpy", interactive=False)
                stat = gr.Textbox(label="Status")
        with gr.Row():
            m3d = gr.Model3D(label="3D Preview")
            dl = gr.File(label="Download GLB")
        btn.click(reconstruct_objects, inputs=[inp], outputs=[m3d, prev, stat])
        # Mirror the previewed model into the download widget.
        m3d.change(lambda x: x, inputs=[m3d], outputs=[dl])
    with gr.Tab("Diagnose"):
        dbtn = gr.Button("Diagnose GPU & Modules")
        dout = gr.Textbox(lines=15)
        dbtn.click(diagnose, outputs=[dout])

demo.launch(mcp_server=True)