File size: 3,986 Bytes
663494c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import torch
import math
import numpy as np
from typing import List, Dict, Tuple, Callable, Union
def min_ade(
    traj: torch.Tensor, traj_gt: torch.Tensor, masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes average displacement error for the best trajectory in a set,
    with respect to ground truth.

    A mask value of 1 (or True) marks a time step that is excluded from the
    average; 0 (or False) marks a step that is counted. A sample whose steps
    are all masked gets an error of 0 for every mode, because the divisor is
    clipped to at least 1.

    :param traj: predictions, shape [batch_size, num_modes, sequence_length, 2]
        (only the first two entries of the last axis are used)
    :param traj_gt: ground truth trajectory, shape
        [batch_size, sequence_length, 2]
    :param masks: masks for varying length ground truth, shape
        [batch_size, sequence_length]; numeric 0/1 or bool
    :return errs, inds: errors and indices for modes with min error, shape
        [batch_size]
    """
    # Broadcasting over the mode axis replaces explicit repeat() copies.
    diff = traj_gt.unsqueeze(1) - traj[..., 0:2]
    dist = torch.pow(torch.sum(torch.pow(diff, exponent=2), dim=3), exponent=0.5)

    # `1 - mask` is not defined for bool tensors in PyTorch, so bool masks
    # are inverted with `~`; numeric masks keep the original arithmetic.
    if masks.dtype == torch.bool:
        valid = (~masks).to(dist.dtype)
    else:
        valid = 1 - masks
    valid = valid.unsqueeze(1)  # [batch_size, 1, sequence_length]

    # Clip the step count so fully masked samples divide by 1 instead of 0.
    num_valid = torch.clip(torch.sum(valid, dim=2), min=1)
    err = torch.sum(dist * valid, dim=2) / num_valid
    err, inds = torch.min(err, dim=1)
    return err, inds
def min_fde(
    traj: torch.Tensor, traj_gt: torch.Tensor, masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes final displacement error for the best trajectory in a set,
    with respect to ground truth.

    The final step of each sample is taken at index (number of unmasked
    steps - 1), so the unmasked steps are expected to form a prefix of the
    sequence. A mask value of 1 (or True) marks a masked step. A sample
    with no unmasked step has no final step; its error is reported as 0.

    :param traj: predictions, shape [batch_size, num_modes, sequence_length, 2]
        (only the first two entries of the last axis are used)
    :param traj_gt: ground truth trajectory, shape
        [batch_size, sequence_length, 2]
    :param masks: masks for varying length ground truth, shape
        [batch_size, sequence_length]; numeric 0/1 or bool
    :return errs, inds: errors and indices for modes with min error,
        shape [batch_size]
    """
    num_modes = traj.shape[1]

    # `1 - mask` is not defined for bool tensors in PyTorch, so bool masks
    # are inverted with `~`; numeric masks keep the original arithmetic.
    if masks.dtype == torch.bool:
        lengths = torch.sum(~masks, dim=1)
    else:
        lengths = torch.sum(1 - masks, dim=1).long()

    # A length of 0 would give index -1, which torch.gather rejects.
    # Clamp to 0 here and blank out those samples' errors further down.
    last = torch.clamp(lengths - 1, min=0)
    inds = last[:, None, None, None].expand(-1, num_modes, 1, 2)

    traj_last = torch.gather(traj[..., :2], dim=2, index=inds).squeeze(2)
    # One gather on the un-repeated ground truth; shape [batch_size, 1, 2]
    # broadcasts against the per-mode predictions.
    traj_gt_last = torch.gather(traj_gt, dim=1, index=inds[:, 0])

    err = traj_gt_last - traj_last
    err = torch.pow(torch.sum(torch.pow(err, exponent=2), dim=2), exponent=0.5)
    err = err.masked_fill((lengths <= 0).unsqueeze(1), 0.0)
    err, inds = torch.min(err, dim=1)
    return err, inds
def miss_rate(
    traj: torch.Tensor,
    traj_gt: torch.Tensor,
    masks: torch.Tensor,
    dist_thresh: float = 2,
) -> torch.Tensor:
    """
    Computes miss rate for mini batch of trajectories,
    with respect to ground truth and given distance threshold.

    For every mode the largest displacement over the unmasked steps is
    taken; a sample counts as a miss when even its best mode (the one with
    the smallest such maximum) exceeds the threshold.

    :param traj: predictions, shape [batch_size, num_modes, sequence_length, 2]
    :param traj_gt: ground truth trajectory,
        shape [batch_size, sequence_length, 2]
    :param masks: masks for varying length ground truth,
        shape [batch_size, sequence_length]; non-zero entries are ignored
    :param dist_thresh: distance threshold for computing miss rate.
    :return m_r: scalar tensor, number of missed samples divided by
        batch_size
    """
    # Per-step Euclidean distance, broadcasting ground truth over modes.
    offsets = traj_gt.unsqueeze(1) - traj[..., 0:2]
    step_dist = offsets.pow(2).sum(dim=3).pow(0.5)

    # Masked steps must never win the max over time, so push them to -inf.
    ignored = masks.bool().unsqueeze(1)
    step_dist = step_dist.masked_fill(ignored, -math.inf)

    worst_per_mode = step_dist.max(dim=2).values
    best_mode = worst_per_mode.min(dim=1).values
    missed = best_mode > dist_thresh
    return missed.sum() / len(best_mode)
def traj_fde(gt_box, pred_box, final_step):
    """
    Computes the smallest final displacement error between a ground truth
    trajectory and a set of predicted trajectories.

    The comparison is made at step index ``final_step - 1``, after
    ``final_step`` has been capped at the number of ground truth steps.
    A non-positive ``final_step`` is not rejected: the resulting negative
    index wraps around to the end of the trajectory.

    :param gt_box: object whose ``traj`` attribute is an array indexed by
        time step along axis 0 (presumably shape [num_steps, 2] - confirm
        against the caller)
    :param pred_box: object whose ``traj`` attribute converts to a 3-D
        array indexed as [trajectory, step, coordinate]
    :param final_step: 1-based step at which the error is evaluated
    :return: minimum Euclidean distance over the predicted trajectories,
        or ``np.inf`` when the ground truth trajectory has no steps
    """
    num_gt_steps = gt_box.traj.shape[0]
    if num_gt_steps <= 0:
        return np.inf

    final_step = min(num_gt_steps, final_step)
    # Leading axis of size 1 lets the ground truth broadcast across all
    # predicted trajectories.
    gt_final = gt_box.traj[None, final_step - 1]
    pred_final = np.asarray(pred_box.traj)[:, final_step - 1, :]

    err = np.sqrt(np.sum(np.square(gt_final - pred_final), axis=-1))
    return np.min(err)
|