# BrainFM / Generator / interpol / jit_utils.py
# (Hugging Face upload residue converted to comments so the file parses:
#  uploaded by peirong26, "Upload 187 files", commit 2571f24 verified.)
"""A lot of utility functions for TorchScript"""
import torch
import os
from typing import List, Tuple, Optional
from .utils import torch_version
from torch import Tensor
@torch.jit.script
def pad_list_int(x: List[int], dim: int) -> List[int]:
    """Force `x` to length `dim`: truncate if too long, otherwise replicate
    the last element. An empty list stays empty (nothing to replicate)."""
    n = len(x)
    if n > dim:
        return x[:dim]
    if n < dim:
        return x + x[-1:] * (dim - n)
    return x
@torch.jit.script
def pad_list_float(x: List[float], dim: int) -> List[float]:
    """Force `x` to length `dim`: truncate if too long, otherwise replicate
    the last element. An empty list stays empty (nothing to replicate)."""
    n = len(x)
    if n > dim:
        return x[:dim]
    if n < dim:
        return x + x[-1:] * (dim - n)
    return x
@torch.jit.script
def pad_list_str(x: List[str], dim: int) -> List[str]:
    """Force `x` to length `dim`: truncate if too long, otherwise replicate
    the last element. An empty list stays empty (nothing to replicate)."""
    n = len(x)
    if n > dim:
        return x[:dim]
    if n < dim:
        return x + x[-1:] * (dim - n)
    return x
@torch.jit.script
def list_any(x: List[bool]) -> bool:
    """TorchScript-friendly `any()` over a list of bools (False if empty)."""
    found = False
    for flag in x:
        found = found or flag
    return found
@torch.jit.script
def list_all(x: List[bool]) -> bool:
    """TorchScript-friendly `all()` over a list of bools (True if empty)."""
    ok = True
    for flag in x:
        ok = ok and flag
    return ok
@torch.jit.script
def list_prod_int(x: List[int]) -> int:
    """Product of a list of ints; 1 (the neutral element) for an empty list."""
    prod = 1
    for v in x:
        prod = prod * v
    return prod
@torch.jit.script
def list_sum_int(x: List[int]) -> int:
    """Sum of a list of ints.

    Returns 0 for an empty list. (Fixed: the previous version returned 1
    for ``[]``, an apparent copy-paste from ``list_prod_int`` whose neutral
    element is 1; the neutral element of a sum is 0.)
    """
    total = 0
    for v in x:
        total = total + v
    return total
@torch.jit.script
def list_prod_tensor(x: List[Tensor]) -> Tensor:
    """Elementwise product of a list of tensors.

    An empty list yields a scalar ``ones(())`` (the product's neutral element).
    """
    if len(x) == 0:
        scalar_shape: List[int] = []
        return torch.ones(scalar_shape)
    out = x[0]
    for t in x[1:]:
        out = out * t
    return out
@torch.jit.script
def list_sum_tensor(x: List[Tensor]) -> Tensor:
    """Elementwise sum of a list of tensors.

    An empty list yields a scalar ``zeros(())``. (Fixed: the previous
    version returned ``ones(())``, the neutral element of a *product*,
    apparently copy-pasted from ``list_prod_tensor``; the neutral element
    of a sum is 0.)
    """
    if len(x) == 0:
        empty: List[int] = []
        return torch.zeros(empty)
    total = x[0]
    for t in x[1:]:
        total = total + t
    return total
@torch.jit.script
def list_reverse_int(x: List[int]) -> List[int]:
    """Return a reversed copy of `x` (the empty list is returned as-is)."""
    n = len(x)
    if n == 0:
        return x
    out: List[int] = []
    for i in range(n):
        out.append(x[n - 1 - i])
    return out
@torch.jit.script
def list_cumprod_int(x: List[int], reverse: bool = False,
                     exclusive: bool = False) -> List[int]:
    """Cumulative product of a list of ints.

    reverse : accumulate from the right instead of the left.
    exclusive : each entry excludes its own element, so the first
        accumulated entry is 1 (e.g. [2, 3, 4] -> [1, 2, 6]).
    """
    n = len(x)
    out: List[int] = []
    if n == 0:
        return out
    # Materialise the traversal order so the core loop is direction-agnostic.
    seq: List[int] = []
    for i in range(n):
        if reverse:
            seq.append(x[n - 1 - i])
        else:
            seq.append(x[i])
    running = 1 if exclusive else seq[0]
    out.append(running)
    tail = seq[:n - 1] if exclusive else seq[1:]
    for v in tail:
        running = running * v
        out.append(running)
    if reverse:
        # Flip back so entries line up with the input positions.
        flipped: List[int] = []
        for i in range(n):
            flipped.append(out[n - 1 - i])
        out = flipped
    return out
