File size: 12,784 Bytes
663494c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
import os
import numpy as np
import cv2
import mmcv
from mmdet.datasets.builder import PIPELINES
from mmdet3d.datasets.pipelines import LoadAnnotations3D

@PIPELINES.register_module()
class LoadMultiViewImageFromFilesInCeph(object):
    """Load multi-view images from a list of separate image files.

    Expects results['img_filename'] to be a list of filenames. Images are
    read with ``mmcv.imread`` when the backend is "disk", or fetched as bytes
    through ``mmcv.FileClient`` and decoded with ``mmcv.imfrombytes`` when the
    backend is "petrel".

    Args:
        to_float32 (bool): Whether to convert the img to float32.
            Defaults to False.
        color_type (str): Color type of the file. Defaults to 'unchanged'.
            Only passed to the reader on the "disk" backend.
        file_client_args (dict): Keyword arguments used to build the
            ``mmcv.FileClient``. Its "backend" entry selects how images are
            read ("disk" or "petrel"). The dict is copied, so the shared
            default is never mutated. Defaults to dict(backend="disk").
        img_root (str): Directory joined in front of relative image paths
            that do not already contain it. Defaults to "".
    """

    # Backends that ``__call__`` knows how to read from.
    _SUPPORTED_BACKENDS = ("disk", "petrel")

    def __init__(
        self,
        to_float32=False,
        color_type="unchanged",
        file_client_args=dict(backend="disk"),
        img_root="",
    ):
        self.to_float32 = to_float32
        self.color_type = color_type
        self.file_client_args = file_client_args.copy()
        self.file_client = mmcv.FileClient(**self.file_client_args)
        self.img_root = img_root

    def __call__(self, results):
        """Call function to load multi-view image from files.

        Args:
            results (dict): Result dict containing multi-view image filenames
                under the key "img_filename".

        Returns:
            dict: The result dict containing the multi-view image data.
                Added keys and values are described below.

                - filename (list of str): Multi-view image filenames.
                - img (list of np.ndarray): One image array per view.
                - img_shape (tuple[int]): Shape of the stacked image array.
                - ori_shape (tuple[int]): Shape of original image arrays.
                - pad_shape (tuple[int]): Shape of padded image arrays.
                - scale_factor (float): Scale factor.
                - img_norm_cfg (dict): Normalization configuration of images.

        Raises:
            ValueError: If the configured backend is not supported or an
                image could not be loaded.
        """
        filename = results["img_filename"]
        backend = self.file_client_args["backend"]
        if backend not in self._SUPPORTED_BACKENDS:
            # Previously an unknown backend fell through both branches and
            # failed later with an unrelated unbound-variable error.
            raise ValueError(
                f"Unsupported file client backend: {backend!r}; "
                f"expected one of {self._SUPPORTED_BACKENDS}"
            )

        images_multiView = []
        for img_path in filename:
            path_str = str(img_path)
            # Absolute paths and paths already containing img_root are used
            # as given. An empty img_root is contained in every string, so
            # nothing is prefixed in that case.
            if not path_str.startswith("/") and self.img_root not in path_str:
                img_path = os.path.join(self.img_root, img_path)

            if backend == "petrel":
                img_bytes = self.file_client.get(img_path)
                # NOTE(review): color_type is not forwarded here, so the
                # decoder's default flag is used -- confirm this is intended.
                img = mmcv.imfrombytes(img_bytes)
            else:
                img = mmcv.imread(img_path, self.color_type)

            if img is None:
                raise ValueError(f"Failed to load image: {img_path}")
            images_multiView.append(img)

        # Stack the views along a new trailing axis: (h, w, c, num_views)
        # for multi-channel images.
        img = np.stack(images_multiView, axis=-1)
        if self.to_float32:
            img = img.astype(np.float32)

        results["filename"] = filename
        # unravel to list, see `DefaultFormatBundle` in formating.py
        # which will transpose each image separately and then stack into array
        results["img"] = [img[..., i] for i in range(img.shape[-1])]
        results["img_shape"] = img.shape
        results["ori_shape"] = img.shape
        # Set initial values for default meta_keys
        results["pad_shape"] = img.shape
        results["scale_factor"] = 1.0
        num_channels = 1 if len(img.shape) < 3 else img.shape[2]
        # Identity normalization (zero mean, unit std) as the initial value.
        results["img_norm_cfg"] = dict(
            mean=np.zeros(num_channels, dtype=np.float32),
            std=np.ones(num_channels, dtype=np.float32),
            to_rgb=False,
        )
        return results

    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f"(to_float32={self.to_float32}, "
        repr_str += f"color_type='{self.color_type}')"
        return repr_str

