JasonYinnnn committed on
Commit
9b3eb99
·
1 Parent(s): 054d245

add debug

Browse files
Files changed (1) hide show
  1. app.py +129 -118
app.py CHANGED
@@ -84,7 +84,7 @@ MAX_SEED = np.iinfo(np.int32).max
84
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
85
  EXAMPLE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets/example_data")
86
  DTYPE = torch.float16
87
- DEVICE = "cuda"
88
  VALID_RATIO_THRESHOLD = 0.005
89
  CROP_SIZE = 518
90
  work_space = None
@@ -220,45 +220,51 @@ def run_segmentation(
220
  image_prompts: Any,
221
  polygon_refinement: bool = True,
222
  ) -> Image.Image:
223
- rgb_image = image_prompts["image"].convert("RGB")
 
224
 
225
- global work_space
226
- global sam2_predictor
227
-
228
- if sam2_predictor is None:
229
- sam2_model = build_sam2(
230
- config_file=SAM2_CONFIG,
231
- ckpt_path=SAM2_CHECKPOINT,
232
- )
233
- sam2_predictor = SAM2ImagePredictor(sam2_model)
234
 
235
- # pre-process the layers and get the xyxy boxes of each layer
236
- if len(image_prompts["points"]) == 0:
237
- gr.Error("No points provided for segmentation. Please add points to the image.")
238
- return None
239
-
240
- boxes = [
241
- [
242
- [int(box[0]), int(box[1]), int(box[3]), int(box[4])]
243
- for box in image_prompts["points"]
 
 
 
 
 
 
 
 
244
  ]
245
- ]
246
 
247
- detections = segment(
248
- sam2_predictor,
249
- rgb_image,
250
- boxes=[boxes],
251
- polygon_refinement=polygon_refinement,
252
- )
253
- seg_map_pil = plot_segmentation(rgb_image, detections)
254
 
255
- torch.cuda.empty_cache()
256
-
257
- cleanup_tmp(TMP_DIR, expire_seconds=3600)
 
 
 
 
258
 
259
- work_space = os.path.join(TMP_DIR, f"work_space_{uuid.uuid4()}")
260
- os.makedirs(work_space, exist_ok=True)
261
- seg_map_pil.save(os.path.join(work_space, 'mask.png'))
 
262
 
263
  return seg_map_pil
264
 
@@ -268,92 +274,97 @@ def run_depth_estimation(
268
  image_prompts: Any,
269
  seg_image: Union[str, Image.Image],
270
  ) -> Image.Image:
