ArthurY's picture
update source
c3d0544
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Tuple
import torch
import torch.nn.functional as F
from torch import Tensor
# TODO: the enum below causes segmentation faults with the current TorchScript.
# Go back to the enum after a TorchScript update.
# NOTE: the triple-quoted string below is an inert module-level expression; it
# is never executed. It preserves the intended InterpolationType enum, whose
# values are (index, stride). Until it is restored, `interpolation` selects the
# kernel by plain string name with the same strides (1, 2, 2, 2, 5).
"""
@enum.unique
class InterpolationType(enum.Enum):
NEAREST_NEIGHBOR = (1, 1)
LINEAR = (2, 2)
SMOOTH_STEP_1 = (3, 2)
SMOOTH_STEP_2 = (4, 2)
GAUSSIAN = (6, 5)
def __init__(self, index, stride):
self.index = index
self.stride = stride
"""
@torch.jit.script
def linear_step(x: Tensor) -> Tensor:
    """
    Saturating linear ramp: clamp every element of `x` into ``[0, 1]``.

    Elements below 0 become 0, elements above 1 become 1, and everything in
    between is passed through unchanged.

    Parameters
    ----------
    x: Tensor
        Input tensor of any shape.

    Returns
    -------
    Tensor
        Tensor of the same shape as `x` with values limited to ``[0, 1]``.

    Example
    -------
    >>> linear_step(torch.tensor([-0.5, 0.5, 1.5]))
    tensor([0.0000, 0.5000, 1.0000])
    """
    # torch.clip is an alias of torch.clamp; integer bounds keep the dtype of x.
    return x.clamp(min=0, max=1)
@torch.jit.script
def smooth_step_1(x: Tensor) -> Tensor:
    """
    Compute the cubic smooth step of the input tensor values.

    Applies :math:`f(t) = 3t^2 - 2t^3` with ``t = clip(x, 0, 1)``, so the result
    rises smoothly from 0 at ``x <= 0`` to 1 at ``x >= 1``.

    Parameters
    ----------
    x: Tensor
        Input tensor. Values inside ``[0, 1]`` are interpolated; values outside
        saturate to 0 (below) or 1 (above).

    Returns
    -------
    Tensor
        Smooth-step values in ``[0, 1]``, same shape as `x`.
    """
    # Clamp the *input* first: the cubic is not monotonic outside [0, 1]
    # (e.g. f(2) = -4 and f(-1) = 5), so clipping only the output folded
    # out-of-range inputs back to the wrong end (x=2 -> 0, x=-1 -> 1).
    # For x already in [0, 1] this clamp is the identity, so in-range results
    # are unchanged.
    t = torch.clip(x, 0, 1)
    return torch.clip(3 * t**2 - 2 * t**3, 0, 1)
@torch.jit.script
def smooth_step_2(x: Tensor) -> Tensor:
    """
    Quintic ("smoother step") interpolant of the input tensor values.

    Evaluates :math:`f(x) = x^3 (6x^2 - 15x + 10)` element-wise and limits the
    result to ``[0, 1]``. The polynomial is monotonic, so clamping the output is
    enough to saturate inputs outside ``[0, 1]``.

    Parameters
    ----------
    x: Tensor
        Input tensor; values in ``[0, 1]`` give a meaningful blend.

    Returns
    -------
    Tensor
        Tensor of the same shape as `x` with values in ``[0, 1]``.
    """
    cube = x**3
    blend = 6 * x**2 - 15 * x + 10
    return torch.clamp(cube * blend, min=0, max=1)
@torch.jit.script
def nearest_neighbor_weighting(dist_vec: Tensor, dx: Tensor) -> Tensor:
    """
    Weights for nearest-neighbour interpolation: a single weight of one.

    Nearest-neighbour lookup uses exactly one stencil node, so its weight is
    constant and independent of the distances.

    Parameters
    ----------
    dist_vec: Tensor
        Distance tensor. Only its shape and device are used: everything except
        the last two axes is kept.
    dx: Tensor
        Grid spacing. Accepted for signature compatibility with the other
        weighting functions; it is not used.

    Returns
    -------
    Tensor
        Tensor of ones (default dtype) shaped ``dist_vec.shape[:-2] + [1, 1]``,
        on the device of `dist_vec`.
    """
    leading_dims = dist_vec.shape[:-2]
    return torch.ones(leading_dims + [1, 1], device=dist_vec.device)