@PIPELINES.register_module()
class LoadMultiViewImageFromFilesWithDownsample(object):
    """Load multi-view images from a list of files, downsampling on decode.

    The downsampling is done by the image decoder itself through OpenCV's
    reduced-size imread flags, so only factors 1, 2 and 4 are accepted.

    Args:
        to_float32 (bool): Whether to convert the img to float32.
            Defaults to False.
        img_root (str): The root directory of the images. It is joined in
            front of relative paths that do not already contain it.
            Defaults to "".
        downsample_factor (int): The factor to downsample the image.
            Must be 1, 2 or 4. Defaults to 1.

    Raises:
        ValueError: If ``downsample_factor`` is not 1, 2 or 4.
    """

    def __init__(
        self,
        to_float32=False,
        img_root="",
        downsample_factor=1,
    ):
        self.to_float32 = to_float32
        self.img_root = img_root
        self.downsample_factor = downsample_factor
        # Pick the cv2 decode flag matching the requested factor.
        if downsample_factor == 1:
            self.flag = cv2.IMREAD_UNCHANGED
        elif downsample_factor == 2:
            self.flag = cv2.IMREAD_REDUCED_COLOR_2
        elif downsample_factor == 4:
            self.flag = cv2.IMREAD_REDUCED_COLOR_4
        else:
            raise ValueError(f"Invalid downsample factor: {downsample_factor}")

    def imread(self, img_path) -> np.ndarray:
        """Read ``img_path`` from disk and decode it with ``self.flag``.

        Args:
            img_path (str): Path of the image file to read.

        Returns:
            np.ndarray: The decoded image.

        Raises:
            ValueError: If the file content cannot be decoded as an image.
        """
        with open(img_path, 'rb') as f:
            value_buf = f.read()
        img_np = np.frombuffer(value_buf, np.uint8)
        img = cv2.imdecode(img_np, self.flag)
        if img is None:
            # cv2.imdecode signals failure by returning None; fail here with
            # the offending path instead of an obscure error further down.
            raise ValueError(f"Failed to decode image: {img_path}")
        return img

    def __call__(self, results):
        """Call function to load multi-view image from files.

        Args:
            results (dict): Result dict containing multi-view image filenames
                under the key "img_filename". When ``downsample_factor`` is
                not 1, it must also contain "lidar2img" and "cam_intrinsic".

        Returns:
            dict: The result dict containing the multi-view image data.
                Added keys and values are described below.

                - filename (list of str): Multi-view image filenames.
                - img (list of np.ndarray): One image array per view.
                - img_shape (tuple[int]): Shape of the stacked image array.
                - ori_shape (tuple[int]): Stacked shape with height and width
                  multiplied back by the downsample factor.
                - pad_shape (tuple[int]): Shape of padded image arrays.
                - scale_factor (float): ``1 / downsample_factor``.
                - img_norm_cfg (dict): Normalization configuration of images.

                When ``downsample_factor`` is not 1, "lidar2img" and
                "cam_intrinsic" are replaced by versions whose first two rows
                are scaled by ``1 / downsample_factor``.
        """
        filenames = results["img_filename"]

        images_multiView = []
        for img_path in filenames:
            path_str = str(img_path)
            # Absolute paths and paths already containing img_root are used
            # as given. An empty img_root is contained in every string, so
            # nothing is prefixed in that case.
            if not path_str.startswith('/') and self.img_root not in path_str:
                img_path = os.path.join(self.img_root, img_path)
            images_multiView.append(self.imread(img_path))

        # Stack the views along a new trailing axis: (h, w, c, num_views)
        # for multi-channel images.
        img = np.stack(images_multiView, axis=-1)
        if self.to_float32:
            img = img.astype(np.float32)

        results["filename"] = filenames
        # unravel to list, see `DefaultFormatBundle` in formating.py
        # which will transpose each image separately and then stack into array
        results["img"] = [img[..., i] for i in range(img.shape[-1])]
        results["img_shape"] = img.shape
        # Scale height/width back up and keep every trailing axis as is, so
        # this also works when the stacked array has fewer than four axes.
        results["ori_shape"] = (
            int(img.shape[0] * self.downsample_factor),
            int(img.shape[1] * self.downsample_factor),
        ) + tuple(img.shape[2:])
        # Set initial values for default meta_keys
        results["pad_shape"] = img.shape
        results["scale_factor"] = 1.0 / self.downsample_factor

        if self.downsample_factor != 1:
            # Scale the x and y rows of the projection matrices to match the
            # smaller images.
            scale_matrix = np.eye(4)
            scale_matrix[0, 0] *= 1.0 / self.downsample_factor
            scale_matrix[1, 1] *= 1.0 / self.downsample_factor
            results["lidar2img"] = [scale_matrix @ l2i for l2i in results["lidar2img"]]
            # NOTE(review): the 4x4 left-multiplication requires each
            # cam_intrinsic to have 4 rows -- confirm against the dataset.
            results["cam_intrinsic"] = [scale_matrix @ cam_intrinsic for cam_intrinsic in results["cam_intrinsic"]]

        num_channels = 1 if len(img.shape) < 3 else img.shape[2]
        # Identity normalization (zero mean, unit std) as the initial value.
        results["img_norm_cfg"] = dict(
            mean=np.zeros(num_channels, dtype=np.float32),
            std=np.ones(num_channels, dtype=np.float32),
            to_rgb=False,
        )
        return results

    def __repr__(self):
        """str: Return a string that describes the module."""
        # Only report attributes this class actually defines; it has no
        # ``color_type`` attribute.
        repr_str = self.__class__.__name__
        repr_str += f"(to_float32={self.to_float32}, "
        repr_str += f"img_root='{self.img_root}', "
        repr_str += f"downsample_factor={self.downsample_factor})"
        return repr_str


