jboth's picture
Upload app.py with huggingface_hub
a51525f verified
raw
history blame
8.33 kB
"""SAM 3D Objects – kaolin+pytorch3d stubbed for ZeroGPU (PyTorch 2.10+cu128)."""
import os, sys, subprocess
# Toolchain locations: setdefault() only fills these in when the environment
# has not already defined them.
# NOTE(review): presumably read by CUDA-extension builds/imports later on — confirm.
os.environ.setdefault("CUDA_HOME", "/usr/local/cuda")
os.environ.setdefault("CONDA_PREFIX", "/usr/local")
# Set unconditionally, overriding any existing value.
# NOTE(review): presumably tells the sam3d/lidra package to skip its init step — confirm.
os.environ["LIDRA_SKIP_INIT"] = "true"
# MUST import spaces before torch
import spaces
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download, login
import tempfile
from pathlib import Path
# Log in to the Hugging Face Hub only when a non-empty token is configured.
if os.getenv("HF_TOKEN"):
    login(token=os.getenv("HF_TOKEN"))

# --- Stubs (must be before sam3d imports) ---
# Every stub directory that is present is pushed to the front of sys.path, in
# declaration order (so the one inserted last ends up first).
STUB_KAOLIN = Path("/home/user/app/kaolin_stub")
STUB_PT3D = Path("/home/user/app/pytorch3d_stub")
STUB_SPCONV = Path("/home/user/app/spconv_stub")
for stub in (STUB_KAOLIN, STUB_PT3D, STUB_SPCONV):
    if not stub.exists():
        continue
    sys.path.insert(0, str(stub))
    print(f"Stub added: {stub.name}")
# --- Runtime pip installs ---
def _pip(*a):
r = subprocess.run([sys.executable, "-m", "pip", "install", "--no-cache-dir"] + list(a),
capture_output=True, text=True, timeout=1200)
ok = r.returncode == 0
tag = a[-1][:50] if a else "?"
if ok:
print(f" pip OK: {tag}")
else:
print(f" pip FAIL: {tag}")
print(f" {r.stderr[-300:]}")
return ok
print("=== Runtime installs ===")
# Installs are attempted in this fixed order; the boolean returned by _pip is
# ignored here, so a failing package does not stop the ones after it.
for _pip_args in (
    ("open3d>=0.18.0",),
    ("--no-deps", "utils3d"),  # --no-deps: skip jupyter dependency
    ("iopath",),
    ("--no-deps", "sam2>=1.1.0"),
    ("--no-deps", "git+https://github.com/microsoft/MoGe.git@a8c37341bc0325ca99b9d57981cc3bb2bd3e255b"),
):
    _pip(*_pip_args)

# gsplat: try each wheel index in turn and stop at the first successful install.
for idx in (
    "https://docs.gsplat.studio/whl/pt210cu128",
    "https://docs.gsplat.studio/whl/pt28cu128",
):
    if _pip("--no-deps", f"--extra-index-url={idx}", "gsplat"):
        break
# DO NOT import CUDA-dependent packages here!
# --- Clone sam-3d-objects ---
SAM3D_PATH = Path("/home/user/app/sam-3d-objects")
if not SAM3D_PATH.exists():
    # NOTE(review): the editable install and the hydra patch are assumed to run
    # only right after a fresh clone — confirm against the intended layout.
    print("Cloning sam-3d-objects...")
    # check=True: a failed clone raises CalledProcessError and stops startup.
    subprocess.run(
        ["git", "clone", "--depth", "1",
         "https://github.com/facebookresearch/sam-3d-objects.git",
         str(SAM3D_PATH)],
        check=True,
    )
    # Editable install stays best-effort, but a failure is reported instead of
    # being discarded together with the captured output.
    _install = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-e", str(SAM3D_PATH), "--no-deps"],
        capture_output=True, text=True,
    )
    if _install.returncode != 0:
        print(" pip FAIL: sam-3d-objects (editable install)")
        print(f" {_install.stderr[-300:]}")
    # Hydra patch: run the repository's patch script when it is present.
    patch = SAM3D_PATH / "patching" / "hydra"
    if patch.exists():
        _hydra = subprocess.run(["bash", str(patch)], capture_output=True,
                                text=True, cwd=str(SAM3D_PATH))
        if _hydra.returncode != 0:
            print(f" hydra patch FAIL (exit {_hydra.returncode})")
            print(f" {_hydra.stderr[-300:]}")
