Spaces:
Paused
Paused
File size: 16,689 Bytes
a9397fc 8df3f9a 901cd2d 75e28bb 901cd2d 8df3f9a 901cd2d 8df3f9a 0bb18a8 901cd2d 8df3f9a a9397fc 75e28bb 784d43d a9397fc 8df3f9a a9397fc 31ef6d3 a9397fc 0bb18a8 8df3f9a b64edd3 8df3f9a 784d43d d9d5913 a9397fc 8df3f9a 31ef6d3 8df3f9a ccebd17 7b8ab13 ccebd17 7b8ab13 ccebd17 8df3f9a bef166b 56e808f bef166b 8df3f9a 31ef6d3 d9d5913 901cd2d 8df3f9a 31ef6d3 901cd2d d9d5913 0bb18a8 31ef6d3 a9397fc 31ef6d3 a9397fc 31ef6d3 8df3f9a d66f250 8df3f9a 901cd2d 8df3f9a 31ef6d3 8df3f9a 31ef6d3 84293b6 8df3f9a 31ef6d3 8df3f9a 0bb18a8 8df3f9a 84293b6 31ef6d3 84293b6 31ef6d3 84293b6 8df3f9a 84293b6 31ef6d3 84293b6 8df3f9a a9397fc 8df3f9a a9397fc 8df3f9a a9397fc 8df3f9a 901cd2d 8df3f9a 6b31dc3 901cd2d 236a930 0a42ddd 236a930 8df3f9a 31ef6d3 8df3f9a 236a930 8df3f9a a9397fc 31ef6d3 8df3f9a 901cd2d 8df3f9a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 
329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 | """SAM 3D Objects – kaolin+pytorch3d stubbed for ZeroGPU (PyTorch 2.10+cu128)."""
import os, sys, subprocess
# Toolchain locations: only filled in when the environment has not set them.
os.environ.setdefault("CUDA_HOME", "/usr/local/cuda")
os.environ.setdefault("CONDA_PREFIX", "/usr/local")

# Backend selection: always overwritten, whatever the environment says.
os.environ.update(
    {
        "LIDRA_SKIP_INIT": "true",
        "ATTN_BACKEND": "sdpa",
        "SPARSE_ATTN_BACKEND": "sdpa",
        "SPARSE_BACKEND": "spconv",
    }
)
# MUST import spaces before torch
import spaces
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download, login
import tempfile
from pathlib import Path
# Authenticate against the Hugging Face Hub only when a token is configured.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
# --- Stubs (must be before sam3d imports) ---
_APP_ROOT = Path("/home/user/app")
STUB_KAOLIN = _APP_ROOT / "kaolin_stub"
STUB_PT3D = _APP_ROOT / "pytorch3d_stub"
STUB_FA = _APP_ROOT / "flash_attn_stub"

# Put every stub directory that is present at the front of sys.path so the
# stub package is found before any other package of the same name.
for stub in (STUB_KAOLIN, STUB_PT3D, STUB_FA):
    if not stub.exists():
        continue
    sys.path.insert(0, str(stub))
    print(f"Stub added: {stub.name}")
# --- Runtime pip installs ---
def _pip(*a):
    """Run ``pip install --no-cache-dir`` with the given arguments.

    The install is best-effort: the outcome is printed and reported through
    the return value, and a timed-out or unlaunchable pip is treated as a
    failed install instead of raising.

    Args:
        *a: Extra command-line arguments appended to the pip invocation
            (options and/or requirement specifiers).

    Returns:
        True if pip exited with status 0, False otherwise.
    """
    # The last argument is normally the requirement; truncate it so the
    # log line stays short.
    tag = str(a[-1])[:50] if a else "?"
    cmd = [sys.executable, "-m", "pip", "install", "--no-cache-dir", *a]
    try:
        r = subprocess.run(cmd, capture_output=True, text=True, timeout=1200)
    except subprocess.TimeoutExpired:
        # subprocess.run() kills the child and raises once the timeout
        # elapses; report it like any other failed install.
        print(f" pip FAIL: {tag}")
        print(" timed out after 1200s")
        return False
    except OSError as exc:
        # The interpreter could not be launched at all.
        print(f" pip FAIL: {tag}")
        print(f" {exc}")
        return False
    if r.returncode == 0:
        print(f" pip OK: {tag}")
        return True
    print(f" pip FAIL: {tag}")
    # Only the tail of stderr: that is where pip puts the actual error.
    print(f" {r.stderr[-300:]}")
    return False
