File size: 2,236 Bytes
874cec4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numpy as np
import torch
from typing import Literal, Tuple
# borrowed from https://github.com/nerfstudio-project/nerfstudio
from einops import rearrange

def intersect_aabb(
    origins: torch.Tensor,
    directions: torch.Tensor,
    aabb: torch.Tensor = torch.tensor([-1., -1., -1., 1., 1., 1.]).float(),
    max_bound: float = 1e10,
    invalid_value: float = 1e10,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Implementation of ray intersection with AABB box (slab method).

    Args:
        origins: [..., 3] tensor of ray origins. A 4D [b, h, w, 3] input is
            flattened to [(b*h*w), 3]; ``directions`` must then be 4D as well.
        directions: [..., 3] tensor of ray directions, broadcastable against
            ``origins``. Distances are measured in multiples of these vectors,
            so they are Euclidean distances only if the directions are normalized.
        aabb: [6] tensor of the aabb box in the form of
            [x_min, y_min, z_min, x_max, y_max, z_max]. Defaults to the cube [-1, 1]^3.
        max_bound: Upper clamp applied to both t_min and t_max (both are also
            clamped below at 0).
        invalid_value: Value written to both t_min and t_max for rays that miss the box.

    Returns:
        t_min, t_max - two tensors with the broadcast batch shape of the inputs
        (shape [N] for [N, 3] inputs) holding the entry and exit distances of
        the intersection from the origin.
    """
    # Image-shaped ray batches are flattened into a plain list of rays.
    if len(origins.shape) == 4:
        origins = rearrange(origins, 'b h w c -> (b h w) c')
        directions = rearrange(directions, 'b h w c -> (b h w) c')
    # send directions and aabb to origins's device
    directions = directions.to(origins.device)
    aabb = aabb.to(origins.device)

    # Per-axis distances to the min/max planes of each slab. A zero direction
    # component yields +/-inf here, which the min/max reductions handle.
    # NOTE(review): an origin lying exactly on a slab plane with a zero
    # direction component gives 0/0 = NaN, which propagates to the output.
    tx_min = (aabb[:3] - origins) / directions
    tx_max = (aabb[3:] - origins) / directions

    # Order each axis pair as (near, far) regardless of the direction's sign.
    t_min = torch.minimum(tx_min, tx_max)
    t_max = torch.maximum(tx_min, tx_max)

    # A point is inside the box only while it is inside all three slabs.
    t_min = t_min.amax(dim=-1)
    t_max = t_max.amin(dim=-1)

    # Intersections behind the origin are clipped to 0.
    t_min = torch.clamp(t_min, min=0, max=max_bound)
    t_max = torch.clamp(t_max, min=0, max=max_bound)

    # An empty (or zero-length) interval means the ray misses the box.
    cond = t_max <= t_min
    # Build the fill value with t_min's own shape, dtype and device so
    # torch.where works for any batch shape and on torch versions that
    # require matching dtypes (replaces a host-side .repeat(N) hack).
    invalid = torch.full_like(t_min, invalid_value)
    t_min = torch.where(cond, invalid, t_min)
    t_max = torch.where(cond, invalid, t_max)
    # t_max is the value consumed by intersect_aabb_end.
    return t_min, t_max

def intersect_aabb_end(origin,dir,min=0,max=4):
    """
    Return the exit distance (t_max) of rays against the default AABB, with sanity checks.

    Args:
        origin: ray origins, forwarded unchanged to ``intersect_aabb``.
        dir: ray directions, forwarded unchanged to ``intersect_aabb``.
        min: exclusive lower bound that the smallest t_max must exceed.
        max: exclusive upper bound that the smallest t_max must stay below.

    Returns:
        The t_max tensor returned by ``intersect_aabb`` (rays that miss the box
        hold that function's ``invalid_value``).

    Raises:
        AssertionError: if t_max contains NaN, or its smallest value is not
            strictly between ``min`` and ``max``.
    """
    # NOTE: `dir`, `min` and `max` shadow builtins; the names are kept so that
    # keyword callers keep working.
    t_max = intersect_aabb(origin,dir)[1]
    # Explicit raises instead of `assert` so the checks survive `python -O`;
    # AssertionError is kept so existing handlers still match.
    if torch.isnan(t_max).any():
        raise AssertionError("nan in t_max of intersect_aabb_end")
    # NOTE(review): only the smallest t_max is range-checked; individual rays
    # that miss (t_max == invalid_value) pass unless every ray misses.
    t_max_min = t_max.min()
    if not (min < t_max_min < max):
        raise AssertionError("t_max out of range %s, min is %s, max is %s" % (t_max_min, min, max))
    return t_max