# Depth-map evaluation utilities: path grouping, depth/disparity conversion,
# scale-and-shift alignment (L2 / L1 / scale-only), and standard depth metrics.
| import torch | |
| import numpy as np | |
| import cv2 | |
| import glob | |
| import argparse | |
| from pathlib import Path | |
| from tqdm import tqdm | |
| from copy import deepcopy | |
| from scipy.optimize import minimize | |
| import os | |
| from collections import defaultdict | |
def group_by_directory(pathes, idx=-1):
    """
    Group file paths by one of the directory components of their parent path.

    Parameters:
    - pathes (list): List of file paths (assumed to use "/" separators).
    - idx (int): Index into the parent directory's "/"-split components.
      The default -1 selects the immediate parent directory (i.e. the
      second-to-last component of the full path).

    Returns:
    - dict: A defaultdict(list) mapping the selected directory name to the
      list of file paths that fall under it.
    """
    grouped_pathes = defaultdict(list)
    for path in pathes:
        # os.path.dirname drops the filename; splitting on "/" exposes the
        # directory components so `idx` can pick one of them.
        dir_name = os.path.dirname(path).split("/")[idx]
        grouped_pathes[dir_name].append(path)
    return grouped_pathes
def depth2disparity(depth, return_mask=False):
    """
    Convert a depth map to disparity (1 / depth) for positive depths.

    Parameters:
    - depth (torch.Tensor or numpy.ndarray): Depth map. Non-positive entries
      are treated as invalid and map to 0 disparity.
    - return_mask (bool): If True, also return the validity mask (depth > 0).

    Returns:
    - disparity (same type/shape as input), and optionally the boolean mask.

    Raises:
    - TypeError: If `depth` is neither a torch.Tensor nor a numpy.ndarray.
    """
    if isinstance(depth, torch.Tensor):
        disparity = torch.zeros_like(depth)
    elif isinstance(depth, np.ndarray):
        disparity = np.zeros_like(depth)
    else:
        # Fail fast with a clear message instead of hitting an unbound-name
        # NameError on `disparity` below.
        raise TypeError(f"Unsupported depth type: {type(depth)}")
    non_negtive_mask = depth > 0
    disparity[non_negtive_mask] = 1.0 / depth[non_negtive_mask]
    if return_mask:
        return disparity, non_negtive_mask
    else:
        return disparity
def absolute_error_loss(params, predicted_depth, ground_truth_depth):
    """L1 misalignment cost for a scale-and-shift pair.

    `params` is (scale, shift); returns the summed absolute difference
    between scale * predicted_depth + shift and ground_truth_depth.
    """
    scale, shift = params
    residual = scale * predicted_depth + shift - ground_truth_depth
    return np.sum(np.abs(residual))
def absolute_value_scaling(predicted_depth, ground_truth_depth, s=1, t=0):
    """
    Fit scale s and shift t minimizing sum |s * pred + t - gt| (L1 error)
    with scipy.optimize.minimize.

    Parameters:
    - predicted_depth (torch.Tensor): Predicted depth values (any shape).
    - ground_truth_depth (torch.Tensor): Ground truth depth values.
    - s, t: Initial guesses; may be Python numbers or 0-d torch tensors
      (callers pass a torch median ratio for s).

    Returns:
    - (s, t): The fitted scale and shift as floats.
    """
    predicted_depth_np = predicted_depth.cpu().numpy().reshape(-1)
    ground_truth_depth_np = ground_truth_depth.cpu().numpy().reshape(-1)

    def _l1_loss(params):
        # Summed absolute alignment error for the current (s, t) candidate.
        s_cur, t_cur = params
        return np.sum(
            np.abs(s_cur * predicted_depth_np + t_cur - ground_truth_depth_np)
        )

    # float() guards against tensor-valued initial guesses so scipy always
    # receives a plain numeric starting point.
    initial_params = [float(s), float(t)]
    result = minimize(_l1_loss, initial_params)
    s, t = result.x
    return s, t
def absolute_value_scaling2(
    predicted_depth,
    ground_truth_depth,
    s_init=1.0,
    t_init=0.0,
    lr=1e-4,
    max_iters=1000,
    tol=1e-6,
):
    """Estimate a scale/shift pair minimizing summed L1 error with Adam.

    Runs at most `max_iters` Adam steps on loss = sum |s * pred + t - gt|,
    stopping early once two successive losses differ by less than `tol`.

    Returns:
        (s, t): The fitted scale and shift as Python floats.
    """
    device, dtype = predicted_depth.device, predicted_depth.dtype
    # Learnable scalars, placed alongside the prediction tensor.
    s = torch.tensor([s_init], requires_grad=True, device=device, dtype=dtype)
    t = torch.tensor([t_init], requires_grad=True, device=device, dtype=dtype)

    optimizer = torch.optim.Adam([s, t], lr=lr)
    prev_loss = None

    for _ in range(max_iters):
        optimizer.zero_grad()
        loss = torch.sum(torch.abs(s * predicted_depth + t - ground_truth_depth))
        loss.backward()
        optimizer.step()
        # Early exit once the loss has effectively stopped moving.
        if prev_loss is not None and torch.abs(prev_loss - loss) < tol:
            break
        prev_loss = loss.item()

    return s.detach().item(), t.detach().item()
