|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import PIL |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
from eval.mv_recon.dataset_utils.transforms import ImgNorm |
|
|
from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates |
|
|
import eval.mv_recon.dataset_utils.cropping as cropping |
|
|
|
|
|
|
|
|
class BaseStereoViewDataset:
    """Define all basic options shared by the multi-view datasets.

    Subclasses are expected to:
      - define ``self.scenes`` (``__len__`` returns ``len(self.scenes)``);
      - override ``_get_views`` so that it returns a list of view dicts.

    Each view returned by ``_get_views`` must contain ``img`` (an object
    exposing ``.size`` as ``(width, height)`` and accepted by ``transform``),
    ``depthmap`` and ``camera_intrinsics``. ``camera_pose`` is optional.
    ``pts3d`` and ``valid_mask`` must NOT be present: ``__getitem__`` computes
    them. Error messages additionally read the ``dataset``, ``label`` and
    ``instance`` entries of the view (through ``view_name``).

    Usage:
        class MyDataset (BaseStereoViewDataset):
            def _get_views(self, idx, resolution, rng):
                # overload here
                views = []
                views.append(dict(img=, ...))
                return views
    """

    def __init__(
        self,
        *,  # only keyword arguments
        split=None,
        resolution=None,  # square_size or (width, height) or list of [(width, height), ...]
        transform=ImgNorm,
        aug_crop=False,
        seed=None,
    ):
        """Store the dataset options.

        Params:
            - split: stored as-is, only used by ``__repr__`` in this class.
            - resolution: int, (width, height) tuple, or list of those
              (see ``_set_resolutions``). Must not be None.
            - transform: callable applied to every view image, or a string
              that evaluates to such a callable.
            - aug_crop: stored as-is; not read anywhere in this class.
            - seed: if not None, the rng is re-seeded with ``seed + idx`` on
              every ``__getitem__`` call (``seed=0`` is a valid seed).
        """
        self.num_views = 2
        self.split = split
        self._set_resolutions(resolution)

        # Resolve a transform given by name BEFORE storing it, otherwise
        # self.transform would remain a (non-callable) string.
        if isinstance(transform, str):
            # NOTE(security): eval() runs arbitrary code; only trusted
            # configuration strings must ever reach this point.
            transform = eval(transform)
        self.transform = transform

        self.aug_crop = aug_crop
        self.seed = seed

    def __len__(self):
        """Return the number of scenes; ``self.scenes`` is set by subclasses."""
        return len(self.scenes)

    def get_stats(self):
        """Return a short human-readable size summary."""
        return f"{len(self)} pairs"

    def __repr__(self):
        """Return a compact one-line description (spaces and newlines removed)."""
        resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]"
        return (
            f"""{type(self).__name__}({self.get_stats()},
            {self.split=},
            {self.seed=},
            resolutions={resolutions_str},
            {self.transform=})""".replace(
                "self.", ""
            )
            .replace("\n", "")
            .replace(" ", "")
        )

    def _get_views(self, idx, resolution, rng):
        """Return the list of view dicts for sample ``idx`` (to be overridden)."""
        raise NotImplementedError()

    def __getitem__(self, idx):
        """Load, validate and complete the views of one sample.

        ``idx`` is either a plain index (only allowed when a single resolution
        is configured) or an ``(idx, ar_idx)`` tuple where ``ar_idx`` selects
        one of the configured resolutions.

        Every returned view is completed with: ``idx`` (position in the list),
        ``true_shape`` (``(height, width)`` as int32, taken before the
        transform), the transformed ``img``, ``camera_pose`` (all-NaN 4x4 if
        missing), ``pts3d``, ``valid_mask``, ``img_mask=True``,
        ``ray_mask=False``, an all-NaN ``ray_map`` of shape ``(6, H, W)``,
        ``update=True``, ``reset=False`` and a random integer ``rng``.
        Portrait views are then passed through ``transpose_to_landscape``.
        """
        if isinstance(idx, tuple):
            # the idx is specifying the aspect-ratio
            idx, ar_idx = idx
        else:
            assert len(self._resolutions) == 1
            ar_idx = 0

        # Set up the rng. `is not None` (rather than truthiness) so that
        # seed=0 also yields a per-item deterministic generator.
        if self.seed is not None:
            self._rng = np.random.default_rng(seed=self.seed + idx)
        elif not hasattr(self, "_rng"):
            # Created lazily, once per dataset object, from torch's seed.
            seed = torch.initial_seed()
            self._rng = np.random.default_rng(seed=seed)

        # over-loaded code
        resolution = self._resolutions[ar_idx]
        views = self._get_views(idx, resolution, self._rng)

        # check data-types
        for v, view in enumerate(views):
            assert (
                "pts3d" not in view
            ), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
            view["idx"] = v

            # encode the image; .size is (width, height), true_shape is (height, width)
            width, height = view["img"].size
            view["true_shape"] = np.int32((height, width))
            view["img"] = self.transform(view["img"])

            assert "camera_intrinsics" in view
            if "camera_pose" not in view:
                view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32)
            else:
                assert np.isfinite(
                    view["camera_pose"]
                ).all(), f"NaN in camera pose for view {view_name(view)}"
            assert "valid_mask" not in view
            assert np.isfinite(
                view["depthmap"]
            ).all(), f"NaN in depthmap for view {view_name(view)}"
            # The whole view dict is forwarded as keyword arguments.
            pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)

            view["pts3d"] = pts3d
            # A point is valid only if all of its coordinates are finite.
            view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1)

            # check all datatypes (done before the plain-bool flags below are added)
            for key, val in view.items():
                res, err_msg = is_good_type(key, val)
                assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"

            view["img_mask"] = True
            view["ray_mask"] = False
            # NOTE(review): ray_map uses the pre-transpose image shape and is
            # not touched by transpose_to_landscape -- confirm that this is
            # intended for portrait views.
            view["ray_map"] = torch.full(
                (6, view["img"].shape[-2], view["img"].shape[-1]), torch.nan
            )
            view["update"] = True
            view["reset"] = False

        # last thing done!
        for view in views:
            # transpose to make sure all views are the same size
            transpose_to_landscape(view)
            # this allows to check whether the RNG is in the same state each time
            view["rng"] = int.from_bytes(self._rng.bytes(4), "big")
        return views

    def _set_resolutions(self, resolutions):
        """Set the resolution(s) of the dataset.

        Params:
            - resolutions: int or tuple or list of tuples

        An int ``s`` means a square ``(s, s)``. Anything that is not a list is
        treated as a single resolution. Each resolution is stored as a
        ``(width, height)`` tuple of ints in ``self._resolutions`` and must
        satisfy ``width >= height``.
        """
        assert resolutions is not None, "undefined resolution"

        if not isinstance(resolutions, list):
            resolutions = [resolutions]

        self._resolutions = []
        for resolution in resolutions:
            if isinstance(resolution, int):
                width = height = resolution
            else:
                width, height = resolution
            assert isinstance(
                width, int
            ), f"Bad type for {width=} {type(width)=}, should be int"
            assert isinstance(
                height, int
            ), f"Bad type for {height=} {type(height)=}, should be int"
            assert width >= height
            self._resolutions.append((width, height))

    def _crop_resize_if_necessary(
        self, image, depthmap, intrinsics, resolution, rng=None, info=None
    ):
        """Crop and rescale a view so that it matches ``resolution``.

        This function:
            - crops the largest window centered on the (rounded) principal
              point that fits inside the image;
            - flips ``resolution`` to portrait when the cropped image is
              portrait (H > 1.1 * W), or randomly when it is near-square and
              ``resolution`` itself is not square;
            - rescales with ``cropping.rescale_image_depthmap``
              (presumably Lanczos interpolation for the image -- see that helper);
            - finally crops to ``resolution`` with the crop helpers of
              ``cropping`` (``offset_factor=0.5``).

        Params:
            - image: PIL image, or an array accepted by ``PIL.Image.fromarray``.
            - depthmap, intrinsics: forwarded to / updated by ``cropping``.
            - resolution: ``(width, height)`` with ``width >= height``.
            - rng: generator used for the random orientation choice; a fresh
              unseeded ``np.random.default_rng()`` is used when None.
            - info: only used in the assertion messages.

        Returns:
            The ``(image, depthmap, intrinsics)`` triple after processing.
        """
        # `import PIL` alone does not guarantee that the PIL.Image submodule
        # is loaded; import it explicitly.
        import PIL.Image

        if not isinstance(image, PIL.Image.Image):
            image = PIL.Image.fromarray(image)

        # cropping centered on the principal point
        W, H = image.size
        cx, cy = intrinsics[:2, 2].round().astype(int)

        min_margin_x = min(cx, W - cx)
        min_margin_y = min(cy, H - cy)
        assert min_margin_x > W / 5, f"Bad principal point in view={info}"
        assert min_margin_y > H / 5, f"Bad principal point in view={info}"

        # the new window is a rectangle of size (2*min_margin_x, 2*min_margin_y)
        # centered on (cx, cy)
        left, top = cx - min_margin_x, cy - min_margin_y
        right, bottom = cx + min_margin_x, cy + min_margin_y
        crop_bbox = (left, top, right, bottom)

        image, depthmap, intrinsics = cropping.crop_image_depthmap(
            image, depthmap, intrinsics, crop_bbox
        )

        # transpose the resolution if necessary
        W, H = image.size  # new size
        assert resolution[0] >= resolution[1]
        if H > 1.1 * W:
            # image is in portrait mode
            resolution = resolution[::-1]
        elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:
            # image is (almost) square: choose portrait or landscape randomly
            if rng is None:
                # rng is optional in the signature; do not crash on None
                rng = np.random.default_rng()
            if rng.integers(2):
                resolution = resolution[::-1]

        target_resolution = np.array(resolution)

        image, depthmap, intrinsics = cropping.rescale_image_depthmap(
            image, depthmap, intrinsics, target_resolution
        )

        # actual cropping (if necessary)
        intrinsics2 = cropping.camera_matrix_of_crop(
            intrinsics, image.size, resolution, offset_factor=0.5
        )
        crop_bbox = cropping.bbox_from_intrinsics_in_out(
            intrinsics, intrinsics2, resolution
        )
        image, depthmap, intrinsics = cropping.crop_image_depthmap(
            image, depthmap, intrinsics, crop_bbox
        )
        return image, depthmap, intrinsics
