Create math_utils.py
Browse fileswe are flattening the directory, since HF only supports flat imports
- math_utils.py +123 -0
math_utils.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# MIT License
|
| 7 |
+
|
| 8 |
+
# Copyright (c) 2022 Petr Kellnhofer
|
| 9 |
+
|
| 10 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 11 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 12 |
+
# in the Software without restriction, including without limitation the rights
|
| 13 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 14 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 15 |
+
# furnished to do so, subject to the following conditions:
|
| 16 |
+
|
| 17 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 18 |
+
# copies or substantial portions of the Software.
|
| 19 |
+
|
| 20 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 21 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 22 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 23 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 24 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 25 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 26 |
+
# SOFTWARE.
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
|
| 30 |
+
def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
|
| 31 |
+
"""
|
| 32 |
+
Left-multiplies MxM @ NxM. Returns NxM.
|
| 33 |
+
"""
|
| 34 |
+
res = torch.matmul(vectors4, matrix.T)
|
| 35 |
+
return res
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
|
| 39 |
+
"""
|
| 40 |
+
Normalize vector lengths.
|
| 41 |
+
"""
|
| 42 |
+
return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
|
| 43 |
+
|
| 44 |
+
def torch_dot(x: torch.Tensor, y: torch.Tensor):
|
| 45 |
+
"""
|
| 46 |
+
Dot product of two tensors.
|
| 47 |
+
"""
|
| 48 |
+
return (x * y).sum(-1)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
|
| 52 |
+
"""
|
| 53 |
+
Author: Petr Kellnhofer
|
| 54 |
+
Intersects rays with the [-1, 1] NDC volume.
|
| 55 |
+
Returns min and max distance of entry.
|
| 56 |
+
Returns -1 for no intersection.
|
| 57 |
+
https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
|
| 58 |
+
"""
|
| 59 |
+
o_shape = rays_o.shape
|
| 60 |
+
rays_o = rays_o.detach().reshape(-1, 3)
|
| 61 |
+
rays_d = rays_d.detach().reshape(-1, 3)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)]
|
| 65 |
+
bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)]
|
| 66 |
+
bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device)
|
| 67 |
+
is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)
|
| 68 |
+
|
| 69 |
+
# Precompute inverse for stability.
|
| 70 |
+
invdir = 1 / rays_d
|
| 71 |
+
sign = (invdir < 0).long()
|
| 72 |
+
|
| 73 |
+
# Intersect with YZ plane.
|
| 74 |
+
tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
|
| 75 |
+
tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
|
| 76 |
+
|
| 77 |
+
# Intersect with XZ plane.
|
| 78 |
+
tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
|
| 79 |
+
tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
|
| 80 |
+
|
| 81 |
+
# Resolve parallel rays.
|
| 82 |
+
is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False
|
| 83 |
+
|
| 84 |
+
# Use the shortest intersection.
|
| 85 |
+
tmin = torch.max(tmin, tymin)
|
| 86 |
+
tmax = torch.min(tmax, tymax)
|
| 87 |
+
|
| 88 |
+
# Intersect with XY plane.
|
| 89 |
+
tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
|
| 90 |
+
tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
|
| 91 |
+
|
| 92 |
+
# Resolve parallel rays.
|
| 93 |
+
is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False
|
| 94 |
+
|
| 95 |
+
# Use the shortest intersection.
|
| 96 |
+
tmin = torch.max(tmin, tzmin)
|
| 97 |
+
tmax = torch.min(tmax, tzmax)
|
| 98 |
+
|
| 99 |
+
# Mark invalid.
|
| 100 |
+
tmin[torch.logical_not(is_valid)] = -1
|
| 101 |
+
tmax[torch.logical_not(is_valid)] = -2
|
| 102 |
+
|
| 103 |
+
return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
|
| 107 |
+
"""
|
| 108 |
+
Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.
|
| 109 |
+
Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch.
|
| 110 |
+
"""
|
| 111 |
+
# create a tensor of 'num' steps from 0 to 1
|
| 112 |
+
steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)
|
| 113 |
+
|
| 114 |
+
# reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings
|
| 115 |
+
# - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
|
| 116 |
+
# "cannot statically infer the expected size of a list in this contex", hence the code below
|
| 117 |
+
for i in range(start.ndim):
|
| 118 |
+
steps = steps.unsqueeze(-1)
|
| 119 |
+
|
| 120 |
+
# the output starts at 'start' and increments until 'stop' in each dimension
|
| 121 |
+
out = start[None] + steps * (stop - start)[None]
|
| 122 |
+
|
| 123 |
+
return out
|