# Make the checkout and its notebook/ directory importable
# (`from inference import Inference` is used by the endpoints).
sys.path.insert(0, str(SAM3D_PATH))
sys.path.insert(0, str(SAM3D_PATH / "notebook"))
# --- Pre-download checkpoints ---
print("Downloading SAM3D checkpoints...")
CKPT_DIR = snapshot_download(repo_id="facebook/sam-3d-objects",
                             token=os.environ.get("HF_TOKEN"))
hf_ckpt = Path(CKPT_DIR) / "checkpoints"
local_ckpt = SAM3D_PATH / "checkpoints" / "hf"
if hf_ckpt.exists():
    # Path.exists() follows symlinks, so a dangling link left over from an
    # earlier snapshot looks "missing" and symlink_to() would then fail with
    # FileExistsError. Remove the stale link before re-creating it.
    if local_ckpt.is_symlink() and not local_ckpt.exists():
        local_ckpt.unlink()
    if not local_ckpt.exists():
        local_ckpt.parent.mkdir(parents=True, exist_ok=True)
        local_ckpt.symlink_to(hf_ckpt)
# The pipeline config is read through the checkout-local path.
CONFIG_PATH = str(local_ckpt / "pipeline.yaml")
print(f"Config exists: {Path(CONFIG_PATH).exists()}")
print("=== Startup complete ===")
# --- Endpoints ---
@spaces.GPU(duration=60)
def diagnose():
    """Build a plain-text report of the GPU runtime and importable modules.

    The report lists the torch version and CUDA availability (plus the device
    name when CUDA is available), then one ``OK``/``FAIL`` line per optional
    dependency, for the SAM2 mask generator and for the SAM3D ``Inference``
    class, and finally whether the pipeline config file exists. Import
    failures are recorded with the exception text rather than raised.

    Returns:
        The report lines joined with newlines.
    """
    import torch

    has_cuda = torch.cuda.is_available()
    report = [f"torch={torch.__version__}", f"cuda={has_cuda}"]
    if has_cuda:
        report.append(f"gpu={torch.cuda.get_device_name()}")

    # Top-level packages: import each one and show its version if it has one.
    for name in ("kaolin", "utils3d", "iopath", "pytorch3d", "open3d", "gsplat", "moge"):
        try:
            module = __import__(name)
            version = getattr(module, "__version__", "-")
            entry = f"{name}: OK ({version})"
        except Exception as exc:
            entry = f"{name}: FAIL - {exc}"
        report.append(entry)

    # SAM2: the specific class the reconstruction endpoint needs.
    try:
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
        sam2_status = "sam2: OK"
    except Exception as exc:
        sam2_status = f"sam2: FAIL - {exc}"
    report.append(sam2_status)

    # SAM3D: the Inference entry point.
    try:
        from inference import Inference
        sam3d_status = "SAM3D Inference: importable"
    except Exception as exc:
        sam3d_status = f"SAM3D Inference: FAIL - {exc}"
    report.append(sam3d_status)

    report.append(f"config: {Path(CONFIG_PATH).exists()}")
    return "\n".join(report)
def _find_gaussians(result):
    """Locate the splat-like object inside a SAM3D inference result.

    Returns *result* itself when it has a ``save_ply`` attribute. For a dict,
    returns the first usable value stored under ``gs``, ``gaussian``,
    ``gaussians`` or ``scene`` (the first element when that value is a
    list/tuple; empty sequences are skipped). Returns None otherwise.
    """
    if hasattr(result, "save_ply"):
        return result
    if isinstance(result, dict):
        for key in ("gs", "gaussian", "gaussians", "scene"):
            value = result.get(key)
            if value is None:
                continue
            if isinstance(value, (list, tuple)):
                if not value:
                    # An empty container holds nothing to export; try the next key.
                    continue
                return value[0]
            return value
    return None


def _write_poisson_glb(pcd, glb_path):
    """Estimate normals on *pcd*, Poisson-reconstruct a mesh (depth 8) and write it to *glb_path*."""
    import open3d as o3d

    pcd.estimate_normals()
    mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=8)
    o3d.io.write_triangle_mesh(glb_path, mesh)