print("=== Runtime installs ===")
# Each tuple is one `pip install` invocation; they run in the listed order
# and their results are not checked.
for _pip_args in (
    ("open3d>=0.18.0",),
    # --no-deps: skip jupyter dependency
    ("--no-deps", "git+https://github.com/EasternJournalist/utils3d.git"),
    ("iopath",),
    ("--no-deps", "sam2>=1.1.0"),
    ("--no-deps", "git+https://github.com/microsoft/MoGe.git@a8c37341bc0325ca99b9d57981cc3bb2bd3e255b"),
):
    _pip(*_pip_args)

# gsplat: try the wheel indexes in order and stop at the first that installs.
_GSPLAT_WHEEL_INDEXES = (
    "https://docs.gsplat.studio/whl/pt210cu128",
    "https://docs.gsplat.studio/whl/pt28cu128",
)
for idx in _GSPLAT_WHEEL_INDEXES:
    if _pip("--no-deps", f"--extra-index-url={idx}", "gsplat"):
        break

# spconv (sparse convolution – needed for SAM3D's SLatFlowModel)
# cu124 wheel is forward-compatible with cu128
_pip("spconv-cu124==2.3.8")
# DO NOT import CUDA-dependent packages here!
# --- Clone sam-3d-objects ---
SAM3D_PATH = Path("/home/user/app/sam-3d-objects")
if not SAM3D_PATH.exists():
    # NOTE(review): the editable install and the Hydra patch are treated as
    # part of the one-time clone step — confirm this matches the intended
    # nesting.
    print("Cloning sam-3d-objects...")
    # check=True: nothing below can work without the checkout, so a failed
    # clone aborts startup.
    subprocess.run(
        ["git", "clone", "--depth", "1",
         "https://github.com/facebookresearch/sam-3d-objects.git",
         str(SAM3D_PATH)],
        check=True,
    )
    _editable = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-e", str(SAM3D_PATH), "--no-deps"],
        capture_output=True, text=True,
    )
    if _editable.returncode != 0:
        # Best-effort (SAM3D_PATH is also put on sys.path below), but say so
        # instead of failing silently.
        print(" pip FAIL: editable install of sam-3d-objects")
        print(f" {_editable.stderr[-300:]}")
    # Hydra patch
    patch = SAM3D_PATH / "patching" / "hydra"
    if patch.exists():
        subprocess.run(["bash", str(patch)], capture_output=True, cwd=str(SAM3D_PATH))

# CRITICAL PATCH: Prevent SAM3D from overriding ATTN_BACKEND to flash_attn
# inference_pipeline.py auto-detects H200/A100 and forces flash_attn,
# but we don't have the real flash_attn package.
ip_file = SAM3D_PATH / "sam3d_objects" / "pipeline" / "inference_pipeline.py"
if ip_file.exists():
    ip_src = ip_file.read_text(encoding="utf-8")
    # Cheap presence check for the line that forces flash_attn.
    old_marker = 'os.environ["ATTN_BACKEND"] = "flash_attn"'
    if old_marker in ip_src:
        # Replace the entire if-block that forces flash_attn.
        # NOTE(review): str.replace() is whitespace-sensitive — confirm the
        # leading indentation inside these literals against the upstream file.
        patched_src = ip_src.replace(
            'if "A100" in gpu_name or "H100" in gpu_name or "H200" in gpu_name:\n'
            ' # logger.info("Use flash_attn")\n'
            ' os.environ["ATTN_BACKEND"] = "flash_attn"\n'
            ' os.environ["SPARSE_ATTN_BACKEND"] = "flash_attn"',
            '# PATCHED: Always use sdpa backend (flash_attn not available on ZeroGPU)\n'
            ' logger.info("Using sdpa backend (patched for ZeroGPU)")\n'
            ' os.environ.setdefault("ATTN_BACKEND", "sdpa")\n'
            ' os.environ.setdefault("SPARSE_ATTN_BACKEND", "sdpa")'
        )
        if patched_src != ip_src:
            ip_src = patched_src
            ip_file.write_text(ip_src, encoding="utf-8")
            print("PATCHED: inference_pipeline.py - forced sdpa backend")
        else:
            # The marker is present but the full block did not match, so the
            # file is unchanged; do not claim success.
            print("WARNING: inference_pipeline.py still forces flash_attn - patch pattern did not match")
    else:
        print("INFO: inference_pipeline.py already patched or different version")

