# qian43's picture
# Upload 115 files
# 874cec4 verified
import numpy as np
import torch
from typing import Literal, Tuple
# borrowed from https://github.com/nerfstudio-project/nerfstudio
from einops import rearrange
def intersect_aabb(
    origins: torch.Tensor,
    directions: torch.Tensor,
    aabb: torch.Tensor = torch.tensor([-1., -1., -1., 1., 1., 1.]).float(),
    max_bound: float = 1e10,
    invalid_value: float = 1e10,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Intersect rays with an axis-aligned bounding box (slab method).

    Args:
        origins: [..., 3] tensor of ray origins. A 4-D input of shape
            [b, h, w, 3] is flattened to [b*h*w, 3].
        directions: [..., 3] tensor of (normalized) ray directions; must
            broadcast against ``origins``. A 4-D input is flattened the same
            way. Returned distances are in units of the direction length.
        aabb: [6] box in the form [x_min, y_min, z_min, x_max, y_max, z_max];
            any array-like accepted by ``torch.as_tensor`` works.
        max_bound: Upper clamp applied to t_min and t_max.
        invalid_value: Value written to both outputs for rays that miss the box.
    Returns:
        t_min, t_max - tensors with the broadcast batch shape of the inputs
        (``[N]`` for ``[N, 3]`` inputs): entry and exit distances from the
        origin, clamped to [0, max_bound]. Rays whose clamped interval is
        empty (miss, or box entirely behind the origin) get ``invalid_value``.
    """
    # Flatten image-shaped ray bundles [b, h, w, c] into a flat list of rays.
    if origins.dim() == 4:
        origins = origins.reshape(-1, origins.shape[-1])
    if directions.dim() == 4:
        directions = directions.reshape(-1, directions.shape[-1])
    directions = directions.to(origins.device)
    aabb = torch.as_tensor(aabb, device=origins.device)

    # Per-axis distances to the two slab planes. A zero direction component
    # yields +-inf, or NaN (0/0) when the origin lies exactly on that plane.
    t0 = (aabb[:3] - origins) / directions
    t1 = (aabb[3:] - origins) / directions
    near = torch.minimum(t0, t1)
    far = torch.maximum(t0, t1)

    # A 0/0 axis means the ray is parallel to that slab and lies on its
    # boundary, so that axis must not constrain the interval. Only NaNs caused
    # by a zero direction component are neutralized; other NaNs propagate.
    parallel_on_plane = torch.isnan(near) & (directions == 0)
    near = torch.where(parallel_on_plane, torch.full_like(near, float('-inf')), near)
    far = torch.where(parallel_on_plane, torch.full_like(far, float('inf')), far)

    # Entry is the latest slab entry, exit the earliest slab exit.
    t_min = torch.clamp(near.amax(dim=-1), min=0, max=max_bound)
    t_max = torch.clamp(far.amin(dim=-1), min=0, max=max_bound)

    # An empty interval after clamping means no forward intersection.
    miss = t_max <= t_min
    # full_like keeps the fill's shape/dtype/device identical to the outputs,
    # so torch.where works for any batch shape.
    fill = torch.full_like(t_min, invalid_value)
    t_min = torch.where(miss, fill, t_min)
    t_max = torch.where(miss, fill, t_max)
    return t_min, t_max
def intersect_aabb_end(origin, dir, min=0, max=4):
    """
    Return the exit distance of each ray through the default AABB, after
    sanity-checking it.

    Args:
        origin: Ray origins, forwarded to ``intersect_aabb``.
        dir: Ray directions, forwarded to ``intersect_aabb``.
        min: Exclusive lower bound that the smallest exit distance must exceed.
        max: Exclusive upper bound that the smallest exit distance must stay below.
    Returns:
        The ``t_max`` tensor from ``intersect_aabb`` (rays that miss carry
        that function's invalid value).
    Raises:
        AssertionError: if ``t_max`` contains NaN, or its minimum is not
            strictly inside (min, max).
    """
    # `dir`, `min` and `max` shadow builtins but are part of the public signature.
    t_max = intersect_aabb(origin, dir)[1]
    # Explicit raises instead of `assert` so the checks survive `python -O`;
    # AssertionError is kept for callers that already expect it.
    if torch.isnan(t_max).any():
        raise AssertionError("nan in t_max of intersect_aabb_end")
    lowest = t_max.min()
    # NOTE(review): only the smallest exit distance is range-checked; larger
    # values (including the invalid value of missed rays) are not bounded by `max`.
    if not (min < lowest < max):
        raise AssertionError(
            "t_max out of range %s, min is %s, max is %s" % (lowest, min, max)
        )
    return t_max