@torch.jit.script
def _hyper_cube_weighting(lower_point: Tensor, upper_point: Tensor) -> Tensor:
dim = lower_point.shape[-1]
weights = []
weights = [upper_point[..., 0], lower_point[..., 0]]
for i in range(1, dim):
new_weights = []
for w in weights:
new_weights.append(w * upper_point[..., i])
new_weights.append(w * lower_point[..., i])
weights = new_weights
weights = torch.stack(weights, dim=-1)
return torch.unsqueeze(weights, dim=-1)
@torch.jit.script
def linear_weighting(dist_vec: Tensor, dx: Tensor) -> Tensor:
    """
    Multilinear interpolation weights from query-to-node offsets.

    The offsets are scaled by the grid spacing. The offset to the first stencil
    node and the negated offset to the last stencil node are then combined into
    per-corner products by `_hyper_cube_weighting`.

    Parameters
    ----------
    dist_vec: Tensor
        Offsets with stencil nodes on axis -2 and spatial coordinates on axis -1
        (the function reads ``[..., 0, :]`` and ``[..., -1, :]``).
    dx: Tensor
        Grid spacing, broadcastable against `dist_vec`.

    Returns
    -------
    Tensor
        Per-node weights as produced by `_hyper_cube_weighting`.
    """
    scaled = dist_vec / dx
    first_node = scaled[..., 0, :]
    last_node = scaled[..., -1, :].neg()
    return _hyper_cube_weighting(first_node, last_node)
@torch.jit.script
def smooth_step_1_weighting(dist_vec: Tensor, dx: Tensor) -> Tensor:
    """
    Interpolation weights shaped by the cubic `smooth_step_1` ramp.

    Same construction as the linear weighting, except that the scaled offsets
    to the first and last stencil nodes are passed through `smooth_step_1`
    before the per-corner products are formed.

    Parameters
    ----------
    dist_vec: Tensor
        Offsets with stencil nodes on axis -2 and spatial coordinates on axis -1.
    dx: Tensor
        Grid spacing, broadcastable against `dist_vec`.

    Returns
    -------
    Tensor
        Per-node weights as produced by `_hyper_cube_weighting`.
    """
    scaled = dist_vec / dx
    first_node = smooth_step_1(scaled[..., 0, :])
    last_node = smooth_step_1(scaled[..., -1, :].neg())
    return _hyper_cube_weighting(first_node, last_node)
@torch.jit.script
def smooth_step_2_weighting(dist_vec: Tensor, dx: Tensor) -> Tensor:
    """
    Interpolation weights shaped by the quintic `smooth_step_2` ramp.

    The scaled offsets to the first and last stencil nodes are passed through
    `smooth_step_2`, then combined into per-corner products by
    `_hyper_cube_weighting`.

    Parameters
    ----------
    dist_vec: Tensor
        Offsets with stencil nodes on axis -2 and spatial coordinates on axis -1.
    dx: Tensor
        Spacing between grid points, broadcastable against `dist_vec`.

    Returns
    -------
    Tensor
        Per-node weights as produced by `_hyper_cube_weighting`.
    """
    scaled = dist_vec / dx
    first_node = smooth_step_2(scaled[..., 0, :])
    last_node = smooth_step_2(scaled[..., -1, :].neg())
    return _hyper_cube_weighting(first_node, last_node)
@torch.jit.script
def gaussian_weighting(dist_vec: Tensor, dx: Tensor) -> Tensor:
    """
    Normalized Gaussian weights over the stencil nodes.

    Each node gets ``prod_d exp(-0.5 * (dist_d / sigma_d)**2)`` with
    ``sigma = dx / 2``. The weights are then normalized to sum to one over the
    stencil axis.

    Parameters
    ----------
    dist_vec: Tensor
        Offsets with stencil nodes on axis -2 and spatial coordinates on axis -1.
        Any number of leading axes is accepted.
    dx: Tensor
        Grid spacing, broadcastable against `dist_vec`.

    Returns
    -------
    Tensor
        Weights of shape ``dist_vec.shape[:-1] + [1]`` that sum to one along
        axis -2.
    """
    sharpen = 2.0
    sigma = dx / sharpen
    # Unnormalized kernel per stencil node (product over spatial axes).
    # The Gaussian prefactor 1 / ((2*pi)^(d/2) * prod(sigma)) is deliberately
    # omitted: it is a constant that cancels in the normalization below, and for
    # very fine grids prod(sigma) underflows, turning the weights into inf/inf = NaN.
    kernel = torch.exp(-0.5 * torch.square(dist_vec / sigma)).prod(dim=-1)
    # Negative axes (rather than fixed dims 2/3) support any number of leading axes.
    weights = kernel / kernel.sum(dim=-1, keepdim=True)
    return weights.unsqueeze(-1)
