| import torch | |
def normalize_bbox(bboxes):
    """Re-pack boxes into a log-size / sin-cos heading layout.

    The last dimension of ``bboxes`` is read as
    ``(cx, cy, cz, w, l, h, rot[, vx, vy])``.

    Args:
        bboxes: Tensor whose last dimension holds the box fields above.
            All leading dimensions are carried through unchanged.

    Returns:
        Tensor with last-dimension layout
        ``(cx, cy, log w, log l, cz, log h, sin rot, cos rot[, vx, vy])``.
        The velocity columns are appended only when the input has more
        than 7 columns.
    """
    center_xy = bboxes[..., 0:2]
    center_z = bboxes[..., 2:3]
    log_wl = bboxes[..., 3:5].log()
    log_h = bboxes[..., 5:6].log()
    heading = bboxes[..., 6:7]

    # Note the output order: cz sits between the planar sizes and height.
    fields = [
        center_xy,
        log_wl,
        center_z,
        log_h,
        heading.sin(),
        heading.cos(),
    ]
    if bboxes.size(-1) > 7:
        # Columns 7 and 8 are passed through untouched.
        fields.append(bboxes[..., 7:9])
    return torch.cat(fields, dim=-1)
def denormalize_bbox(normalized_bboxes):
    """Convert boxes from the log-size / sin-cos layout back to plain fields.

    The last dimension of ``normalized_bboxes`` is read as
    ``(cx, cy, log w, log l, cz, log h, sin rot, cos rot[, vx, vy])``.

    Args:
        normalized_bboxes: Tensor whose last dimension holds the fields
            above. All leading dimensions are carried through unchanged.

    Returns:
        Tensor with last-dimension layout
        ``(cx, cy, cz, w, l, h, rot[, vx, vy])``. The angle is recovered
        with ``atan2(sin, cos)`` and the sizes with ``exp``. Velocity
        columns are appended only when the input has more than 8 columns.
    """
    # atan2 recovers the angle from its (sin, cos) pair.
    heading = torch.atan2(
        normalized_bboxes[..., 6:7],
        normalized_bboxes[..., 7:8],
    )

    center_xy = normalized_bboxes[..., 0:2]
    center_z = normalized_bboxes[..., 4:5]
    size_wl = normalized_bboxes[..., 2:4].exp()
    size_h = normalized_bboxes[..., 5:6].exp()

    fields = [center_xy, center_z, size_wl, size_h, heading]
    if normalized_bboxes.size(-1) > 8:
        # Columns 8 and 9 are passed through untouched.
        fields.append(normalized_bboxes[..., 8:10])
    return torch.cat(fields, dim=-1)
def encode_bbox(bboxes, pc_range=None):
    """Encode boxes as (center, log-size, sin, cos[, velocity]).

    Unlike ``(cx, cy, w, l, cz, h, ...)`` style layouts, this keeps the
    three center columns together, followed by the three log-sizes.

    Args:
        bboxes: Tensor whose last dimension is read as
            ``(x, y, z, w, l, h, rot[, vx, vy])``. It is not modified;
            the center columns are cloned before any scaling.
        pc_range: Optional indexable of six values. Indices 0-2 are used
            as per-axis lower bounds and indices 3-5 as per-axis upper
            bounds. When given, each center coordinate is mapped to
            ``(value - lower) / (upper - lower)``.

    Returns:
        Tensor with last-dimension layout
        ``(x, y, z, log w, log l, log h, sin rot, cos rot[, vx, vy])``.
        Velocity columns are appended only when the input has more than
        7 columns.
    """
    center = bboxes[..., 0:3].clone()
    log_size = bboxes[..., 3:6].log()
    heading = bboxes[..., 6:7]

    if pc_range is not None:
        for axis in range(3):
            lower = pc_range[axis]
            upper = pc_range[axis + 3]
            # In-place write is safe: ``center`` is a private clone.
            center[..., axis] = (center[..., axis] - lower) / (upper - lower)

    pieces = [center, log_size, heading.sin(), heading.cos()]
    if bboxes.shape[-1] > 7:
        pieces.append(bboxes[..., 7:9].clone())
    return torch.cat(pieces, dim=-1)
def decode_bbox(bboxes, pc_range=None):
    """Decode boxes from (center, log-size, sin, cos[, velocity]) form.

    Args:
        bboxes: Tensor whose last dimension is read as
            ``(x, y, z, log w, log l, log h, sin rot, cos rot[, vx, vy])``.
            It is not modified; the center columns are cloned before any
            scaling.
        pc_range: Optional indexable of six values. Indices 0-2 are used
            as per-axis lower bounds and indices 3-5 as per-axis upper
            bounds. When given, each center coordinate is mapped to
            ``value * (upper - lower) + lower``.

    Returns:
        Tensor with last-dimension layout
        ``(x, y, z, w, l, h, rot[, vx, vy])``. Sizes are recovered with
        ``exp`` and the angle with ``atan2(sin, cos)``. Velocity columns
        are appended only when the input has more than 8 columns.
    """
    center = bboxes[..., 0:3].clone()
    size = bboxes[..., 3:6].exp()
    heading = torch.atan2(bboxes[..., 6:7], bboxes[..., 7:8])

    if pc_range is not None:
        for axis in range(3):
            lower = pc_range[axis]
            upper = pc_range[axis + 3]
            # In-place write is safe: ``center`` is a private clone.
            center[..., axis] = center[..., axis] * (upper - lower) + lower

    pieces = [center, size, heading]
    if bboxes.shape[-1] > 8:
        pieces.append(bboxes[..., 8:10].clone())
    return torch.cat(pieces, dim=-1)