| | from functools import cache |
| |
|
| | import torch |
| | from einops import reduce |
| | from jaxtyping import Float |
| | from lpips import LPIPS |
| | from skimage.metrics import structural_similarity |
| | from torch import Tensor |
| |
|
| |
|
@torch.no_grad()
def compute_psnr(
    ground_truth: Float[Tensor, "batch channel height width"],
    predicted: Float[Tensor, "batch channel height width"],
) -> Float[Tensor, " batch"]:
    """Return the per-image peak signal-to-noise ratio of a batch.

    Both inputs are clipped to [0, 1] before comparison, so the value is
    ``-10 * log10(mse)``, where ``mse`` is the squared error averaged over
    the channel, height and width axes of each batch element.

    Args:
        ground_truth: Reference images.
        predicted: Images to score against the reference.

    Returns:
        One PSNR value per batch element.
    """
    target = ground_truth.clip(min=0, max=1)
    estimate = predicted.clip(min=0, max=1)
    squared_error = (target - estimate) ** 2
    # Collapse everything except the batch axis into a single mean.
    mse = reduce(squared_error, "b c h w -> b", "mean")
    return -10 * mse.log10()
| |
|
| |
|
@cache
def get_lpips(device: torch.device) -> LPIPS:
    """Build a VGG-backed LPIPS model on ``device``.

    The result is memoized by ``functools.cache``, so repeated calls with an
    equal ``device`` return the same model instance.
    """
    model = LPIPS(net="vgg")
    return model.to(device)
| |
|
| |
|
@torch.no_grad()
def compute_lpips(
    ground_truth: Float[Tensor, "batch channel height width"],
    predicted: Float[Tensor, "batch channel height width"],
) -> Float[Tensor, " batch"]:
    """Return the per-image LPIPS distance between two image batches.

    The model is fetched (and cached) for the device ``predicted`` lives on.

    Args:
        ground_truth: Reference images.
        predicted: Images to score against the reference.

    Returns:
        One LPIPS value per batch element.
    """
    model = get_lpips(predicted.device)
    # NOTE(review): normalize=True presumably tells LPIPS the inputs are in
    # [0, 1] rather than [-1, 1] -- confirm against the lpips documentation.
    distance = model.forward(ground_truth, predicted, normalize=True)
    # Keep only the first element of the three trailing axes per batch item.
    return distance[:, 0, 0, 0]
| |
|
| |
|
@torch.no_grad()
def compute_ssim(
    ground_truth: Float[Tensor, "batch channel height width"],
    predicted: Float[Tensor, "batch channel height width"],
) -> Float[Tensor, " batch"]:
    """Return the per-image structural similarity of two image batches.

    Each pair of images is moved to the CPU, converted to NumPy and scored
    with ``skimage.metrics.structural_similarity`` using an 11-pixel
    Gaussian-weighted window, channels on axis 0 and a data range of 1.0.

    Args:
        ground_truth: Reference images.
        predicted: Images to score against the reference.

    Returns:
        One SSIM value per batch element, with the dtype and device of
        ``predicted``.
    """
    scores = []
    for gt_image, pred_image in zip(ground_truth, predicted):
        score = structural_similarity(
            gt_image.detach().cpu().numpy(),
            pred_image.detach().cpu().numpy(),
            win_size=11,
            gaussian_weights=True,
            channel_axis=0,
            data_range=1.0,
        )
        scores.append(score)
    return torch.tensor(scores, dtype=predicted.dtype, device=predicted.device)
| |
|
| |
|
def compute_geodesic_distance_from_two_matrices(m1, m2):
    """Return the rotation angle, in radians, between paired 3x3 matrices.

    For each pair the product ``m1 @ m2^T`` is formed and the angle is
    recovered from its trace as ``acos((trace - 1) / 2)``. The formula is the
    standard one for rotation matrices; the inputs are not checked for
    orthonormality.

    Args:
        m1: Tensor of shape ``(..., 3, 3)``.
        m2: Tensor whose leading dimensions match (or broadcast with) ``m1``.

    Returns:
        Tensor of shape ``(...)`` holding angles in ``[0, pi]``.
    """
    relative = torch.matmul(m1, m2.transpose(-1, -2))
    trace = relative[..., 0, 0] + relative[..., 1, 1] + relative[..., 2, 2]
    # Round-off can push the cosine slightly outside [-1, 1], where acos
    # would return NaN, so clamp it back into the valid domain first.
    cos = ((trace - 1) / 2).clamp(min=-1.0, max=1.0)
    return torch.acos(cos)