@torch.jit.script
def movedim1(x, source: int, destination: int):
    """Move dimension `source` of `x` to position `destination`.

    TorchScript-safe single-dimension equivalent of `torch.movedim`.
    """
    ndim = x.dim()
    if source < 0:
        source = ndim + source
    if destination < 0:
        destination = ndim + destination
    # Drop `source` from the identity permutation, then re-insert it
    # at the target position.
    perm: List[int] = []
    for d in range(ndim):
        if d != source:
            perm.append(d)
    perm.insert(destination, source)
    return x.permute(perm)
@torch.jit.script
def sub2ind(subs, shape: List[int]):
    """Convert sub indices (i, j, k) into linear indices.

    The rightmost dimension is the most rapidly changing one
    -> if shape == [D, H, W], the strides are therefore [H*W, W, 1].

    Parameters
    ----------
    subs : (D, ...) tensor
        Stack of sub-indices; the first dimension indexes the D dimensions.
        Each element should have the same number of elements and shape.
    shape : (D,) list[int]
        Size of each dimension. Its length should equal ``subs.shape[0]``.

    Returns
    -------
    ind : (...) tensor
        Linear indices.
    """
    # Row-major strides for every dimension but the last (whose stride is 1):
    # strides[k] == prod(shape[k+1:]).
    strides: List[int] = []
    running = 1
    for d in range(len(shape) - 1, 0, -1):
        running = running * shape[d]
        strides.insert(0, running)
    coords = subs.unbind(0)
    # Start from the stride-1 coordinate; clone so the input is not mutated.
    ind = coords[-1].clone()
    for i in range(len(coords) - 1):
        ind += coords[i] * strides[i]
    return ind
@torch.jit.script
def sub2ind_list(subs: List[Tensor], shape: List[int]):
    """Convert sub indices (i, j, k) into linear indices.

    The rightmost dimension is the most rapidly changing one
    -> if shape == [D, H, W], the strides are therefore [H*W, W, 1].

    Parameters
    ----------
    subs : (D,) list[tensor]
        List of sub-indices, one tensor per dimension.
        Each element should have the same number of elements and shape.
    shape : (D,) list[int]
        Size of each dimension. Its length should equal ``len(subs)``.

    Returns
    -------
    ind : (...) tensor
        Linear indices.
    """
    # Row-major strides for every dimension but the last (whose stride is 1):
    # strides[k] == prod(shape[k+1:]).
    strides: List[int] = []
    running = 1
    for d in range(len(shape) - 1, 0, -1):
        running = running * shape[d]
        strides.insert(0, running)
    # Start from the stride-1 coordinate; clone so the input is not mutated.
    ind = subs[-1].clone()
    for i in range(len(subs) - 1):
        ind += subs[i] * strides[i]
    return ind
# floor_divide returns wrong results for negative values, because it truncates
# instead of performing a proper floor. In recent version of pytorch, it is
# advised to use div(..., rounding_mode='trunc'|'floor') instead.
# Here, we only use floor_divide on positive values so we do not care.
# NOTE(review): `floor_div` is rebound at the bottom of this file to
# torch.div / torch.floor_divide, which clobbers the scripted version
# defined here -- confirm that is intentional. `floor_div_int` is not rebound.
if torch_version('>=', [1, 8]):
    @torch.jit.script
    def floor_div(x, y) -> torch.Tensor:
        # Proper floor division (correct even for negative operands).
        return torch.div(x, y, rounding_mode='floor')
    @torch.jit.script
    def floor_div_int(x, y: int) -> torch.Tensor:
        # Same, with a plain-int divisor (separate signature for TorchScript).
        return torch.div(x, y, rounding_mode='floor')