def depth_evaluation(
    predicted_depth_original,
    ground_truth_depth_original,
    max_depth=80,
    custom_mask=None,
    post_clip_min=None,
    post_clip_max=None,
    pre_clip_min=None,
    pre_clip_max=None,
    align_with_lstsq=False,
    align_with_lad=False,
    align_with_lad2=False,
    metric_scale=False,
    lr=1e-4,
    max_iters=1000,
    use_gpu=False,
    align_with_scale=False,
    disp_input=False,
):
    """
    Evaluate the depth map using various metrics and return a depth error parity map, with an option for least squares alignment.

    Args:
        predicted_depth_original (numpy.ndarray or torch.Tensor): The predicted depth map.
        ground_truth_depth_original (numpy.ndarray or torch.Tensor): The ground truth depth map.
        max_depth (float or None): The maximum depth value to consider. Default is 80 meters.
            If None, every positive ground-truth pixel is kept.
        custom_mask (numpy.ndarray or torch.Tensor or None): Optional extra validity mask
            with the same shape as the ground truth, applied after alignment.
        post_clip_min / post_clip_max (float or None): Clamp bounds applied to the
            predictions AFTER alignment.
        pre_clip_min / pre_clip_max (float or None): Clamp bounds applied to the
            predictions BEFORE alignment.
        align_with_lstsq (bool): If True, perform least squares (L2) scale+shift alignment
            of the predicted depth with ground truth.
        align_with_lad (bool): If True, L1 scale+shift alignment via scipy
            (absolute_value_scaling), seeded with the median ratio.
        align_with_lad2 (bool): If True, L1 scale+shift alignment via Adam
            (absolute_value_scaling2) using `lr` and `max_iters`.
        metric_scale (bool): If True, predictions are treated as already metric;
            no alignment is applied.
        lr (float): Learning rate for align_with_lad2.
        max_iters (int): Iteration cap for align_with_lad2.
        use_gpu (bool): If True, move the depth tensors to CUDA.
        align_with_scale (bool): If True, scale-only alignment refined with a
            Weiszfeld-style iteratively reweighted scheme.
        disp_input (bool): If True, the prediction is a disparity map; alignment
            happens in disparity space and the result is converted back.
            At most one alignment flag should be set; with none set, median
            scaling is used.

    Returns:
        dict: A dictionary containing the evaluation metrics.
        torch.Tensor: The depth error parity map (full size, zeros at invalid pixels).
        torch.Tensor: The aligned full-size predicted depth map.
        torch.Tensor: The ground truth map with invalid pixels zeroed.
    """
    # Normalize inputs to torch tensors.
    if isinstance(predicted_depth_original, np.ndarray):
        predicted_depth_original = torch.from_numpy(predicted_depth_original)
    if isinstance(ground_truth_depth_original, np.ndarray):
        ground_truth_depth_original = torch.from_numpy(ground_truth_depth_original)
    if custom_mask is not None and isinstance(custom_mask, np.ndarray):
        custom_mask = torch.from_numpy(custom_mask)

    # if the dimension is 3, flatten to 2d along the batch dimension
    # (a (B, H, W) stack becomes (B*H, W) so it is treated as one image).
    if predicted_depth_original.dim() == 3:
        _, h, w = predicted_depth_original.shape  # h is unused; only w is needed
        predicted_depth_original = predicted_depth_original.view(-1, w)
        ground_truth_depth_original = ground_truth_depth_original.view(-1, w)
        if custom_mask is not None:
            custom_mask = custom_mask.view(-1, w)

    # put to device
    # NOTE(review): custom_mask is NOT moved to the GPU here, yet it is
    # indexed below as `custom_mask.cpu()[mask]` while `mask` lives on the
    # GPU when use_gpu=True — confirm this path on CUDA.
    if use_gpu:
        predicted_depth_original = predicted_depth_original.cuda()
        ground_truth_depth_original = ground_truth_depth_original.cuda()

    # Filter out depths greater than max_depth (validity mask: GT > 0,
    # optionally capped at max_depth).
    if max_depth is not None:
        mask = (ground_truth_depth_original > 0) & (
            ground_truth_depth_original < max_depth
        )
    else:
        mask = ground_truth_depth_original > 0
    predicted_depth = predicted_depth_original[mask]
    ground_truth_depth = ground_truth_depth_original[mask]

    # Clip the depth values (pre-alignment clamp of predictions only)
    if pre_clip_min is not None:
        predicted_depth = torch.clamp(predicted_depth, min=pre_clip_min)
    if pre_clip_max is not None:
        predicted_depth = torch.clamp(predicted_depth, max=pre_clip_max)

    if disp_input:  # align the pred to gt in the disparity space
        real_gt = ground_truth_depth.clone()  # kept to restore depth-space GT later
        ground_truth_depth = 1 / (ground_truth_depth + 1e-8)

    # various alignment methods; the estimated s / t / scale_factor are
    # reused further down to align the full-resolution map as well
    if metric_scale:
        # Predictions are already metric — leave them untouched.
        predicted_depth = predicted_depth
    elif align_with_lstsq:
        # Convert to numpy for lstsq
        predicted_depth_np = predicted_depth.cpu().numpy().reshape(-1, 1)
        ground_truth_depth_np = ground_truth_depth.cpu().numpy().reshape(-1, 1)

        # Add a column of ones for the shift term
        A = np.hstack([predicted_depth_np, np.ones_like(predicted_depth_np)])

        # Solve for scale (s) and shift (t) using least squares
        result = np.linalg.lstsq(A, ground_truth_depth_np, rcond=None)
        s, t = result[0][0], result[0][1]

        # convert to torch tensor
        s = torch.tensor(s, device=predicted_depth_original.device)
        t = torch.tensor(t, device=predicted_depth_original.device)

        # Apply scale and shift
        predicted_depth = s * predicted_depth + t
    elif align_with_lad:
        # L1 (least absolute deviation) scale+shift, seeded with the
        # median ratio of GT to prediction.
        s, t = absolute_value_scaling(
            predicted_depth,
            ground_truth_depth,
            s=torch.median(ground_truth_depth) / torch.median(predicted_depth),
        )
        predicted_depth = s * predicted_depth + t
    elif align_with_lad2:
        # Same L1 objective but optimized with Adam; seeded likewise.
        s_init = (
            torch.median(ground_truth_depth) / torch.median(predicted_depth)
        ).item()
        s, t = absolute_value_scaling2(
            predicted_depth,
            ground_truth_depth,
            s_init=s_init,
            lr=lr,
            max_iters=max_iters,
        )
        predicted_depth = s * predicted_depth + t
    elif align_with_scale:
        # Compute initial scale factor 's' using the closed-form solution (L2 norm)
        dot_pred_gt = torch.nanmean(ground_truth_depth)
        dot_pred_pred = torch.nanmean(predicted_depth)
        s = dot_pred_gt / dot_pred_pred

        # Iterative reweighted least squares using the Weiszfeld method
        for _ in range(10):
            # Compute residuals between scaled predictions and ground truth
            residuals = s * predicted_depth - ground_truth_depth
            abs_residuals = (
                residuals.abs() + 1e-8
            )  # Add small constant to avoid division by zero

            # Compute weights inversely proportional to the residuals
            weights = 1.0 / abs_residuals

            # Update 's' using weighted sums
            weighted_dot_pred_gt = torch.sum(
                weights * predicted_depth * ground_truth_depth
            )
            weighted_dot_pred_pred = torch.sum(weights * predicted_depth**2)
            s = weighted_dot_pred_gt / weighted_dot_pred_pred

            # Optionally clip 's' to prevent extreme scaling
            s = s.clamp(min=1e-3)
            # Detach 's' if you want to stop gradients from flowing through it
            s = s.detach()

        # Apply the scale factor to the predicted depth
        predicted_depth = s * predicted_depth
    else:
        # Align the predicted depth with the ground truth using median scaling
        scale_factor = torch.median(ground_truth_depth) / torch.median(predicted_depth)
        predicted_depth *= scale_factor

    if disp_input:
        # convert back to depth (alignment happened in disparity space)
        ground_truth_depth = real_gt
        predicted_depth = depth2disparity(predicted_depth)

    # Clip the predicted depth values (post-alignment clamp)
    if post_clip_min is not None:
        predicted_depth = torch.clamp(predicted_depth, min=post_clip_min)
    if post_clip_max is not None:
        predicted_depth = torch.clamp(predicted_depth, max=post_clip_max)

    if custom_mask is not None:
        assert custom_mask.shape == ground_truth_depth_original.shape
        # Restrict metrics to pixels that also pass the custom mask
        # (indexing the flat valid-pixel vectors with the mask-of-the-mask).
        mask_within_mask = custom_mask.cpu()[mask]
        predicted_depth = predicted_depth[mask_within_mask]
        ground_truth_depth = ground_truth_depth[mask_within_mask]

    # Calculate the metrics (standard monocular-depth error measures)
    abs_rel = torch.mean(
        torch.abs(predicted_depth - ground_truth_depth) / ground_truth_depth
    ).item()
    sq_rel = torch.mean(
        ((predicted_depth - ground_truth_depth) ** 2) / ground_truth_depth
    ).item()

    # Correct RMSE calculation
    rmse = torch.sqrt(torch.mean((predicted_depth - ground_truth_depth) ** 2)).item()

    # Clip the depth values to avoid log(0)
    predicted_depth = torch.clamp(predicted_depth, min=1e-5)
    log_rmse = torch.sqrt(
        torch.mean((torch.log(predicted_depth) - torch.log(ground_truth_depth)) ** 2)
    ).item()

    # Calculate the accuracy thresholds
    max_ratio = torch.maximum(
        predicted_depth / ground_truth_depth, ground_truth_depth / predicted_depth
    )
    # NOTE(review): max_ratio >= 1 by construction, so the strict "< 1.0"
    # test below can never pass — threshold_0 is effectively always 0.
    threshold_0 = torch.mean((max_ratio < 1.0).float()).item()
    threshold_1 = torch.mean((max_ratio < 1.25).float()).item()
    threshold_2 = torch.mean((max_ratio < 1.25**2).float()).item()
    threshold_3 = torch.mean((max_ratio < 1.25**3).float()).item()

    # Compute the depth error parity map on the FULL-resolution maps, reusing
    # the alignment parameters estimated above on the valid pixels.
    if metric_scale:
        predicted_depth_original = predicted_depth_original
        if disp_input:
            predicted_depth_original = depth2disparity(predicted_depth_original)
        depth_error_parity_map = (
            torch.abs(predicted_depth_original - ground_truth_depth_original)
            / ground_truth_depth_original
        )
    elif align_with_lstsq or align_with_lad or align_with_lad2:
        predicted_depth_original = predicted_depth_original * s + t
        if disp_input:
            predicted_depth_original = depth2disparity(predicted_depth_original)
        depth_error_parity_map = (
            torch.abs(predicted_depth_original - ground_truth_depth_original)
            / ground_truth_depth_original
        )
    elif align_with_scale:
        predicted_depth_original = predicted_depth_original * s
        if disp_input:
            predicted_depth_original = depth2disparity(predicted_depth_original)
        depth_error_parity_map = (
            torch.abs(predicted_depth_original - ground_truth_depth_original)
            / ground_truth_depth_original
        )
    else:
        predicted_depth_original = predicted_depth_original * scale_factor
        if disp_input:
            predicted_depth_original = depth2disparity(predicted_depth_original)
        depth_error_parity_map = (
            torch.abs(predicted_depth_original - ground_truth_depth_original)
            / ground_truth_depth_original
        )

    # Reshape the depth_error_parity_map back to the original image size,
    # zeroing the pixels that failed the validity mask.
    depth_error_parity_map_full = torch.zeros_like(ground_truth_depth_original)
    depth_error_parity_map_full = torch.where(
        mask, depth_error_parity_map, depth_error_parity_map_full
    )

    predict_depth_map_full = predicted_depth_original
    gt_depth_map_full = torch.zeros_like(ground_truth_depth_original)
    gt_depth_map_full = torch.where(
        mask, ground_truth_depth_original, gt_depth_map_full
    )

    # Number of pixels actually scored (after the optional custom mask).
    num_valid_pixels = (
        torch.sum(mask).item()
        if custom_mask is None
        else torch.sum(mask_within_mask).item()
    )
    if num_valid_pixels == 0:
        # No valid pixels: report zeros instead of NaNs from empty means.
        (
            abs_rel,
            sq_rel,
            rmse,
            log_rmse,
            threshold_0,
            threshold_1,
            threshold_2,
            threshold_3,
        ) = (0, 0, 0, 0, 0, 0, 0, 0)

    results = {
        "Abs Rel": abs_rel,
        "Sq Rel": sq_rel,
        "RMSE": rmse,
        "Log RMSE": log_rmse,
        "δ < 1.": threshold_0,
        "δ < 1.25": threshold_1,
        "δ < 1.25^2": threshold_2,
        "δ < 1.25^3": threshold_3,
        "valid_pixels": num_valid_pixels,
    }

    return (
        results,
        depth_error_parity_map_full,
        predict_depth_map_full,
        gt_depth_map_full,
    )