dev-bjoern committed on
Commit
98cf79e
·
1 Parent(s): 0e828b5

Fix SAM 3D Objects import: use notebook/inference.py

Browse files
Files changed (1) hide show
  1. app.py +39 -21
app.py CHANGED
@@ -30,9 +30,10 @@ if not SAM3D_PATH.exists():
30
  "https://github.com/facebookresearch/sam-3d-objects.git",
31
  str(SAM3D_PATH)
32
  ], check=True)
33
- sys.path.insert(0, str(SAM3D_PATH))
34
 
 
35
  sys.path.insert(0, str(SAM3D_PATH))
 
36
 
37
  # Global models
38
  SAM3D_MODEL = None
@@ -66,15 +67,21 @@ def load_sam3d():
66
  import torch
67
  print("Loading SAM 3D Objects model...")
68
 
 
69
  checkpoint_dir = snapshot_download(
70
  repo_id="facebook/sam-3d-objects",
71
  token=os.environ.get("HF_TOKEN")
72
  )
73
 
74
- from sam_3d_objects import Sam3dObjects
 
75
 
76
- device = "cuda" if torch.cuda.is_available() else "cpu"
77
- SAM3D_MODEL = Sam3dObjects.from_pretrained(checkpoint_dir, device=device)
 
 
 
 
78
 
79
  print("✓ SAM 3D Objects loaded")
80
  return SAM3D_MODEL
@@ -101,7 +108,7 @@ def reconstruct_objects(image: np.ndarray):
101
 
102
  # Load models
103
  generator = load_sam2()
104
- sam3d = load_sam3d()
105
 
106
  # Convert to PIL if needed
107
  if isinstance(image, np.ndarray):
@@ -110,9 +117,9 @@ def reconstruct_objects(image: np.ndarray):
110
  pil_image = image
111
  image = np.array(pil_image)
112
 
113
- # Auto-detect all objects
114
  print("Detecting objects...")
115
- masks = generator.generate(pil_image)
116
 
117
  if not masks or len(masks) == 0:
118
  return None, image, "⚠️ No objects detected"
@@ -125,26 +132,37 @@ def reconstruct_objects(image: np.ndarray):
125
  preview = image.copy()
126
  preview[best_mask] = (preview[best_mask] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
127
 
128
- # Run 3D reconstruction on largest object
 
 
 
129
  print("Reconstructing 3D...")
130
- mask_uint8 = best_mask.astype(np.uint8)
131
- outputs = sam3d.predict(image, mask_uint8)
132
 
133
- if outputs is None:
134
  return None, preview, "⚠️ 3D reconstruction failed"
135
 
136
- # Export as GLB
137
  output_dir = tempfile.mkdtemp()
138
- glb_path = f"{output_dir}/object_{uuid.uuid4().hex[:8]}.glb"
139
-
140
- # Get vertices from gaussian splat
141
- vertices = outputs.get_xyz().cpu().numpy()
142
 
143
- # Export as point cloud GLB
144
- cloud = trimesh.PointCloud(vertices)
145
- cloud.export(glb_path, file_type='glb')
146
-
147
- return glb_path, preview, f"✓ Detected {len(masks)} objects, reconstructed largest ({len(vertices)} points)"
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
  except Exception as e:
150
  import traceback
 
30
  "https://github.com/facebookresearch/sam-3d-objects.git",
31
  str(SAM3D_PATH)
32
  ], check=True)
 
33
 
34
+ # Add both repo root and notebook folder to path
35
  sys.path.insert(0, str(SAM3D_PATH))
36
+ sys.path.insert(0, str(SAM3D_PATH / "notebook"))
37
 
38
  # Global models
39
  SAM3D_MODEL = None
 
67
  import torch
68
  print("Loading SAM 3D Objects model...")
69
 
70
+ # Download checkpoints
71
  checkpoint_dir = snapshot_download(
72
  repo_id="facebook/sam-3d-objects",
73
  token=os.environ.get("HF_TOKEN")
74
  )
75
 
76
+ # Import from notebook/inference.py
77
+ from inference import Inference
78
 
79
+ # Config path in the repo
80
+ config_path = str(SAM3D_PATH / "sam3d_objects" / "configs" / "default.yaml")
81
+
82
+ SAM3D_MODEL = Inference(config_path, compile=False)
83
+ # Point to downloaded checkpoints
84
+ SAM3D_MODEL.checkpoint_dir = checkpoint_dir
85
 
86
  print("✓ SAM 3D Objects loaded")
87
  return SAM3D_MODEL
 
108
 
109
  # Load models
110
  generator = load_sam2()
111
+ inference = load_sam3d()
112
 
113
  # Convert to PIL if needed
114
  if isinstance(image, np.ndarray):
 
117
  pil_image = image
118
  image = np.array(pil_image)
119
 
120
+ # Auto-detect all objects with SAM2
121
  print("Detecting objects...")
122
+ masks = generator.generate(image)
123
 
124
  if not masks or len(masks) == 0:
125
  return None, image, "⚠️ No objects detected"
 
132
  preview = image.copy()
133
  preview[best_mask] = (preview[best_mask] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
134
 
135
+ # Convert mask to PIL
136
+ mask_pil = PILImage.fromarray((best_mask * 255).astype(np.uint8))
137
+
138
+ # Run 3D reconstruction
139
  print("Reconstructing 3D...")
140
+ result = inference(image=pil_image, mask=mask_pil)
 
141
 
142
+ if result is None:
143
  return None, preview, "⚠️ 3D reconstruction failed"
144
 
145
+ # Export as PLY (gaussian splat format)
146
  output_dir = tempfile.mkdtemp()
147
+ ply_path = f"{output_dir}/object_{uuid.uuid4().hex[:8]}.ply"
 
 
 
148
 
149
+ # Save the gaussian splat
150
+ if hasattr(result, 'save_ply'):
151
+ result.save_ply(ply_path)
152
+ elif 'gaussians' in result:
153
+ result['gaussians'].save_ply(ply_path)
154
+ else:
155
+ # Try to extract vertices and save as point cloud
156
+ vertices = result.get('xyz', result.get('points', None))
157
+ if vertices is not None:
158
+ if torch.is_tensor(vertices):
159
+ vertices = vertices.cpu().numpy()
160
+ cloud = trimesh.PointCloud(vertices)
161
+ cloud.export(ply_path)
162
+ else:
163
+ return None, preview, "⚠️ Could not extract 3D data"
164
+
165
+ return ply_path, preview, f"✓ Detected {len(masks)} objects, reconstructed largest"
166
 
167
  except Exception as e:
168
  import traceback