# Make the checkout and its notebook/ directory importable.
sys.path.insert(0, str(SAM3D_PATH))
sys.path.insert(0, str(SAM3D_PATH / "notebook"))
# --- Monkey-patch: inject depth_edge into utils3d.numpy ---
# utils3d package lacks depth_edge in newer versions; SAM3D needs it for layout post-optimization
def _depth_edge(depth, rtol=0.03, mask=None):
    """Flag pixels whose relative depth gradient exceeds ``rtol``.

    The gradient magnitude comes from Sobel filters along both image axes
    and is divided by the absolute depth (floored at 1e-6).

    Args:
        depth: 2-D depth map. Non-floating input is converted to float64.
        rtol: Threshold on gradient magnitude divided by depth.
        mask: Optional validity mask, interpreted as booleans. Pixels outside
            it are treated as depth 0 and are never reported as edges.

    Returns:
        Boolean array with the same shape as ``depth``.
    """
    import numpy as _np
    from scipy.ndimage import sobel

    d = _np.asarray(depth)
    if not _np.issubdtype(d.dtype, _np.floating):
        # Integer (especially unsigned) input would wrap around inside the
        # Sobel filter and would truncate the 1e-6 floor below to 0.
        d = d.astype(_np.float64)
    if mask is not None:
        mask = _np.asarray(mask, dtype=bool)
        d = _np.where(mask, d, 0.0)
    gx = sobel(d, axis=1)
    gy = sobel(d, axis=0)
    grad = _np.sqrt(gx**2 + gy**2)
    # Floor the denominator so zero depth does not divide by zero.
    denom = _np.maximum(_np.abs(d), 1e-6)
    edge = (grad / denom) > rtol
    if mask is not None:
        edge = edge & mask
    return edge


def _normals_edge(normals, tol=0.1, mask=None):
    """Detect normal discontinuities.

    Args:
        normals: Array whose first two axes are the image axes and whose last
            axis holds the normal components.
        tol: Threshold on the Sobel gradient magnitude of each component.
        mask: Optional validity mask, interpreted as booleans. Pixels outside
            it are zeroed before filtering and are never reported as edges.

    Returns:
        Boolean array of shape ``normals.shape[:2]``; True where any
        component's gradient magnitude exceeds ``tol``.
    """
    import numpy as _np
    from scipy.ndimage import sobel

    normals = _np.asarray(normals)
    if mask is not None:
        mask = _np.asarray(mask, dtype=bool)
    edges = _np.zeros(normals.shape[:2], dtype=bool)
    # Compute gradient of each normal component
    for c in range(normals.shape[-1]):
        ch = normals[..., c]
        if mask is not None:
            ch = _np.where(mask, ch, 0.0)
        gx = sobel(ch, axis=1)
        gy = sobel(ch, axis=0)
        grad = _np.sqrt(gx**2 + gy**2)
        edges |= (grad > tol)
    if mask is not None:
        edges = edges & mask
    return edges


def _make_u3d_catchall(orig_getattr=None):
    """Build a module-level ``__getattr__`` that stubs out missing functions.

    Args:
        orig_getattr: The ``__getattr__`` the module already defines, or None.
            It is consulted first, so names the module can resolve itself
            keep resolving normally.

    Returns:
        A function suitable for assignment to ``module.__getattr__``. For a
        name nothing else resolves it warns and returns a dummy callable that
        ignores its arguments and returns ``numpy.zeros(1)``. Dunder names
        raise AttributeError. Once installed, ``hasattr`` on the module is
        True for every non-dunder name.
    """
    def _u3d_catchall(name):
        if orig_getattr is not None:
            try:
                return orig_getattr(name)
            except AttributeError:
                pass  # not provided by the module itself; fall back to the stub
        if name.startswith('__') and name.endswith('__'):
            # Protocol probes (copy, pickle, inspect, ...) must see a real miss.
            raise AttributeError(name)
        import warnings
        warnings.warn(f"utils3d.numpy stub: {name} not implemented, returning dummy")

        def _dummy(*a, **kw):
            import numpy as _np
            return _np.zeros(1)
        return _dummy
    return _u3d_catchall


