import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd.profiler as profiler
import numpy as np
from einops import rearrange, repeat, einsum

from .math_utils import get_ray_limits_box, linspace
from ...modules.diffusionmodules.openaimodel import Timestep


class ImageEncoder(nn.Module):
    """Identity placeholder encoder: returns the input image unchanged."""

    def __init__(self, output_dim: int = 64) -> None:
        super().__init__()
        self.output_dim = output_dim

    def forward(self, image):
        return image


class PositionalEncoding(torch.nn.Module):
    """
    Implement NeRF's positional encoding
    """

    def __init__(self, num_freqs=6, d_in=3, freq_factor=np.pi, include_input=True):
        super().__init__()
        self.num_freqs = num_freqs
        self.d_in = d_in
        self.freqs = freq_factor * 2.0 ** torch.arange(0, num_freqs)
        self.d_out = self.num_freqs * 2 * d_in
        self.include_input = include_input
        if include_input:
            self.d_out += d_in
        # f1 f1 f2 f2 ... to multiply x by
        self.register_buffer(
            "_freqs", torch.repeat_interleave(self.freqs, 2).view(1, -1, 1)
        )
        # 0 pi/2 0 pi/2 ... so that
        # (sin(x + _phases[0]), sin(x + _phases[1]) ...) = (sin(x), cos(x) ...)
        _phases = torch.zeros(2 * self.num_freqs)
        _phases[1::2] = np.pi * 0.5
        self.register_buffer("_phases", _phases.view(1, -1, 1))

    def forward(self, x):
        """
        Apply positional encoding (new implementation)
        :param x (batch, self.d_in)
        :return (batch, self.d_out)
        """
        with profiler.record_function("positional_enc"):
            embed = repeat(x, "... C -> ... N C", N=self.num_freqs * 2)
            embed = torch.sin(torch.addcmul(self._phases, embed, self._freqs))
            embed = rearrange(embed, "... N C -> ... (N C)")
            if self.include_input:
                embed = torch.cat((x, embed), dim=-1)
            return embed
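
# Shape sketch (illustrative only, not part of the original module): with the
# defaults (num_freqs=6, d_in=3, include_input=True), d_out = 3 + 6 * 2 * 3 = 39:
#
#   enc = PositionalEncoding()
#   out = enc(torch.rand(8, 3))  # -> (8, 39)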


class RayGenerator(torch.nn.Module):
    """
    Converts camera poses and intrinsics into ray origins and directions.
    """

    def __init__(self):
        super().__init__()
        (
            self.ray_origins_h,
            self.ray_directions,
            self.depths,
            self.image_coords,
            self.rendering_options,
        ) = (None, None, None, None, None)

    def forward(self, cam2world_matrix, intrinsics, render_size):
        """
        Create batches of rays and return origins and directions.
        cam2world_matrix: (N, 4, 4)
        intrinsics: (N, 3, 3), normalized by resolution
        render_size: int
        ray_origins: (N, M, 3)
        ray_dirs: (N, M, 3)
        """
        N, M = cam2world_matrix.shape[0], render_size**2
        cam_locs_world = cam2world_matrix[:, :3, 3]
        fx = intrinsics[:, 0, 0]
        fy = intrinsics[:, 1, 1]
        cx = intrinsics[:, 0, 2]
        cy = intrinsics[:, 1, 2]
        sk = intrinsics[:, 0, 1]

        # Pixel-center coordinates in [0, 1], flattened to (N, M, 2).
        uv = torch.stack(
            torch.meshgrid(
                torch.arange(
                    render_size, dtype=torch.float32, device=cam2world_matrix.device
                ),
                torch.arange(
                    render_size, dtype=torch.float32, device=cam2world_matrix.device
                ),
                indexing="ij",
            )
        )
        uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
        uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
        x_cam = uv[:, :, 0].view(N, -1) * (1.0 / render_size) + (0.5 / render_size)
        y_cam = uv[:, :, 1].view(N, -1) * (1.0 / render_size) + (0.5 / render_size)
        z_cam = torch.ones((N, M), device=cam2world_matrix.device)

        # Unproject pixel centers to camera space, accounting for skew.
        x_lift = (
            (
                x_cam
                - cx.unsqueeze(-1)
                + cy.unsqueeze(-1) * sk.unsqueeze(-1) / fy.unsqueeze(-1)
                - sk.unsqueeze(-1) * y_cam / fy.unsqueeze(-1)
            )
            / fx.unsqueeze(-1)
            * z_cam
        )
        y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam
        cam_rel_points = torch.stack(
            (x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1
        )

        # NOTE: this should be named _blender2opencv
        _opencv2blender = (
            torch.tensor(
                [
                    [1, 0, 0, 0],
                    [0, -1, 0, 0],
                    [0, 0, -1, 0],
                    [0, 0, 0, 1],
                ],
                dtype=torch.float32,
                device=cam2world_matrix.device,
            )
            .unsqueeze(0)
            .repeat(N, 1, 1)
        )
        cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender)

        world_rel_points = torch.bmm(
            cam2world_matrix, cam_rel_points.permute(0, 2, 1)
        ).permute(0, 2, 1)[:, :, :3]
        ray_dirs = world_rel_points - cam_locs_world[:, None, :]
        ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)
        ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1)
        return ray_origins, ray_dirs
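
# Usage sketch (illustrative only; intrinsics are normalized by resolution):
#
#   gen = RayGenerator()
#   c2w = torch.eye(4)[None]                      # (1, 4, 4)
#   K = torch.tensor([[[1.0, 0.0, 0.5],
#                      [0.0, 1.0, 0.5],
#                      [0.0, 0.0, 1.0]]])         # (1, 3, 3)
#   origins, dirs = gen(c2w, K, render_size=16)   # both (1, 256, 3)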


class RaySampler(torch.nn.Module):
    def __init__(
        self,
        num_samples_per_ray,
        bbox_length=1.0,
        near=0.5,
        far=10000.0,
        disparity=False,
    ):
        super().__init__()
        self.num_samples_per_ray = num_samples_per_ray
        self.bbox_length = bbox_length
        self.near = near
        self.far = far
        self.disparity = disparity

    def forward(self, ray_origins, ray_directions):
        if not self.disparity:
            t_start, t_end = get_ray_limits_box(
                ray_origins, ray_directions, 2 * self.bbox_length
            )
        else:
            t_start = torch.full_like(ray_origins, self.near)
            t_end = torch.full_like(ray_origins, self.far)
        is_ray_valid = t_end > t_start
        if torch.any(is_ray_valid).item():
            # Rays that miss the box get a fallback interval from the valid rays.
            t_start[~is_ray_valid] = t_start[is_ray_valid].min()
            t_end[~is_ray_valid] = t_start[is_ray_valid].max()

        if not self.disparity:
            # Stratified sampling: uniform bins between t_start and t_end,
            # each jittered by up to one bin width.
            depths = linspace(t_start, t_end, self.num_samples_per_ray)
            depths += (
                torch.rand_like(depths)
                * (t_end - t_start)
                / (self.num_samples_per_ray - 1)
            )
        else:
            # Sample uniformly in disparity (1/depth) between near and far.
            step = 1.0 / self.num_samples_per_ray
            z_steps = torch.linspace(
                0, 1 - step, self.num_samples_per_ray, device=ray_origins.device
            )
            z_steps += torch.rand_like(z_steps) * step
            depths = 1 / (1 / self.near * (1 - z_steps) + 1 / self.far * z_steps)
            depths = depths[..., None, None, None]
        return ray_origins[None] + ray_directions[None] * depths
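
# Usage sketch (illustrative only): sample 8 depths per ray along rays that
# intersect a box of half-length bbox_length centered at the origin.
#
#   sampler = RaySampler(num_samples_per_ray=8)
#   pts = sampler(origins, dirs)   # (8, N, M, 3) given origins/dirs of (N, M, 3)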


class PixelNeRF(torch.nn.Module):
    def __init__(
        self,
        num_samples_per_ray: int = 128,
        feature_dim: int = 64,
        interp: str = "bilinear",
        padding: str = "border",
        disparity: bool = False,
        near: float = 0.5,
        far: float = 10000.0,
        use_feats_std: bool = False,
        use_pos_emb: bool = False,
    ) -> None:
        super().__init__()
        # self.positional_encoder = Timestep(3)  # TODO
        self.num_samples_per_ray = num_samples_per_ray
        self.ray_generator = RayGenerator()
        self.ray_sampler = RaySampler(
            num_samples_per_ray, near=near, far=far, disparity=disparity
        )  # TODO
        self.interp = interp
        self.padding = padding
        self.positional_encoder = PositionalEncoding()
        self.use_feats_std = use_feats_std
        self.use_pos_emb = use_pos_emb

        d_in = feature_dim
        if use_feats_std:
            d_in += feature_dim
        if use_pos_emb:
            d_in += self.positional_encoder.d_out
        # Maps aggregated per-sample features to 128 features + 1 blend weight.
        self.feature_aggregator = nn.Sequential(
            nn.Linear(d_in, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 129),
        )
        # Decodes pooled ray features into 3 RGB channels + 128 feature channels.
        self.decoder = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 131),
        )

    def project(self, ray_samples, source_c2ws, source_intrinsics):
        # S is the number of source cameras.
        # ray_samples: [B, N, H * W, N_sample, 3]
        # source_c2ws: [B, S, 4, 4]
        # source_intrinsics: [B, S, 3, 3]
        # returns: [B, S, N, H * W, N_sample, 2]
        S = source_c2ws.shape[1]
        B = ray_samples.shape[0]
        N = ray_samples.shape[1]
        HW = ray_samples.shape[2]
        ray_samples = repeat(
            ray_samples,
            "B N HW N_sample C -> B S N HW N_sample C",
            S=S,
        )
        # Homogeneous coordinates for the world -> camera transform.
        padding = torch.ones((B, S, N, HW, self.num_samples_per_ray, 1)).to(ray_samples)
        ray_samples_homo = torch.cat([ray_samples, padding], dim=-1)
        source_c2ws = repeat(source_c2ws, "B S C1 C2 -> B S N 1 1 C1 C2", N=N)
        source_intrinsics = repeat(source_intrinsics, "B S C1 C2 -> B S N 1 1 C1 C2", N=N)
        source_w2c = source_c2ws.inverse()
        projected_samples = einsum(
            source_w2c, ray_samples_homo, "... i j, ... j -> ... i"
        )[..., :3]
        # NOTE: assumes OpenGL convention (camera looks down -z)
        projected_samples = -1 * projected_samples[..., :2] / projected_samples[..., 2:]
        # NOTE: intrinsics here are normalized by resolution
        fx = source_intrinsics[..., 0, 0]
        fy = source_intrinsics[..., 1, 1]
        cx = source_intrinsics[..., 0, 2]
        cy = source_intrinsics[..., 1, 2]
        x = projected_samples[..., 0] * fx + cx
        # The negative sign comes from OpenGL; F.grid_sample follows the OpenCV convention.
        y = -projected_samples[..., 1] * fy + cy
        return torch.stack([x, y], dim=-1)
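
    # Projection sketch (illustrative only, with num_samples_per_ray=1): a world
    # point one unit in front of an identity-pose OpenGL camera with normalized
    # intrinsics fx = fy = 1, cx = cy = 0.5 projects to the image center:
    #
    #   pts = torch.tensor([0.0, 0.0, -1.0]).view(1, 1, 1, 1, 3)
    #   eye = torch.eye(4).view(1, 1, 4, 4)
    #   K = torch.tensor([[1.0, 0.0, 0.5],
    #                     [0.0, 1.0, 0.5],
    #                     [0.0, 0.0, 1.0]]).view(1, 1, 3, 3)
    #   model.project(pts, eye, K)  # -> [0.5, 0.5], shape (1, 1, 1, 1, 1, 2)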

    def forward(
        self, image_feats, source_c2ws, source_intrinsics, c2ws, intrinsics, render_size
    ):
        # image_feats: [B, S, C, H, W]
        B = c2ws.shape[0]
        T = c2ws.shape[1]  # number of target views
        ray_origins, ray_directions = self.ray_generator(
            c2ws.reshape(-1, 4, 4), intrinsics.reshape(-1, 3, 3), render_size
        )  # [B * N, H * W, 3]
        ray_samples = self.ray_sampler(
            ray_origins, ray_directions
        )  # [N_sample, B * N, H * W, 3]
        ray_samples = rearrange(ray_samples, "Ns (B N) HW C -> B N HW Ns C", B=B)
        projected_samples = self.project(ray_samples, source_c2ws, source_intrinsics)
        image_feats = rearrange(image_feats, "B S C H W -> (B S) C H W")
        projected_samples = rearrange(
            projected_samples, "B S N HW Ns xy -> (B S) (N Ns) HW xy"
        )
        # Map projected samples from [0, 1] to [-1, 1], as required by F.grid_sample.
        joint = F.grid_sample(
            image_feats,
            projected_samples * 2.0 - 1.0,
            padding_mode=self.padding,
            mode=self.interp,
            align_corners=True,
        )
        joint = rearrange(
            joint,
            "(B S) C (N Ns) HW -> B S N HW Ns C",
            B=B,
            Ns=self.num_samples_per_ray,
        )
        reduced = torch.mean(joint, dim=1)  # reduce over the source dimension
        if self.use_feats_std:
            if not joint.shape[1] == 1:
                reduced = torch.cat((reduced, joint.std(dim=1)), dim=-1)
            else:
                reduced = torch.cat((reduced, torch.zeros_like(reduced)), dim=-1)
        if self.use_pos_emb:
            reduced = torch.cat((reduced, self.positional_encoder(ray_samples)), dim=-1)
        reduced = self.feature_aggregator(reduced)
        feats, weights = reduced.split([reduced.shape[-1] - 1, 1], dim=-1)
        # feats: [B, N, H * W, N_samples, 128]
        # weights: [B, N, H * W, N_samples, 1]
        weights = F.softmax(weights, dim=-2)
        feats = torch.sum(feats * weights, dim=-2)  # weighted pooling along each ray
        rgb, feats = self.decoder(feats).split([3, 128], dim=-1)
        rgb = torch.sigmoid(rgb)
        rgb = rearrange(rgb, "B N (H W) C -> B N C H W", H=render_size)
        feats = rearrange(feats, "B N (H W) C -> B N C H W", H=render_size)
        return rgb, feats
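

if __name__ == "__main__":
    # Smoke test with made-up sizes (illustrative only; every shape and value
    # below is an assumption, not something this module prescribes). Because of
    # the relative imports, run it as a module from within its package.
    B, S, T, C = 1, 2, 1, 64
    render_size = 16
    model = PixelNeRF(num_samples_per_ray=8, feature_dim=C)
    pose = torch.eye(4)
    pose[2, 3] = 3.0  # camera 3 units along +z, looking toward the origin
    K = torch.tensor([[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [0.0, 0.0, 1.0]])
    c2ws = pose.expand(B, T, 4, 4).contiguous()
    source_c2ws = pose.expand(B, S, 4, 4).contiguous()
    intrinsics = K.expand(B, T, 3, 3).contiguous()
    source_intrinsics = K.expand(B, S, 3, 3).contiguous()
    image_feats = torch.rand(B, S, C, 32, 32)
    rgb, feats = model(
        image_feats, source_c2ws, source_intrinsics, c2ws, intrinsics, render_size
    )
    print(rgb.shape, feats.shape)  # (B, T, 3, 16, 16) and (B, T, 128, 16, 16)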