else:
    @torch.jit.script
    def floor_div(x, y) -> torch.Tensor:
        # Pre-1.8 fallback: true divide then floor. `x / y` allocates a new
        # tensor, so the in-place floor_() does not clobber the caller's data.
        return (x / y).floor_()
    @torch.jit.script
    def floor_div_int(x, y: int) -> torch.Tensor:
        # Same fallback, with a plain-int divisor.
        return (x / y).floor_()
@torch.jit.script
def ind2sub(ind, shape: List[int]):
    """Convert linear indices into sub indices (i, j, k).
    The rightmost dimension is the most rapidly changing one
    -> if shape == [D, H, W], the strides are therefore [H*W, W, 1]
    Parameters
    ----------
    ind : tensor_like
        Linear indices
    shape : (D,) vector_like
        Size of each dimension.
    Returns
    -------
    subs : (D, ...) tensor
        Sub-indices.
    """
    # Exclusive reverse cumprod -> stride[d] == prod(shape[d+1:]), last one 1.
    stride = list_cumprod_int(shape, reverse=True, exclusive=True)
    # NOTE(review): `[len(shape)] + ind.shape` only concatenates under
    # TorchScript, where `Tensor.shape` is a List[int]; in eager mode
    # `list + torch.Size` would raise. Presumably this always runs scripted.
    sub = ind.new_empty([len(shape)] + ind.shape)
    # Broadcast the linear index into every row of the leading axis.
    sub.copy_(ind)
    for d in range(len(shape)):
        if d > 0:
            # Strip the contribution of all coarser (earlier) dimensions...
            sub[d] = torch.remainder(sub[d], stride[d-1])
        # ...then divide by this dimension's stride to get its coordinate.
        sub[d] = floor_div_int(sub[d], stride[d])
    return sub
@torch.jit.script
def inbounds_mask_3d(extrapolate: int, gx, gy, gz, nx: int, ny: int, nz: int) \
        -> Optional[Tensor]:
    """Boolean mask of voxels whose coordinates fall inside the volume.

    Returns None (no masking) unless `extrapolate` is 0 ('no') or 2 ('hist').
    Mode 0 accepts [0, n-1] with a tiny tolerance; mode 2 widens the band
    by half a voxel.
    """
    if extrapolate != 0 and extrapolate != 2:
        return None
    eps = 5e-2
    bound = eps if extrapolate == 0 else 0.5 + eps
    mask = (gx > -bound) & (gx < nx - 1 + bound)
    mask = mask & (gy > -bound) & (gy < ny - 1 + bound)
    mask = mask & (gz > -bound) & (gz < nz - 1 + bound)
    return mask
@torch.jit.script
def inbounds_mask_2d(extrapolate: int, gx, gy, nx: int, ny: int) \
        -> Optional[Tensor]:
    """Boolean mask of pixels whose coordinates fall inside the image.

    Returns None (no masking) unless `extrapolate` is 0 ('no') or 2 ('hist').
    Mode 0 accepts [0, n-1] with a tiny tolerance; mode 2 widens the band
    by half a voxel.
    """
    if extrapolate != 0 and extrapolate != 2:
        return None
    eps = 5e-2
    bound = eps if extrapolate == 0 else 0.5 + eps
    mask = (gx > -bound) & (gx < nx - 1 + bound)
    mask = mask & (gy > -bound) & (gy < ny - 1 + bound)
    return mask
@torch.jit.script
def inbounds_mask_1d(extrapolate: int, gx, nx: int) -> Optional[Tensor]:
    """Boolean mask of samples whose coordinate falls inside [0, nx-1].

    Returns None (no masking) unless `extrapolate` is 0 ('no') or 2 ('hist').
    Mode 0 accepts [0, nx-1] with a tiny tolerance; mode 2 widens the band
    by half a voxel.
    """
    if extrapolate != 0 and extrapolate != 2:
        return None
    eps = 5e-2
    bound = eps if extrapolate == 0 else 0.5 + eps
    return (gx > -bound) & (gx < nx - 1 + bound)
@torch.jit.script
def make_sign(sign: List[Optional[Tensor]]) -> Optional[Tensor]:
    """Multiply together the non-None entries of `sign`.

    Returns None when every entry is None (equivalent to the original's
    `list_all(is_none)` check: all-None iff nothing survives filtering).
    """
    present: List[Tensor] = []
    for s in sign:
        if s is not None:
            present.append(s)
    if len(present) == 0:
        return None
    out = present[0]
    for t in present[1:]:
        out = out * t
    return out
