| | import torch |
| | import torch.nn as nn |
| | import numpy as np |
| | from einops import rearrange |
| | import os |
| | from typing_extensions import Literal |
| |
|
class SimpleAdapter(nn.Module):
    """Convolutional adapter applied frame-by-frame to a 5-D tensor.

    Every frame is space-to-depth downsampled by a factor of 8
    (``nn.PixelUnshuffle``), projected to ``out_dim`` channels by an unpadded
    ``nn.Conv2d``, and then passed through ``num_residual_blocks``
    ``ResidualBlock`` modules.

    Args:
        in_dim: Channel count of the input tensor.
        out_dim: Channel count produced by the projection and preserved by
            the residual stack.
        kernel_size: Kernel size of the projection convolution.
        stride: Stride of the projection convolution.
        num_residual_blocks: Number of ``ResidualBlock`` modules applied after
            the projection (0 yields an empty ``nn.Sequential``).
    """

    def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1):
        super().__init__()
        # Space-to-depth: (N, C, H, W) -> (N, C * 8 * 8, H / 8, W / 8).
        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)
        # 64 == 8 ** 2, the channel multiplier introduced by the unshuffle above.
        self.conv = nn.Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0)
        blocks = [ResidualBlock(out_dim) for _ in range(num_residual_blocks)]
        self.residual_blocks = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply the adapter to ``x`` of shape (batch, channels, frames, height, width).

        Height and width must be divisible by 8 (an ``nn.PixelUnshuffle``
        requirement). Returns a tensor of shape
        (batch, out_dim, frames, height', width'), where height'/width' follow
        from the unshuffle and the unpadded projection convolution.
        """
        batch, channels, frames, height, width = x.size()
        # Move frames next to batch and merge the two axes so the 2-D layers
        # treat every frame as an independent image.
        per_frame = x.permute(0, 2, 1, 3, 4).contiguous().view(batch * frames, channels, height, width)

        features = self.pixel_unshuffle(per_frame)
        features = self.conv(features)
        features = self.residual_blocks(features)

        _, out_channels, out_height, out_width = features.shape
        # Split batch and frames apart again, then put channels before frames.
        unfolded = features.view(batch, frames, out_channels, out_height, out_width)
        return unfolded.permute(0, 2, 1, 3, 4)

    def process_camera_coordinates(
        self,
        direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"],
        length: int,
        height: int,
        width: int,
        speed: float = 1/54,
        origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
    ):
        """Build a Plücker-ray embedding for a synthetic camera trajectory.

        Uses ``generate_camera_coordinates`` to produce ``length`` flat camera
        entries that move in ``direction`` by ``speed`` per frame. Then uses
        ``process_pose_file`` (with ``width`` and ``height``) to turn them into
        the embedding. Passing ``origin=None`` falls back to the same default
        origin entry. ``self`` is not used.

        Returns:
            The value returned by ``process_pose_file`` for the generated
            entries (in this file, a tensor of shape (frames, height, width, 6)).
        """
        start = origin
        if start is None:
            start = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
        trajectory = generate_camera_coordinates(direction, length, speed, start)
        return process_pose_file(trajectory, width, height)
| | |
| | |
| |
|
class ResidualBlock(nn.Module):
    """Two same-padded 3x3 convolutions with a ReLU in between and an identity skip.

    Computes ``conv2(relu(conv1(x))) + x``. The channel count ``dim`` and the
    spatial size are preserved, so the skip connection needs no projection.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
        # In-place is safe here: it only overwrites conv1's fresh output.
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):
        branch = self.conv2(self.relu(self.conv1(x)))
        # Accumulate the skip into conv2's output buffer; `x` itself is untouched.
        return branch.add_(x)
