|
|
import numpy as np
|
|
|
import torch
|
|
|
from torchvision.transforms import functional as tfn
|
|
|
import torchvision.transforms.functional as tvf
|
|
|
|
|
|
from ..utils import decompose_rotmat
|
|
|
from ..image import pad_image, rectify_image, resize_image
|
|
|
from ...utils.wrappers import Camera
|
|
|
from ..schema import KITTIDataConfiguration
|
|
|
|
|
|
|
|
|
class BEVTransform:
    """Callable that converts one front-view sample and its BEV labels to tensors.

    The image is converted to a tensor, rectified and resized/padded according
    to ``cfg``; the BEV instance mask is converted into down-sampled one-hot
    semantic masks plus a mask of pixels that carry no class.
    """

    # Number of leading rows and centred columns kept from the rotated BEV mask.
    BEV_CROP_SIZE = 672
    # (height, width) the one-hot segmentation masks are resized to.
    SEG_MASK_SIZE = (100, 100)

    def __init__(self, cfg: "KITTIDataConfiguration", augmentations):
        """Store the configuration and the image augmentation callable.

        Args:
            cfg: Configuration object; this class reads ``gravity_align``,
                ``rectify_pitch``, ``target_focal_length``, ``resize_image``,
                ``pad_to_multiple``, ``pad_to_square``, ``num_classes`` and
                ``class_mapping`` from it.
            augmentations: Callable applied to the image tensor after all
                geometric processing.
        """
        self.cfg = cfg
        self.augmentations = augmentations

    @staticmethod
    def _compact_labels(msk, cat, iscrowd):
        """Renumber the ids present in ``msk`` to ``0..K-1`` and gather metadata.

        Id 0 is always part of the compact id set (it is prepended when it does
        not occur in ``msk``), so compact id 0 always refers to original id 0.

        Args:
            msk: Integer numpy array of ids.
            cat: Numpy array indexed by original id.
            iscrowd: Numpy array indexed by original id, or ``None``.

        Returns:
            Tuple ``(msk, cat, iscrowd)`` where ``msk`` holds compact ids and
            ``cat`` / ``iscrowd`` are indexed by compact id (``iscrowd`` stays
            ``None`` when ``None`` was given).
        """
        ids = np.unique(msk)
        if 0 not in ids:
            ids = np.concatenate((np.array([0], dtype=np.int32), ids), axis=0)

        # Lookup table: original id -> position of that id in the sorted ``ids``.
        ids_to_compact = np.zeros((ids.max() + 1,), dtype=np.int32)
        ids_to_compact[ids] = np.arange(0, ids.size, dtype=np.int32)

        msk = ids_to_compact[msk]
        cat = cat[ids]
        if iscrowd is not None:
            iscrowd = iscrowd[ids]

        return msk, cat, iscrowd

    @staticmethod
    def _crop_bev_mask(mask, size):
        """Keep the first ``size`` rows and the central ``size`` columns.

        Explicit ``left:left + size`` offsets are used instead of a negative
        end index, because ``a:-a`` degenerates to the empty slice ``0:-0``
        when the mask is already exactly ``size`` wide.

        Args:
            mask: Tensor indexed as ``(1, rows, cols)``.
            size: Number of rows / columns to keep.

        Returns:
            The cropped view of ``mask``. When ``mask`` has fewer than ``size``
            columns, all of its columns are kept.
        """
        left = max((mask.size(2) - size) // 2, 0)
        return mask[:, :size, left:left + size]

    def __call__(self, img, bev_msk=None, bev_plabel=None, fv_msk=None, bev_weights_msk=None,
                 bev_cat=None, bev_iscrowd=None, fv_cat=None, fv_iscrowd=None,
                 fv_intrinsics=None, ego_pose=None):
        """Transform one sample.

        Args:
            img: Front-view image accepted by torchvision's ``to_tensor``.
            bev_msk: BEV id mask (required).
            bev_cat: Per-id category lookup for ``bev_msk`` (required).
            bev_iscrowd: Optional per-id crowd flags; compacted but not returned.
            fv_intrinsics: Camera matrix indexed as ``[row][col]``; entries
                ``[0][0]``, ``[1][1]``, ``[0][2]`` and ``[1][2]`` are read
                (required).
            ego_pose: Pose matrix whose upper-left 3x3 block is decomposed into
                roll, pitch and yaw (required).
            bev_plabel, fv_msk, bev_weights_msk, fv_cat, fv_iscrowd: Accepted
                for interface compatibility; not used here.

        Returns:
            Dict with keys ``image``, ``valid``, ``camera``, ``seg_masks``,
            ``flood_masks``, ``roll_pitch_yaw`` and ``confidence_map``
            (``confidence_map`` is the same tensor as ``flood_masks``).

        Raises:
            ValueError: If any of the required inputs is ``None``.
        """
        # These inputs default to None in the signature but every one of them
        # is needed below; fail early with a message that names what is missing.
        required = (
            ("bev_msk", bev_msk),
            ("bev_cat", bev_cat),
            ("fv_intrinsics", fv_intrinsics),
            ("ego_pose", ego_pose),
        )
        missing = [name for name, value in required if value is None]
        if missing:
            raise ValueError(
                f"BEVTransform requires non-None values for: {', '.join(missing)}"
            )

        # Wrap the label metadata and the pose in numpy arrays.
        bev_cat = np.array(bev_cat, dtype=np.int32)
        if bev_iscrowd is not None:
            bev_iscrowd = np.array(bev_iscrowd, dtype=np.uint8)
        ego_pose = np.array(ego_pose, dtype=np.float32)

        roll, pitch, yaw = decompose_rotmat(ego_pose[:3, :3])

        # ---- Image transformations -------------------------------------
        img = tfn.to_tensor(img)  # channels-first tensor: (C, H, W)

        fx = fv_intrinsics[0][0]
        fy = fv_intrinsics[1][1]
        cx = fv_intrinsics[0][2]
        cy = fv_intrinsics[1][2]
        width = img.shape[2]
        height = img.shape[1]

        # NOTE(review): the principal point is shifted by half a pixel,
        # presumably to match the pixel-origin convention of ``Camera`` -- confirm.
        cam = Camera(torch.tensor(
            [width, height, fx, fy, cx - 0.5, cy - 0.5])).float()

        if not self.cfg.gravity_align:
            # Gravity alignment disabled: rectify with zero roll and pitch.
            roll = 0.0
            pitch = 0.0
            img, valid = rectify_image(img, cam, roll, pitch)
        else:
            img, valid = rectify_image(
                img, cam, roll, pitch if self.cfg.rectify_pitch else None
            )
            # The rectified angles are reported as zero in the output.
            roll = 0.0
            if self.cfg.rectify_pitch:
                pitch = 0.0

        if self.cfg.target_focal_length is not None:
            # Rescale so that the camera focal length becomes the target one.
            factor = self.cfg.target_focal_length / cam.f.numpy()
            size = (np.array(img.shape[-2:][::-1]) * factor).astype(int)
            img, _, cam, valid = resize_image(img, size, camera=cam, valid=valid)
            size_out = self.cfg.resize_image
            if size_out is None:
                # Round each edge up to the next multiple of the stride.
                stride = self.cfg.pad_to_multiple
                size_out = (np.ceil((size / stride)) * stride).astype(int)
            img, valid, cam = pad_image(
                img, size_out, cam, valid, crop_and_center=False
            )
        elif self.cfg.resize_image is not None:
            img, _, cam, valid = resize_image(
                img, self.cfg.resize_image, fn=max, camera=cam, valid=valid
            )
            if self.cfg.pad_to_square:
                img, valid, cam = pad_image(img, self.cfg.resize_image, cam, valid)

        # ---- Label transformations -------------------------------------
        # np.asarray copies only when needed; ``np.array(..., copy=False)``
        # raises on NumPy >= 2.0 whenever a copy cannot be avoided.
        bev_msk = np.expand_dims(np.asarray(bev_msk, dtype=np.int32), axis=0)
        bev_msk, bev_cat, bev_iscrowd = self._compact_labels(
            bev_msk, bev_cat, bev_iscrowd
        )

        bev_msk = torch.from_numpy(bev_msk)
        bev_cat = torch.from_numpy(bev_cat)

        rotated_mask = torch.rot90(bev_msk, dims=(1, 2))
        cropped_mask = self._crop_bev_mask(rotated_mask, self.BEV_CROP_SIZE)

        bev_msk = cropped_mask.squeeze(0)
        # Per-pixel category, looked up through the compact ids.
        seg_masks = bev_cat[bev_msk]

        # Category 255 must not produce a one-hot channel: map it to class 0
        # for the encoding, then clear every channel at those pixels.
        seg_masks_onehot = seg_masks.clone()
        seg_masks_onehot[seg_masks_onehot == 255] = 0
        seg_masks_onehot = torch.nn.functional.one_hot(
            seg_masks_onehot.to(torch.int64),
            num_classes=self.cfg.num_classes
        )
        seg_masks_onehot[seg_masks == 255] = 0

        # Resize expects channels first; restore channels-last afterwards.
        seg_masks_onehot = seg_masks_onehot.permute(2, 0, 1)
        seg_masks_down = tvf.resize(seg_masks_onehot, self.SEG_MASK_SIZE)
        seg_masks_down = seg_masks_down.permute(1, 2, 0)

        if self.cfg.class_mapping is not None:
            # Select / reorder the class channels.
            seg_masks_down = seg_masks_down[:, :, self.cfg.class_mapping]

        img = self.augmentations(img)
        # 1.0 where a pixel has no active class channel, 0.0 elsewhere.
        flood_masks = torch.all(seg_masks_down == 0, dim=2).float()

        ret = {
            "image": img,
            "valid": valid,
            "camera": cam,
            "seg_masks": seg_masks_down.float().contiguous(),
            "flood_masks": flood_masks,
            "roll_pitch_yaw": torch.tensor((roll, pitch, yaw)).float(),
            "confidence_map": flood_masks,
        }

        # Convert any remaining numpy arrays so the output only holds tensors
        # (and the camera wrapper).
        for key, value in ret.items():
            if isinstance(value, np.ndarray):
                ret[key] = torch.from_numpy(value)

        return ret
|
|
|
|