Spaces:
Sleeping
Sleeping
Zhen Ye
committed on
Commit
·
18ba97a
1
Parent(s):
91f3b56
fixed shape mismatch
Browse files
models/depth_estimators/depth_pro.py
CHANGED
|
@@ -48,56 +48,85 @@ class DepthProEstimator(DepthEstimator):
|
|
| 48 |
Returns:
|
| 49 |
DepthResult with depth_map (HxW float32 in meters) and focal_length
|
| 50 |
"""
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
raw_depth =
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
valid_depths = depth_map[np.isfinite(depth_map)]
|
| 98 |
-
if len(valid_depths) > 0:
|
| 99 |
-
logging.warning(
|
| 100 |
-
f"Valid depth range: {valid_depths.min():.4f} - {valid_depths.max():.4f}"
|
| 101 |
)
|
| 102 |
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
Returns:
|
| 49 |
DepthResult with depth_map (HxW float32 in meters) and focal_length
|
| 50 |
"""
|
| 51 |
+
try:
|
| 52 |
+
# Convert BGR to RGB
|
| 53 |
+
rgb_frame = frame[:, :, ::-1] # BGR → RGB
|
| 54 |
+
|
| 55 |
+
# Convert to PIL Image
|
| 56 |
+
pil_image = Image.fromarray(rgb_frame)
|
| 57 |
+
height, width = pil_image.height, pil_image.width
|
| 58 |
+
|
| 59 |
+
# Preprocess image
|
| 60 |
+
inputs = self.image_processor(images=pil_image, return_tensors="pt").to(self.device)
|
| 61 |
+
|
| 62 |
+
# Run inference (no gradient needed)
|
| 63 |
+
with torch.no_grad():
|
| 64 |
+
outputs = self.model(**inputs)
|
| 65 |
+
|
| 66 |
+
# Debug: Inspect output structure
|
| 67 |
+
logging.debug(f"Model outputs type: {type(outputs)}")
|
| 68 |
+
logging.debug(f"Model outputs keys: {outputs.keys() if hasattr(outputs, 'keys') else 'N/A'}")
|
| 69 |
+
|
| 70 |
+
# Get raw depth prediction - the shape varies by model
|
| 71 |
+
raw_depth = outputs.predicted_depth
|
| 72 |
+
|
| 73 |
+
# Log the actual shape for debugging
|
| 74 |
+
logging.info(f"Raw depth shape: {raw_depth.shape}, dtype: {raw_depth.dtype}")
|
| 75 |
+
|
| 76 |
+
# Ensure we have a 4D tensor [B, C, H, W]
|
| 77 |
+
if raw_depth.dim() == 2:
|
| 78 |
+
# [H, W] -> [1, 1, H, W]
|
| 79 |
+
raw_depth = raw_depth.unsqueeze(0).unsqueeze(0)
|
| 80 |
+
elif raw_depth.dim() == 3:
|
| 81 |
+
# [B, H, W] or [C, H, W] -> [1, 1, H, W]
|
| 82 |
+
raw_depth = raw_depth.unsqueeze(1) if raw_depth.shape[0] == 1 else raw_depth.unsqueeze(0)
|
| 83 |
+
elif raw_depth.dim() == 1:
|
| 84 |
+
# This is unexpected - possibly a flattened output
|
| 85 |
+
# Try to reshape based on expected output size
|
| 86 |
+
expected_size = 1536 # Model's default output size
|
| 87 |
+
raw_depth = raw_depth.reshape(1, 1, expected_size, expected_size)
|
| 88 |
+
|
| 89 |
+
# Now resize to target size
|
| 90 |
+
if raw_depth.shape[-2:] != (height, width):
|
| 91 |
+
import torch.nn.functional as F
|
| 92 |
+
raw_depth = F.interpolate(
|
| 93 |
+
raw_depth,
|
| 94 |
+
size=(height, width),
|
| 95 |
+
mode='bilinear',
|
| 96 |
+
align_corners=False
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
)
|
| 98 |
|
| 99 |
+
# Convert to numpy and remove batch/channel dims
|
| 100 |
+
depth_map = raw_depth.squeeze().cpu().numpy() # Shape: [H, W]
|
| 101 |
+
|
| 102 |
+
# Get focal length from outputs if available
|
| 103 |
+
if hasattr(outputs, 'fov_deg') and outputs.fov_deg is not None:
|
| 104 |
+
# Convert field of view to focal length
|
| 105 |
+
fov_rad = float(outputs.fov_deg) * np.pi / 180.0
|
| 106 |
+
focal_length = float(width / (2.0 * np.tan(fov_rad / 2.0)))
|
| 107 |
+
else:
|
| 108 |
+
focal_length = 1.0
|
| 109 |
+
|
| 110 |
+
# Debug: Check for NaN values
|
| 111 |
+
if np.isnan(depth_map).any():
|
| 112 |
+
nan_count = np.isnan(depth_map).sum()
|
| 113 |
+
total = depth_map.size
|
| 114 |
+
logging.warning(
|
| 115 |
+
f"Depth map contains {nan_count}/{total} ({100*nan_count/total:.1f}%) NaN values"
|
| 116 |
+
)
|
| 117 |
+
logging.warning(f"Depth map shape: {depth_map.shape}, dtype: {depth_map.dtype}")
|
| 118 |
+
valid_depths = depth_map[np.isfinite(depth_map)]
|
| 119 |
+
if len(valid_depths) > 0:
|
| 120 |
+
logging.warning(
|
| 121 |
+
f"Valid depth range: {valid_depths.min():.4f} - {valid_depths.max():.4f}"
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
return DepthResult(depth_map=depth_map, focal_length=focal_length)
|
| 125 |
+
|
| 126 |
+
except Exception as e:
|
| 127 |
+
logging.error(f"Depth estimation failed: {e}")
|
| 128 |
+
logging.error(f"Frame shape: {frame.shape}")
|
| 129 |
+
# Return a blank depth map as fallback
|
| 130 |
+
h, w = frame.shape[:2]
|
| 131 |
+
depth_map = np.zeros((h, w), dtype=np.float32)
|
| 132 |
+
return DepthResult(depth_map=depth_map, focal_length=1.0)
|