|
|
|
|
|
|
|
|
from typing import Callable, Optional, Union, Sequence
|
|
|
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
import torchvision.transforms.functional as tvf
|
|
|
import collections
|
|
|
from scipy.spatial.transform import Rotation
|
|
|
|
|
|
from ..utils.geometry import from_homogeneous, to_homogeneous
|
|
|
from ..utils.wrappers import Camera
|
|
|
|
|
|
|
|
|
def rectify_image(
    image: torch.Tensor,
    cam: Camera,
    roll: float,
    pitch: Optional[float] = None,
    valid: Optional[torch.Tensor] = None,
):
    """Undo the camera roll (and optionally pitch) by resampling the image.

    Builds a pixel grid for the output image, rotates the corresponding
    normalized camera rays by the given Euler angles, and bilinearly samples
    the input image at the rotated locations.

    Args:
        image: image tensor; `image[None]` is fed to grid_sample, so a
            channels-first (C, H, W) layout is assumed — TODO confirm.
        cam: project Camera wrapper providing `normalize`/`denormalize`
            between pixel and normalized image coordinates.
        roll: roll angle in degrees (the "Z" axis of the scipy Euler
            convention used below).
        pitch: optional pitch angle in degrees (the "X" axis).
        valid: optional (H, W) mask of valid input pixels; if given, it is
            warped with nearest-neighbor sampling instead of being recomputed.

    Returns:
        (rectified, valid): the resampled image with the leading batch dim
        squeezed away, and a boolean (H, W) mask of output pixels whose
        sampling location fell inside the input image.
    """
    *_, h, w = image.shape
    # (x, y) integer pixel coordinates for every output pixel, shape (H, W, 2).
    grid = torch.meshgrid(
        [torch.arange(w, device=image.device), torch.arange(h, device=image.device)],
        indexing="xy",
    )
    grid = torch.stack(grid, -1).to(image.dtype)

    if pitch is not None:
        args = ("ZX", (roll, pitch))
    else:
        args = ("Z", roll)
    R = Rotation.from_euler(*args, degrees=True).as_matrix()
    # Move the rotation matrix to the image's device/dtype.
    R = torch.from_numpy(R).to(image)

    # Rotate the homogeneous (normalized) rays of the output grid, then map
    # back to pixel coordinates in the source image.
    grid_rect = to_homogeneous(cam.normalize(grid)) @ R.T
    grid_rect = cam.denormalize(from_homogeneous(grid_rect))
    # grid_sample expects coordinates in [-1, 1]; the +0.5 shifts integer
    # pixel indices to pixel centers before normalizing.
    grid_norm = (grid_rect + 0.5) / grid.new_tensor([w, h]) * 2 - 1
    rectified = torch.nn.functional.grid_sample(
        image[None],
        grid_norm[None],
        align_corners=False,
        mode="bilinear",
    ).squeeze(0)
    if valid is None:
        # Valid wherever the sampling location lands inside the source image.
        valid = torch.all((grid_norm >= -1) & (grid_norm <= 1), -1)
    else:
        # Warp the provided mask through the same grid; nearest-neighbor
        # sampling keeps it binary, and > 0 converts back to bool.
        valid = (
            torch.nn.functional.grid_sample(
                valid[None, None].float(),
                grid_norm[None],
                align_corners=False,
                mode="nearest",
            )[0, 0]
            > 0
        )
    return rectified, valid
