import math
import torch
from einops import rearrange
def bivariate_gaussian_activation(ip):
"""
Activation function to output parameters of bivariate Gaussian distribution.
Args:
ip (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor containing the parameters of the bivariate Gaussian distribution.
"""
# print('ip', ip.size())
mu_x = ip[..., 0:1]
# print('mu_x', mu_x.size())
mu_y = ip[..., 1:2]
# below not used in the current models, deterministic outputs
sig_x = ip[..., 2:3]
# print('sig_x', sig_x.size())
sig_y = ip[..., 3:4]
rho = ip[..., 4:5]
# print('rho', rho.size())
sig_x = torch.exp(sig_x)
# print('sig_x', sig_x.size())
sig_y = torch.exp(sig_y)
rho = torch.tanh(rho)
# print('rho', rho.size())
out = torch.cat([mu_x, mu_y, sig_x, sig_y, rho], dim=-1)
# print('out', out.size())
return out
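

# Illustrative sketch (not part of the original module): a quick sanity check of
# the activation, assuming the last dimension carries the five raw parameters.
def _example_bivariate_activation():
    raw = torch.randn(2, 6, 5)  # hypothetical (batch, num_agents, 5) head output
    params = bivariate_gaussian_activation(raw)
    assert params.shape == raw.shape
    assert (params[..., 2:4] > 0).all()  # exp keeps the scales positive
    assert (params[..., 4:5].abs() <= 1).all()  # tanh bounds the correlation
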
def norm_points(pos, pc_range):
"""
Normalize the end points of a given position tensor.
Args:
pos (torch.Tensor): Input position tensor.
pc_range (List[float]): Point cloud range.
Returns:
torch.Tensor: Normalized end points tensor.
"""
x_norm = (pos[..., 0] - pc_range[0]) / (pc_range[3] - pc_range[0])
y_norm = (pos[..., 1] - pc_range[1]) / (pc_range[4] - pc_range[1])
return torch.stack([x_norm, y_norm], dim=-1)
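

# Illustrative sketch (not from the original code): normalizing BEV points with a
# hypothetical nuScenes-style range [x_min, y_min, z_min, x_max, y_max, z_max].
def _example_norm_points():
    pc_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
    pos = torch.tensor([[0.0, 0.0], [51.2, -51.2]])  # ego center and a corner
    normed = norm_points(pos, pc_range)
    # (0, 0) maps to (0.5, 0.5); (51.2, -51.2) maps to (1.0, 0.0)
    assert torch.allclose(normed, torch.tensor([[0.5, 0.5], [1.0, 0.0]]))
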
def pos2posemb2d(pos, num_pos_feats=128, temperature=10000):
"""
Convert 2D position into positional embeddings.
Args:
pos (torch.Tensor): Input 2D position tensor.
num_pos_feats (int, optional): Number of positional features. Default is 128.
temperature (int, optional): Temperature factor for positional embeddings. Default is 10000.
Returns:
torch.Tensor: Positional embeddings tensor.
"""
scale = 2 * math.pi
pos = pos * scale
dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
pos_x = pos[..., 0, None] / dim_t
pos_y = pos[..., 1, None] / dim_t
pos_x = torch.stack(
(pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1
).flatten(-2)
pos_y = torch.stack(
(pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1
).flatten(-2)
posemb = torch.cat((pos_y, pos_x), dim=-1)
return posemb
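

# Illustrative sketch (assumption, not original code): with the default
# num_pos_feats=128, each axis contributes 128 features, so the embedding of a
# batch of normalized 2D points has 256 channels.
def _example_pos2posemb2d():
    pos = torch.rand(4, 2)  # 4 points, already normalized to [0, 1]
    emb = pos2posemb2d(pos)
    assert emb.shape == (4, 256)  # (pos_y features, pos_x features) concatenated
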
def rot_2d(yaw):
"""
Compute 2D rotation matrix for a given yaw angle tensor.
Args:
yaw (torch.Tensor): Input yaw angle tensor.
Returns:
torch.Tensor: 2D rotation matrix tensor.
"""
    sy, cy = torch.sin(yaw), torch.cos(yaw)
    out = torch.stack(
        [torch.stack([cy, -sy]), torch.stack([sy, cy])]
    ).permute(2, 0, 1)  # (2, 2, N) -> (N, 2, 2)
    return out
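

# Illustrative sketch (not original code): rot_2d stacks per-angle rotation
# matrices, so a 90-degree yaw maps the x-axis onto the y-axis.
def _example_rot_2d():
    yaw = torch.tensor([0.0, math.pi / 2])  # (num_agents,)
    rot = rot_2d(yaw)  # (num_agents, 2, 2)
    x_axis = torch.tensor([1.0, 0.0])
    rotated = rot @ x_axis  # first row stays (1, 0); second becomes ~(0, 1)
    assert torch.allclose(rotated[1], torch.tensor([0.0, 1.0]), atol=1e-6)
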
def anchor_coordinate_transform(
anchors, bbox_results, with_translation_transform=True, with_rotation_transform=True
):
"""
Transform anchor coordinates with respect to detected bounding boxes in the batch.
Args:
anchors (torch.Tensor): A tensor containing the k-means anchor values.
bbox_results (List[Tuple[torch.Tensor]]): A list of tuples containing the bounding box results for each image in the batch.
with_translate (bool, optional): Whether to perform translation transformation. Defaults to True.
with_rot (bool, optional): Whether to perform rotation transformation. Defaults to True.
Returns:
torch.Tensor: A tensor containing the transformed anchor coordinates.
"""
batch_size = len(bbox_results)
    batched_anchors = []
    for i in range(batch_size):
        # re-initialize from the shared k-means anchors for every sample and add
        # an agent axis: (num_groups, num_modes, 12, 2) -> (1, ...)
        transformed_anchors = anchors[None, ...]
        bboxes, scores, labels, bbox_index, mask = bbox_results[i]
        yaw = bboxes.yaw.to(transformed_anchors.device)
        bbox_centers = bboxes.gravity_center.to(transformed_anchors.device)
        if with_rotation_transform:
            angle = yaw - math.pi  # (num_agents,)
            rot_yaw = rot_2d(angle)  # (num_agents, 2, 2)
            rot_yaw = rot_yaw[:, None, None, :, :]  # (num_agents, 1, 1, 2, 2)
            transformed_anchors = rearrange(
                transformed_anchors, "b g m t c -> b g m c t"
            )  # (1, num_groups, num_modes, 12, 2) -> (1, num_groups, num_modes, 2, 12)
            transformed_anchors = torch.matmul(
                rot_yaw, transformed_anchors
            )  # -> (num_agents, num_groups, num_modes, 2, 12)
            transformed_anchors = rearrange(
                transformed_anchors, "b g m c t -> b g m t c"
            )  # -> (num_agents, num_groups, num_modes, 12, 2)
if with_translation_transform:
transformed_anchors = (
bbox_centers[:, None, None, None, :2] + transformed_anchors
)
batched_anchors.append(transformed_anchors)
return torch.stack(batched_anchors)
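

# Illustrative sketch (hypothetical stand-in, not the real detector output): the
# bbox entries only need `.yaw` of shape (num_agents,) and `.gravity_center` of
# shape (num_agents, 3) for the transform to apply.
def _example_anchor_transform():
    class _FakeBoxes:
        yaw = torch.zeros(3)
        gravity_center = torch.zeros(3, 3)

    bbox_results = [(_FakeBoxes(), None, None, None, None)]
    anchors = torch.zeros(2, 4, 12, 2)  # (num_groups, num_modes, steps, xy)
    out = anchor_coordinate_transform(anchors, bbox_results)
    assert out.shape == (1, 3, 2, 4, 12, 2)  # (batch, agents, groups, modes, steps, xy)
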
def trajectory_coordinate_transform(
trajectory,
bbox_results,
with_translation_transform=True,
with_rotation_transform=True,
):
"""
Transform trajectory coordinates with respect to detected bounding boxes in the batch.
Args:
trajectory (torch.Tensor): predicted trajectory.
bbox_results (List[Tuple[torch.Tensor]]): A list of tuples containing the bounding box results for each image in the batch.
with_translate (bool, optional): Whether to perform translation transformation. Defaults to True.
with_rot (bool, optional): Whether to perform rotation transformation. Defaults to True.
Returns:
torch.Tensor: A tensor containing the transformed trajectory coordinates.
"""
batch_size = len(bbox_results)
batched_trajectories = []
for i in range(batch_size):
bboxes, scores, labels, bbox_index, mask = bbox_results[i]
yaw = bboxes.yaw.to(trajectory.device)
bbox_centers = bboxes.gravity_center.to(trajectory.device)
transformed_trajectory = trajectory[i, ...]
        if with_rotation_transform:
            # negate the angle to rotate the trajectory back into the
            # ego-centric coordinate frame
            angle = -(yaw - math.pi)
            rot_yaw = rot_2d(angle)  # (A, 2, 2)
            rot_yaw = rot_yaw[:, None, None, :, :]  # (A, 1, 1, 2, 2)
            transformed_trajectory = rearrange(
                transformed_trajectory, "a g p t c -> a g p c t"
            )  # (A, G, P, 12, 2) -> (A, G, P, 2, 12)
            transformed_trajectory = torch.matmul(
                rot_yaw, transformed_trajectory
            )  # -> (A, G, P, 2, 12)
            transformed_trajectory = rearrange(
                transformed_trajectory, "a g p c t -> a g p t c"
            )  # -> (A, G, P, 12, 2)
if with_translation_transform:
transformed_trajectory = (
bbox_centers[:, None, None, None, :2] + transformed_trajectory
)
batched_trajectories.append(transformed_trajectory)
return torch.stack(batched_trajectories)
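

# Illustrative sketch (hypothetical stand-in, mirroring the anchor example): the
# trajectory tensor is expected as (batch, agents, groups, modes, steps, xy).
def _example_trajectory_transform():
    class _FakeBoxes:
        yaw = torch.zeros(3)
        gravity_center = torch.zeros(3, 3)

    bbox_results = [(_FakeBoxes(), None, None, None, None)]
    trajectory = torch.zeros(1, 3, 2, 4, 12, 2)  # (batch, agents, groups, modes, steps, xy)
    out = trajectory_coordinate_transform(trajectory, bbox_results)
    assert out.shape == trajectory.shape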