@PIPELINES.register_module()
class LoadAnnotations3D_E2E(LoadAnnotations3D):
    """Load Annotations3D, plus future-frame annotations and instance ids.

    Runs the parent ``LoadAnnotations3D`` loading first, then optionally adds
    per-future-frame annotations and 3D instance indices to the results.

    Args:
        with_future_anns (bool, optional): Whether to load annotations of
            future frames from ``results["occ_future_ann_infos"]``.
            Defaults to False.
        with_ins_inds_3d (bool, optional): Whether to move the instance
            indices from ``results["ann_info"]["gt_inds"]`` to
            ``results["gt_inds"]``. Defaults to False.
        ins_inds_add_1 (bool, optional): Whether to add 1 to every loaded
            instance index so that they start from 1, not 0.
            Defaults to False.
        **kwargs: Forwarded unchanged to ``LoadAnnotations3D``. The options
            listed below are the ones documented for it.
        with_bbox_3d (bool, optional): Whether to load 3D boxes.
            Defaults to True.
        with_label_3d (bool, optional): Whether to load 3D labels.
            Defaults to True.
        with_attr_label (bool, optional): Whether to load attribute label.
            Defaults to False.
        with_mask_3d (bool, optional): Whether to load 3D instance masks.
            for points. Defaults to False.
        with_seg_3d (bool, optional): Whether to load 3D semantic masks.
            for points. Defaults to False.
        with_bbox (bool, optional): Whether to load 2D boxes.
            Defaults to False.
        with_label (bool, optional): Whether to load 2D labels.
            Defaults to False.
        with_mask (bool, optional): Whether to load 2D instance masks.
            Defaults to False.
        with_seg (bool, optional): Whether to load 2D semantic masks.
            Defaults to False.
        with_bbox_depth (bool, optional): Whether to load 2.5D boxes.
            Defaults to False.
        poly2mask (bool, optional): Whether to convert polygon annotations
            to bitmasks. Defaults to True.
        seg_3d_dtype (dtype, optional): Dtype of 3D semantic masks.
            Defaults to int64
        file_client_args (dict): Config dict of file clients, refer to
            https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
            for more details.
    """

    def __init__(
        self,
        with_future_anns=False,
        with_ins_inds_3d=False,
        ins_inds_add_1=False,  # NOTE: make ins_inds start from 1, not 0
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.with_future_anns = with_future_anns
        self.with_ins_inds_3d = with_ins_inds_3d

        self.ins_inds_add_1 = ins_inds_add_1

    def _load_future_anns(self, results):
        """Private function to collect the annotations of future frames.

        Each entry of ``results["occ_future_ann_infos"]`` is either an
        annotation dict or None (invalid frame). The output lists keep one
        element per entry, with None for invalid frames.

        Args:
            results (dict): Result dict containing "occ_future_ann_infos".

        Returns:
            dict: The dict with "future_gt_bboxes_3d", "future_gt_labels_3d",
                "future_gt_inds" and "future_gt_vis_tokens" added.
        """

        gt_bboxes_3d = []
        gt_labels_3d = []
        gt_inds_3d = []
        gt_vis_tokens = []

        for ann_info in results["occ_future_ann_infos"]:
            if ann_info is None:
                # invalid frame: keep a placeholder so indices stay aligned
                gt_bboxes_3d.append(None)
                gt_labels_3d.append(None)
                gt_inds_3d.append(None)
                gt_vis_tokens.append(None)
                continue

            gt_bboxes_3d.append(ann_info["gt_bboxes_3d"])
            gt_labels_3d.append(ann_info["gt_labels_3d"])

            ann_gt_inds = ann_info["gt_inds"]
            if self.ins_inds_add_1:
                # Build a new object instead of using ``+=`` so that the
                # caller's ann_info["gt_inds"] is not modified in place.
                ann_gt_inds = ann_gt_inds + 1
                # NOTE: sdc query is changed from -10 -> -9
            gt_inds_3d.append(ann_gt_inds)

            gt_vis_tokens.append(ann_info["gt_vis_tokens"])

        results["future_gt_bboxes_3d"] = gt_bboxes_3d
        # No 'bbox3d_fields' entry is added: that field is used for
        # augmentations, not needed here.
        results["future_gt_labels_3d"] = gt_labels_3d
        results["future_gt_inds"] = gt_inds_3d
        results["future_gt_vis_tokens"] = gt_vis_tokens

        return results

    def _load_ins_inds_3d(self, results):
        """Private function to move instance indices to ``results["gt_inds"]``.

        Args:
            results (dict): Result dict whose "ann_info" contains "gt_inds".

        Returns:
            dict: The dict with "gt_inds" added and the "gt_inds" entry
                removed from ``results["ann_info"]``.
        """
        # Work on a copy so the original array is left untouched.
        ann_gt_inds = results["ann_info"]["gt_inds"].copy()

        # NOTE: Avoid gt_inds generated twice
        results["ann_info"].pop("gt_inds")

        if self.ins_inds_add_1:
            ann_gt_inds += 1
        results["gt_inds"] = ann_gt_inds
        return results

    def __call__(self, results):
        """Run the parent loading, then the optional extra loading steps.

        Args:
            results (dict): Result dict passed through the pipeline.

        Returns:
            dict: The result dict with the loaded annotations added.
        """
        results = super().__call__(results)

        if self.with_future_anns:
            results = self._load_future_anns(results)
        if self.with_ins_inds_3d:
            results = self._load_ins_inds_3d(results)

        # Generate ann for plan
        # NOTE(review): `_load_future_anns_plan` is not defined in this
        # class; it must come from the parent class or a subclass, otherwise
        # this branch raises AttributeError -- TODO confirm.
        if "occ_future_ann_infos_for_plan" in results.keys():
            results = self._load_future_anns_plan(results)

        return results

    def __repr__(self):
        """str: Return the parent description extended with the new flags."""
        repr_str = super().__repr__()
        indent_str = "    "
        repr_str += f"{indent_str}with_future_anns={self.with_future_anns}, "
        repr_str += f"{indent_str}with_ins_inds_3d={self.with_ins_inds_3d}, "

        return repr_str