try:
    import utils3d.numpy as _u3d_np
    if not hasattr(_u3d_np, 'depth_edge'):
        _u3d_np.depth_edge = _depth_edge
        _u3d_np.normals_edge = _normals_edge
        # Also inject a catch-all __getattr__ for any future missing functions,
        # chained behind whatever __getattr__ the module already has.
        _orig_getattr = getattr(_u3d_np, '__getattr__', None)
        _u3d_catchall = _make_u3d_catchall(_orig_getattr)
        _u3d_np.__getattr__ = _u3d_catchall
        print("Injected depth_edge + normals_edge + catch-all into utils3d.numpy")
except Exception as e:
    # Best-effort: startup continues without the patch.
    print(f"depth_edge patch skipped: {e}")
# --- Pre-download checkpoints ---
print("Downloading SAM3D checkpoints...")
CKPT_DIR = snapshot_download(repo_id="facebook/sam-3d-objects",
                             token=os.environ.get("HF_TOKEN"))
hf_ckpt = Path(CKPT_DIR) / "checkpoints"
local_ckpt = SAM3D_PATH / "checkpoints" / "hf"
# Path.exists() follows symlinks, so a dangling link reports False while still
# occupying the path; remove it first or symlink_to() raises FileExistsError.
if local_ckpt.is_symlink() and not local_ckpt.exists():
    local_ckpt.unlink()
# Expose the downloaded checkpoints at checkpoints/hf inside the checkout.
if hf_ckpt.exists() and not local_ckpt.exists():
    local_ckpt.parent.mkdir(parents=True, exist_ok=True)
    local_ckpt.symlink_to(hf_ckpt)
CONFIG_PATH = str(local_ckpt / "pipeline.yaml")
print(f"Config exists: {Path(CONFIG_PATH).exists()}")
print("=== Startup complete ===")
# --- Endpoints ---
@spaces.GPU(duration=60)
def diagnose():
    """Report the torch/CUDA state and which optional modules can be imported.

    Returns:
        A newline-joined report: torch version, CUDA availability (plus the
        device name when available), one OK/FAIL line per probed module, and
        whether the pipeline config file exists.
    """
    import torch

    report = [f"torch={torch.__version__}", f"cuda={torch.cuda.is_available()}"]
    if torch.cuda.is_available():
        report.append(f"gpu={torch.cuda.get_device_name()}")

    # Top-level packages probed with a plain import.
    probe_names = ("kaolin", "utils3d", "iopath", "pytorch3d", "open3d", "gsplat", "moge")
    for mod_name in probe_names:
        try:
            module = __import__(mod_name)
            version = getattr(module, '__version__', '-')
            report.append(f"{mod_name}: OK ({version})")
        except Exception as exc:
            report.append(f"{mod_name}: FAIL - {exc}")

    # These two are probed through the exact names the endpoints import.
    try:
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
        report.append("sam2: OK")
    except Exception as exc:
        report.append(f"sam2: FAIL - {exc}")

    try:
        from inference import Inference
        report.append("SAM3D Inference: importable")
    except Exception as exc:
        report.append(f"SAM3D Inference: FAIL - {exc}")

    report.append(f"config: {Path(CONFIG_PATH).exists()}")
    return "\n".join(report)
# Result-dict keys probed, in order, for a Gaussian-splat object.
_GAUSSIAN_KEYS = ("gs", "gaussian", "gaussians", "scene")


def _overlay_mask(image_np, mask):
    """Return a copy of ``image_np`` with the ``mask`` pixels blended 50/50 with green."""
    # Assumes a 3-channel image so the green triple broadcasts — TODO confirm.
    preview = image_np.copy()
    preview[mask] = (preview[mask] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
    return preview


def _find_gaussians(result):
    """Return the Gaussian-splat object carried by ``result``, or None."""
    if hasattr(result, "save_ply"):
        return result
    if isinstance(result, dict):
        for key in _GAUSSIAN_KEYS:
            value = result.get(key)
            if value is None:
                continue
            if isinstance(value, (list, tuple)):
                if not value:
                    continue  # empty container: keep looking under the other keys
                return value[0]
            return value
    return None


def _write_poisson_glb(pcd, glb):
    """Estimate normals for ``pcd``, run Poisson reconstruction and write the mesh to ``glb``."""
    import open3d as o3d
    pcd.estimate_normals()
    mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=8)
    o3d.io.write_triangle_mesh(glb, mesh)


