| import numpy as np |
| import torch |
| from typing import Literal, Tuple |
| |
| from einops import rearrange |
|
|
def intersect_aabb(
    origins: torch.Tensor,
    directions: torch.Tensor,
    aabb: torch.Tensor = torch.tensor([-1., -1., -1., 1., 1., 1.]).float(),
    max_bound: float = 1e10,
    invalid_value: float = 1e10,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Intersect rays with an axis-aligned bounding box (slab method).

    Args:
        origins: [..., 3] tensor of ray origins; a 4D [B, H, W, 3] input is
            flattened to [B*H*W, 3]. Must broadcast against ``directions``.
        directions: [..., 3] tensor of ray directions (flattened together with
            ``origins`` when ``origins`` is 4D). Returned t values are in units
            of the direction length, i.e. true distances for unit directions.
            Zero components (axis-parallel rays) are allowed.
        aabb: [6] tensor (or array-like) describing the box as
            [x_min, y_min, z_min, x_max, y_max, z_max].
        max_bound: Upper clamp applied to both t_min and t_max.
        invalid_value: Value written to both t_min and t_max for rays that do
            not intersect the box in front of the origin.

    Returns:
        (t_min, t_max): entry and exit distances, each clamped to
        [0, max_bound], with shape equal to the broadcast leading dimensions
        of the (possibly flattened) inputs, e.g. [N] for [N, 3] inputs.
    """
    if origins.dim() == 4:
        # 'b h w c -> (b h w) c': flatten image-shaped ray bundles.
        origins = origins.reshape(-1, origins.shape[-1])
        directions = directions.reshape(-1, directions.shape[-1])
    directions = directions.to(origins.device)
    aabb = torch.as_tensor(aabb, device=origins.device)

    # Per-axis distances to the two slab planes. A zero direction component
    # gives +/-inf (ray parallel to the slab), or NaN (0/0) when the origin
    # lies exactly on one of that slab's planes.
    t_lo = (aabb[:3] - origins) / directions
    t_hi = (aabb[3:] - origins) / directions
    t_near = torch.minimum(t_lo, t_hi)
    t_far = torch.maximum(t_lo, t_hi)

    # A parallel ray lying on a face is treated as inside that slab, so the
    # axis does not constrain the interval (instead of poisoning it with NaN).
    on_face = torch.isnan(t_near) & (directions == 0)
    t_near = t_near.masked_fill(on_face, float("-inf"))
    t_far = t_far.masked_fill(on_face, float("inf"))

    # Entry is the latest slab entry, exit the earliest slab exit; clamping
    # at 0 discards the part of the interval behind the origin.
    t_min = torch.clamp(t_near.amax(dim=-1), min=0, max=max_bound)
    t_max = torch.clamp(t_far.amin(dim=-1), min=0, max=max_bound)

    # Empty (or tangent / fully-behind) intervals are misses. masked_fill
    # keeps the output shape and dtype regardless of the input batch layout.
    miss = t_max <= t_min
    t_min = t_min.masked_fill(miss, invalid_value)
    t_max = t_max.masked_fill(miss, invalid_value)

    return t_min, t_max
|
|
def intersect_aabb_end(origin, dir, min=0, max=4):
    """
    Return the exit distance t_max of each ray through the default AABB, after sanity checks.

    Args:
        origin: Ray origins, forwarded to ``intersect_aabb``.
        dir: Ray directions, forwarded to ``intersect_aabb``.
        min: Exclusive lower bound that the smallest t_max must exceed.
        max: Exclusive upper bound that the smallest t_max must stay below.

    Returns:
        The t_max tensor produced by ``intersect_aabb`` with its default box.

    Raises:
        AssertionError: If t_max contains NaN, or its minimum lies outside (min, max).
    """
    t_max = intersect_aabb(origin, dir)[1]
    # Explicit raises instead of `assert` so the checks survive `python -O`;
    # AssertionError is kept so callers see the same exception type as before.
    if torch.isnan(t_max).any():
        raise AssertionError("nan in t_max of intersect_aabb_end")
    smallest = t_max.min()
    # NOTE(review): only the smallest t_max is range-checked, so rays that miss
    # the box (filled with intersect_aabb's invalid_value, 1e10 by default)
    # pass unless every ray misses — confirm whether all values should be < max.
    if not (min < smallest < max):
        raise AssertionError(
            "t_max out of range %s, min is %s, max is %s" % (smallest, min, max)
        )
    return t_max
|
|