| | |
class Camera:
    """Camera intrinsics and extrinsics parsed from one flat pose entry.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Entry layout as read here: ``entry[1:5]`` holds fx, fy, cx, cy. ``entry[7:]``
    holds a 3x4 world-to-camera matrix in row-major order. ``entry[0]``,
    ``entry[5]`` and ``entry[6]`` are ignored.

    Attributes:
        fx, fy, cx, cy: Intrinsic values exactly as stored in the entry.
        w2c_mat: 4x4 world-to-camera matrix whose bottom row is [0, 0, 0, 1].
        c2w_mat: Inverse of ``w2c_mat`` (camera-to-world).
    """

    def __init__(self, entry):
        self.fx, self.fy, self.cx, self.cy = entry[1:5]
        homogeneous = np.eye(4)
        # Overwrite the top three rows; the last row stays [0, 0, 0, 1].
        homogeneous[:3, :] = np.array(entry[7:]).reshape(3, 4)
        self.w2c_mat = homogeneous
        self.c2w_mat = np.linalg.inv(homogeneous)
| |
|
def get_relative_pose(cam_params):
    """Re-express a sequence of camera poses relative to the first camera.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Args:
        cam_params: Iterable of objects exposing 4x4 ``w2c_mat`` and ``c2w_mat``
            arrays (e.g. ``Camera`` instances). It must be non-empty; otherwise
            an ``IndexError`` is raised.

    Returns:
        float32 array of shape (N, 4, 4). Entry 0 is the identity. Entry i > 0
        is ``w2c_0 @ c2w_i``, i.e. camera i expressed in camera 0's frame.
    """
    cameras = list(cam_params)
    # Pose assigned to the first camera: the identity (upstream builds it with
    # a zero `cam_to_origin` offset on the y-translation).
    reference_c2w = np.eye(4)
    world_to_reference = reference_c2w @ cameras[0].w2c_mat

    poses = [reference_c2w]
    poses.extend(world_to_reference @ cam.c2w_mat for cam in cameras[1:])
    return np.array(poses, dtype=np.float32)
| |
|
def custom_meshgrid(*args):
    """Return ``torch.meshgrid`` of the given 1-D tensors using matrix ('ij') indexing.

    With 'ij' indexing, the k-th output grid varies along dimension k. For two
    inputs of lengths (H, W), both outputs have shape (H, W).
    """
    grids = torch.meshgrid(*args, indexing="ij")
    return grids
| |
|
| |
|
def ray_condition(K, c2w, H, W, device):
    """Compute per-pixel Plücker ray embeddings for a batch of camera trajectories.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Args:
        K: Intrinsics whose last dimension is [fx, fy, cx, cy] in pixels.
            Presumably shaped (B, N, 4); process_pose_file passes (1, N, 4).
        c2w: Camera-to-world matrices of shape (B, N, 4, 4).
        H, W: Height and width of the pixel grid.
        device: Device on which the pixel grid is created. K and c2w should
            already live there, since they are combined with the grid.

    Returns:
        Tensor of shape (B, N, H, W, 6). The first three channels are the ray
        moment ``o x d`` and the last three are the unit direction ``d``.
    """
    batch = K.shape[0]
    num_pixels = H * W

    row_grid, col_grid = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    # Flatten the grids and shift by 0.5 so rays pass through pixel centres.
    u = col_grid.reshape([1, 1, num_pixels]).expand([batch, 1, num_pixels]) + 0.5
    v = row_grid.reshape([1, 1, num_pixels]).expand([batch, 1, num_pixels]) + 0.5

    fx, fy, cx, cy = K.chunk(4, dim=-1)

    # Back-project every pixel onto the z = 1 plane in camera space, then normalise.
    ones = torch.ones_like(u)
    x_cam = (u - cx) / fx * ones
    y_cam = (v - cy) / fy * ones
    z_cam = ones.expand_as(y_cam)
    cam_dirs = torch.stack((x_cam, y_cam, z_cam), dim=-1)
    cam_dirs = cam_dirs / cam_dirs.norm(dim=-1, keepdim=True)

    # Rotate directions into world space; each camera's rays share its translation as origin.
    rays_d = cam_dirs @ c2w[..., :3, :3].transpose(-1, -2)
    rays_o = c2w[..., :3, 3][:, :, None].expand_as(rays_d)

    moment = torch.linalg.cross(rays_o, rays_d)
    plucker = torch.cat([moment, rays_d], dim=-1)
    return plucker.reshape(batch, c2w.shape[1], H, W, 6)