def _export_glb(result, out_dir):
    """Write the reconstruction held in ``result`` to ``<out_dir>/object.glb``.

    Returns:
        The GLB path, or None when ``result`` holds nothing exportable.

    Raises:
        RuntimeError: An export ran but left no file behind.
    """
    glb = f"{out_dir}/object.glb"
    gs = _find_gaussians(result)
    if gs is not None and hasattr(gs, "save_ply"):
        import open3d as o3d
        ply = f"{out_dir}/temp.ply"
        gs.save_ply(ply)
        _write_poisson_glb(o3d.io.read_point_cloud(ply), glb)
    elif gs is not None and hasattr(gs, "_xyz"):
        import open3d as o3d
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(gs._xyz.detach().cpu().numpy())
        _write_poisson_glb(pcd, glb)
    elif isinstance(result, dict) and hasattr(result.get("mesh"), "export"):
        result["mesh"].export(glb)
    else:
        return None
    if not os.path.exists(glb):
        # Never hand the UI a path that was not actually written.
        raise RuntimeError(f"3D export did not produce {glb}")
    return glb


def _count_faces(glb):
    """Return the face count of the mesh stored at ``glb``, or 0 if it cannot be read."""
    import trimesh
    try:
        return len(trimesh.load(glb, force="mesh").faces)
    except Exception:
        # The count only decorates the status text; a finished reconstruction
        # must not fail because of it.
        return 0