# @torch.jit.script
def _gather_nd(params: Tensor, indices: Tensor) -> Tensor:
"""As seen here https://discuss.pytorch.org/t/how-to-do-the-tf-gather-nd-in-pytorch/6445/30"""
orig_shape = list(indices.shape)
num_samples = 1
for s in orig_shape[:-1]:
num_samples *= s
m = orig_shape[-1]
n = len(params.shape)
if m <= n:
out_shape = orig_shape[:-1] + list(params.shape)[m:]
else:
raise ValueError(
f"the last dimension of indices must less or equal to the rank of params. Got indices:{indices.shape}, params:{params.shape}. {m} > {n}"
)
indices = indices.reshape((num_samples, m)).transpose(0, 1).tolist()
output = params[indices] # (num_samples, ...)
return output.reshape(out_shape).contiguous()
@torch.jit.script
def index_values_high_mem(points: Tensor, idx: Tensor) -> Tensor:
    """
    Gather per-batch rows of `points` at the indices `idx` via ``torch.gather``.

    Parameters
    ----------
    points: Tensor
        Source values of shape ``(B, M, C)``.
    idx: Tensor
        Integer indices into axis 1 of `points`, shape ``(B, N, K)``.

    Returns
    -------
    Tensor
        ``out[b, n, k, :] = points[b, idx[b, n, k], :]``, shape ``(B, N, K, C)``.
    """
    n_query = idx.size(1)
    n_feat = points.size(-1)
    # expand() creates broadcast views (stride 0) instead of copies, and
    # torch.gather reads through them. The previous repeat_interleave allocated
    # a full (B, N, M, C) copy of `points`; the gathered values are identical.
    index = idx.unsqueeze(3).expand([-1, -1, -1, n_feat])
    source = points.unsqueeze(1).expand([-1, n_query, -1, -1])
    return torch.gather(source, dim=2, index=index)
# @torch.jit.script
def index_values_low_mem(points: Tensor, idx: Tensor) -> Tensor:
    """
    Gather per-batch rows of `points` at the indices `idx`.

    Parameters
    ----------
    points: Tensor
        Source values of shape ``(B, M, C)`` (trailing axes after the first two
        are carried through unchanged).
    idx: Tensor
        Integer indices into axis 1 of `points`, shape ``(B, N, K)``.

    Returns
    -------
    Tensor
        ``out[b, n, k, :] = points[b, idx[b, n, k], :]``, shape ``(B, N, K, C)``.
    """
    batch_size = idx.shape[0]
    # Batch index broadcast against idx's (N, K) axes. The previous
    # tile-based construction produced [0, 1, ..., B-1, 0, 1, ...] while the
    # flattened idx is batch-major, so rows were read from the wrong batch
    # whenever B > 1.
    batch_idx = torch.arange(batch_size, device=points.device).view(batch_size, 1, 1)
    return points[batch_idx, idx.to(torch.int64)]
