Spaces:
Paused
Paused
| """SAM 3D Objects – kaolin+pytorch3d stubbed for ZeroGPU (PyTorch 2.10+cu128).""" | |
| import os, sys, subprocess | |
| os.environ.setdefault("CUDA_HOME", "/usr/local/cuda") | |
| os.environ.setdefault("CONDA_PREFIX", "/usr/local") | |
| os.environ["LIDRA_SKIP_INIT"] = "true" | |
| # MUST import spaces before torch | |
| import spaces | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| from huggingface_hub import snapshot_download, login | |
| import tempfile | |
| from pathlib import Path | |
# Log in to the Hugging Face Hub, but only when HF_TOKEN is set and non-empty.
if "HF_TOKEN" in os.environ and os.environ["HF_TOKEN"]:
    login(token=os.environ["HF_TOKEN"])
| # --- Stubs (must be before sam3d imports) --- | |
# Directories holding the kaolin / pytorch3d / spconv stubs.
STUB_KAOLIN = Path("/home/user/app/kaolin_stub")
STUB_PT3D = Path("/home/user/app/pytorch3d_stub")
STUB_SPCONV = Path("/home/user/app/spconv_stub")

# Put every stub directory that exists at the front of sys.path; missing ones
# are skipped without a message.
for stub in (STUB_KAOLIN, STUB_PT3D, STUB_SPCONV):
    if not stub.exists():
        continue
    sys.path.insert(0, str(stub))
    print(f"Stub added: {stub.name}")
| # --- Runtime pip installs --- | |
| def _pip(*a): | |
| r = subprocess.run([sys.executable, "-m", "pip", "install", "--no-cache-dir"] + list(a), | |
| capture_output=True, text=True, timeout=1200) | |
| ok = r.returncode == 0 | |
| tag = a[-1][:50] if a else "?" | |
| if ok: | |
| print(f" pip OK: {tag}") | |
| else: | |
| print(f" pip FAIL: {tag}") | |
| print(f" {r.stderr[-300:]}") | |
| return ok | |
print("=== Runtime installs ===")
# Installed one by one, in this order; the result of each call is ignored.
for _pip_args in (
    ("open3d>=0.18.0",),
    ("--no-deps", "utils3d"),  # --no-deps: skip jupyter dependency
    ("iopath",),
    ("--no-deps", "sam2>=1.1.0"),
    ("--no-deps", "git+https://github.com/microsoft/MoGe.git@a8c37341bc0325ca99b9d57981cc3bb2bd3e255b"),
):
    _pip(*_pip_args)

# gsplat: try each wheel index in turn and stop at the first successful install.
_GSPLAT_WHEEL_INDEXES = (
    "https://docs.gsplat.studio/whl/pt210cu128",
    "https://docs.gsplat.studio/whl/pt28cu128",
)
for idx in _GSPLAT_WHEEL_INDEXES:
    if _pip("--no-deps", f"--extra-index-url={idx}", "gsplat"):
        break
| # DO NOT import CUDA-dependent packages here! | |
| # --- Clone sam-3d-objects --- | |
SAM3D_PATH = Path("/home/user/app/sam-3d-objects")
if not SAM3D_PATH.exists():
    # NOTE(review): the editable install and the Hydra patch are assumed to
    # belong to this fresh-clone branch — confirm against the deployed app.
    print("Cloning sam-3d-objects...")
    # check=True: a failed clone aborts startup, nothing below can work without it.
    subprocess.run(["git", "clone", "--depth", "1",
                    "https://github.com/facebookresearch/sam-3d-objects.git",
                    str(SAM3D_PATH)], check=True)
    install = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-e", str(SAM3D_PATH), "--no-deps"],
        capture_output=True, text=True)
    if install.returncode != 0:
        # Output is captured, so a failure would otherwise be invisible.
        print(f" pip FAIL: {SAM3D_PATH.name}")
        print(f" {install.stderr[-300:]}")
    # Hydra patch
    patch = SAM3D_PATH / "patching" / "hydra"
    if patch.exists():
        patched = subprocess.run(["bash", str(patch)], capture_output=True, text=True,
                                 cwd=str(SAM3D_PATH))
        if patched.returncode != 0:
            print(f" Hydra patch FAIL (exit {patched.returncode})")
            print(f" {patched.stderr[-300:]}")
