| | import numpy as np |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | import roma |
| | from kiui.op import safe_normalize |
| |
|
def get_rays(pose, h, w, fovy, opengl=True):
    """Build per-pixel camera rays for a pinhole camera.

    Args:
        pose: [4, 4] camera-to-world matrix (torch tensor).
        h, w: image height and width in pixels.
        fovy: vertical field of view in degrees.
        opengl: if True, use the OpenGL convention (y up, camera looks down -z);
            otherwise y and z keep their positive sign.

    Returns:
        rays_o: [h, w, 3] ray origins (the camera center broadcast per pixel).
        rays_d: [h, w, 3] normalized ray directions in world space.
    """
    device = pose.device

    # Pixel coordinates over the full image, flattened to [h*w].
    xs, ys = torch.meshgrid(
        torch.arange(w, device=device),
        torch.arange(h, device=device),
        indexing="xy",
    )
    xs = xs.flatten()
    ys = ys.flatten()

    # Principal point at the image center; focal length (pixels) from the
    # vertical field of view.
    cx = 0.5 * w
    cy = 0.5 * h
    focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))

    # Sign flips y and sets the z component per the chosen convention.
    sign = -1.0 if opengl else 1.0

    # Camera-space directions; the +0.5 targets pixel centers, and F.pad
    # appends the constant z component.
    camera_dirs = F.pad(
        torch.stack(
            [
                (xs - cx + 0.5) / focal,
                (ys - cy + 0.5) / focal * sign,
            ],
            dim=-1,
        ),
        (0, 1),
        value=sign,
    )

    # Rotate directions into world space; the origin is the camera center,
    # shared by every ray.
    rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1)
    rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d)

    rays_o = rays_o.view(h, w, 3)
    rays_d = safe_normalize(rays_d).view(h, w, 3)

    return rays_o, rays_d
| |
|
def orbit_camera_jitter(poses, strength=0.1):
    """Randomly perturb orbit camera poses with small rotations about the origin.

    Args:
        poses: [B, 4, 4] camera-to-world matrices (torch tensor).
        strength: scales the random rotation angles (yaw up to strength*pi,
            pitch up to strength*pi/2).

    Returns:
        [B, 4, 4] jittered poses; the input tensor is left unmodified.
    """
    batch = poses.shape[0]
    device = poses.device

    # One random angle factor in [-1, 1) per pose, for each axis.
    # NOTE: the draw order (up-axis first, then right-axis) is part of the
    # reproducibility contract under a fixed torch seed.
    jitter_up = torch.rand(batch, 1, device=device) * 2 - 1
    jitter_right = torch.rand(batch, 1, device=device) * 2 - 1

    # Rotation vectors about the camera's up axis (column 1) and right axis
    # (column 0), with magnitude encoded in the vector length.
    rotvec_up = poses[:, :3, 1] * strength * np.pi * jitter_up
    rotvec_right = poses[:, :3, 0] * strength * np.pi / 2 * jitter_right

    # Compose the two jitters and apply them to both orientation and position,
    # so the camera orbits around the world origin.
    rot = roma.rotvec_to_rotmat(rotvec_up) @ roma.rotvec_to_rotmat(rotvec_right)

    jittered = poses.clone()
    jittered[:, :3, :3] = rot @ poses[:, :3, :3]
    jittered[:, :3, 3:] = rot @ poses[:, :3, 3:]

    return jittered
| |
|
def grid_distortion(images, strength=0.5):
    """Apply a random piecewise-linear grid warp to a batch of images.

    Control points of a regular grid are jittered independently for every
    sample, and the image is resampled through the resulting warped grid,
    giving a mild elastic-style distortion.

    Args:
        images: [B, C, H, W] float tensor.
        strength: jitter amplitude for the grid control points; values in
            (0, 1] keep the warped grid monotonic along each axis.

    Returns:
        [B, C, H, W] distorted images on the same device as the input.
    """
    B, C, H, W = images.shape

    # One random control-point count shared by the whole batch.
    num_steps = np.random.randint(8, 17)
    # Target sample positions in grid_sample's normalized [-1, 1] space.
    grid_steps = torch.linspace(-1, 1, num_steps)

    def _warped_axis(size):
        # Jittered control points in [0, 1], pinned to the image borders so
        # the warp never samples outside the frame.
        steps = torch.linspace(0, 1, num_steps)
        steps = (steps + strength * (torch.rand_like(steps) - 0.5) / (num_steps - 1)).clamp(0, 1)
        steps = (steps * size).long()
        steps[0] = 0
        steps[-1] = size
        # Piecewise-linear map: segment i covers steps[i+1]-steps[i] pixels
        # (possibly zero); lengths telescope so the result has `size` entries.
        segments = [
            torch.linspace(grid_steps[i], grid_steps[i + 1], steps[i + 1] - steps[i])
            for i in range(num_steps - 1)
        ]
        return torch.cat(segments, dim=0)

    grids = []
    for _ in range(B):
        # Draw order (x axis first, then y) is kept stable for seeded
        # reproducibility.
        xs = _warped_axis(W)
        ys = _warped_axis(H)
        grid_x, grid_y = torch.meshgrid(xs, ys, indexing="xy")
        grids.append(torch.stack([grid_x, grid_y], dim=-1))

    # [B, H, W, 2] sampling grid, moved to the images' device.
    grids = torch.stack(grids, dim=0).to(images.device)

    return F.grid_sample(images, grids, align_corners=False)
| |
|
| |
|