@torch.jit.script
def _grid_knn_idx(
    query_points: Tensor,
    grid: List[Tuple[float, float, int]],
    stride: int,
    padding: bool = True,
) -> Tensor:
    """
    Flattened grid indices of the ``stride**len(grid)`` stencil around each query point.

    Parameters
    ----------
    query_points: Tensor
        Query coordinates whose last axis has ``len(grid)`` entries. The final
        ``view(shape[0:2] + ...)`` implies a 3-D tensor, e.g. ``(1, N, dims)``
        as built by ``interpolation``.
    grid: List[Tuple[float, float, int]]
        Per-axis ``(start, end, num_points)``; the spacing is
        ``(end - start) / (num_points - 1)``.
    stride: int
        Number of stencil nodes per axis.
    padding: bool, optional
        If True, indices address a grid padded by ``stride // 2`` nodes on each
        side of every axis (the start shifts by ``k * dx`` and every axis size
        grows by ``2 * k``).

    Returns
    -------
    Tensor
        int64 tensor of shape
        ``(query_points.shape[0], query_points.shape[1], stride**len(grid))``
        holding row-major (axis 0 slowest) linear indices into the grid. No
        bounds checking is done: queries outside the grid give out-of-range
        indices.

    Raises
    ------
    RuntimeError
        If ``len(grid)`` is not 1, 2 or 3.
    """
    # set k (number of padding nodes per side)
    k = stride // 2
    # set device
    device = query_points.device
    # find nearest neighbors of query points from a grid
    # dx vector on grid (per-axis spacing), reshaped to broadcast over (B, N, dims)
    dx = torch.tensor([(x[1] - x[0]) / (x[2] - 1) for x in grid])
    dx = dx.view(1, 1, len(grid)).to(device)
    # min point on grid (this will change if we are padding the grid)
    start = torch.tensor([val[0] for val in grid]).to(device)
    if padding:
        start = start - (k * dx)
    start = start.view(1, 1, len(grid))
    # this is the center nearest neighbor in the grid.
    # (stride / 2.0 % 1.0) is 0 for even strides (the node at or below the
    # query) and 0.5 for odd strides (the nearest node); .to(torch.int64)
    # truncates toward zero.
    center_idx = (((query_points - start) / dx) + (stride / 2.0 % 1.0)).to(torch.int64)
    # index window: `stride` offsets from -((stride - 1) // 2) to stride // 2,
    # e.g. [0, 1] for stride 2, [-2, ..., 2] for stride 5, [0] for stride 1
    idx_add = (
        torch.arange(-((stride - 1) // 2), stride // 2 + 1).view(1, 1, -1).to(device)
    )
    # find all index in window around center index; per-axis offsets are
    # combined by broadcasting into row-major linear indices
    # TODO make for more general diminsions
    if len(grid) == 1:
        idx_row_0 = center_idx[..., 0:1] + idx_add
        idx = idx_row_0.view(idx_row_0.shape[0:2] + torch.Size([int(stride)]))
    elif len(grid) == 2:
        # size of axis 1 (padded if requested) is the row length for axis 0
        dim_size_1 = grid[1][2]
        if padding:
            dim_size_1 += 2 * k
        idx_row_0 = dim_size_1 * (center_idx[..., 0:1] + idx_add)
        idx_row_0 = idx_row_0.unsqueeze(-1)
        idx_row_1 = center_idx[..., 1:2] + idx_add
        idx_row_1 = idx_row_1.unsqueeze(2)
        idx = (idx_row_0 + idx_row_1).view(
            idx_row_0.shape[0:2] + torch.Size([int(stride**2)])
        )
    elif len(grid) == 3:
        # sizes of axes 1 and 2 (padded if requested) give the row-major strides
        dim_size_1 = grid[1][2]
        dim_size_2 = grid[2][2]
        if padding:
            dim_size_1 += 2 * k
            dim_size_2 += 2 * k
        idx_row_0 = dim_size_2 * dim_size_1 * (center_idx[..., 0:1] + idx_add)
        idx_row_0 = idx_row_0.unsqueeze(-1).unsqueeze(-1)
        idx_row_1 = dim_size_2 * (center_idx[..., 1:2] + idx_add)
        idx_row_1 = idx_row_1.unsqueeze(2).unsqueeze(-1)
        idx_row_2 = center_idx[..., 2:3] + idx_add
        idx_row_2 = idx_row_2.unsqueeze(2).unsqueeze(3)
        idx = (idx_row_0 + idx_row_1 + idx_row_2).view(
            idx_row_0.shape[0:2] + torch.Size([int(stride**3)])
        )
    else:
        # only 1-D, 2-D and 3-D grids are supported
        raise RuntimeError
    return idx
# TODO currently the `tolist` operation is not supported by torch script and when fixed torch script will be used
# @torch.jit.script
def interpolation(
    query_points: Tensor,
    context_grid: Tensor,
    grid: List[Tuple[float, float, int]],
    interpolation_type: str = "smooth_step_2",
    mem_speed_trade: bool = True,
) -> Tensor:
    """
    Interpolate values at `query_points` from the regular grid `context_grid`.

    The grid is padded by ``stride // 2`` nodes per side, the
    ``stride**len(grid)`` stencil nodes around every query point are looked up,
    and their values are blended with the weights of the selected kernel.

    Parameters
    ----------
    query_points: Tensor
        Points at which interpolation is performed, with the ``len(grid)``
        coordinates on the last axis. A batch axis is added internally, so this
        is presumably ``(N, len(grid))``.
    context_grid: Tensor
        Grid values with the channel axis first, followed by one axis per entry
        of `grid` (it is reshaped to ``[1, channels, num_grid_points]``).
    grid: List[Tuple[float, float, int]]
        Per-axis ``(start, end, num_points)`` describing range and resolution.
    interpolation_type: str, optional
        One of ``"nearest_neighbor"``, ``"linear"``, ``"smooth_step_1"``,
        ``"smooth_step_2"`` or ``"gaussian"``; by default ``"smooth_step_2"``.
    mem_speed_trade: bool, optional
        If True (default), index with `index_values_low_mem`; otherwise with
        `index_values_high_mem`.

    Returns
    -------
    Tensor
        Interpolated values at `query_points`, channels on the last axis.

    Raises
    ------
    RuntimeError
        If `interpolation_type` is not supported.
    """
    # Stencil width (nodes per axis) of each supported kernel.
    # TODO this will be replaced with InterpolationType later
    strides = {
        "nearest_neighbor": 1,
        "linear": 2,
        "smooth_step_1": 2,
        "smooth_step_2": 2,
        "gaussian": 5,
    }
    if interpolation_type not in strides:
        raise RuntimeError(f"Interpolation type {interpolation_type} not supported")
    stride = strides[interpolation_type]

    device = query_points.device
    dims = len(grid)
    nr_channels = context_grid.size(0)
    # Per-axis node spacing.
    dx = [((x[1] - x[0]) / (x[2] - 1)) for x in grid]

    # Node coordinates of the grid padded by k = stride // 2 nodes per side.
    k = stride // 2
    linspace = [
        torch.linspace(x[0] - k * dx_i, x[1] + k * dx_i, x[2] + 2 * k)
        for x, dx_i in zip(grid, dx)
    ]
    # "ij" (the historical default) keeps axis 0 slowest when flattened, which
    # matches the row-major linear indices from _grid_knn_idx. Passing it
    # explicitly also silences torch's meshgrid deprecation warning.
    meshgrid = torch.stack(torch.meshgrid(*linspace, indexing="ij"), dim=-1).to(device)

    # Zero-pad the context grid by the same k nodes so stencils near the
    # boundary stay in range.
    context_grid = F.pad(context_grid, dims * (k, k))

    # Flatten mesh grid and context grid so both are addressed by the same
    # linear indices: [1, num_grid_points, dims] and [1, num_grid_points, channels].
    nr_grid_points = math.prod(x[2] + 2 * k for x in grid)
    meshgrid = meshgrid.view(1, nr_grid_points, dims)
    context_grid = torch.reshape(context_grid, [1, nr_channels, nr_grid_points])
    context_grid = torch.swapaxes(context_grid, 1, 2)
    query_points = query_points.unsqueeze(0)

    # Linear indices of the stencil nodes around every query point.
    idx = _grid_knn_idx(query_points, grid, stride, padding=True)

    # Offsets from each stencil node to its query point.
    if mem_speed_trade:
        mesh_grid_idx = index_values_low_mem(meshgrid, idx)
    else:
        mesh_grid_idx = index_values_high_mem(meshgrid, idx)
    dist_vec = query_points.unsqueeze(2) - mesh_grid_idx

    # Spacing as a broadcastable [1, 1, 1, dims] tensor for the weighting functions.
    dx_t = torch.tensor(dx, dtype=torch.float32).reshape(1, 1, 1, dims).to(device)

    if interpolation_type == "nearest_neighbor":
        weights = nearest_neighbor_weighting(dist_vec, dx_t)
    elif interpolation_type == "linear":
        weights = linear_weighting(dist_vec, dx_t)
    elif interpolation_type == "smooth_step_1":
        weights = smooth_step_1_weighting(dist_vec, dx_t)
    elif interpolation_type == "smooth_step_2":
        weights = smooth_step_2_weighting(dist_vec, dx_t)
    elif interpolation_type == "gaussian":
        weights = gaussian_weighting(dist_vec, dx_t)
    else:
        # Unreachable after the validation above; kept so `weights` is always bound.
        raise RuntimeError(f"Interpolation type {interpolation_type} not supported")

    # Values of the context grid at the stencil nodes.
    if mem_speed_trade:
        context_grid_idx = index_values_low_mem(context_grid, idx)
    else:
        context_grid_idx = index_values_high_mem(context_grid, idx)

    # Weighted sum over the stencil axis, then drop the internal batch axis.
    interpolated_points = (weights * context_grid_idx).sum(dim=2)
    return interpolated_points[0]