@torch.jit.script
def square(x):
    """Elementwise square, out of place (`x` is left untouched)."""
    return x.mul(x)
@torch.jit.script
def square_(x):
    """Elementwise square, in place: `x` is modified and returned."""
    x *= x
    return x
@torch.jit.script
def cube(x):
    """Elementwise cube, out of place (`x` is left untouched)."""
    return x.mul(x).mul(x)
@torch.jit.script
def cube_(x):
    """Elementwise cube, in place: `x` is modified and returned.

    BUG FIX: the previous body was ``square_(x).mul_(x)``. After the
    in-place squaring, the *same tensor* `x` already holds x**2, so the
    final ``.mul_(x)`` multiplied x**2 by x**2 and returned x**4, not x**3.
    A true in-place cube needs one out-of-place temporary to preserve the
    original values of `x`.
    """
    x2 = x.mul(x)       # out-of-place temp so the original x survives
    return x.mul_(x2)   # x * x**2 == x**3, stored back into x
@torch.jit.script
def pow4(x):
    """Elementwise fourth power, out of place (square of the square)."""
    x2 = x.mul(x)
    return x2.mul(x2)
@torch.jit.script
def pow4_(x):
    """Elementwise fourth power, in place: `x` is modified and returned.

    After the first in-place squaring `x` holds x**2, so squaring it
    against itself once more yields x**4 (the self-aliasing is intended).
    """
    return x.mul_(x).mul_(x)
@torch.jit.script
def pow5(x):
    """Elementwise fifth power, out of place (x * x**4)."""
    x4 = x.mul(x)
    x4 = x4.mul(x4)
    return x * x4
@torch.jit.script
def pow5_(x):
    """Elementwise fifth power, in place: `x` is modified and returned.

    BUG FIX: the previous body was ``pow4_(x).mul_(x)``. After the in-place
    fourth power, the *same tensor* `x` already holds x**4, so the final
    ``.mul_(x)`` multiplied x**4 by x**4 and returned x**8, not x**5.
    The exponentiation is now staged through out-of-place temporaries.
    """
    x2 = x.mul(x)       # temp: x**2 (original x preserved)
    x4 = x2.mul_(x2)    # reuse the temp in place: x**4
    return x.mul_(x4)   # x * x**4 == x**5, stored back into x
@torch.jit.script
def pow6(x):
    """Elementwise sixth power, out of place (square of the cube)."""
    x3 = x.mul(x).mul(x)
    return x3.mul(x3)
@torch.jit.script
def pow6_(x):
    """Elementwise sixth power, in place: `x` is modified and returned.

    BUG FIX: the previous body was ``square_(cube_(x))``, and ``cube_``
    itself had an in-place aliasing bug that produced x**4, so this
    function returned x**8 instead of x**6. Rewritten self-contained:
    x -> x**3 (via one out-of-place temp), then squared against itself.
    """
    x2 = x.mul(x)       # temp: x**2 (original x preserved)
    x = x.mul_(x2)      # x now holds x**3
    return x.mul_(x)    # x**3 * x**3 == x**6, in place
@torch.jit.script
def pow7(x):
    """Elementwise seventh power, out of place (x**6 * x)."""
    x3 = x.mul(x).mul(x)
    return x3.mul(x3).mul(x)
@torch.jit.script
def pow7_(x):
    """Elementwise seventh power, in place: `x` is modified and returned.

    BUG FIX: the previous body was ``pow6_(x).mul_(x)``; by the time the
    final ``.mul_(x)`` ran, the *same tensor* `x` no longer held the
    original values (and ``pow6_`` itself was buggy), so the result was
    x**16 instead of x**7. Rewritten with out-of-place temporaries:
    x**7 == x * x**2 * x**4.
    """
    x2 = x.mul(x)        # temp: x**2 (original x preserved)
    x4 = x2.mul(x2)      # temp: x**4
    return x.mul_(x2).mul_(x4)   # x * x**2 * x**4 == x**7, stored into x
