| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from dataclasses import dataclass |
| | from typing import Tuple |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| |
|
@dataclass
class DifferentiableProjectiveCamera:
    """
    Implements a batch, differentiable, standard pinhole camera.

    Each of the N cameras is described by a position (`origin`) and three
    basis vectors (`x`, `y`, `z`). A ray direction for a pixel is built as
    `z + x * fx + y * fy` (then normalized), where `fx`/`fy` are the pixel's
    normalized image coordinates scaled by `tan(fov / 2)`.
    """

    # All four tensors must be rank-2 with shape (N, 3); see __post_init__.
    origin: torch.Tensor  # camera positions
    x: torch.Tensor  # basis vector scaled by the horizontal pixel fraction
    y: torch.Tensor  # basis vector scaled by the vertical pixel fraction
    z: torch.Tensor  # basis vector pointing through the image center
    width: int  # image width in pixels
    height: int  # image height in pixels
    x_fov: float  # horizontal field of view (passed to tan(), so radians)
    y_fov: float  # vertical field of view (passed to tan(), so radians)
    # Logical batch layout (batch_size, *inner_shape). `camera_rays` requires
    # the product of all entries to equal N.
    shape: Tuple[int, ...]

    def __post_init__(self):
        """Validate that origin/x/y/z are all rank-2 tensors of shape (N, 3)."""
        # Check the rank first so that indexing shape[1] below cannot raise an
        # IndexError on malformed (e.g. rank-1) inputs.
        assert len(self.x.shape) == len(self.y.shape) == len(self.z.shape) == len(self.origin.shape) == 2
        assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0]
        assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3

    def resolution(self):
        """Return the image size as a float32 tensor `[width, height]`."""
        return torch.tensor([self.width, self.height], dtype=torch.float32)

    def fov(self):
        """Return the field of view as a float32 tensor `[x_fov, y_fov]`."""
        return torch.tensor([self.x_fov, self.y_fov], dtype=torch.float32)

    def get_image_coords(self) -> torch.Tensor:
        """
        Enumerate every pixel of the image in row-major order.

        :return: integer coords of shape (width * height, 2); column 0 is the
                 pixel's x index (`i % width`) and column 1 its y index
                 (`i // width`).
        """
        pixel_indices = torch.arange(self.height * self.width)
        coords = torch.stack(
            [
                pixel_indices % self.width,
                torch.div(pixel_indices, self.width, rounding_mode="trunc"),
            ],
            dim=1,
        )
        return coords

    @property
    def camera_rays(self):
        """
        Rays for every pixel of every camera.

        :return: tensor of shape
                 (batch_size, inner_batch_size * height * width, 2, 3), where
                 `batch_size, *inner_shape = self.shape`. Index 0 of the
                 second-to-last axis holds the ray origin, index 1 the unit
                 ray direction.
        """
        batch_size, *inner_shape = self.shape
        inner_batch_size = int(np.prod(inner_shape))

        # The same pixel grid is shared by all cameras, so broadcast it over
        # the flattened camera axis instead of materializing copies.
        coords = self.get_image_coords()
        coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape])
        rays = self.get_camera_rays(coords)

        rays = rays.view(batch_size, inner_batch_size * self.height * self.width, 2, 3)

        return rays

    def get_camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
        """
        Compute one ray per pixel coordinate.

        :param coords: pixel coordinates of shape (N, *shape, 2), where N must
                       equal the number of cameras (`origin.shape[0]`) and the
                       last axis is (x, y).
        :return: tensor of shape (N, *shape, 2, 3); `[..., 0, :]` is the ray
                 origin and `[..., 1, :]` the normalized ray direction.
        """
        batch_size, *shape, n_coords = coords.shape
        assert n_coords == 2
        assert batch_size == self.origin.shape[0]

        flat = coords.view(batch_size, -1, 2)

        res = self.resolution()
        fov = self.fov()

        # Map pixel indices from [0, res - 1] to [-1, 1], then scale by the
        # half-angle tangent to get offsets along x/y at unit distance along z.
        # NOTE: a width or height of 1 makes `res - 1` zero here.
        fracs = (flat.float() / (res - 1)) * 2 - 1
        fracs = fracs * torch.tan(fov / 2)

        fracs = fracs.view(batch_size, -1, 2)
        directions = (
            self.z.view(batch_size, 1, 3)
            + self.x.view(batch_size, 1, 3) * fracs[:, :, :1]
            + self.y.view(batch_size, 1, 3) * fracs[:, :, 1:]
        )
        directions = directions / directions.norm(dim=-1, keepdim=True)
        rays = torch.stack(
            [
                torch.broadcast_to(self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]),
                directions,
            ],
            dim=2,
        )
        return rays.view(batch_size, *shape, 2, 3)

    def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera":
        """
        Creates a new camera for the resized view assuming the aspect ratio does not change.

        :param width: new image width in pixels.
        :param height: new image height in pixels.
        :return: a camera sharing this camera's pose tensors, field of view
                 and batch shape, with the new resolution.
        """
        assert width * self.height == height * self.width, "The aspect ratio should not change."
        return DifferentiableProjectiveCamera(
            origin=self.origin,
            x=self.x,
            y=self.y,
            z=self.z,
            width=width,
            height=height,
            x_fov=self.x_fov,
            y_fov=self.y_fov,
            # `shape` is a required field; omitting it made this call raise
            # TypeError.
            shape=self.shape,
        )
| |
|
| |
|
def create_pan_cameras(size: int) -> DifferentiableProjectiveCamera:
    """
    Build a ring of 20 cameras panning around the world origin.

    For each angle in `np.linspace(0, 2 * np.pi, num=20)` the viewing
    direction is the normalized vector `(sin, cos, -0.5)` and the camera sits
    at `-4 * direction`, i.e. 4 units away with its view axis passing through
    the world origin. The linspace endpoint is included, so the first and last
    angles (0 and 2*pi) describe the same view.

    :param size: width and height, in pixels, of the square image.
    :return: a camera batch with `shape == (1, 20)` and a field of view of 0.7
             on both axes.
    """
    frames = []
    for angle in np.linspace(0, 2 * np.pi, num=20):
        forward = np.array([np.sin(angle), np.cos(angle), -0.5])
        forward /= np.sqrt(np.sum(forward**2))
        right = np.array([np.cos(angle), -np.sin(angle), 0.0])
        # Per-view tuple layout: (origin, x, y, z).
        frames.append((-forward * 4, right, np.cross(forward, right), forward))

    # Transpose the list of per-view tuples into one stacked tensor per field.
    origin, x, y, z = (torch.from_numpy(np.stack(column, axis=0)).float() for column in zip(*frames))
    return DifferentiableProjectiveCamera(
        origin=origin,
        x=x,
        y=y,
        z=z,
        width=size,
        height=size,
        x_fov=0.7,
        y_fov=0.7,
        shape=(1, len(frames)),
    )
| |
|