@spaces.GPU(duration=300)
def reconstruct_objects(image: np.ndarray):
    """Segment the largest object in *image* and reconstruct it as a GLB file.

    Pipeline: SAM2 automatic mask generation, selection of the mask with the
    largest ``area``, SAM3D inference on that mask (fixed ``seed=42``), then
    conversion of the result into ``object.glb`` inside a fresh temporary
    directory. Both models are loaded on every call.

    Args:
        image: Input picture, or None when nothing was provided. Inputs that
            are not already ``np.ndarray`` are converted with ``np.array``.
            NOTE(review): the green overlay assumes a 3-channel image — confirm
            against the Gradio input configuration.

    Returns:
        A ``(glb_path, preview, status)`` tuple. ``glb_path`` is the written
        GLB file or None on failure; ``preview`` is the input with the chosen
        mask blended 50% with green (the unmodified input when no mask was
        found, None when there was no image or an exception occurred);
        ``status`` is a human-readable message. Exceptions are caught, their
        traceback printed, and reported as ``"Error: ..."``.
    """
    if image is None:
        return None, None, "No image"
    try:
        import time

        import torch
        import trimesh

        t0 = time.time()
        print(f"GPU: {torch.cuda.get_device_name()}")

        # --- Detection: SAM2 proposes masks; keep the one with the largest area.
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
        sam2_gen = SAM2AutomaticMaskGenerator.from_pretrained("facebook/sam2-hiera-large")
        print(f" SAM2 loaded ({time.time()-t0:.0f}s)")
        image_np = image if isinstance(image, np.ndarray) else np.array(image)
        masks = sam2_gen.generate(image_np)
        if not masks:
            return None, image_np, "No objects detected"
        best_mask = max(masks, key=lambda m: m["area"])["segmentation"]
        preview = image_np.copy()
        preview[best_mask] = (preview[best_mask] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
        print(f" {len(masks)} masks ({time.time()-t0:.0f}s)")

        # --- Reconstruction: SAM3D on the selected mask.
        from inference import Inference
        sam3d = Inference(CONFIG_PATH, compile=False)
        print(f" SAM3D loaded ({time.time()-t0:.0f}s)")
        result = sam3d(image=image_np, mask=best_mask, seed=42)
        print(f" Reconstructed ({time.time()-t0:.0f}s)")
        if result is None:
            return None, preview, "Reconstruction returned None"

        # --- Export: turn whatever the pipeline returned into a GLB mesh.
        od = tempfile.mkdtemp()
        glb = f"{od}/object.glb"
        gs = _find_gaussians(result)
        if gs is not None and hasattr(gs, "save_ply"):
            import open3d as o3d
            # Round-trip through a PLY file, then mesh the resulting point cloud.
            ply = f"{od}/temp.ply"
            gs.save_ply(ply)
            _write_poisson_glb(o3d.io.read_point_cloud(ply), glb)
        elif gs is not None and hasattr(gs, "_xyz"):
            import open3d as o3d
            # No PLY writer: build the point cloud straight from the positions.
            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(gs._xyz.detach().cpu().numpy())
            _write_poisson_glb(pcd, glb)
        elif isinstance(result, dict) and "mesh" in result:
            mesh = result["mesh"]
            if hasattr(mesh, "export"):
                mesh.export(glb)
        else:
            keys = list(result.keys()) if isinstance(result, dict) else dir(result)
            return None, preview, f"Cannot extract 3D. Keys: {keys}"

        # A mesh object without export(), or a writer that failed, leaves no
        # file behind; report that instead of returning a path to nothing.
        if not os.path.isfile(glb):
            return None, preview, "Error: no GLB file was written"

        # The face count is informational only, so a load failure is logged
        # and the count falls back to 0.
        n = 0
        try:
            n = len(trimesh.load(glb, force="mesh").faces)
        except Exception as exc:
            print(f" face count unavailable: {exc}")
        elapsed = int(time.time() - t0)
        return glb, preview, f"OK: {len(masks)} objects, {n:,} faces ({elapsed}s)"
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, None, f"Error: {e}"
# --- UI ---
with gr.Blocks(title="SAM 3D Objects") as demo:
    gr.Markdown("# SAM 3D Objects\nImage → 3D (GLB). SAM2 detection + SAM3D reconstruction.")

    # Tab 1: upload an image, run the reconstruction, inspect the outputs.
    with gr.Tab("Reconstruct"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="numpy", label="Input")
                run_button = gr.Button("Reconstruct", size="lg", variant="primary")
            with gr.Column():
                preview_image = gr.Image(type="numpy", label="Detection", interactive=False)
                status_box = gr.Textbox(label="Status")
        with gr.Row():
            model_view = gr.Model3D(label="3D Preview")
            glb_file = gr.File(label="Download GLB")
        run_button.click(
            reconstruct_objects,
            inputs=[image_input],
            outputs=[model_view, preview_image, status_box],
        )
        # Mirror the 3D viewer's value into the download widget whenever it changes.
        model_view.change(lambda x: x, inputs=[model_view], outputs=[glb_file])

    # Tab 2: one button that prints the environment report.
    with gr.Tab("Diagnose"):
        diag_button = gr.Button("Diagnose GPU & Modules")
        diag_output = gr.Textbox(lines=15)
        diag_button.click(diagnose, outputs=[diag_output])

demo.launch(mcp_server=True)