import torch
import math
import numpy as np
from typing import List, Dict, Tuple, Callable, Union


def _pointwise_dist(traj: torch.Tensor, traj_gt: torch.Tensor) -> torch.Tensor:
    """
    Euclidean distance between each predicted mode and the ground truth at every timestep.

    :param traj: predictions, shape [batch_size, num_modes, sequence_length, >=2];
        only the first two channels are used
    :param traj_gt: ground truth trajectory, shape [batch_size, sequence_length, 2]
    :return: distances, shape [batch_size, num_modes, sequence_length]
    """
    # unsqueeze(1) broadcasts the ground truth across the mode dimension
    diff = traj_gt.unsqueeze(1) - traj[..., 0:2]
    return torch.pow(torch.sum(torch.pow(diff, exponent=2), dim=3), exponent=0.5)


def min_ade(
    traj: torch.Tensor, traj_gt: torch.Tensor, masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes average displacement error for the best trajectory in a set, with respect to ground truth

    :param traj: predictions, shape [batch_size, num_modes, sequence_length, 2]
    :param traj_gt: ground truth trajectory, shape [batch_size, sequence_length, 2]
    :param masks: masks for varying length ground truth, shape [batch_size, sequence_length];
        1 marks a padded (ignored) timestep, 0 a valid one
    :return errs, inds: errors and indices for modes with min error, shape [batch_size].
        A sample with no valid timesteps gets an error of 0.
    """
    dist = _pointwise_dist(traj, traj_gt)
    # [batch_size, 1, sequence_length]; broadcasts over the mode dimension
    valid = 1 - masks.unsqueeze(1)
    # clip avoids division by zero for samples whose every timestep is masked
    err = torch.sum(dist * valid, dim=2) / torch.clip(torch.sum(valid, dim=2), min=1)
    err, inds = torch.min(err, dim=1)
    return err, inds


def min_fde(
    traj: torch.Tensor, traj_gt: torch.Tensor, masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes final displacement error for the best trajectory in a set, with respect to ground truth

    The "final" timestep of each sample is taken to be index (number of valid steps - 1),
    so valid timesteps are assumed to form a prefix of the sequence (padding at the end).

    :param traj: predictions, shape [batch_size, num_modes, sequence_length, 2]
    :param traj_gt: ground truth trajectory, shape [batch_size, sequence_length, 2]
    :param masks: masks for varying length ground truth, shape [batch_size, sequence_length];
        1 marks a padded (ignored) timestep, 0 a valid one
    :return errs, inds: errors and indices for modes with min error, shape [batch_size].
        A sample with no valid timesteps gets an error of 0 (consistent with min_ade).
    """
    num_modes = traj.shape[1]
    traj_gt_rpt = traj_gt.unsqueeze(1).repeat(1, num_modes, 1, 1)
    lengths = torch.sum(1 - masks, dim=1).long()
    # Clamp so fully-masked samples do not yield an out-of-range (-1) gather index.
    last_idx = torch.clamp(lengths - 1, min=0)
    gather_idx = last_idx.unsqueeze(1).unsqueeze(2).unsqueeze(3).repeat(1, num_modes, 1, 2)
    traj_last = torch.gather(traj[..., :2], dim=2, index=gather_idx).squeeze(2)
    traj_gt_last = torch.gather(traj_gt_rpt, dim=2, index=gather_idx).squeeze(2)
    err = traj_gt_last - traj_last
    err = torch.pow(torch.sum(torch.pow(err, exponent=2), dim=2), exponent=0.5)
    # Samples without any valid timestep have no meaningful final point: report zero error.
    err = err.masked_fill((lengths == 0).unsqueeze(1), 0.0)
    err, inds = torch.min(err, dim=1)
    return err, inds


def miss_rate(
    traj: torch.Tensor,
    traj_gt: torch.Tensor,
    masks: torch.Tensor,
    dist_thresh: float = 2,
) -> torch.Tensor:
    """
    Computes miss rate for mini batch of trajectories, with respect to ground truth and given distance threshold

    A sample is a miss when, for every mode, the largest displacement over its valid
    timesteps exceeds dist_thresh. A sample with no valid timesteps is never a miss.

    :param traj: predictions, shape [batch_size, num_modes, sequence_length, 2]
    :param traj_gt: ground truth trajectory, shape [batch_size, sequence_length, 2]
    :param masks: masks for varying length ground truth, shape [batch_size, sequence_length];
        1 marks a padded (ignored) timestep, 0 a valid one
    :param dist_thresh: distance threshold for computing miss rate.
    :return m_r: fraction of samples in the batch that are misses (0-dim tensor)
    """
    dist = _pointwise_dist(traj, traj_gt)
    # -inf at padded steps so they can never be the maximum over time
    dist = dist.masked_fill(masks.unsqueeze(1).bool(), -math.inf)
    dist, _ = torch.max(dist, dim=2)  # worst valid displacement of each mode
    dist, _ = torch.min(dist, dim=1)  # best mode of each sample
    m_r = torch.sum(dist > dist_thresh) / len(dist)
    return m_r


def traj_fde(gt_box, pred_box, final_step: int) -> float:
    """
    Minimum final displacement error across predicted modes, measured at a given step.

    :param gt_box: object with a ``traj`` array indexed as [timestep, coord]
        (NOTE(review): presumably shape [T, 2] -- confirm with callers)
    :param pred_box: object whose ``traj`` converts via np.array to [num_modes, T, coord]
    :param final_step: 1-based step at which to measure; clipped to the ground-truth length.
        Expected to be >= 1 -- a value of 0 would index -1, i.e. the last step.
    :return: smallest Euclidean error across modes, or np.inf if the ground truth is empty
    """
    if gt_box.traj.shape[0] <= 0:
        return np.inf
    final_step = min(gt_box.traj.shape[0], final_step)
    gt_final = gt_box.traj[None, final_step - 1]  # [1, coord], broadcasts against modes
    pred_final = np.array(pred_box.traj)[:, final_step - 1, :]  # [num_modes, coord]
    err = np.sqrt(np.sum(np.square(gt_final - pred_final), axis=-1))
    return np.min(err)