| import torch | |
| import math | |
| import numpy as np | |
| from typing import List, Dict, Tuple, Callable, Union | |
def min_ade(
    traj: torch.Tensor, traj_gt: torch.Tensor, masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes average displacement error for the best trajectory in a set,
    with respect to ground truth.

    A mask value of 1 (or True) marks a time step that is excluded from the
    average; 0 (or False) marks a step that is counted. For a sample whose
    steps are all masked the divisor is clipped to 1, so its error is the
    (empty) masked sum, i.e. 0 for finite displacements.

    :param traj: predictions, shape [batch_size, num_modes, sequence_length, 2]
        (only the first two channels of the last dimension are used)
    :param traj_gt: ground truth trajectory, shape
        [batch_size, sequence_length, 2]
    :param masks: masks for varying length ground truth, shape
        [batch_size, sequence_length]; numeric 0/1 or bool
    :return errs, inds: errors and indices for modes with min error, shape
        [batch_size]
    """
    # torch rejects `1 - masks` for bool tensors, so invert those with `~`.
    valid = ~masks if masks.dtype == torch.bool else 1 - masks
    # [batch_size, 1, sequence_length]: broadcasts over the mode dimension.
    valid = valid.unsqueeze(1)

    # Per-step Euclidean distance, shape [batch_size, num_modes, sequence_length].
    diff = traj_gt.unsqueeze(1) - traj[:, :, :, 0:2]
    dist = torch.pow(torch.sum(torch.pow(diff, exponent=2), dim=3), exponent=0.5)

    # Mean over counted steps; the clip prevents 0/0 for fully masked samples.
    err = torch.sum(dist * valid, dim=2) / torch.clip(
        torch.sum(valid, dim=2), min=1
    )
    err, inds = torch.min(err, dim=1)
    return err, inds
def min_fde(
    traj: torch.Tensor, traj_gt: torch.Tensor, masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes final displacement error for the best trajectory in a set,
    with respect to ground truth.

    The final step of each sample is taken at index (number of unmasked
    steps - 1), so unmasked steps are expected to form a prefix of the
    sequence. A mask value of 1 (or True) marks a masked step. Samples with
    no unmasked step have no final position; they report an error of 0 for
    every mode instead of raising an out-of-bounds index error.

    :param traj: predictions, shape [batch_size, num_modes, sequence_length, 2]
        (only the first two channels of the last dimension are used)
    :param traj_gt: ground truth trajectory, shape
        [batch_size, sequence_length, 2]
    :param masks: masks for varying length ground truth, shape
        [batch_size, sequence_length]; numeric 0/1 or bool
    :return errs, inds: errors and indices for modes with min error,
        shape [batch_size]
    """
    num_modes = traj.shape[1]

    # torch rejects `1 - masks` for bool tensors, so invert those with `~`.
    valid = ~masks if masks.dtype == torch.bool else 1 - masks
    lengths = torch.sum(valid, dim=1).long()
    has_valid = lengths > 0

    # torch.gather does not accept negative indices: a fully masked sample
    # (length 0) would otherwise request index -1 and fail the whole batch.
    last = torch.clip(lengths - 1, min=0)

    pred_index = last[:, None, None, None].expand(-1, num_modes, 1, 2)
    gt_index = last[:, None, None].expand(-1, 1, 2)

    # [batch_size, num_modes, 2] and [batch_size, 1, 2] respectively.
    traj_last = torch.gather(traj[..., :2], dim=2, index=pred_index).squeeze(2)
    traj_gt_last = torch.gather(traj_gt, dim=1, index=gt_index)

    err = traj_gt_last - traj_last
    err = torch.pow(err, exponent=2)
    err = torch.sum(err, dim=2)
    err = torch.pow(err, exponent=0.5)

    # The clamped index is only a placeholder for fully masked samples.
    err = err.masked_fill(~has_valid.unsqueeze(1), 0.0)

    err, inds = torch.min(err, dim=1)
    return err, inds
def miss_rate(
    traj: torch.Tensor,
    traj_gt: torch.Tensor,
    masks: torch.Tensor,
    dist_thresh: float = 2,
) -> torch.Tensor:
    """
    Computes miss rate for mini batch of trajectories,
    with respect to ground truth and given distance threshold.

    For every mode the largest displacement over the unmasked time steps is
    taken; a sample counts as a miss when even its best mode (the one with
    the smallest such displacement) exceeds dist_thresh. Masked steps are
    set to -inf so they never decide the maximum, which also means a sample
    with every step masked is never counted as a miss.

    :param traj: predictions, shape [batch_size, num_modes, sequence_length, 2]
        (only the first two channels of the last dimension are used)
    :param traj_gt: ground truth trajectory,
        shape [batch_size, sequence_length, 2]
    :param masks: masks for varying length ground truth,
        shape [batch_size, sequence_length]; non-zero / True marks a masked step
    :param dist_thresh: distance threshold for computing miss rate.
    :return m_r: scalar tensor, number of missed samples divided by
        batch_size
    """
    # Per-step Euclidean distance, shape [batch_size, num_modes, sequence_length].
    diff = traj_gt.unsqueeze(1) - traj[:, :, :, 0:2]
    dist = torch.pow(torch.sum(torch.pow(diff, exponent=2), dim=3), exponent=0.5)

    # The mask broadcasts over the mode dimension.
    dist = dist.masked_fill(masks.bool().unsqueeze(1), -math.inf)

    worst_per_mode, _ = torch.max(dist, dim=2)
    best_mode, _ = torch.min(worst_per_mode, dim=1)

    m_r = torch.sum(best_mode > dist_thresh) / len(best_mode)
    return m_r
def traj_fde(gt_box, pred_box, final_step):
    """
    Computes the smallest final displacement error between a ground truth
    box trajectory and the candidate trajectories of a predicted box.

    The comparison step is final_step capped at the number of ground truth
    steps; positions are read at index (step - 1).

    :param gt_box: object whose `traj` attribute is an array indexed by
        step along its first axis (assumed shape [num_steps, 2] -- TODO confirm)
    :param pred_box: object whose `traj` attribute converts to a 3-D array
        (assumed shape [num_modes, num_steps, 2] -- TODO confirm)
    :param final_step: 1-based step at which the error is evaluated
    :return: minimum Euclidean error over the first axis of the prediction,
        or np.inf when the ground truth trajectory has no steps
    """
    num_gt_steps = gt_box.traj.shape[0]
    if num_gt_steps <= 0:
        return np.inf

    # Never look past the end of the ground truth trajectory.
    final_step = min(num_gt_steps, final_step)

    # Leading axis of size 1 so the ground truth broadcasts over predictions.
    gt_final = gt_box.traj[None, final_step - 1]
    pred_final = np.array(pred_box.traj)[:, final_step - 1, :]

    err = np.sqrt(np.sum(np.square(gt_final - pred_final), axis=-1))
    return np.min(err)