|
|
|
|
|
|
|
|
def is_good_type(key, v):
    """Check that a view entry has an accepted type.

    Params:
        - key: name of the entry (unused, kept for the caller's signature).
        - v: the value to check.

    Returns:
        ``(is_good, err_msg)``: ``(True, None)`` when ``v`` is a ``str``,
        ``int`` (this includes ``bool``) or ``tuple``, or when ``v.dtype`` is
        one of the accepted dtypes; ``(False, message)`` otherwise.
    """
    if isinstance(v, (str, int, tuple)):
        return True, None
    # Values without a dtype (float, None, list, ...) are reported as bad
    # instead of raising AttributeError, so the caller can emit its own
    # descriptive error message.
    if not hasattr(v, "dtype"):
        return False, f"bad {type(v)=}"
    if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
        return False, f"bad {v.dtype=}"
    return True, None
|
|
|
|
|
|
|
|
def view_name(view, batch_index=None):
    """Build the ``"<dataset>/<label>/<instance>"`` identifier of a view.

    When ``batch_index`` is neither ``None`` nor ``slice(None)``, each of the
    three entries is indexed with it first (e.g. to pick one element of a
    batched view); otherwise the entries are used as they are.
    """
    keep_whole = batch_index in (None, slice(None))

    parts = []
    for field in ("dataset", "label", "instance"):
        entry = view[field]
        parts.append(entry if keep_whole else entry[batch_index])

    return "/".join(f"{part}" for part in parts)
|
|
|
|
|
|
|
|
def transpose_to_landscape(view):
    """Rectify a portrait view to landscape, in place.

    ``view["true_shape"]`` is read as ``(height, width)``. When the width is
    smaller than the height, the two spatial axes of ``img``, ``valid_mask``,
    ``depthmap`` and ``pts3d`` are swapped (after asserting their shapes) and
    the first two rows of ``camera_intrinsics`` are exchanged.
    ``true_shape`` itself is left untouched. Nothing happens otherwise.
    """
    rows, cols = view["true_shape"]

    if not (cols < rows):
        return  # already landscape (or square)

    # (entry, shape expected before the swap, axes to exchange)
    spatial_swaps = (
        ("img", (3, rows, cols), (1, 2)),
        ("valid_mask", (rows, cols), (0, 1)),
        ("depthmap", (rows, cols), (0, 1)),
        ("pts3d", (rows, cols, 3), (0, 1)),
    )
    for name, expected_shape, axes in spatial_swaps:
        assert view[name].shape == expected_shape
        view[name] = view[name].swapaxes(*axes)

    # transpose x and y pixels
    view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]]
|
|
|