File size: 5,447 Bytes
f71ac1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""Post process after transformation."""

from __future__ import annotations

import torch

from vis4d.common.typing import NDArrayF32, NDArrayI64
from vis4d.data.const import CommonKeys as K
from vis4d.op.box.box2d import bbox_area, bbox_clip

from .base import Transform


@Transform(
    in_keys=[
        K.boxes2d,
        K.boxes2d_classes,
        K.boxes2d_track_ids,
        K.input_hw,
        K.boxes3d,
        K.boxes3d_classes,
        K.boxes3d_track_ids,
    ],
    out_keys=[
        K.boxes2d,
        K.boxes2d_classes,
        K.boxes2d_track_ids,
        K.boxes3d,
        K.boxes3d_classes,
        K.boxes3d_track_ids,
    ],
)
class PostProcessBoxes2D:
    """Filter 2D boxes and their associated annotations after transformation.

    For every sample the 2D boxes are optionally clipped to the input image
    size, and every box whose area is below ``min_area`` is removed. The same
    keep-mask is applied to the 2D classes and track ids, and to the 3D boxes,
    3D classes and 3D track ids when present, so all per-box annotations stay
    aligned with the surviving 2D boxes.
    """

    def __init__(
        self, min_area: float = 7.0 * 7.0, clip_bboxes_to_image: bool = True
    ) -> None:
        """Creates an instance of the class.

        Args:
            min_area (float): Minimum area a bounding box must have to be
                kept. Defaults to 7.0 * 7.0.
            clip_bboxes_to_image (bool): Whether to clip the bounding boxes to
                the image size before computing their area. Defaults to True.
        """
        self.min_area = min_area
        self.clip_bboxes_to_image = clip_bboxes_to_image

    def __call__(
        self,
        boxes_list: list[NDArrayF32],
        classes_list: list[NDArrayI64],
        track_ids_list: list[NDArrayI64] | None,
        input_hw_list: list[tuple[int, int]],
        boxes3d_list: list[NDArrayF32] | None,
        boxes3d_classes_list: list[NDArrayI64] | None,
        boxes3d_track_ids_list: list[NDArrayI64] | None,
    ) -> tuple[
        list[NDArrayF32],
        list[NDArrayI64],
        list[NDArrayI64] | None,
        list[NDArrayF32] | None,
        list[NDArrayI64] | None,
        list[NDArrayI64] | None,
    ]:
        """Post process according to boxes2D after transformation.

        Args:
            boxes_list (list[NDArrayF32]): The bounding boxes to be post
                processed, one array per sample.
            classes_list (list[NDArrayI64]): The classes of the bounding boxes.
            track_ids_list (list[NDArrayI64] | None): The track ids of the
                bounding boxes.
            input_hw_list (list[tuple[int, int]]): The height and width of the
                input image of each sample.
            boxes3d_list (list[NDArrayF32] | None): The 3D bounding boxes to be
                post processed.
            boxes3d_classes_list (list[NDArrayI64] | None): The classes of the
                3D bounding boxes.
            boxes3d_track_ids_list (list[NDArrayI64] | None): The track ids of
                the 3D bounding boxes.

        Returns:
            tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None,
                list[NDArrayF32] | None, list[NDArrayI64] | None,
                list[NDArrayI64] | None]: The filtered (and possibly clipped)
                2D boxes, their classes and track ids, followed by the
                filtered 3D boxes, 3D classes and 3D track ids. An optional
                output is None exactly when the corresponding input is None.
        """
        # Always build fresh output lists instead of overwriting entries of
        # the caller's lists, so every output is handled the same way.
        new_boxes: list[NDArrayF32] = []
        new_classes: list[NDArrayI64] = []
        new_track_ids: list[NDArrayI64] | None = (
            [] if track_ids_list is not None else None
        )
        new_boxes3d: list[NDArrayF32] | None = (
            [] if boxes3d_list is not None else None
        )
        new_boxes3d_classes: list[NDArrayI64] | None = (
            [] if boxes3d_classes_list is not None else None
        )
        new_boxes3d_track_ids: list[NDArrayI64] | None = (
            [] if boxes3d_track_ids_list is not None else None
        )
        for i, (boxes, classes) in enumerate(zip(boxes_list, classes_list)):
            # torch.from_numpy shares memory with ``boxes``; bbox_clip may
            # therefore also modify the caller's array in place.
            boxes_ = torch.from_numpy(boxes)
            if self.clip_bboxes_to_image:
                boxes_ = bbox_clip(boxes_, input_hw_list[i])

            keep = (bbox_area(boxes_) >= self.min_area).numpy()

            # Select from the (possibly clipped) tensor so the output carries
            # the clipped coordinates whether bbox_clip writes into the shared
            # buffer or returns a new tensor.
            new_boxes.append(boxes_.numpy()[keep])
            new_classes.append(classes[keep])

            if track_ids_list is not None:
                assert new_track_ids is not None
                new_track_ids.append(track_ids_list[i][keep])

            if boxes3d_list is not None:
                assert new_boxes3d is not None
                new_boxes3d.append(boxes3d_list[i][keep])

            if boxes3d_classes_list is not None:
                assert new_boxes3d_classes is not None
                new_boxes3d_classes.append(boxes3d_classes_list[i][keep])

            if boxes3d_track_ids_list is not None:
                assert new_boxes3d_track_ids is not None
                new_boxes3d_track_ids.append(boxes3d_track_ids_list[i][keep])

        return (
            new_boxes,
            new_classes,
            new_track_ids,
            new_boxes3d,
            new_boxes3d_classes,
            new_boxes3d_track_ids,
        )


@Transform(in_keys=[K.boxes2d_track_ids], out_keys=[K.boxes2d_track_ids])
class RescaleTrackIDs:
    """Remap track ids to consecutive integers starting from zero."""

    def __call__(self, track_ids_list: list[NDArrayI64]) -> list[NDArrayI64]:
        """Rescale the track ids in place.

        New ids are handed out in order of first appearance while walking the
        samples (and the ids within each sample) in order. The first distinct
        id becomes 0, the next distinct id becomes 1, and so on. The same
        original id maps to the same new id in every sample.

        Args:
            track_ids_list (list[NDArrayI64]): The track ids to be rescaled,
                one array per sample. Each array is overwritten element by
                element.

        Returns:
            list[NDArrayI64]: The same list that was passed in, whose arrays
                now hold the rescaled track ids.
        """
        # First pass: give every distinct id its rank of first appearance.
        # ``len(remap)`` is evaluated before insertion, so a new id receives
        # the current number of known ids.
        remap: dict[int, int] = {}
        for sample_ids in track_ids_list:
            for original_id in sample_ids:
                remap.setdefault(original_id, len(remap))

        # Second pass: overwrite every entry with its remapped value.
        for sample_ids in track_ids_list:
            for position, original_id in enumerate(sample_ids):
                sample_ids[position] = remap[original_id]

        return track_ids_list