Spaces:
Build error
Build error
| # based on https://gist.github.com/ranftlr/45f4c7ddeb1bbb88d606bc600cab6c8d | |
| import torch | |
class DepthMetric:
    """Scale-and-shift-invariant depth evaluation metrics.

    The prediction is converted to disparity and aligned to the ground-truth
    disparity with a per-image least-squares scale and shift. The aligned
    result is converted back to depth and compared with the ground-truth depth.

    Args:
        thresholds: Ratio thresholds ``t`` for the ``d>t`` metrics. Each one is
            the fraction of valid pixels whose ``max(pred/gt, gt/pred)``
            exceeds ``t``. ``None`` (the default) means
            ``[1.25, 1.25**2, 1.25**3]``.
        depth_cap: Optional maximum depth. When set, aligned disparities below
            ``1 / depth_cap`` are raised to that value, so predicted depth never
            exceeds ``depth_cap``.
        prediction_type: ``'depth'`` or ``'disparity'``. This says what
            ``compute_metrics`` receives as ``prediction``; it is only checked
            when ``compute_metrics`` runs.
    """

    DEFAULT_THRESHOLDS = (1.25, 1.25 ** 2, 1.25 ** 3)

    def __init__(self, thresholds=None, depth_cap=None, prediction_type='depth'):
        # Build a fresh list per instance. A list literal as the default value
        # would be a single object shared by every instance.
        if thresholds is None:
            thresholds = list(self.DEFAULT_THRESHOLDS)
        self.thresholds = thresholds
        self.depth_cap = depth_cap
        self.metric_keys = self.get_metric_keys()
        self.prediction_type = prediction_type

    def compute_scale_and_shift(self, prediction, target, mask):
        """Fit ``scale * prediction + shift`` to ``target`` by per-image least squares.

        The fit minimizes ``sum(mask * (scale * prediction + shift - target) ** 2)``.
        The sums run over dims 1 and 2, so the inputs are presumably (B, H, W)
        (TODO confirm for external callers).

        Returns:
            ``(scale, shift)``, each with one value per image. An image whose
            normal-equation determinant is not positive gets scale = shift = 0.
            With a 0/1 mask this happens when the mask is empty or the masked
            prediction is constant.
        """
        # system matrix: A = [[a_00, a_01], [a_10, a_11]]
        a_00 = torch.sum(mask * prediction * prediction, (1, 2))
        a_01 = torch.sum(mask * prediction, (1, 2))
        a_11 = torch.sum(mask, (1, 2))
        # right hand side: b = [b_0, b_1]
        b_0 = torch.sum(mask * prediction * target, (1, 2))
        b_1 = torch.sum(mask * target, (1, 2))
        # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / det . b
        # A is symmetric (a_10 == a_01), hence a_01 * a_01 below.
        x_0 = torch.zeros_like(b_0)
        x_1 = torch.zeros_like(b_1)
        det = a_00 * a_11 - a_01 * a_01
        # A needs to be a positive definite matrix; degenerate images keep 0.
        valid = det > 0
        x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]
        x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]
        return x_0, x_1

    def get_metric_keys(self):
        """Return metric names in the order ``compute_metrics`` fills them."""
        return ['d>{}'.format(threshold) for threshold in self.thresholds] + [
            'rmse',
            'l1_err',
            'abs_rel',
        ]

    @staticmethod
    def _masked_mean(values, valid, num_valid):
        """Return the per-image mean of ``values`` over the valid pixels.

        ``values`` holds one entry per True element of ``valid`` (B, H, W).
        Invalid pixels count as zero, and each image's sum is divided by its
        entry in ``num_valid``.
        """
        dense = torch.zeros(valid.shape, dtype=values.dtype, device=values.device)
        dense[valid] = values
        return torch.sum(dense, (1, 2)) / num_valid

    def compute_metrics(self, prediction, target, mask):
        """Align ``prediction`` to ``target`` and compute depth error metrics.

        Args:
            prediction: (B, 1, H, W) predicted depth or disparity, depending on
                ``self.prediction_type``.
            target: (B, 1, H, W) ground-truth depth. Its reciprocal is taken on
                masked pixels, so it should be non-zero there.
            mask: (B, 1, H, W). Pixels with a value > 0.5 are evaluated.

        Returns:
            ``(metrics, depth)``.

            - ``metrics`` maps each key of ``self.metric_keys`` to a tensor of
              per-image values. An image with no valid pixels gets NaN.
            - ``depth`` is the aligned predicted depth, shape (B, 1, H, W).
              Outside the mask the aligned disparity is just the shift, so
              ``depth`` there is not meaningful.

        Raises:
            AssertionError: if the inputs are not matching (B, 1, H, W) tensors.
            ValueError: if ``self.prediction_type`` is not 'depth' or 'disparity'.
        """
        # check inputs (after .float() the dtype is float32 for all three)
        prediction = prediction.float()
        target = target.float()
        mask = mask.float()
        assert prediction.shape == target.shape == mask.shape
        assert len(prediction.shape) == 4
        assert prediction.shape[1] == 1

        # process inputs: drop the channel dim and binarize the mask
        prediction = prediction.squeeze(1)
        target = target.squeeze(1)
        mask = (mask.squeeze(1) > 0.5).long()
        valid = mask == 1
        num_valid = torch.sum(mask, (1, 2))

        # predicted disparity on valid pixels (zero elsewhere)
        prediction_disparity = torch.zeros_like(prediction)
        if self.prediction_type == 'depth':
            # The epsilon guards against division by a zero predicted depth.
            prediction_disparity[valid] = 1.0 / (prediction[valid] + 1.e-6)
        elif self.prediction_type == 'disparity':
            prediction_disparity[valid] = prediction[valid]
        else:
            raise ValueError('Unknown prediction type: {}'.format(self.prediction_type))

        # align in disparity space, then convert back to depth
        target_disparity = torch.zeros_like(target)
        target_disparity[valid] = 1.0 / target[valid]
        scale, shift = self.compute_scale_and_shift(prediction_disparity, target_disparity, mask)
        prediction_aligned = scale.view(-1, 1, 1) * prediction_disparity + shift.view(-1, 1, 1)
        if self.depth_cap is not None:
            # Raising small disparities to 1 / depth_cap caps depth at depth_cap.
            prediction_aligned = prediction_aligned.clamp(min=1.0 / self.depth_cap)
        prediction_depth = 1.0 / prediction_aligned

        pred_valid = prediction_depth[valid]
        target_valid = target[valid]
        metrics = {}

        # delta > threshold, [batch_size, ]
        # NOTE(review): without depth_cap, a negative aligned disparity yields
        # a negative depth. Its ratio is then negative and never counts as
        # exceeding a threshold. Left unchanged so values stay comparable.
        ratio = torch.max(pred_valid / target_valid, target_valid / pred_valid)
        for threshold in self.thresholds:
            metrics['d>{}'.format(threshold)] = self._masked_mean(
                (ratio > threshold).float(), valid, num_valid
            )

        diff = pred_valid - target_valid
        # rmse, [batch_size, ]
        metrics['rmse'] = torch.sqrt(self._masked_mean(diff ** 2, valid, num_valid))
        # l1 error, [batch_size, ]
        metrics['l1_err'] = self._masked_mean(torch.abs(diff), valid, num_valid)
        # abs_rel, [batch_size, ]
        metrics['abs_rel'] = self._masked_mean(torch.abs(diff) / target_valid, valid, num_valid)

        return metrics, prediction_depth.unsqueeze(1)