| |
|
| |
|
def process_pose_file(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
    """Convert flat camera entries into a per-pixel Plücker-ray embedding.

    Adapted from CameraCtrl's inference pipeline.

    Args:
        cam_params: Sequence of flat camera entries in the layout parsed by
            ``Camera`` (fx, fy, cx, cy at indices 1-4 and a 3x4 world-to-camera
            matrix from index 7). The intrinsics appear to be normalised by image
            size, since they are multiplied by ``width``/``height`` below.
            Confirm against the pose source.
        width, height: Size of the output ray grid in pixels.
        original_pose_width, original_pose_height: Frame size the poses were
            recorded at. Only their aspect ratio is used, to rescale fx (pose
            wider than the sample) or fy (otherwise).
        device: Device on which intrinsics, poses and the ray grid are built.
        return_poses: If True, return ``cam_params`` unchanged and do nothing else.

    Returns:
        A float32 tensor of shape (len(cam_params), height, width, 6) on
        ``device``. For each pixel it holds the ray moment (origin x direction)
        followed by the unit direction, with poses relative to the first camera.
    """
    if return_poses:
        return cam_params

    cameras = [Camera(entry) for entry in cam_params]

    # Compensate for the aspect-ratio mismatch between recorded poses and the
    # sampled frames by rescaling the focal length along one axis.
    sample_wh_ratio = width / height
    pose_wh_ratio = original_pose_width / original_pose_height
    if pose_wh_ratio > sample_wh_ratio:
        resized_ori_w = height * pose_wh_ratio
        for cam in cameras:
            cam.fx = resized_ori_w * cam.fx / width
    else:
        resized_ori_h = width / pose_wh_ratio
        for cam in cameras:
            cam.fy = resized_ori_h * cam.fy / height

    intrinsic = np.asarray(
        [[cam.fx * width, cam.fy * height, cam.cx * width, cam.cy * height] for cam in cameras],
        dtype=np.float32,
    )
    # Build K and the poses on `device`. ray_condition creates its pixel grid
    # there, and combining it with CPU tensors fails for any non-CPU device.
    K = torch.as_tensor(intrinsic, device=device)[None]  # (1, N, 4)
    c2ws = torch.as_tensor(get_relative_pose(cameras), device=device)[None]  # (1, N, 4, 4)

    plucker = ray_condition(K, c2ws, height, width, device=device)[0]  # (N, H, W, 6)
    # Keep the data stored channels-first (N, 6, H, W) and return a
    # channels-last view. This preserves the memory layout this function has
    # always returned, in case callers depend on it.
    return plucker.permute(0, 3, 1, 2).contiguous().permute(0, 2, 3, 1)
| |
|
| |
|
| |
|
def generate_camera_coordinates(
    direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"],
    length: int,
    speed: float = 1/54,
    origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
):
    """Generate a list of flat camera entries drifting in ``direction``.

    The first entry is a copy of ``origin``. Each following entry copies the
    previous one and adds ``speed`` to index 9 for "Left", subtracts it for
    "Right", adds it to index 13 for "Up" and subtracts it for "Down".
    Direction keywords are matched as substrings, so "LeftUp" moves along both.
    Strings containing none of these keywords yield a static trajectory.
    ``length <= 1`` returns just the origin entry.

    NOTE(review): ``Camera`` parses ``entry[7:]`` as a row-major 3x4
    world-to-camera matrix. Under that layout, index 9 is w2c[0, 2] and index 13
    is w2c[1, 2], which are rotation-block entries, not the translation column
    (indices 10 and 14). Behaviour is kept as-is; confirm this is intended.

    Returns:
        A list of ``max(length, 1)`` independent lists, each a copy of the entry layout.
    """
    frames = [list(origin)]
    if length <= 1:
        return frames

    # (entry index, signed step) pairs, applied in this order on every frame.
    deltas = []
    if "Left" in direction:
        deltas.append((9, speed))
    if "Right" in direction:
        deltas.append((9, -speed))
    if "Up" in direction:
        deltas.append((13, speed))
    if "Down" in direction:
        deltas.append((13, -speed))

    while len(frames) < length:
        next_frame = list(frames[-1])
        for index, delta in deltas:
            next_frame[index] += delta
        frames.append(next_frame)
    return frames
| |
|