@spaces.GPU(duration=300)
def reconstruct_objects(image: np.ndarray):
    """Segment the largest object with SAM2 and reconstruct it in 3D with SAM3D.

    Args:
        image: Input image, or None when nothing was provided.

    Returns:
        A ``(glb_path, preview, status)`` tuple. ``glb_path`` is the exported
        GLB file or None on failure; ``preview`` is the input with the chosen
        mask tinted green (the untouched input when nothing was detected,
        None on early exit or error); ``status`` is a human-readable message.
        Exceptions are caught and reported through ``status``.
    """
    if image is None:
        return None, None, "No image"
    try:
        import time
        import torch
        import trimesh  # noqa: F401 - fail before any GPU work if it is missing

        t0 = time.time()
        print(f"GPU: {torch.cuda.get_device_name()}")
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
        print(f" Loading SAM2... (VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        sam2_gen = SAM2AutomaticMaskGenerator.from_pretrained("facebook/sam2-hiera-small")
        print(f" SAM2 loaded ({time.time()-t0:.0f}s, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        image_np = np.array(image) if not isinstance(image, np.ndarray) else image
        masks = sam2_gen.generate(image_np)
        if not masks:
            return None, image_np, "No objects detected"
        # Only the largest mask is reconstructed; max() keeps the first of any ties.
        best_mask = max(masks, key=lambda m: m["area"])["segmentation"]
        preview = _overlay_mask(image_np, best_mask)
        print(f" {len(masks)} masks ({time.time()-t0:.0f}s)")
        # Free SAM2 to save VRAM for SAM3D
        del sam2_gen
        torch.cuda.empty_cache()
        print(f" SAM2 freed (VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        from inference import Inference
        print(f" Loading SAM3D... (VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        sam3d = Inference(CONFIG_PATH, compile=False)
        print(f" SAM3D loaded ({time.time()-t0:.0f}s, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        print(f" Running reconstruction... (VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        result = sam3d(image=image_np, mask=best_mask, seed=42)
        print(f" Reconstructed ({time.time()-t0:.0f}s, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        if result is None:
            return None, preview, "Reconstruction returned None"
        # The directory is not removed here: the returned path is read after
        # this function has returned.
        od = tempfile.mkdtemp()
        glb = _export_glb(result, od)
        if glb is None:
            keys = list(result.keys()) if isinstance(result, dict) else dir(result)
            return None, preview, f"Cannot extract 3D. Keys: {keys}"
        n = _count_faces(glb)
        elapsed = int(time.time() - t0)
        return glb, preview, f"OK: {len(masks)} objects, {n:,} faces ({elapsed}s)"
    except Exception:
        import traceback
        tb = traceback.format_exc()
        print(tb)
        # Only the tail of the traceback is sent to the status box.
        return None, None, f"Error:\n{tb[-1500:]}"


# NOTE(review): this endpoint requests 60 s of GPU time while reconstruct_objects
# requests 300 s for the same SAM3D load + run — confirm 60 s is enough.
@spaces.GPU(duration=60)
def test_sam3d_only(image: np.ndarray):
    """Test SAM3D reconstruction with center-crop mask (no SAM2).

    Args:
        image: Input image, or None when nothing was provided.

    Returns:
        A ``(glb_path, preview, status)`` tuple with the same meaning as for
        ``reconstruct_objects``; the preview tints the central rectangle
        spanning 20%-80% of each image axis.
    """
    if image is None:
        return None, None, "No image"
    try:
        import time
        import torch

        t0 = time.time()
        print(f"GPU: {torch.cuda.get_device_name()}, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB")
        image_np = np.array(image) if not isinstance(image, np.ndarray) else image
        h, w = image_np.shape[:2]
        # Create a center mask (middle 60% of image)
        mask = np.zeros((h, w), dtype=bool)
        y1, y2 = int(h * 0.2), int(h * 0.8)
        x1, x2 = int(w * 0.2), int(w * 0.8)
        mask[y1:y2, x1:x2] = True
        preview = _overlay_mask(image_np, mask)
        print(f" Mask created: {mask.sum()} pixels ({time.time()-t0:.0f}s)")
        from inference import Inference
        print(f" Loading SAM3D... VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB")
        sam3d = Inference(CONFIG_PATH, compile=False)
        print(f" SAM3D loaded ({time.time()-t0:.0f}s, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        print(f" Running reconstruction...")
        result = sam3d(image=image_np, mask=mask, seed=42)
        print(f" Done ({time.time()-t0:.0f}s, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
        if result is None:
            return None, preview, "Reconstruction returned None"
        # The directory is not removed here: the returned path is read after
        # this function has returned.
        od = tempfile.mkdtemp()
        glb = _export_glb(result, od)
        if glb is None:
            keys = list(result.keys()) if isinstance(result, dict) else dir(result)
            return None, preview, f"Cannot extract 3D. Keys: {keys}"
        n = _count_faces(glb)
        elapsed = int(time.time() - t0)
        return glb, preview, f"OK: {n:,} faces ({elapsed}s)"
    except Exception:
        import traceback
        tb = traceback.format_exc()
        print(tb)
        # Only the tail of the traceback is sent to the status box.
        return None, None, f"Error:\n{tb[-1500:]}"
# --- UI ---
# NOTE(review): the nesting of rows/columns below is inferred from the order
# in which the components are created — confirm against the intended layout.
with gr.Blocks(title="SAM 3D Objects") as demo:
    gr.Markdown("# SAM 3D Objects\nImage → 3D (GLB). SAM2 detection + SAM3D reconstruction.")

    # Tab 1: SAM2 detection followed by SAM3D reconstruction.
    with gr.Tab("Reconstruct"):
        with gr.Row():
            with gr.Column():
                inp = gr.Image(label="Input", type="numpy")
                btn = gr.Button("Reconstruct", variant="primary", size="lg")
            with gr.Column():
                prev = gr.Image(label="Detection", type="numpy", interactive=False)
                stat = gr.Textbox(label="Status")
        with gr.Row():
            m3d = gr.Model3D(label="3D Preview")
            dl = gr.File(label="Download GLB")
        btn.click(
            reconstruct_objects,
            inputs=[inp],
            outputs=[m3d, prev, stat],
        )
        # Mirror whatever the 3D viewer holds into the download widget.
        m3d.change(lambda x: x, inputs=[m3d], outputs=[dl])

    # Tab 2: SAM3D alone, driven by a fixed centre mask.
    with gr.Tab("Test SAM3D Only"):
        with gr.Row():
            with gr.Column():
                tinp = gr.Image(label="Input", type="numpy")
                tbtn = gr.Button("Test SAM3D (no SAM2)", variant="primary")
            with gr.Column():
                tprev = gr.Image(label="Mask Preview", type="numpy", interactive=False)
                tstat = gr.Textbox(label="Status")
        with gr.Row():
            tm3d = gr.Model3D(label="3D Preview")
        tbtn.click(
            test_sam3d_only,
            inputs=[tinp],
            outputs=[tm3d, tprev, tstat],
        )

    # Tab 3: environment report.
    with gr.Tab("Diagnose"):
        dbtn = gr.Button("Diagnose GPU & Modules")
        dout = gr.Textbox(lines=15)
        dbtn.click(diagnose, outputs=[dout])

demo.launch(mcp_server=True)
|