271
- rgb_image = image_prompts["image"].convert("RGB")
272
-
273
- rgb_image = rgb_image.resize((1024, 1024), Image.Resampling.LANCZOS)
274
-
275
- global pipeline
276
- pipeline.cuda()
277
-
278
- global dpt_pack
279
- global work_space
280
- if work_space is None:
281
- work_space = os.path.join(TMP_DIR, f"work_space_{uuid.uuid4()}")
282
- os.makedirs(work_space, exist_ok=True)
283
- global generated_object_map
284
-
285
- generated_object_map = {}
286
-
287
- origin_W, origin_H = rgb_image.size
288
- if max(origin_H, origin_W) > 1024:
289
- factor = max(origin_H, origin_W) / 1024
290
- H = int(origin_H // factor)
291
- W = int(origin_W // factor)
292
- rgb_image = rgb_image.resize((W, H), Image.Resampling.LANCZOS)
293
- W, H = rgb_image.size
294
-
295
- input_image = np.array(rgb_image).astype(np.float32)
296
- input_image = torch.tensor(input_image / 255, dtype=torch.float32, device=DEVICE).permute(2, 0, 1)
297
-
298
- output = pipeline.models['scene_cond_model'].infer(input_image)
299
- depth = output['depth']
300
- intrinsics = output['intrinsics']
301
-
302
- invalid_mask = torch.logical_or(torch.isnan(depth), torch.isinf(depth))
303
- depth_mask = ~invalid_mask
304
-
305
- depth = torch.where(invalid_mask, 0.0, depth)
306
- K = torch.from_numpy(
307
- np.array([
308
- [intrinsics[0, 0].item() * W, 0, 0.5*W],
309
- [0, intrinsics[1, 1].item() * H, 0.5*H],
310
- [0, 0, 1]
311
- ])
312
- ).to(dtype=torch.float32, device=DEVICE)
313
-
314
- dpt_pack = {
315
- 'c2w': c2w.to(DEVICE),
316
- 'K': K,
317
- 'depth_mask': depth_mask,
318
- 'depth': depth
319
- }
320
-
321
- instance_labels = np.unique(np.array(seg_image).reshape(-1, 3), axis=0)
322
- seg_image = seg_image.resize((W, H), Image.Resampling.LANCZOS)
323
- seg_image = np.array(seg_image)
324
-
325
- mask_pack = []
326
- for instance_label in instance_labels:
327
- if (instance_label == np.array([0, 0, 0])).all():
328
- continue
329
- else:
330
- instance_mask = (seg_image.reshape(-1, 3) == instance_label).all(axis=-1).reshape(H, W)
331
- mask_pack.append(instance_mask)
332
- fg_mask = torch.from_numpy(np.stack(mask_pack).any(axis=0)).to(DEVICE)
333
-
334
- scene_est_depth_pts, scene_est_depth_pts_colors = \
335
- project2ply(depth_mask, depth, input_image, K, c2w)
336
- save_ply_path = os.path.join(work_space, "scene_pcd.glb")
337
-
338
- fg_depth_pts, _ = \
339
- project2ply(fg_mask, depth, input_image, K, c2w)
340
- _, trans, scale = normalize_vertices(fg_depth_pts)
341
 
342
- if trans.shape[0] == 1:
343
- trans = trans[0]
 
344
 
345
- dpt_pack.update(
346
- {
347
- "trans": trans,
348
- "scale": scale,
349
- }
350
- )
351
-
352
- trimesh.PointCloud(scene_est_depth_pts.reshape(-1, 3), scene_est_depth_pts_colors.reshape(-1, 3)).\
353
- apply_translation(-trans).apply_scale(1. / (scale + 1e-6)).\
354
- apply_transform(rot).export(save_ply_path)
355
-
356
- torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
 
358
  return save_ply_path
359
 
 
84
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
85
  EXAMPLE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets/example_data")
86
  DTYPE = torch.float16
87
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
88
  VALID_RATIO_THRESHOLD = 0.005
89
  CROP_SIZE = 518
90
  work_space = None
 
220
  image_prompts: Any,
221
  polygon_refinement: bool = True,
222
  ) -> Image.Image:
223
+ try:
224
+ rgb_image = image_prompts["image"].convert("RGB")
225
 
226
+ global work_space
227
+ global sam2_predictor
 
 
 
 
 
 
 
228
 
229
+ if sam2_predictor is None:
230
+ sam2_model = build_sam2(
231
+ config_file=SAM2_CONFIG,
232
+ ckpt_path=SAM2_CHECKPOINT,
233
+ )
234
+ sam2_predictor = SAM2ImagePredictor(sam2_model)
235
+
236
+ # pre-process the layers and get the xyxy boxes of each layer
237
+ if len(image_prompts["points"]) == 0:
238
+ gr.Error("No points provided for segmentation. Please add points to the image.")
239
+ return None
240
+
241
+ boxes = [
242
+ [
243
+ [int(box[0]), int(box[1]), int(box[3]), int(box[4])]
244
+ for box in image_prompts["points"]
245
+ ]
246
  ]
 
247
 
248
+ detections = segment(
249
+ sam2_predictor,
250
+ rgb_image,
251
+ boxes=[boxes],
252
+ polygon_refinement=polygon_refinement,
253
+ )
254
+ seg_map_pil = plot_segmentation(rgb_image, detections)
255
 
256
+ torch.cuda.empty_cache()
257
+
258
+ cleanup_tmp(TMP_DIR, expire_seconds=3600)
259
+
260
+ work_space = os.path.join(TMP_DIR, f"work_space_{uuid.uuid4()}")
261
+ os.makedirs(work_space, exist_ok=True)
262
+ seg_map_pil.save(os.path.join(work_space, 'mask.png'))
263
 
264
+ except Exception as e:
265
+ import traceback
266
+ traceback.print_exc()
267
+ raise gr.Error(f"run_segmentation failed: {e}")
268
 
269
  return seg_map_pil
270
 
 
274
  image_prompts: Any,
275
  seg_image: Union[str, Image.Image],
276
  ) -> Image.Image:
