Zhen Ye committed on
Commit
18ba97a
·
1 Parent(s): 91f3b56

fixed shape mismatch

Browse files
Files changed (1) hide show
  1. models/depth_estimators/depth_pro.py +80 -51
models/depth_estimators/depth_pro.py CHANGED
@@ -48,56 +48,85 @@ class DepthProEstimator(DepthEstimator):
48
  Returns:
49
  DepthResult with depth_map (HxW float32 in meters) and focal_length
50
  """
51
- # Convert BGR to RGB
52
- rgb_frame = frame[:, :, ::-1] # BGR RGB
53
-
54
- # Convert to PIL Image
55
- pil_image = Image.fromarray(rgb_frame)
56
- height, width = pil_image.height, pil_image.width
57
-
58
- # Preprocess image
59
- inputs = self.image_processor(images=pil_image, return_tensors="pt").to(self.device)
60
-
61
- # Run inference (no gradient needed)
62
- with torch.no_grad():
63
- outputs = self.model(**inputs)
64
-
65
- # Get raw depth prediction
66
- raw_depth = outputs.predicted_depth # Shape: [1, 1, H, W]
67
-
68
- # Resize to target size if needed
69
- if raw_depth.shape[-2:] != (height, width):
70
- import torch.nn.functional as F
71
- raw_depth = F.interpolate(
72
- raw_depth,
73
- size=(height, width),
74
- mode='bilinear',
75
- align_corners=False
76
- )
77
-
78
- # Convert to numpy and remove batch/channel dims
79
- depth_map = raw_depth.squeeze().cpu().numpy() # Shape: [H, W]
80
-
81
- # Get focal length from outputs if available
82
- if hasattr(outputs, 'fov_deg') and outputs.fov_deg is not None:
83
- # Convert field of view to focal length
84
- fov_rad = outputs.fov_deg * np.pi / 180.0
85
- focal_length = float(width / (2.0 * np.tan(fov_rad / 2.0)))
86
- else:
87
- focal_length = 1.0
88
-
89
- # Debug: Check for NaN values
90
- if np.isnan(depth_map).any():
91
- nan_count = np.isnan(depth_map).sum()
92
- total = depth_map.size
93
- logging.warning(
94
- f"Depth map contains {nan_count}/{total} ({100*nan_count/total:.1f}%) NaN values"
95
- )
96
- logging.warning(f"Depth map shape: {depth_map.shape}, dtype: {depth_map.dtype}")
97
- valid_depths = depth_map[np.isfinite(depth_map)]
98
- if len(valid_depths) > 0:
99
- logging.warning(
100
- f"Valid depth range: {valid_depths.min():.4f} - {valid_depths.max():.4f}"
101
  )
102
 
103
- return DepthResult(depth_map=depth_map, focal_length=focal_length)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  Returns:
49
  DepthResult with depth_map (HxW float32 in meters) and focal_length
50
  """
51
+ try:
52
+ # Convert BGR to RGB
53
+ rgb_frame = frame[:, :, ::-1] # BGR → RGB
54
+
55
+ # Convert to PIL Image
56
+ pil_image = Image.fromarray(rgb_frame)
57
+ height, width = pil_image.height, pil_image.width
58
+
59
+ # Preprocess image
60
+ inputs = self.image_processor(images=pil_image, return_tensors="pt").to(self.device)
61
+
62
+ # Run inference (no gradient needed)
63
+ with torch.no_grad():
64
+ outputs = self.model(**inputs)
65
+
66
+ # Debug: Inspect output structure
67
+ logging.debug(f"Model outputs type: {type(outputs)}")
68
+ logging.debug(f"Model outputs keys: {outputs.keys() if hasattr(outputs, 'keys') else 'N/A'}")
69
+
70
+ # Get raw depth prediction - the shape varies by model
71
+ raw_depth = outputs.predicted_depth
72
+
73
+ # Log the actual shape for debugging
74
+ logging.info(f"Raw depth shape: {raw_depth.shape}, dtype: {raw_depth.dtype}")
75
+
76
+ # Ensure we have a 4D tensor [B, C, H, W]
77
+ if raw_depth.dim() == 2:
78
+ # [H, W] -> [1, 1, H, W]
79
+ raw_depth = raw_depth.unsqueeze(0).unsqueeze(0)
80
+ elif raw_depth.dim() == 3:
81
+ # [B, H, W] or [C, H, W] -> [1, 1, H, W]
82
+ raw_depth = raw_depth.unsqueeze(1) if raw_depth.shape[0] == 1 else raw_depth.unsqueeze(0)
83
+ elif raw_depth.dim() == 1:
84
+ # This is unexpected - possibly a flattened output
85
+ # Try to reshape based on expected output size
86
+ expected_size = 1536 # Model's default output size
87
+ raw_depth = raw_depth.reshape(1, 1, expected_size, expected_size)
88
+
89
+ # Now resize to target size
90
+ if raw_depth.shape[-2:] != (height, width):
91
+ import torch.nn.functional as F
92
+ raw_depth = F.interpolate(
93
+ raw_depth,
94
+ size=(height, width),
95
+ mode='bilinear',
96
+ align_corners=False
 
 
 
 
97
  )
98
 
99
+ # Convert to numpy and remove batch/channel dims
100
+ depth_map = raw_depth.squeeze().cpu().numpy() # Shape: [H, W]
101
+
102
+ # Get focal length from outputs if available
103
+ if hasattr(outputs, 'fov_deg') and outputs.fov_deg is not None:
104
+ # Convert field of view to focal length
105
+ fov_rad = float(outputs.fov_deg) * np.pi / 180.0
106
+ focal_length = float(width / (2.0 * np.tan(fov_rad / 2.0)))
107
+ else:
108
+ focal_length = 1.0
109
+
110
+ # Debug: Check for NaN values
111
+ if np.isnan(depth_map).any():
112
+ nan_count = np.isnan(depth_map).sum()
113
+ total = depth_map.size
114
+ logging.warning(
115
+ f"Depth map contains {nan_count}/{total} ({100*nan_count/total:.1f}%) NaN values"
116
+ )
117
+ logging.warning(f"Depth map shape: {depth_map.shape}, dtype: {depth_map.dtype}")
118
+ valid_depths = depth_map[np.isfinite(depth_map)]
119
+ if len(valid_depths) > 0:
120
+ logging.warning(
121
+ f"Valid depth range: {valid_depths.min():.4f} - {valid_depths.max():.4f}"
122
+ )
123
+
124
+ return DepthResult(depth_map=depth_map, focal_length=focal_length)
125
+
126
+ except Exception as e:
127
+ logging.error(f"Depth estimation failed: {e}")
128
+ logging.error(f"Frame shape: {frame.shape}")
129
+ # Return a blank depth map as fallback
130
+ h, w = frame.shape[:2]
131
+ depth_map = np.zeros((h, w), dtype=np.float32)
132
+ return DepthResult(depth_map=depth_map, focal_length=1.0)