File size: 6,067 Bytes
874cec4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ed17f3
 
874cec4
3ed17f3
874cec4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import numpy as np
from scipy.spatial.transform import Rotation
import torch
from einops import repeat, rearrange
from easydict import EasyDict as edict
import torch.nn.functional as F

from source.rendering.aabb import intersect_aabb_end
from source.rendering.point_sampler import perturb_points_per_ray

def decompose_rotmat(R_c2w):
    """Extract (roll, pitch, yaw) in degrees from a camera-to-world rotation.

    Inverse of ``compose_rotmat``: maps the CV camera frame into the XYZ
    world frame, then reads off intrinsic Y-X-Z Euler angles.
    """
    cv_to_xyz = Rotation.from_euler("X", -90, degrees=True)
    world_to_cam = cv_to_xyz * Rotation.from_matrix(R_c2w).inv()
    angles = world_to_cam.as_euler("YXZ", degrees=True)
    return angles[0], angles[1], angles[2]

def normalize_angles(angles):
    """Wrap angles (degrees) into the half-open interval [-180, 180)."""
    shifted = np.array(angles) + 180
    return np.mod(shifted, 360) - 180

def compose_rotmat(roll, pitch, yaw):
    """Build a camera-to-world rotation matrix from Euler angles in degrees.

    Angles are intrinsic Y-X-Z; the fixed X(-90) rotation converts between
    the CV camera convention and the XYZ world frame.
    """
    cv_to_xyz = Rotation.from_euler("X", -90, degrees=True)
    cam_rotation = Rotation.from_euler("YXZ", (roll, pitch, yaw), degrees=True)
    world_to_cam = cv_to_xyz.inv() * cam_rotation
    return world_to_cam.inv().as_matrix()


def fov_size2intrinsics(fov, img_size):
    """Build a 3x3 pinhole intrinsic matrix from field of view and image size.

    fov: scalar (applied to both axes) or [fov_x, fov_y], in degrees.
    img_size: [width, height] in pixels; the principal point is the center.
    """
    fov_pair = [fov, fov] if isinstance(fov, (int, float)) else fov
    half_x = np.deg2rad(fov_pair[0]) / 2
    half_y = np.deg2rad(fov_pair[1]) / 2
    cx = img_size[0] / 2
    cy = img_size[1] / 2
    K = np.zeros((3, 3))
    K[0, 0] = cx / np.tan(half_x)   # fx
    K[1, 1] = cy / np.tan(half_y)   # fy
    K[0, 2] = cx
    K[1, 2] = cy
    K[2, 2] = 1.0
    return K

def from_Euler_and_position_to_c2w(roll_pitch_yaw, position):
    """Assemble a 4x4 camera-to-world pose from Euler angles (degrees) and a position.

    position may be a numpy array-like or a torch.Tensor (moved to CPU first).
    """
    T = np.eye(4)
    T[:3, :3] = compose_rotmat(*roll_pitch_yaw)
    if isinstance(position, torch.Tensor):
        position = position.cpu().numpy()
    T[:3, 3] = position
    return T