@torch.jit.script
def dot(x, y, dim: int = -1, keepdim: bool = False):
    """(Batched) dot product of `x` and `y` along dimension `dim`.

    Implemented as a (1, n) @ (n, 1) matmul after moving `dim` last, so
    it batches over all remaining dimensions.
    """
    x = movedim1(x, dim, -1).unsqueeze(-2)   # (..., 1, n)
    y = movedim1(y, dim, -1).unsqueeze(-1)   # (..., n, 1)
    d = torch.matmul(x, y).squeeze(-1).squeeze(-1)
    if keepdim:
        # BUG FIX: Tensor.unsqueeze is out of place, so the result must be
        # rebound; previously `keepdim=True` silently did nothing.
        d = d.unsqueeze(dim)
    return d
@torch.jit.script
def dot_multi(x, y, dim: List[int], keepdim: bool = False):
    """(Batched) dot product of `x` and `y` over several dimensions.

    All dimensions in `dim` are moved to the end, flattened, and reduced
    with one (1, n) @ (n, 1) matmul per batch element.
    """
    for d in dim:
        x = movedim1(x, d, -1)
        y = movedim1(y, d, -1)
    x = x.reshape(x.shape[:-len(dim)] + [1, -1])
    # BUG FIX: flatten y using its own shape (previously x.shape, which
    # silently assumed the two inputs have identical batch shapes).
    y = y.reshape(y.shape[:-len(dim)] + [-1, 1])
    dt = torch.matmul(x, y).squeeze(-1).squeeze(-1)
    if keepdim:
        for d in dim:
            # BUG FIX: rebind -- unsqueeze is out of place, so this loop
            # previously did nothing.
            # NOTE(review): re-inserting the dims one by one assumes `dim`
            # is sorted in increasing order -- confirm with callers.
            dt = dt.unsqueeze(d)
    return dt
# cartesian_prod takes multiple input tensors as input in eager mode
# but takes a list of tensors in jit mode. This is a helper that works
# in both cases. Same story for meshgrid (tuple out in eager, list in jit).
if not int(os.environ.get('PYTORCH_JIT', '1')):
    cartesian_prod = lambda x: torch.cartesian_prod(*x)

    if torch_version('>=', (1, 10)):
        def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return torch.meshgrid(*x, indexing='ij')
        def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return torch.meshgrid(*x, indexing='xy')
    else:
        def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return torch.meshgrid(*x)
        def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
            # BUG FIX: in eager mode torch.meshgrid returns a *tuple*;
            # copy to a list before swapping the first two axes, otherwise
            # the item assignments below raise TypeError.
            grid = list(torch.meshgrid(*x))
            if len(grid) > 1:
                grid[0] = grid[0].transpose(0, 1)
                grid[1] = grid[1].transpose(0, 1)
            return grid
else:
    cartesian_prod = torch.cartesian_prod
    if torch_version('>=', (1, 10)):
        @torch.jit.script
        def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return torch.meshgrid(x, indexing='ij')
        @torch.jit.script
        def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return torch.meshgrid(x, indexing='xy')
    else:
        @torch.jit.script
        def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
            return torch.meshgrid(x)
        @torch.jit.script
        def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
            # BUG FIX: this function was misnamed `meshgrid_xyt`, leaving
            # `meshgrid_xy` undefined on the jit + torch<1.10 code path.
            grid = torch.meshgrid(x)
            if len(grid) > 1:
                grid[0] = grid[0].transpose(0, 1)
                grid[1] = grid[1].transpose(0, 1)
            return grid
        # Backward-compatible alias for the old (typo'd) public name.
        meshgrid_xyt = meshgrid_xy
meshgrid = meshgrid_ij
# In torch < 1.6, div applied to integer tensor performed a floor_divide
# In torch > 1.6, it performs a true divide.
# Floor division must be done using `floor_divide`, but it was buggy
# until torch 1.13 (it was doing a trunc divide instead of a floor divide).
# There was at some point a deprecation warning for floor_divide, but it
# seems to have been lifted afterwards. In torch >= 1.13, floor_divide
# performs a correct floor division.
# Since we only apply floor_divide to positive values, we are fine.
# NOTE(review): this rebinding overwrites the @torch.jit.script `floor_div`
# defined earlier in this file (the rounding_mode='floor' version) -- on
# 1.6 <= torch < 1.13 the replacement truncates rather than floors, which
# is only safe under the positive-operands assumption above. Confirm this
# shadowing is intentional.
if torch_version('<', (1, 6)):
    floor_div = torch.div
else:
    floor_div = torch.floor_divide