Spaces: Running on Zero
Commit 98cf79e · 1 Parent(s): 0e828b5
Fix SAM 3D Objects import: use notebook/inference.py
Browse files
app.py
CHANGED
|
@@ -30,9 +30,10 @@ if not SAM3D_PATH.exists():
|
|
| 30 |
"https://github.com/facebookresearch/sam-3d-objects.git",
|
| 31 |
str(SAM3D_PATH)
|
| 32 |
], check=True)
|
| 33 |
-
sys.path.insert(0, str(SAM3D_PATH))
|
| 34 |
|
|
|
|
| 35 |
sys.path.insert(0, str(SAM3D_PATH))
|
|
|
|
| 36 |
|
| 37 |
# Global models
|
| 38 |
SAM3D_MODEL = None
|
|
@@ -66,15 +67,21 @@ def load_sam3d():
|
|
| 66 |
import torch
|
| 67 |
print("Loading SAM 3D Objects model...")
|
| 68 |
|
|
|
|
| 69 |
checkpoint_dir = snapshot_download(
|
| 70 |
repo_id="facebook/sam-3d-objects",
|
| 71 |
token=os.environ.get("HF_TOKEN")
|
| 72 |
)
|
| 73 |
|
| 74 |
-
|
|
|
|
| 75 |
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
print("✓ SAM 3D Objects loaded")
|
| 80 |
return SAM3D_MODEL
|
|
@@ -101,7 +108,7 @@ def reconstruct_objects(image: np.ndarray):
|
|
| 101 |
|
| 102 |
# Load models
|
| 103 |
generator = load_sam2()
|
| 104 |
-
|
| 105 |
|
| 106 |
# Convert to PIL if needed
|
| 107 |
if isinstance(image, np.ndarray):
|
|
@@ -110,9 +117,9 @@ def reconstruct_objects(image: np.ndarray):
|
|
| 110 |
pil_image = image
|
| 111 |
image = np.array(pil_image)
|
| 112 |
|
| 113 |
-
# Auto-detect all objects
|
| 114 |
print("Detecting objects...")
|
| 115 |
-
masks = generator.generate(
|
| 116 |
|
| 117 |
if not masks or len(masks) == 0:
|
| 118 |
return None, image, "⚠️ No objects detected"
|
|
@@ -125,26 +132,37 @@ def reconstruct_objects(image: np.ndarray):
|
|
| 125 |
preview = image.copy()
|
| 126 |
preview[best_mask] = (preview[best_mask] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
|
| 127 |
|
| 128 |
-
#
|
|
|
|
|
|
|
|
|
|
| 129 |
print("Reconstructing 3D...")
|
| 130 |
-
|
| 131 |
-
outputs = sam3d.predict(image, mask_uint8)
|
| 132 |
|
| 133 |
-
if
|
| 134 |
return None, preview, "⚠️ 3D reconstruction failed"
|
| 135 |
|
| 136 |
-
# Export as
|
| 137 |
output_dir = tempfile.mkdtemp()
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
# Get vertices from gaussian splat
|
| 141 |
-
vertices = outputs.get_xyz().cpu().numpy()
|
| 142 |
|
| 143 |
-
#
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
except Exception as e:
|
| 150 |
import traceback
|
|
|
|
| 30 |
"https://github.com/facebookresearch/sam-3d-objects.git",
|
| 31 |
str(SAM3D_PATH)
|
| 32 |
], check=True)
|
|
|
|
| 33 |
|
| 34 |
+
# Add both repo root and notebook folder to path
|
| 35 |
sys.path.insert(0, str(SAM3D_PATH))
|
| 36 |
+
sys.path.insert(0, str(SAM3D_PATH / "notebook"))
|
| 37 |
|
| 38 |
# Global models
|
| 39 |
SAM3D_MODEL = None
|
|
|
|
| 67 |
import torch
|
| 68 |
print("Loading SAM 3D Objects model...")
|
| 69 |
|
| 70 |
+
# Download checkpoints
|
| 71 |
checkpoint_dir = snapshot_download(
|
| 72 |
repo_id="facebook/sam-3d-objects",
|
| 73 |
token=os.environ.get("HF_TOKEN")
|
| 74 |
)
|
| 75 |
|
| 76 |
+
# Import from notebook/inference.py
|
| 77 |
+
from inference import Inference
|
| 78 |
|
| 79 |
+
# Config path in the repo
|
| 80 |
+
config_path = str(SAM3D_PATH / "sam3d_objects" / "configs" / "default.yaml")
|
| 81 |
+
|
| 82 |
+
SAM3D_MODEL = Inference(config_path, compile=False)
|
| 83 |
+
# Point to downloaded checkpoints
|
| 84 |
+
SAM3D_MODEL.checkpoint_dir = checkpoint_dir
|
| 85 |
|
| 86 |
print("✓ SAM 3D Objects loaded")
|
| 87 |
return SAM3D_MODEL
|
|
|
|
| 108 |
|
| 109 |
# Load models
|
| 110 |
generator = load_sam2()
|
| 111 |
+
inference = load_sam3d()
|
| 112 |
|
| 113 |
# Convert to PIL if needed
|
| 114 |
if isinstance(image, np.ndarray):
|
|
|
|
| 117 |
pil_image = image
|
| 118 |
image = np.array(pil_image)
|
| 119 |
|
| 120 |
+
# Auto-detect all objects with SAM2
|
| 121 |
print("Detecting objects...")
|
| 122 |
+
masks = generator.generate(image)
|
| 123 |
|
| 124 |
if not masks or len(masks) == 0:
|
| 125 |
return None, image, "⚠️ No objects detected"
|
|
|
|
| 132 |
preview = image.copy()
|
| 133 |
preview[best_mask] = (preview[best_mask] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
|
| 134 |
|
| 135 |
+
# Convert mask to PIL
|
| 136 |
+
mask_pil = PILImage.fromarray((best_mask * 255).astype(np.uint8))
|
| 137 |
+
|
| 138 |
+
# Run 3D reconstruction
|
| 139 |
print("Reconstructing 3D...")
|
| 140 |
+
result = inference(image=pil_image, mask=mask_pil)
|
|
|
|
| 141 |
|
| 142 |
+
if result is None:
|
| 143 |
return None, preview, "⚠️ 3D reconstruction failed"
|
| 144 |
|
| 145 |
+
# Export as PLY (gaussian splat format)
|
| 146 |
output_dir = tempfile.mkdtemp()
|
| 147 |
+
ply_path = f"{output_dir}/object_{uuid.uuid4().hex[:8]}.ply"
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
+
# Save the gaussian splat
|
| 150 |
+
if hasattr(result, 'save_ply'):
|
| 151 |
+
result.save_ply(ply_path)
|
| 152 |
+
elif 'gaussians' in result:
|
| 153 |
+
result['gaussians'].save_ply(ply_path)
|
| 154 |
+
else:
|
| 155 |
+
# Try to extract vertices and save as point cloud
|
| 156 |
+
vertices = result.get('xyz', result.get('points', None))
|
| 157 |
+
if vertices is not None:
|
| 158 |
+
if torch.is_tensor(vertices):
|
| 159 |
+
vertices = vertices.cpu().numpy()
|
| 160 |
+
cloud = trimesh.PointCloud(vertices)
|
| 161 |
+
cloud.export(ply_path)
|
| 162 |
+
else:
|
| 163 |
+
return None, preview, "⚠️ Could not extract 3D data"
|
| 164 |
+
|
| 165 |
+
return ply_path, preview, f"✓ Detected {len(masks)} objects, reconstructed largest"
|
| 166 |
|
| 167 |
except Exception as e:
|
| 168 |
import traceback
|