| |
|
| |
|
def angle_error_mat(R1, R2):
    """Return the angle, in degrees, between two 2-D matrices.

    The angle is recovered from the trace of ``R1^T @ R2`` as
    ``acos((trace - 1) / 2)``, the standard formula for 3x3 rotation
    matrices. The cosine is clamped to [-1, 1] before ``acos``.
    """
    relative = torch.mm(R1.T, R2)
    cos_angle = ((torch.trace(relative) - 1) / 2).clamp(-1.0, 1.0)
    angle_rad = torch.acos(cos_angle).abs()
    return torch.rad2deg(angle_rad)
| |
|
| |
|
def angle_error_vec(v1, v2):
    """Return the angle, in degrees, between two 1-D vectors.

    The cosine is the dot product divided by the product of the norms and is
    clamped to [-1, 1] before ``acos``. A zero-length input makes that
    division 0/0, so the result is NaN in that case.
    """
    norm_product = v1.norm() * v2.norm()
    cosine = (torch.dot(v1, v2) / norm_product).clamp(-1.0, 1.0)
    return torch.acos(cosine).rad2deg()
| |
|
| |
|
def compute_translation_error(t1, t2):
    """Return the norm of the difference between ``t1`` and ``t2``."""
    offset = t1 - t2
    return offset.norm()
| |
|
| |
|
@torch.no_grad()
def compute_pose_error(pose_gt, pose_pred):
    """Compare a predicted pose matrix against a ground-truth one.

    The top-left 3x3 block of each pose is used as the rotation and the
    first three entries of the fourth column as the translation.

    Args:
        pose_gt: Ground-truth pose matrix.
        pose_pred: Predicted pose matrix.

    Returns:
        A tuple ``(error_t, error_t_scale, error_R)``: the angle in degrees
        between the two translation vectors (folded into [0, 90]), the norm
        of their difference, and the rotation angle in degrees between the
        two rotation blocks.
    """
    rot_gt, trans_gt = pose_gt[:3, :3], pose_gt[:3, 3]
    rot_pred, trans_pred = pose_pred[:3, :3], pose_pred[:3, 3]

    direction_err = angle_error_vec(trans_pred, trans_gt)
    # A direction and its opposite count as the same, so fold angles above
    # 90 degrees back down.
    direction_err = torch.minimum(direction_err, 180 - direction_err)
    distance_err = compute_translation_error(trans_pred, trans_gt)
    rotation_err = angle_error_mat(rot_pred, rot_gt)
    return direction_err, distance_err, rotation_err
| |
|
@torch.no_grad()
def abs_relative_difference(output, target, valid_mask=None):
    """Return the mean absolute relative error ``|output - target| / target``.

    The error is averaged over the last two axes of each item and then over
    all remaining axes.

    Args:
        output: Predicted values.
        target: Reference values.
        valid_mask: Optional boolean mask. Entries where it is False
            contribute zero error and are excluded from the per-item count.

    Returns:
        A scalar tensor.
    """
    rel_err = (output - target).abs() / target
    if valid_mask is None:
        pixel_count = output.shape[-1] * output.shape[-2]
    else:
        # Zero the masked entries so inf/NaN from invalid targets cannot
        # leak into the sum.
        rel_err[~valid_mask] = 0
        pixel_count = valid_mask.sum((-1, -2))
    per_item = rel_err.sum((-1, -2)) / pixel_count
    return per_item.mean()
| |
|
| | |
@torch.no_grad()
def threshold_percentage(output, target, threshold_val, valid_mask=None):
    """Return the fraction of entries with ``max(o/t, t/o) < threshold_val``.

    The fraction is computed over the last two axes of each item and then
    averaged over all remaining axes.

    Args:
        output: Predicted values.
        target: Reference values.
        threshold_val: Upper bound on the larger of the two ratios.
        valid_mask: Optional boolean mask. Entries where it is False never
            count as hits and are excluded from the per-item count.

    Returns:
        A scalar tensor.
    """
    ratio = torch.max(output / target, target / output)
    # 1 where the ratio is under the threshold and 0 elsewhere, in the
    # dtype of `output`.
    hits = (ratio < threshold_val).to(output.dtype)
    if valid_mask is None:
        pixel_count = output.shape[-1] * output.shape[-2]
    else:
        hits[~valid_mask] = 0
        pixel_count = valid_mask.sum((-1, -2))
    per_item = hits.sum((-1, -2)) / pixel_count
    return per_item.mean()
| |
|
@torch.no_grad()
def delta1_acc(pred, gt, valid_mask):
    """Return ``threshold_percentage`` of ``pred`` vs ``gt`` at threshold 1.25.

    Args:
        pred: Predicted values.
        gt: Reference values.
        valid_mask: Boolean mask forwarded to ``threshold_percentage``.

    Returns:
        Whatever ``threshold_percentage`` returns for these arguments.
    """
    return threshold_percentage(
        output=pred,
        target=gt,
        threshold_val=1.25,
        valid_mask=valid_mask,
    )