qbhf2's picture
added NvidiaWarp and GarmentCode repos
66c9c8a
from typing import Union, Any, Tuple
import warp as wp
import warp.types
from warp.sparse import BsrMatrix, bsr_zeros, bsr_get_diag, bsr_mv
from warp.utils import array_inner
def bsr_to_scipy(matrix: BsrMatrix) -> "scipy.sparse.bsr_array":
try:
from scipy.sparse import csr_array, bsr_array
except ImportError:
# WAR for older scipy
from scipy.sparse import csr_matrix as csr_array, bsr_matrix as bsr_array
if matrix.block_shape == (1, 1):
return csr_array(
(
matrix.values.numpy().flatten()[: matrix.nnz],
matrix.columns.numpy()[: matrix.nnz],
matrix.offsets.numpy(),
),
shape=matrix.shape,
)
return bsr_array(
(
matrix.values.numpy().reshape((matrix.values.shape[0], *matrix.block_shape))[: matrix.nnz],
matrix.columns.numpy()[: matrix.nnz],
matrix.offsets.numpy(),
),
shape=matrix.shape,
)
def scipy_to_bsr(sp: Union["scipy.sparse.bsr_array", "scipy.sparse.csr_array"], device=None, dtype=None) -> BsrMatrix:
try:
from scipy.sparse import csr_array
except ImportError:
# WAR for older scipy
from scipy.sparse import csr_matrix as csr_array
if dtype is None:
dtype = warp.types.np_dtype_to_warp_type[sp.dtype]
sp.sort_indices()
if isinstance(sp, csr_array):
matrix = bsr_zeros(sp.shape[0], sp.shape[1], dtype, device=device)
else:
block_shape = sp.blocksize
block_type = wp.types.matrix(shape=block_shape, dtype=dtype)
matrix = bsr_zeros(sp.shape[0] // block_shape[0], sp.shape[1] // block_shape[1], block_type, device=device)
matrix.nnz = sp.nnz
matrix.values = wp.array(sp.data.flatten(), dtype=matrix.values.dtype, device=device)
matrix.columns = wp.array(sp.indices, dtype=matrix.columns.dtype, device=device)
matrix.offsets = wp.array(sp.indptr, dtype=matrix.offsets.dtype, device=device)
return matrix
@wp.kernel
def _bsr_cg_kernel_1(
rs_old: wp.array(dtype=Any),
p_Ap: wp.array(dtype=Any),
x: wp.array(dtype=Any),
r: wp.array(dtype=Any),
p: wp.array(dtype=Any),
Ap: wp.array(dtype=Any),
):
i = wp.tid()
if p_Ap[0] != 0.0:
alpha = rs_old[0] / p_Ap[0]
x[i] = x[i] + alpha * p[i]
r[i] = r[i] - alpha * Ap[i]
@wp.kernel
def _bsr_cg_kernel_2(
tol: Any,
rs_old: wp.array(dtype=Any),
rs_new: wp.array(dtype=Any),
z: wp.array(dtype=Any),
p: wp.array(dtype=Any),
):
# p = r + (rsnew / rsold) * p;
i = wp.tid()
if rs_new[0] > tol:
beta = rs_new[0] / rs_old[0]
else:
beta = rs_new[0] - rs_new[0]
p[i] = z[i] + beta * p[i]
@wp.kernel
def _bsr_cg_solve_block_diag_precond_kernel(
diag: wp.array(dtype=Any),
r: wp.array(dtype=Any),
z: wp.array(dtype=Any),
):
i = wp.tid()
d = wp.get_diag(diag[i])
if wp.dot(d, d) == 0.0:
z[i] = r[i]
else:
d_abs = wp.max(d, -d)
z[i] = wp.cw_div(r[i], d_abs)
@wp.kernel
def _bsr_cg_solve_scalar_diag_precond_kernel(
diag: wp.array(dtype=Any),
r: wp.array(dtype=Any),
z: wp.array(dtype=Any),
):
i = wp.tid()
d = diag[i]
if d == 0.0:
z[i] = r[i]
else:
z[i] = r[i] / wp.abs(d)
def bsr_cg(
A: BsrMatrix,
x: wp.array,
b: wp.array,
max_iters: int = 0,
tol: float = 0.0001,
check_every=10,
use_diag_precond=True,
mv_routine=bsr_mv,
device=None,
quiet=False,
) -> Tuple[float, int]:
"""Solves the linear system A x = b using the Conjugate Gradient method, optionally with diagonal preconditioning
Args:
A: system left-hand side
x: result vector and initial guess
b: system right-hand-side
max_iters: maximum number of iterations to performing before aborting. If set to zero, equal to the system size.
tol: relative tolerance under which to stop the solve
check_every: number of iterations every which to evaluate the current residual norm to compare against tolerance
use_diag_precond: Whether to use diagonal preconditioning
mv_routine: Matrix-vector multiplication routine to for multiplications with ``A``
device: Warp device to use for the computation
Returns:
Tuple (residual norm, iteration count)
"""
if max_iters == 0:
max_iters = A.shape[0]
if device is None:
device = A.values.device
scalar_dtype = A.scalar_type
r = wp.zeros_like(b)
p = wp.zeros_like(b)
Ap = wp.zeros_like(b)
if use_diag_precond:
A_diag = bsr_get_diag(A)
z = wp.zeros_like(b)
if A.block_shape == (1, 1):
precond_kernel = _bsr_cg_solve_scalar_diag_precond_kernel
else:
precond_kernel = _bsr_cg_solve_block_diag_precond_kernel
else:
z = r
rz_old = wp.empty(n=1, dtype=scalar_dtype, device=device)
rz_new = wp.empty(n=1, dtype=scalar_dtype, device=device)
p_Ap = wp.empty(n=1, dtype=scalar_dtype, device=device)
# r = b - A * x;
r.assign(b)
mv_routine(A, x, r, alpha=-1.0, beta=1.0)
# z = M^-1 r
if use_diag_precond:
wp.launch(kernel=precond_kernel, dim=A.nrow, device=device, inputs=[A_diag, r, z])
# p = z;
p.assign(z)
# rsold = r' * z;
array_inner(r, z, out=rz_old)
tol_sq = tol * tol * A.shape[0]
err = rz_old.numpy()[0]
end_iter = 0
if err > tol_sq:
end_iter = max_iters
for i in range(max_iters):
# Ap = A * p;
mv_routine(A, p, Ap)
array_inner(p, Ap, out=p_Ap)
wp.launch(kernel=_bsr_cg_kernel_1, dim=A.nrow, device=device, inputs=[rz_old, p_Ap, x, r, p, Ap])
# z = M^-1 r
if use_diag_precond:
wp.launch(kernel=precond_kernel, dim=A.nrow, device=device, inputs=[A_diag, r, z])
# rznew = r' * z;
array_inner(r, z, out=rz_new)
if ((i + 1) % check_every) == 0:
err = rz_new.numpy()[0]
if not quiet:
print(f"At iteration {i} error = \t {err} \t tol: {tol_sq}")
if err <= tol_sq:
end_iter = i
break
wp.launch(kernel=_bsr_cg_kernel_2, dim=A.nrow, device=device, inputs=[tol_sq, rz_old, rz_new, z, p])
# swap buffers
rs_tmp = rz_old
rz_old = rz_new
rz_new = rs_tmp
err = rz_old.numpy()[0]
if not quiet:
print(f"Terminated after {end_iter} iterations with error = \t {err}")
return err, end_iter
def invert_diagonal_bsr_mass_matrix(A: BsrMatrix):
"""Inverts each block of a block-diagonal mass matrix"""
scale = A.scalar_type(A.block_shape[0])
values = A.values
if not wp.types.type_is_matrix(values.dtype):
values = values.view(dtype=wp.mat(shape=(1, 1), dtype=A.scalar_type))
wp.launch(kernel=_block_diagonal_mass_invert, dim=A.nrow, inputs=[values, scale], device=values.device)
@wp.kernel
def _block_diagonal_mass_invert(values: wp.array(dtype=Any), scale: Any):
i = wp.tid()
values[i] = scale * values[i] / wp.ddot(values[i], values[i])