# Make both the repository root and its notebook/ directory importable.
sys.path.insert(0, str(SAM3D_PATH))
sys.path.insert(0, str(SAM3D_PATH / "notebook"))
| # --- Pre-download checkpoints --- | |
print("Downloading SAM3D checkpoints...")
CKPT_DIR = snapshot_download(repo_id="facebook/sam-3d-objects",
                             token=os.environ.get("HF_TOKEN"))
# Expose the downloaded checkpoints/ directory as <repo>/checkpoints/hf.
hf_ckpt = Path(CKPT_DIR) / "checkpoints"
local_ckpt = SAM3D_PATH / "checkpoints" / "hf"
if hf_ckpt.exists() and not local_ckpt.exists():
    local_ckpt.parent.mkdir(parents=True, exist_ok=True)
    if local_ckpt.is_symlink():
        # exists() follows symlinks, so a dangling link ends up here; remove
        # it first or symlink_to() raises FileExistsError.
        local_ckpt.unlink()
    local_ckpt.symlink_to(hf_ckpt)
CONFIG_PATH = str(local_ckpt / "pipeline.yaml")
print(f"Config exists: {Path(CONFIG_PATH).exists()}")
print("=== Startup complete ===")
| # --- Endpoints --- | |
def diagnose():
    """Build a plain-text report of the runtime environment.

    The report lists the torch version and CUDA availability (plus the GPU
    name when CUDA is available), the import status of each optional
    dependency, of the SAM2 mask generator and of the SAM3D ``Inference``
    class, and whether the pipeline config file exists.

    Returns:
        The report as one newline-separated string.
    """
    import torch

    report = [f"torch={torch.__version__}", f"cuda={torch.cuda.is_available()}"]
    if torch.cuda.is_available():
        report.append(f"gpu={torch.cuda.get_device_name()}")

    # Top-level packages: record the version when importable, else the error.
    for name in ("kaolin", "utils3d", "iopath", "pytorch3d", "open3d", "gsplat", "moge"):
        try:
            version = getattr(__import__(name), "__version__", "-")
            entry = f"{name}: OK ({version})"
        except Exception as exc:
            entry = f"{name}: FAIL - {exc}"
        report.append(entry)

    # The two imports below exist only to prove the names can be imported.
    try:
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator  # noqa: F401
        sam2_entry = "sam2: OK"
    except Exception as exc:
        sam2_entry = f"sam2: FAIL - {exc}"
    report.append(sam2_entry)

    try:
        from inference import Inference  # noqa: F401
        sam3d_entry = "SAM3D Inference: importable"
    except Exception as exc:
        sam3d_entry = f"SAM3D Inference: FAIL - {exc}"
    report.append(sam3d_entry)

    report.append(f"config: {Path(CONFIG_PATH).exists()}")
    return "\n".join(report)
def _find_gaussians(result):
    """Return the gaussian-like object carried by *result*, or None."""
    if hasattr(result, "save_ply"):
        return result
    if not isinstance(result, dict):
        return None
    for key in ("gs", "gaussian", "gaussians", "scene"):
        value = result.get(key)
        if value is None:
            continue
        if isinstance(value, (list, tuple)):
            # An empty sequence holds nothing usable; avoid an IndexError.
            return value[0] if value else None
        return value
    return None


def _poisson_glb(pcd, glb_path):
    """Estimate normals on *pcd*, Poisson-reconstruct a mesh, write it to *glb_path*."""
    import open3d as o3d

    pcd.estimate_normals()
    mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=8)
    o3d.io.write_triangle_mesh(glb_path, mesh)