277
+ try:
278
+ rgb_image = image_prompts["image"].convert("RGB")
279
+
280
+ rgb_image = rgb_image.resize((1024, 1024), Image.Resampling.LANCZOS)
281
+
282
+ global pipeline
283
+ pipeline.cuda()
284
+
285
+ global dpt_pack
286
+ global work_space
287
+ if work_space is None:
288
+ work_space = os.path.join(TMP_DIR, f"work_space_{uuid.uuid4()}")
289
+ os.makedirs(work_space, exist_ok=True)
290
+ global generated_object_map
291
+
292
+ generated_object_map = {}
293
+
294
+ origin_W, origin_H = rgb_image.size
295
+ if max(origin_H, origin_W) > 1024:
296
+ factor = max(origin_H, origin_W) / 1024
297
+ H = int(origin_H // factor)
298
+ W = int(origin_W // factor)
299
+ rgb_image = rgb_image.resize((W, H), Image.Resampling.LANCZOS)
300
+ W, H = rgb_image.size
301
+
302
+ input_image = np.array(rgb_image).astype(np.float32)
303
+ input_image = torch.tensor(input_image / 255, dtype=torch.float32, device=DEVICE).permute(2, 0, 1)
304
+
305
+ output = pipeline.models['scene_cond_model'].infer(input_image)
306
+ depth = output['depth']
307
+ intrinsics = output['intrinsics']
308
+
309
+ invalid_mask = torch.logical_or(torch.isnan(depth), torch.isinf(depth))
310
+ depth_mask = ~invalid_mask
311
+
312
+ depth = torch.where(invalid_mask, 0.0, depth)
313
+ K = torch.from_numpy(
314
+ np.array([
315
+ [intrinsics[0, 0].item() * W, 0, 0.5*W],
316
+ [0, intrinsics[1, 1].item() * H, 0.5*H],
317
+ [0, 0, 1]
318
+ ])
319
+ ).to(dtype=torch.float32, device=DEVICE)
320
+
321
+ dpt_pack = {
322
+ 'c2w': c2w.to(DEVICE),
323
+ 'K': K,
324
+ 'depth_mask': depth_mask,
325
+ 'depth': depth
326
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
 
328
+ instance_labels = np.unique(np.array(seg_image).reshape(-1, 3), axis=0)
329
+ seg_image = seg_image.resize((W, H), Image.Resampling.LANCZOS)
330
+ seg_image = np.array(seg_image)
331
 
332
+ mask_pack = []
333
+ for instance_label in instance_labels:
334
+ if (instance_label == np.array([0, 0, 0])).all():
335
+ continue
336
+ else:
337
+ instance_mask = (seg_image.reshape(-1, 3) == instance_label).all(axis=-1).reshape(H, W)
338
+ mask_pack.append(instance_mask)
339
+ fg_mask = torch.from_numpy(np.stack(mask_pack).any(axis=0)).to(DEVICE)
340
+
341
+ scene_est_depth_pts, scene_est_depth_pts_colors = \
342
+ project2ply(depth_mask, depth, input_image, K, c2w)
343
+ save_ply_path = os.path.join(work_space, "scene_pcd.glb")
344
+
345
+ fg_depth_pts, _ = \
346
+ project2ply(fg_mask, depth, input_image, K, c2w)
347
+ _, trans, scale = normalize_vertices(fg_depth_pts)
348
+
349
+ if trans.shape[0] == 1:
350
+ trans = trans[0]
351
+
352
+ dpt_pack.update(
353
+ {
354
+ "trans": trans,
355
+ "scale": scale,
356
+ }
357
+ )
358
+
359
+ trimesh.PointCloud(scene_est_depth_pts.reshape(-1, 3), scene_est_depth_pts_colors.reshape(-1, 3)).\
360
+ apply_translation(-trans).apply_scale(1. / (scale + 1e-6)).\
361
+ apply_transform(rot).export(save_ply_path)
362
+
363
+ torch.cuda.empty_cache()
364
+ except Exception as e:
365
+ import traceback
366
+ traceback.print_exc()
367
+ raise gr.Error(f"run_depth_estimation failed: {e}")
368
 
369
  return save_ply_path
370