File size: 4,424 Bytes
da2e2ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from __future__ import annotations

from typing import Dict

import numpy as np
import torch

from det_map.data.datasets.dataclasses import AgentInput, Camera
from det_map.data.datasets.lidar_utils import transform_points, render_image
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder
from mmcv.parallel import DataContainer as DC

class LiDARCameraFeatureBuilder(AbstractFeatureBuilder):
    """Assemble multi-frame camera and LiDAR features from an ``AgentInput``.

    ``pipelines`` is used as a mapping; only its ``'img'`` entry is read. That
    entry is called on every selected camera and must return an object that
    exposes the per-camera attributes gathered in ``compute_features``.
    """

    def __init__(self, pipelines):
        super().__init__()
        # Stored as given; compute_features looks up pipelines['img'].
        self.pipelines = pipelines

    def compute_features(self, agent_input: AgentInput) -> Dict[str, torch.Tensor]:
        """Return camera tensors plus key-frame-aligned, time-stamped LiDAR clouds.

        The last frame of ``agent_input`` is the key frame. Every earlier point
        cloud is mapped through its own ego2global pose and then through the
        inverse of the key frame's pose. The key-frame cloud is used as-is.
        Row 4 of each cloud is overwritten with the frame's time offset to the
        key frame. Timestamps are divided by 1e6, presumably converting
        microseconds to seconds (TODO confirm units). The first five rows are
        then transposed into a torch tensor.

        Camera outputs are stacked per frame (over cameras) and then across
        frames. Every camera output except ``"image"`` is cast to the image
        tensor's dtype and device.

        Returns:
            A dict with keys ``"image"``, ``"canvas"``,
            ``"sensor2lidar_rotation"``, ``"sensor2lidar_translation"``,
            ``"intrinsics"``, ``"distortion"``, ``"post_rot"``, ``"post_tran"``
            (tensors) and ``"lidars_warped"``. Despite the return annotation,
            ``"lidars_warped"`` is a Python list with one tensor per frame.
        """
        img_pipeline = self.pipelines['img']

        # ---- LiDAR ---------------------------------------------------------
        frame_times = agent_input.timestamps
        # Offset of every frame relative to the last (key) frame.
        time_offsets = [(frame_times[-1] - t) / 1e6 for t in frame_times]

        # np.copy keeps the in-place timestamp write below from modifying
        # agent_input's own arrays.
        clouds = [np.copy(frame.lidar_pc) for frame in agent_input.lidars]
        poses = list(agent_input.ego2globals)
        key_global2ego = np.linalg.inv(poses[-1])

        aligned = []
        for cloud, ego2global in zip(clouds[:-1], poses[:-1]):
            # Frame ego -> global, then global -> key-frame ego.
            in_global = transform_points(cloud, ego2global)
            aligned.append(transform_points(in_global, key_global2ego))
        # The key frame is already expressed in its own ego frame.
        aligned.append(clouds[-1])

        lidars_warped = []
        for idx, cloud in enumerate(aligned):
            # Rows 0-3 are kept (x, y, z, intensity). Row 4 becomes the time offset.
            cloud[4] = time_offsets[idx]
            lidars_warped.append(torch.from_numpy(cloud[:5]).t())
        # For debugging, render_image (from lidar_utils) can visualize clouds.

        # ---- Cameras -------------------------------------------------------
        # Only the front (cam_f0) and back (cam_b0) cameras are used; the
        # cam_l0..cam_l2 and cam_r0..cam_r2 cameras are currently disabled.
        cams_all_frames = [[frame.cam_f0, frame.cam_b0] for frame in agent_input.cameras]

        # Processed-camera attribute names; they double as output dict keys.
        field_names = (
            'image',
            'canvas',
            'sensor2lidar_rotation',
            'sensor2lidar_translation',
            'intrinsics',
            'distortion',
            'post_rot',
            'post_tran',
        )
        per_frame = {name: [] for name in field_names}
        for frame_cams in cams_all_frames:
            gathered = {name: [] for name in field_names}
            for cam in frame_cams:
                cam_processed: Camera = img_pipeline(cam)
                for name in field_names:
                    gathered[name].append(getattr(cam_processed, name))
            # Stack over cameras of this frame.
            for name in field_names:
                per_frame[name].append(torch.stack(gathered[name]))

        # Stack over frames: leading dims are (T frames, N_CAM cameras); the
        # trailing dims come from the pipeline, presumably (C, H, W) for images.
        imgs = torch.stack(per_frame['image'])
        features = {'image': imgs}
        for name in field_names[1:]:
            # Tensor.to(other) matches the image tensor's dtype and device.
            features[name] = torch.stack(per_frame[name]).to(imgs)
        features['lidars_warped'] = lidars_warped
        return features