"""Post process after transformation."""
from __future__ import annotations
import torch
from vis4d.common.typing import NDArrayF32, NDArrayI64
from vis4d.data.const import CommonKeys as K
from vis4d.op.box.box2d import bbox_area, bbox_clip
from .base import Transform
@Transform(
    in_keys=[
        K.boxes2d,
        K.boxes2d_classes,
        K.boxes2d_track_ids,
        K.input_hw,
        K.boxes3d,
        K.boxes3d_classes,
        K.boxes3d_track_ids,
    ],
    out_keys=[
        K.boxes2d,
        K.boxes2d_classes,
        K.boxes2d_track_ids,
        K.boxes3d,
        K.boxes3d_classes,
        K.boxes3d_track_ids,
    ],
)
class PostProcessBoxes2D:
    """Post process after transformation.

    Optionally clips 2D boxes to the image and drops boxes whose area falls
    below a threshold; every aligned annotation (classes, track ids, and the
    corresponding 3D boxes / classes / track ids, when present) is filtered
    with the same keep mask so all lists stay index-aligned.
    """

    def __init__(
        self, min_area: float = 7.0 * 7.0, clip_bboxes_to_image: bool = True
    ) -> None:
        """Creates an instance of the class.

        Args:
            min_area (float): Minimum area of the bounding box. Defaults to
                7.0 * 7.0.
            clip_bboxes_to_image (bool): Whether to clip the bounding boxes to
                the image size. Defaults to True.
        """
        self.min_area = min_area
        self.clip_bboxes_to_image = clip_bboxes_to_image

    def __call__(
        self,
        boxes_list: list[NDArrayF32],
        classes_list: list[NDArrayI64],
        track_ids_list: list[NDArrayI64] | None,
        input_hw_list: list[tuple[int, int]],
        boxes3d_list: list[NDArrayF32] | None,
        boxes3d_classes_list: list[NDArrayI64] | None,
        boxes3d_track_ids_list: list[NDArrayI64] | None,
    ) -> tuple[
        list[NDArrayF32],
        list[NDArrayI64],
        list[NDArrayI64] | None,
        list[NDArrayF32] | None,
        list[NDArrayI64] | None,
        list[NDArrayI64] | None,
    ]:
        """Post process according to boxes2D after transformation.

        Args:
            boxes_list (list[NDArrayF32]): The bounding boxes to be post
                processed.
            classes_list (list[NDArrayI64]): The classes of the bounding
                boxes.
            track_ids_list (list[NDArrayI64] | None): The track ids of the
                bounding boxes.
            input_hw_list (list[tuple[int, int]]): The height and width of the
                input image.
            boxes3d_list (list[NDArrayF32] | None): The 3D bounding boxes to
                be post processed.
            boxes3d_classes_list (list[NDArrayI64] | None): The classes of the
                3D bounding boxes.
            boxes3d_track_ids_list (list[NDArrayI64] | None): The track ids of
                the 3D bounding boxes.

        Returns:
            tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None,
                list[NDArrayF32] | None, list[NDArrayI64] | None,
                list[NDArrayI64] | None]: The post processed results.
        """
        # Optional outputs mirror whether the matching input was supplied.
        kept_track_ids: list[NDArrayI64] | None = (
            None if track_ids_list is None else []
        )
        kept_boxes3d: list[NDArrayF32] | None = (
            None if boxes3d_list is None else []
        )
        kept_boxes3d_classes: list[NDArrayI64] | None = (
            None if boxes3d_classes_list is None else []
        )
        kept_boxes3d_track_ids: list[NDArrayI64] | None = (
            None if boxes3d_track_ids_list is None else []
        )
        for idx, (boxes, classes) in enumerate(
            zip(boxes_list, classes_list)
        ):
            boxes_t = torch.from_numpy(boxes)
            if self.clip_bboxes_to_image:
                boxes_t = bbox_clip(boxes_t, input_hw_list[idx])
            # Keep mask is computed on the (possibly clipped) boxes and then
            # applied to every aligned annotation array.
            keep = (bbox_area(boxes_t) >= self.min_area).numpy()
            boxes_list[idx] = boxes[keep]
            classes_list[idx] = classes[keep]
            if track_ids_list is not None:
                assert kept_track_ids is not None
                kept_track_ids.append(track_ids_list[idx][keep])
            if boxes3d_list is not None:
                assert kept_boxes3d is not None
                kept_boxes3d.append(boxes3d_list[idx][keep])
            if boxes3d_classes_list is not None:
                assert kept_boxes3d_classes is not None
                kept_boxes3d_classes.append(boxes3d_classes_list[idx][keep])
            if boxes3d_track_ids_list is not None:
                assert kept_boxes3d_track_ids is not None
                kept_boxes3d_track_ids.append(
                    boxes3d_track_ids_list[idx][keep]
                )
        return (
            boxes_list,
            classes_list,
            kept_track_ids,
            kept_boxes3d,
            kept_boxes3d_classes,
            kept_boxes3d_track_ids,
        )
@Transform(in_keys=[K.boxes2d_track_ids], out_keys=[K.boxes2d_track_ids])
class RescaleTrackIDs:
    """Rescale track ids.

    Remaps the track ids across all frames onto a dense contiguous range
    ``0..N-1``, numbering ids in first-seen order. Arrays are rewritten in
    place and the input list is returned.
    """

    def __call__(self, track_ids_list: list[NDArrayI64]) -> list[NDArrayI64]:
        """Rescale the track ids.

        Args:
            track_ids_list (list[NDArrayI64]): The track ids to be
                rescaled.

        Returns:
            list[NDArrayI64]: The rescaled track ids.
        """
        # First pass: assign each distinct id its first-seen dense index.
        mapping: dict[int, int] = {}
        for ids in track_ids_list:
            for tid in ids:
                mapping.setdefault(tid, len(mapping))
        # Second pass: overwrite every entry with its dense index in place.
        for ids in track_ids_list:
            for pos, tid in enumerate(ids):
                ids[pos] = mapping[tid]
        return track_ids_list