def reconstruct_objects(image: np.ndarray):
    """Segment the largest object in *image* and reconstruct it as a GLB file.

    Steps: SAM2 automatic mask generation, selection of the mask with the
    largest ``area``, SAM3D inference on that mask with ``seed=42``, then
    conversion of the result to ``object.glb`` in a fresh temporary directory.
    Both models are loaded on every call.

    NOTE(review): the module docstring targets ZeroGPU, but no ``@spaces.GPU``
    decorator is visible on this function — confirm how GPU access is granted.

    Args:
        image: Input picture as a NumPy array (any other object is converted
            with ``np.array``), or None.

    Returns:
        A ``(glb_path, preview, status)`` tuple. ``glb_path`` is the path of
        the written GLB file, or None when nothing could be produced.
        ``preview`` is a copy of the image with the selected mask blended
        50/50 with green (the unmodified image when no mask was found, None
        when there is no input or an exception occurred). ``status`` is a
        human-readable message. Exceptions are not propagated: the traceback
        is printed and ``(None, None, "Error: ...")`` is returned.
    """
    if image is None:
        return None, None, "No image"
    try:
        import torch
        import trimesh
        import time

        t0 = time.time()
        print(f"GPU: {torch.cuda.get_device_name()}")

        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
        sam2_gen = SAM2AutomaticMaskGenerator.from_pretrained("facebook/sam2-hiera-large")
        print(f" SAM2 loaded ({time.time()-t0:.0f}s)")

        image_np = np.array(image) if not isinstance(image, np.ndarray) else image
        masks = sam2_gen.generate(image_np)
        if not masks:
            return None, image_np, "No objects detected"

        # Largest mask first; only that one is reconstructed.
        masks = sorted(masks, key=lambda x: x["area"], reverse=True)
        best_mask = masks[0]["segmentation"]
        preview = image_np.copy()
        preview[best_mask] = (preview[best_mask] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
        print(f" {len(masks)} masks ({time.time()-t0:.0f}s)")

        from inference import Inference
        sam3d = Inference(CONFIG_PATH, compile=False)
        print(f" SAM3D loaded ({time.time()-t0:.0f}s)")
        result = sam3d(image=image_np, mask=best_mask, seed=42)
        print(f" Reconstructed ({time.time()-t0:.0f}s)")
        if result is None:
            return None, preview, "Reconstruction returned None"

        # The directory is intentionally kept: the returned path must stay
        # readable after this function returns.
        od = tempfile.mkdtemp()
        glb = f"{od}/object.glb"
        gs = _find_gaussians(result)
        if gs is not None and hasattr(gs, "save_ply"):
            import open3d as o3d
            ply = f"{od}/temp.ply"
            gs.save_ply(ply)
            _poisson_glb(o3d.io.read_point_cloud(ply), glb)
        elif gs is not None and hasattr(gs, "_xyz"):
            import open3d as o3d
            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(gs._xyz.detach().cpu().numpy())
            _poisson_glb(pcd, glb)
        elif isinstance(result, dict) and "mesh" in result:
            m = result["mesh"]
            if not hasattr(m, "export"):
                # Previously this fell through and reported "OK" with a GLB
                # path that was never written.
                return None, preview, f"Cannot export mesh of type {type(m).__name__}"
            m.export(glb)
        else:
            keys = list(result.keys()) if isinstance(result, dict) else dir(result)
            return None, preview, f"Cannot extract 3D. Keys: {keys}"

        if not os.path.isfile(glb):
            # A writer can fail without raising; never hand back a dead path.
            return None, preview, "Export failed: GLB file was not written"

        n = 0
        try:
            n = len(trimesh.load(glb, force="mesh").faces)
        except Exception as e:
            # The face count is a best-effort statistic; keep the result.
            print(f" face count unavailable: {e}")
        elapsed = int(time.time() - t0)
        return glb, preview, f"OK: {len(masks)} objects, {n:,} faces ({elapsed}s)"
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, None, f"Error: {e}"
| # --- UI --- | |
with gr.Blocks(title="SAM 3D Objects") as demo:
    gr.Markdown("# SAM 3D Objects\nImage → 3D (GLB). SAM2 detection + SAM3D reconstruction.")

    # Tab 1: upload an image, run the reconstruction, inspect the outputs.
    with gr.Tab("Reconstruct"):
        with gr.Row():
            with gr.Column():
                inp = gr.Image(label="Input", type="numpy")
                btn = gr.Button("Reconstruct", variant="primary", size="lg")
            with gr.Column():
                prev = gr.Image(label="Detection", type="numpy", interactive=False)
                stat = gr.Textbox(label="Status")
        with gr.Row():
            m3d = gr.Model3D(label="3D Preview")
            dl = gr.File(label="Download GLB")

        # The three return values map to the 3D viewer, the preview image and
        # the status box, in that order.
        btn.click(
            reconstruct_objects,
            inputs=[inp],
            outputs=[m3d, prev, stat],
        )
        # Mirror the 3D viewer's value into the download slot on every change.
        m3d.change(lambda x: x, inputs=[m3d], outputs=[dl])

    # Tab 2: environment report.
    with gr.Tab("Diagnose"):
        dbtn = gr.Button("Diagnose GPU & Modules")
        dout = gr.Textbox(lines=15)
        dbtn.click(diagnose, outputs=[dout])

demo.launch(mcp_server=True)