class PointSamplerPerspective(torch.nn.Module):
    """Samples 3D points along per-pixel camera rays for a perspective camera.

    Rays are cast through every pixel of a ``render_size`` image; per-ray sample
    depths are clipped by an AABB intersection and then perturbed.
    """

    def __init__(self, num_points, aabb_strict=True, perturbation_strategy='uniform', render_size=[128, 128]):
        """
        num_points: number of points to sample along each ray
        aabb_strict: whether to clip ray lengths with strict AABB intersection
        perturbation_strategy: strategy for perturbing points along the ray
        render_size: [H, W] of the rendered image
        """
        super().__init__()
        # BUG FIX: the original line ended with a stray comma, storing the
        # one-element tuple (aabb_strict,), which is always truthy — so
        # aabb_strict=False silently behaved like True instead of reaching
        # the NotImplementedError branch in forward().
        self.aabb_strict = aabb_strict
        # Diagonal length of the sampling volume (half-extents of 1.5, 1.5,
        # 1.9 presumably — TODO confirm against the AABB used in
        # intersect_aabb_end).
        self.sample_total_length = np.sqrt(1.5**2 + 1.5**2 + 1.9**2)
        self.num_points = num_points
        self.perturbation_strategy = perturbation_strategy
        self.render_size = render_size

    @torch.no_grad()
    def forward(self, intrinsics, c2w):
        """Cast rays and sample points along them.

        Parameters:
        - intrinsics: B x 3 x 3 camera intrinsic matrices
        - c2w: B x 4 x 4 camera-to-world poses

        Returns an edict with:
        - rays_world: B x H x W x 3 unit ray directions
        - ray_origins: B x H x W x 3 ray origins (camera centers)
        - radii_raw: B x H x W x K evenly spaced depths along each ray
        - radii: B x H x W x K depths after perturbation
        - points_world: B x H x W x K x 3 sampled points
        The y component of origins/directions/points is negated at the end
        (converting from a (w, -h, z) to a (w, h, z) convention).
        """
        batch_size = c2w.shape[0]
        device = c2w.device
        H, W = self.render_size

        t = c2w[:, :3, 3].clone()
        output = edict()
        output.ray_origins = repeat(t, 'b c -> b h w c', h=H, w=W).to(device)
        output.ray_origins = output.ray_origins.clone()  # w -h z
        output.rays_world = compute_ray_directions(c2w.to(device), intrinsics.to(device), H, W)

        if self.aabb_strict:
            # Per-ray distance to the AABB exit, capped at sample_total_length;
            # intersect_aabb_end appears to return a flat (b*h*w,) tensor —
            # TODO confirm against its definition.
            sample_total_length = intersect_aabb_end(
                output.ray_origins, output.rays_world, min=0, max=self.sample_total_length)
            sample_total_length = rearrange(
                sample_total_length, '(b h w) -> b h w 1', b=batch_size, h=H, w=W)
            # K evenly spaced depths: (i/K) * ray_length for i = 1..K.
            output.radii_raw = (torch.arange(self.num_points) + 1)[None, None, None, :].to(
                sample_total_length.device) * (sample_total_length / self.num_points)
        else:
            raise NotImplementedError

        output.radii = perturb_points_per_ray(output.radii_raw, strategy=self.perturbation_strategy)
        # origin + direction * depth, broadcast over the K samples.
        sample_point = output.ray_origins.unsqueeze(-1) + output.rays_world.unsqueeze(-1) * output.radii.unsqueeze(-2)
        output.points_world = rearrange(sample_point, 'b h w c k -> b h w k c')

        # Convert from (w, -h, z) to (w, h, z): flip the y component in place.
        output.ray_origins[..., 1] = -output.ray_origins[..., 1]
        output.rays_world[..., 1] = -output.rays_world[..., 1]
        output.points_world[..., 1] = -output.points_world[..., 1]
        return output


def generate_pixel_coordinates(H, W, device=None):
    """
    Generate a homogeneous pixel-coordinate grid on the image plane.

    Parameters:
    - H: Image height
    - W: Image width
    - device: target torch device. Defaults to the current CUDA device when
      CUDA is available, else CPU — preserving the original hard-coded
      behavior while letting callers pin a device explicitly.

    Returns:
    - pixel_coords: float tensor of shape [H, W, 3] where entry (y, x) holds
      the homogeneous pixel coordinate (x, y, 1).
    """
    if device is None:
        device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
    y, x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    pixel_coords = torch.stack([x, y, torch.ones_like(x)], dim=-1).float().to(device)
    return pixel_coords

def compute_ray_directions(camera2world, intrinsics, H, W):
    """
    Compute unit ray directions in world coordinates for every pixel.

    Parameters:
    - camera2world: [B, 4, 4] camera-to-world transformation matrices
    - intrinsics: [B, 3, 3] camera intrinsic matrices
    - H, W: image height and width

    Returns:
    - [B, H, W, 3] L2-normalized ray directions
    """
    batch = camera2world.shape[0]
    grid = generate_pixel_coordinates(H, W)              # [H, W, 3]
    grid = grid.unsqueeze(0).expand(batch, -1, -1, -1)   # [B, H, W, 3]

    # Back-project pixels into camera space via the inverse intrinsics,
    # then rotate into world space with the pose's rotation block.
    cam_coords = torch.einsum('bij,bhwj->bhwi', torch.inverse(intrinsics), grid)
    world_dirs = torch.einsum('bij,bhwj->bhwi', camera2world[:, :3, :3], cam_coords)

    return world_dirs / torch.norm(world_dirs, dim=-1, keepdim=True)