|
|
|
|
|
|
|
|
|
def resize_image(
    image: torch.Tensor,
    size: Union[int, Sequence, np.ndarray],
    fn: Optional[Callable] = None,
    camera: Optional["Camera"] = None,
    valid: Optional[torch.Tensor] = None,
):
    """Resize an image to a fixed size, or according to max or min edge.

    Args:
        image: tensor of shape (..., H, W), assumed to hold values in [0, 1]
            (the result is clamped back to that range after resampling).
        size: target size. With `fn`, an int giving the target length of the
            edge selected by `fn(h, w)`; alone, an int giving a square
            (size, size); or a sequence/array interpreted as (width, height).
        fn: optional reduction such as `max` or `min` selecting which edge
            the integer `size` refers to.
        camera: optional Camera whose intrinsics are scaled to match.
        valid: optional (H, W) validity mask, resized with nearest-neighbor.
            (Fixed annotation: the code calls `.unsqueeze(0)`, so this is a
            torch.Tensor, not np.ndarray.)

    Returns:
        [image, scale], followed by the scaled camera and/or the resized
        mask when those arguments were provided. `scale` is an (x, y) pair
        of resize factors.

    Raises:
        ValueError: if `size` is neither an int nor a sequence/array.
    """
    *_, h, w = image.shape
    if fn is not None:
        assert isinstance(size, int)
        scale = size / fn(h, w)
        h_new, w_new = int(round(h * scale)), int(round(w * scale))
        # NOTE: this reports the nominal scale; rounding h_new/w_new can make
        # the effective per-axis factors differ slightly (kept for backward
        # compatibility with existing callers).
        scale = (scale, scale)
    else:
        if isinstance(size, (collections.abc.Sequence, np.ndarray)):
            w_new, h_new = size
        elif isinstance(size, int):
            w_new = h_new = size
        else:
            raise ValueError(f"Incorrect new size: {size}")
        scale = (w_new / w, h_new / h)
    if (w, h) != (w_new, h_new):
        mode = tvf.InterpolationMode.BILINEAR
        image = tvf.resize(image, (int(h_new), int(w_new)), interpolation=mode, antialias=True)
        # Antialiased bilinear resampling can overshoot; clamp back to [0, 1].
        image.clip_(0, 1)
        if camera is not None:
            camera = camera.scale(scale)
    if valid is not None:
        # Nearest-neighbor keeps the mask binary.
        valid = tvf.resize(
            valid.unsqueeze(0),
            (int(h_new), int(w_new)),
            interpolation=tvf.InterpolationMode.NEAREST,
        ).squeeze(0)
    ret = [image, scale]
    if camera is not None:
        ret.append(camera)
    if valid is not None:
        ret.append(valid)
    return ret
|
|
|
|
|
|
|
|
|
def pad_image(
    image: torch.Tensor,
    size: Union[int, Sequence, np.ndarray],
    camera: Optional["Camera"] = None,
    valid: Optional[torch.Tensor] = None,
    crop_and_center: bool = False,
):
    """Zero-pad (and optionally center-crop) an image to a target size.

    Args:
        image: tensor of shape (..., H, W).
        size: target size; an int gives a square (size, size), a
            sequence/array is interpreted as (width, height).
        camera: optional Camera whose principal point is shifted to match.
        valid: optional (H, W) validity mask carried through the padding.
        crop_and_center: if True, the image content is centered in the
            output, cropping symmetrically when the image is larger than the
            target; if False, the image must fit and is placed top-left.

    Returns:
        (out, out_valid) or, when a camera is given, (out, out_valid, camera).
        `out_valid` is a boolean (H_new, W_new) mask of pixels covered by
        (valid) image content.

    Raises:
        ValueError: if `size` is neither an int nor a sequence/array.
    """
    if isinstance(size, int):
        w_new = h_new = size
    elif isinstance(size, (collections.abc.Sequence, np.ndarray)):
        w_new, h_new = size
    else:
        raise ValueError(f"Incorrect new size: {size}")
    *c, h, w = image.shape
    if crop_and_center:
        # Split the (possibly negative) excess evenly between the two sides.
        diff = np.array([w - w_new, h - h_new])
        left, top = left_top = np.round(diff / 2).astype(int)
        right, bottom = diff - left_top
    else:
        assert h <= h_new
        assert w <= w_new
        top = bottom = left = right = 0
    # Destination region (top-left) and the matching source region; negative
    # margins are clamped to 0 so padding and cropping share one code path.
    slice_out = np.s_[..., : min(h, h_new), : min(w, w_new)]
    slice_in = np.s_[
        ..., max(top, 0) : h - max(bottom, 0), max(left, 0) : w - max(right, 0)
    ]
    if (w, h) == (w_new, h_new):
        out = image
    else:
        # FIX: allocate on the image's device; previously this defaulted to
        # CPU and the copy below failed for CUDA inputs.
        out = torch.zeros((*c, h_new, w_new), dtype=image.dtype, device=image.device)
        out[slice_out] = image[slice_in]
    if camera is not None:
        camera = camera.crop((max(left, 0), max(top, 0)), (w_new, h_new))
    out_valid = torch.zeros((h_new, w_new), dtype=torch.bool, device=image.device)
    out_valid[slice_out] = True if valid is None else valid[slice_in]
    if camera is not None:
        return out, out_valid, camera
    else:
        return out, out_valid
|
|
|
|