Zhen Ye committed on
Commit
91f3b56
·
1 Parent(s): 94c85d4

Added fallback for NaN

Browse files
Files changed (2) hide show
  1. inference.py +24 -4
  2. models/depth_estimators/depth_pro.py +36 -17
inference.py CHANGED
@@ -433,8 +433,24 @@ def process_frames_depth(
433
 
434
  # Compute global min/max (using percentiles to handle outliers)
435
  all_depths = np.concatenate(all_values)
436
- global_min = np.percentile(all_depths, 1) # 1st percentile to clip outliers
437
- global_max = np.percentile(all_depths, 99) # 99th percentile
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
 
439
  logging.info(
440
  "Depth range: %.2f - %.2f meters (1st-99th percentile)",
@@ -472,11 +488,15 @@ def colorize_depth_map(
472
  """
473
  import cv2
474
 
 
 
 
 
475
  if global_max - global_min < 1e-6: # Handle uniform depth
476
- depth_norm = np.zeros_like(depth_map, dtype=np.uint8)
477
  else:
478
  # Clip to global range to handle outliers
479
- depth_clipped = np.clip(depth_map, global_min, global_max)
480
  depth_norm = ((depth_clipped - global_min) / (global_max - global_min) * 255).astype(np.uint8)
481
 
482
  # Apply TURBO colormap for vibrant, perceptually uniform visualization
 
433
 
434
  # Compute global min/max (using percentiles to handle outliers)
435
  all_depths = np.concatenate(all_values)
436
+
437
+ # Filter out NaN and inf values
438
+ valid_depths = all_depths[np.isfinite(all_depths)]
439
+
440
+ if len(valid_depths) == 0:
441
+ logging.warning("All depth values are NaN/inf - using fallback range")
442
+ global_min = 0.0
443
+ global_max = 1.0
444
+ else:
445
+ global_min = float(np.percentile(valid_depths, 1)) # 1st percentile to clip outliers
446
+ global_max = float(np.percentile(valid_depths, 99)) # 99th percentile
447
+
448
+ # Handle edge case where min == max
449
+ if abs(global_max - global_min) < 1e-6:
450
+ global_min = float(valid_depths.min())
451
+ global_max = float(valid_depths.max())
452
+ if abs(global_max - global_min) < 1e-6:
453
+ global_max = global_min + 1.0
454
 
455
  logging.info(
456
  "Depth range: %.2f - %.2f meters (1st-99th percentile)",
 
488
  """
489
  import cv2
490
 
491
+ # Replace NaN/inf with min value for visualization
492
+ depth_clean = np.copy(depth_map)
493
+ depth_clean[~np.isfinite(depth_clean)] = global_min
494
+
495
  if global_max - global_min < 1e-6: # Handle uniform depth
496
+ depth_norm = np.zeros_like(depth_clean, dtype=np.uint8)
497
  else:
498
  # Clip to global range to handle outliers
499
+ depth_clipped = np.clip(depth_clean, global_min, global_max)
500
  depth_norm = ((depth_clipped - global_min) / (global_max - global_min) * 255).astype(np.uint8)
501
 
502
  # Apply TURBO colormap for vibrant, perceptually uniform visualization
models/depth_estimators/depth_pro.py CHANGED
@@ -62,23 +62,42 @@ class DepthProEstimator(DepthEstimator):
62
  with torch.no_grad():
63
  outputs = self.model(**inputs)
64
 
65
- # Post-process to get depth and focal length
66
- post_processed = self.image_processor.post_process_depth_estimation(
67
- outputs,
68
- target_sizes=[(height, width)],
69
- )
70
-
71
- # Extract depth map and focal length
72
- depth_tensor = post_processed[0]["predicted_depth"] # Already at target size
73
- focal_length_value = post_processed[0].get("focal_length", 1.0)
74
-
75
- # Convert to numpy
76
- depth_map = depth_tensor.cpu().numpy()
77
-
78
- # focal_length might be a tensor, convert to float
79
- if isinstance(focal_length_value, torch.Tensor):
80
- focal_length = float(focal_length_value.item())
 
 
 
 
 
81
  else:
82
- focal_length = float(focal_length_value)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  return DepthResult(depth_map=depth_map, focal_length=focal_length)
 
62
  with torch.no_grad():
63
  outputs = self.model(**inputs)
64
 
65
+ # Get raw depth prediction
66
+ raw_depth = outputs.predicted_depth # Shape: [1, 1, H, W]
67
+
68
+ # Resize to target size if needed
69
+ if raw_depth.shape[-2:] != (height, width):
70
+ import torch.nn.functional as F
71
+ raw_depth = F.interpolate(
72
+ raw_depth,
73
+ size=(height, width),
74
+ mode='bilinear',
75
+ align_corners=False
76
+ )
77
+
78
+ # Convert to numpy and remove batch/channel dims
79
+ depth_map = raw_depth.squeeze().cpu().numpy() # Shape: [H, W]
80
+
81
+ # Get focal length from outputs if available
82
+ if hasattr(outputs, 'fov_deg') and outputs.fov_deg is not None:
83
+ # Convert field of view to focal length
84
+ fov_rad = outputs.fov_deg * np.pi / 180.0
85
+ focal_length = float(width / (2.0 * np.tan(fov_rad / 2.0)))
86
  else:
87
+ focal_length = 1.0
88
+
89
+ # Debug: Check for NaN values
90
+ if np.isnan(depth_map).any():
91
+ nan_count = np.isnan(depth_map).sum()
92
+ total = depth_map.size
93
+ logging.warning(
94
+ f"Depth map contains {nan_count}/{total} ({100*nan_count/total:.1f}%) NaN values"
95
+ )
96
+ logging.warning(f"Depth map shape: {depth_map.shape}, dtype: {depth_map.dtype}")
97
+ valid_depths = depth_map[np.isfinite(depth_map)]
98
+ if len(valid_depths) > 0:
99
+ logging.warning(
100
+ f"Valid depth range: {valid_depths.min():.4f} - {valid_depths.max():.4f}"
101
+ )
102
 
103
  return DepthResult(depth_map=depth_map, focal_length=focal_length)