Spaces:
Sleeping
Sleeping
| from typing import Any, Generic, Optional, Tuple, TypeVar, Union | |
| import warp as wp | |
| import warp.types | |
| import warp.utils | |
| from warp.types import Array, Cols, Matrix, Rows, Scalar, Vector | |
# Helper types; in this module they only appear in annotations

_BlockType = TypeVar("BlockType")


class _MatrixBlockType(Matrix):
    """Empty placeholder type standing for matrix-valued blocks (BSR matrices)."""


class _ScalarBlockType(Generic[Scalar]):
    """Empty placeholder type standing for scalar-valued blocks (CSR matrices)."""


BlockType = Union[_MatrixBlockType[Rows, Cols, Scalar], _ScalarBlockType[Scalar]]

# Struct key -> ``wp.codegen.Struct`` built for that key by ``bsr_matrix_t``
_struct_cache = {}
class BsrMatrix(Generic[_BlockType]):
    """Untyped base class for BSR and CSR matrices.

    Should not be constructed directly but through functions such as :func:`bsr_zeros`.

    Attributes:
        nrow (int): Number of rows of blocks
        ncol (int): Number of columns of blocks
        nnz (int): Number of non-zero blocks: must be equal to ``offsets[nrow]``, cached on host for convenience
        offsets (Array[int]): Array of size at least ``1 + nrow`` such that the start and end indices of the blocks of row ``r`` are ``offsets[r]`` and ``offsets[r+1]``, respectively.
        columns (Array[int]): Array of size at least equal to ``nnz`` containing block column indices
        values (Array[BlockType]): Array of size at least equal to ``nnz`` containing block values
    """

    # The accessors below must be properties: the rest of the module reads them as
    # plain attributes (e.g. ``A.block_shape == (1, 1)``, ``dest.scalar_type``,
    # ``self.block_shape[0]``), which fails on a bound method.

    @property
    def scalar_type(self) -> Scalar:
        """Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type"""
        return warp.types.type_scalar_type(self.values.dtype)

    @property
    def block_shape(self) -> Tuple[int, int]:
        """Shape of the individual blocks"""
        # Block types without a ``_shape_`` attribute are treated as 1x1 blocks
        return getattr(self.values.dtype, "_shape_", (1, 1))

    @property
    def block_size(self) -> int:
        """Size of the individual blocks, i.e. number of rows per block times number of columns per block"""
        return warp.types.type_length(self.values.dtype)

    @property
    def shape(self) -> Tuple[int, int]:
        """Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block"""
        block_shape = self.block_shape
        return (self.nrow * block_shape[0], self.ncol * block_shape[1])
def bsr_matrix_t(dtype: BlockType):
    """Returns the warp struct type describing BSR matrices whose blocks have type ``dtype``.

    The struct is built once per block type and then served from ``_struct_cache``.

    Args:
        dtype: Block type; after conversion through ``type_to_warp`` it must be a warp matrix or scalar type

    Raises:
        ValueError: if ``dtype`` is neither a warp matrix type nor a warp scalar type
    """
    dtype = wp.types.type_to_warp(dtype)

    supported = warp.types.type_is_matrix(dtype) or dtype in warp.types.scalar_types
    if not supported:
        raise ValueError(
            f"BsrMatrix block type must be either warp matrix or scalar; got {warp.types.type_repr(dtype)}"
        )

    class BsrMatrixTyped(BsrMatrix):
        # Number of rows of blocks
        nrow: int
        # Number of columns of blocks
        ncol: int
        # Number of non-zero blocks, cached on host for convenience
        nnz: int
        # Row offsets, array of size at least 1 + nrow
        offsets: wp.array(dtype=int)
        # Block column indices, array of size at least equal to nnz
        columns: wp.array(dtype=int)
        # Block values
        values: wp.array(dtype=dtype)

    module = wp.get_module(BsrMatrix.__module__)

    # The struct key encodes the scalar type and, for matrix blocks, the block shape
    if hasattr(dtype, "_shape_"):
        suffix = f"{warp.types.type_scalar_type(dtype).__name__}_{dtype._shape_[0]}_{dtype._shape_[1]}"
    else:
        suffix = dtype.__name__
    key = f"{BsrMatrix.__qualname__}_{suffix}"

    struct = _struct_cache.get(key)
    if struct is None:
        struct = wp.codegen.Struct(
            cls=BsrMatrixTyped,
            key=key,
            module=module,
        )
        _struct_cache[key] = struct

    return struct
def bsr_zeros(
    rows_of_blocks: int,
    cols_of_blocks: int,
    block_type: BlockType,
    device: wp.context.Devicelike = None,
) -> BsrMatrix:
    """
    Constructs and returns an empty BSR or CSR matrix with the given shape

    Args:
        rows_of_blocks: Number of rows of blocks
        cols_of_blocks: Number of columns of blocks
        block_type: Type of individual blocks. For CSR matrices, this should be a scalar type;
            for BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`)
        device: Device on which to allocate the matrix arrays
    """
    matrix = bsr_matrix_t(block_type)()

    matrix.nrow = rows_of_blocks
    matrix.ncol = cols_of_blocks
    matrix.nnz = 0

    # No block is stored yet: zero-length column/value arrays, and all row offsets at zero
    matrix.columns = wp.empty(shape=(0,), dtype=int, device=device)
    matrix.values = wp.empty(shape=(0,), dtype=block_type, device=device)
    matrix.offsets = wp.zeros(shape=(matrix.nrow + 1,), dtype=int, device=device)

    return matrix
def _bsr_ensure_fits(bsr: BsrMatrix, nrow: int = None, nnz: int = None):
    """Reallocates the arrays of `bsr` that are too small to hold `nrow` rows and `nnz` blocks.

    Arrays that are already large enough are left alone. A reallocated array is
    created with ``wp.empty`` on the device of the array it replaces; its previous
    content is not copied over.

    Args:
        bsr: The matrix whose arrays should be checked
        nrow: Required number of rows of blocks, defaults to ``bsr.nrow``
        nnz: Required number of blocks, defaults to ``bsr.nnz``
    """
    required_rows = bsr.nrow if nrow is None else nrow
    required_nnz = bsr.nnz if nnz is None else nnz

    offsets_len = required_rows + 1
    if bsr.offsets.size < offsets_len:
        bsr.offsets = wp.empty(shape=(offsets_len,), dtype=int, device=bsr.offsets.device)

    if bsr.columns.size < required_nnz:
        bsr.columns = wp.empty(shape=(required_nnz,), dtype=int, device=bsr.columns.device)

    if bsr.values.size < required_nnz:
        bsr.values = wp.empty(shape=(required_nnz,), dtype=bsr.values.dtype, device=bsr.values.device)
def bsr_set_zero(bsr: BsrMatrix, rows_of_blocks: Optional[int] = None, cols_of_blocks: Optional[int] = None):
    """
    Sets a BSR matrix to zero, possibly changing its size

    Args:
        bsr: The BSR or CSR matrix to set to zero
        rows_of_blocks: If not ``None``, the new number of rows of blocks
        cols_of_blocks: If not ``None``, the new number of columns of blocks
    """
    # Only overwrite the dimensions that were explicitly requested
    for attr, requested in (("nrow", rows_of_blocks), ("ncol", cols_of_blocks)):
        if requested is not None:
            setattr(bsr, attr, requested)

    bsr.nnz = 0

    # The offsets array may have to grow for the new row count before it gets zeroed
    _bsr_ensure_fits(bsr)
    bsr.offsets.zero_()
def bsr_set_from_triplets(
    dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    rows: "Array[int]",
    columns: "Array[int]",
    values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
):
    """
    Fills a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.

    The first dimension of the three input arrays must match, and determines the number of non-zeros in the constructed matrix.

    Args:
        dest: Sparse matrix to populate
        rows: Row index for each non-zero
        columns: Columns index for each non-zero
        values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
            to the `dest` matrix's block type, or a 3d array with data type equal to the `dest` matrix's scalar type.

    Raises:
        ValueError: if the arguments do not reside on the same device, if the triplet arrays have
            different lengths, or if `values` does not match the block type/shape of `dest`
        NotImplementedError: if the scalar type of `dest` is neither ``wp.float32`` nor ``wp.float64``
    """

    if values.device != columns.device or values.device != rows.device or values.device != dest.values.device:
        raise ValueError("All arguments must reside on the same device")

    if values.shape[0] != rows.shape[0] or values.shape[0] != columns.shape[0]:
        raise ValueError("All triplet arrays must have the same length")

    # Accept either array1d(dtype) or contiguous array3d(scalar_type) as values
    if values.ndim == 1:
        if values.dtype != dest.values.dtype:
            raise ValueError("Values array type must correspond to that of dest matrix")
    elif values.ndim == 3:
        if values.shape[1:] != dest.block_shape:
            raise ValueError(
                f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})"
            )

        if warp.types.type_scalar_type(values.dtype) != dest.scalar_type:
            raise ValueError("Scalar type of values array should correspond to that of matrix")

        if not values.is_contiguous:
            raise ValueError("Multi-dimensional values array should be contiguous")
    else:
        raise ValueError("Number of dimension for values array should be 1 or 3")

    nnz = rows.shape[0]
    if nnz == 0:
        bsr_set_zero(dest)
        return

    # Increase dest array sizes if needed
    _bsr_ensure_fits(dest, nnz=nnz)

    device = dest.values.device
    scalar_type = dest.scalar_type
    from warp.context import runtime

    # Must start out as None: for an unsupported scalar type no branch below assigns it,
    # and the check that follows has to raise NotImplementedError rather than UnboundLocalError.
    native_func = None
    if device.is_cpu:
        if scalar_type == wp.float32:
            native_func = runtime.core.bsr_matrix_from_triplets_float_host
        elif scalar_type == wp.float64:
            native_func = runtime.core.bsr_matrix_from_triplets_double_host
    else:
        if scalar_type == wp.float32:
            native_func = runtime.core.bsr_matrix_from_triplets_float_device
        elif scalar_type == wp.float64:
            native_func = runtime.core.bsr_matrix_from_triplets_double_device

    if not native_func:
        raise NotImplementedError(f"bsr_from_triplets not implemented for scalar type {scalar_type}")

    # The native routine returns the block count that gets cached in dest.nnz
    dest.nnz = native_func(
        dest.block_shape[0],
        dest.block_shape[1],
        dest.nrow,
        nnz,
        rows.ptr,
        columns.ptr,
        values.ptr,
        dest.offsets.ptr,
        dest.columns.ptr,
        dest.values.ptr,
    )
def bsr_assign(dest: BsrMatrix[BlockType[Rows, Cols, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Any]]):
    """Copies the content of the `src` matrix to `dest`, casting the block values if the two matrices use distinct scalar types.

    Args:
        dest: Matrix to overwrite; its arrays are grown if they are too small
        src: Matrix to copy from

    Raises:
        ValueError: if the matrices reside on different devices or have different block shapes
    """
    if dest.values.device != src.values.device:
        raise ValueError("Source and destination matrices must reside on the same device")

    if dest.block_shape != src.block_shape:
        raise ValueError("Source and destination matrices must have the same block shape")

    row_count = src.nrow
    block_count = src.nnz

    dest.nrow = row_count
    dest.ncol = src.ncol
    dest.nnz = block_count

    _bsr_ensure_fits(dest)

    # Topology: row offsets are always copied, columns only when there are blocks
    wp.copy(dest=dest.offsets, src=src.offsets, count=row_count + 1)
    if block_count > 0:
        wp.copy(dest=dest.columns, src=src.columns, count=block_count)
        # Values go through array_cast rather than a plain copy
        warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=block_count)
def bsr_copy(A: BsrMatrix, scalar_type: Optional[Scalar] = None):
    """Returns a copy of matrix ``A``, possibly changing its scalar type.

    Args:
        A: The matrix to copy
        scalar_type: If provided, the returned matrix will use this scalar type instead of the one from `A`.
    """
    if scalar_type is not None:
        if A.block_shape == (1, 1):
            # 1x1 blocks: the block type is the scalar type itself
            block_type = scalar_type
        else:
            block_type = wp.types.matrix(shape=A.block_shape, dtype=scalar_type)
    else:
        block_type = A.values.dtype

    duplicate = bsr_zeros(
        rows_of_blocks=A.nrow,
        cols_of_blocks=A.ncol,
        block_type=block_type,
        device=A.values.device,
    )
    bsr_assign(dest=duplicate, src=A)
    return duplicate
def bsr_set_transpose(dest: BsrMatrix[BlockType[Cols, Rows, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Scalar]]):
    """Assigns the transposed matrix `src` to matrix `dest`

    Args:
        dest: Matrix to overwrite; its block shape must be the reverse of that of `src`
        src: Matrix to transpose

    Raises:
        ValueError: if the matrices reside on different devices, use different scalar types,
            or if the block shape of `dest` is not the transposed block shape of `src`
        NotImplementedError: if the scalar type is neither ``wp.float32`` nor ``wp.float64``
    """

    if dest.values.device != src.values.device:
        raise ValueError("All arguments must reside on the same device")

    if dest.scalar_type != src.scalar_type:
        raise ValueError("All arguments must have the same scalar type")

    transpose_block_shape = src.block_shape[::-1]

    if dest.block_shape != transpose_block_shape:
        raise ValueError(f"Destination block shape must be {transpose_block_shape}")

    if src.nnz == 0:
        # Nothing to transpose, but `dest` must still end up as a valid empty matrix:
        # its offsets array has to fit the new row count and be zeroed, otherwise
        # stale offsets from a previous content would remain.
        bsr_set_zero(dest, rows_of_blocks=src.ncol, cols_of_blocks=src.nrow)
        return

    dest.nrow = src.ncol
    dest.ncol = src.nrow
    dest.nnz = src.nnz

    # Increase dest array sizes if needed
    _bsr_ensure_fits(dest)

    from warp.context import runtime

    # Must start out as None: for an unsupported scalar type no branch below assigns it,
    # and the check that follows has to raise NotImplementedError rather than UnboundLocalError.
    native_func = None
    if dest.values.device.is_cpu:
        if dest.scalar_type == wp.float32:
            native_func = runtime.core.bsr_transpose_float_host
        elif dest.scalar_type == wp.float64:
            native_func = runtime.core.bsr_transpose_double_host
    else:
        if dest.scalar_type == wp.float32:
            native_func = runtime.core.bsr_transpose_float_device
        elif dest.scalar_type == wp.float64:
            native_func = runtime.core.bsr_transpose_double_device

    if not native_func:
        raise NotImplementedError(f"bsr_set_transpose not implemented for scalar type {dest.scalar_type}")

    native_func(
        src.block_shape[0],
        src.block_shape[1],
        src.nrow,
        src.ncol,
        src.nnz,
        src.offsets.ptr,
        src.columns.ptr,
        src.values.ptr,
        dest.offsets.ptr,
        dest.columns.ptr,
        dest.values.ptr,
    )
def bsr_transposed(A: BsrMatrix):
    """Returns a copy of the transposed matrix `A`

    Args:
        A: The matrix to transpose; it is not modified
    """
    if A.block_shape != (1, 1):
        # Matrix blocks get transposed too, so the block shape is reversed
        block_type = wp.types.matrix(shape=A.block_shape[::-1], dtype=A.scalar_type)
    else:
        block_type = A.values.dtype

    result = bsr_zeros(
        rows_of_blocks=A.ncol,
        cols_of_blocks=A.nrow,
        block_type=block_type,
        device=A.values.device,
    )
    bsr_set_transpose(dest=result, src=A)
    return result
# Must be a warp kernel: it is passed as ``kernel=`` to ``wp.launch`` in ``bsr_get_diag``
@wp.kernel
def _bsr_get_diag_kernel(
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
    out: wp.array(dtype=Any),
):
    # One thread per block row: copies the diagonal block of that row to out[row] when it is stored
    row = wp.tid()
    beg = A_offsets[row]
    end = A_offsets[row + 1]

    # Search the columns of this row for the diagonal column index
    diag = wp.lower_bound(A_columns, beg, end, row)
    if diag < end:
        # The search may land on another column; out[row] is left untouched in that case
        if A_columns[diag] == row:
            out[row] = A_values[diag]
def bsr_get_diag(A: BsrMatrix[_BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
    """Returns the array of blocks that constitute the diagonal of a sparse matrix.

    Args:
        A: the sparse matrix from which to extract the diagonal
        out: if provided, the array into which to store the diagonal blocks

    Raises:
        ValueError: if `out` is provided but has the wrong data type, device, or is too short
    """
    diag_len = min(A.nrow, A.ncol)

    if out is not None:
        # Validate the user-provided output array
        if out.dtype != A.values.dtype:
            raise ValueError(f"Output array must have type {A.values.dtype}")
        if out.device != A.values.device:
            raise ValueError(f"Output array must reside on device {A.values.device}")
        if out.shape[0] < diag_len:
            raise ValueError(f"Output array must be of length at least {dim}") if False else ValueError(
                f"Output array must be of length at least {diag_len}"
            )
    else:
        # Zero-initialized, so rows without a stored diagonal block yield zero
        out = wp.zeros(shape=(diag_len,), dtype=A.values.dtype, device=A.values.device)

    wp.launch(
        kernel=_bsr_get_diag_kernel,
        dim=diag_len,
        device=A.values.device,
        inputs=[A.offsets, A.columns, A.values, out],
    )

    return out
# Must be a warp kernel: it is passed as ``kernel=`` to ``wp.launch`` in ``bsr_set_diag``
@wp.kernel
def _bsr_set_diag_kernel(
    diag: wp.array(dtype=Any),
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
):
    # One thread per diagonal block: row `row` holds exactly one block, at column `row`
    row = wp.tid()

    A_offsets[row + 1] = row + 1
    A_columns[row] = row
    A_values[row] = diag[row]

    # Each thread writes the end offset of its row; the first one also writes the leading zero
    if row == 0:
        A_offsets[0] = 0
# Must be a warp kernel: it is passed as ``kernel=`` to ``wp.launch`` in ``bsr_set_diag``
@wp.kernel
def _bsr_set_diag_constant_kernel(
    diag_value: Any,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
):
    # One thread per diagonal block: row `row` holds exactly one block, at column `row`,
    # and every block receives the same value
    row = wp.tid()

    A_offsets[row + 1] = row + 1
    A_columns[row] = row
    A_values[row] = diag_value

    # Each thread writes the end offset of its row; the first one also writes the leading zero
    if row == 0:
        A_offsets[0] = 0
def bsr_set_diag(
    A: BsrMatrix[BlockType],
    diag: "Union[BlockType, Array[BlockType]]",
    rows_of_blocks: Optional[int] = None,
    cols_of_blocks: Optional[int] = None,
):
    """Sets `A` as a block-diagonal matrix

    Args:
        A: the sparse matrix to modify
        diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
            or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
        rows_of_blocks: If not ``None``, the new number of rows of blocks
        cols_of_blocks: If not ``None``, the new number of columns of blocks

    The shape of the matrix will be defined one of the following, in that order:
      - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
      - the first dimension of `diag`, if `diag` is an array
      - the current dimensions of `A` otherwise

    Raises:
        ValueError: if `diag` is an array holding fewer blocks than the diagonal of the resulting matrix
    """

    if rows_of_blocks is None and cols_of_blocks is not None:
        rows_of_blocks = cols_of_blocks
    if cols_of_blocks is None and rows_of_blocks is not None:
        cols_of_blocks = rows_of_blocks

    diag_is_array = warp.types.is_array(diag)

    if diag_is_array:
        if rows_of_blocks is None:
            rows_of_blocks = diag.shape[0]
            cols_of_blocks = diag.shape[0]

        # Validate before touching `A`: the kernel reads diag[row] for every diagonal
        # block, so a shorter array would be read out of bounds.
        if rows_of_blocks is not None:
            diag_len = min(rows_of_blocks, cols_of_blocks)
        else:
            diag_len = min(A.nrow, A.ncol)
        if diag.shape[0] < diag_len:
            raise ValueError(
                f"Diagonal array must contain at least {diag_len} blocks, got {diag.shape[0]}"
            )

    # From here on the dimensions are either both set or both None
    if rows_of_blocks is not None:
        A.nrow = rows_of_blocks
        A.ncol = cols_of_blocks

    A.nnz = min(A.nrow, A.ncol)
    _bsr_ensure_fits(A)

    # NOTE(review): the kernels are launched over A.nnz threads and only write offsets[0..nnz];
    # when nrow > ncol (or nnz == 0) the remaining row offsets are not written here -- confirm
    # whether callers rely on square matrices only.
    if diag_is_array:
        wp.launch(
            kernel=_bsr_set_diag_kernel,
            dim=A.nnz,
            device=A.values.device,
            inputs=[diag, A.offsets, A.columns, A.values],
        )
    else:
        if not warp.types.type_is_value(type(diag)):
            # Cast to launchable type
            diag = A.values.dtype(diag)
        wp.launch(
            kernel=_bsr_set_diag_constant_kernel,
            dim=A.nnz,
            device=A.values.device,
            inputs=[diag, A.offsets, A.columns, A.values],
        )
def bsr_diag(
    diag: "Union[BlockType, Array[BlockType]]",
    rows_of_blocks: Optional[int] = None,
    cols_of_blocks: Optional[int] = None,
) -> BsrMatrix["BlockType"]:
    """Creates and returns a block-diagonal BSR matrix from an given block value or array of block values.

    Args:
        diag: Either a warp array, in which case each element will define one block of the diagonal,
            or a constant block value, in which case it will get assigned to all diagonal blocks.
            The block type of the returned matrix is deduced from `diag`.
        rows_of_blocks: If not ``None``, the number of rows of blocks
        cols_of_blocks: If not ``None``, the number of columns of blocks

    The shape of the matrix will be defined one of the following, in that order:
      - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
      - the first dimension of `diag`, if `diag` is an array

    Raises:
        ValueError: if `diag` is a constant and no dimension is provided, or if `diag` is an array
            holding fewer blocks than the diagonal of the requested matrix
    """

    if rows_of_blocks is None and cols_of_blocks is not None:
        rows_of_blocks = cols_of_blocks
    if cols_of_blocks is None and rows_of_blocks is not None:
        cols_of_blocks = rows_of_blocks

    if warp.types.is_array(diag):
        if rows_of_blocks is None:
            rows_of_blocks = diag.shape[0]
            cols_of_blocks = diag.shape[0]
        elif diag.shape[0] < min(rows_of_blocks, cols_of_blocks):
            # One block of `diag` is read per diagonal block of the matrix
            raise ValueError(
                f"Diagonal array must contain at least {min(rows_of_blocks, cols_of_blocks)} blocks, got {diag.shape[0]}"
            )

        A = bsr_zeros(
            rows_of_blocks,
            cols_of_blocks,
            block_type=diag.dtype,
            device=diag.device,
        )
    else:
        if rows_of_blocks is None:
            raise ValueError(
                "rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal"
            )

        block_type = type(diag)
        # A 2d array-like value that is not already a warp matrix gets a matching warp matrix type
        if not warp.types.type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2:
            block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)

        A = bsr_zeros(
            rows_of_blocks,
            cols_of_blocks,
            block_type=block_type,
        )

    # Forward the resolved dimensions explicitly: without them bsr_set_diag would resize the
    # matrix to the length of an array `diag`, discarding the dimensions requested by the caller.
    bsr_set_diag(A, diag, rows_of_blocks=rows_of_blocks, cols_of_blocks=cols_of_blocks)
    return A
def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None):
    """Sets `A` as the identity matrix

    Args:
        A: the sparse matrix to modify
        rows_of_blocks: if provided, the matrix will be resized as a square matrix with `rows_of_blocks` rows and columns.
    """
    if A.block_shape != (1, 1):
        # Matrix blocks: the diagonal block is a dense identity sized by the block's row count
        from numpy import eye

        identity = eye(A.block_shape[0])
    else:
        # Scalar blocks: the diagonal value is one
        identity = A.scalar_type(1.0)

    bsr_set_diag(A, diag=identity, rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks)
def bsr_identity(
    rows_of_blocks: int, block_type: BlockType[Rows, Rows, Scalar], device: wp.context.Devicelike = None
) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]:
    """Creates and returns a square identity matrix.

    Args:
        rows_of_blocks: Number of rows and columns of blocks in the created matrix.
        block_type: Block type for the newly created matrix -- must be square
        device: Device onto which to allocate the data arrays
    """
    # Start from an empty square matrix, then fill in its diagonal
    identity = bsr_zeros(
        rows_of_blocks=rows_of_blocks,
        cols_of_blocks=rows_of_blocks,
        block_type=block_type,
        device=device,
    )
    bsr_set_identity(identity)
    return identity
# Must be a warp kernel: it is passed as ``kernel=`` to ``wp.launch`` in ``bsr_scale``
@wp.kernel
def _bsr_scale_kernel(
    alpha: Any,
    values: wp.array(dtype=Any),
):
    # One thread per block: scales the block in place
    values[wp.tid()] = alpha * values[wp.tid()]
def bsr_scale(x: BsrMatrix, alpha: Scalar) -> BsrMatrix:
    """
    Performs the operation ``x := alpha * x`` on BSR matrix `x` and returns `x`

    Args:
        x: The matrix to scale in place
        alpha: The scaling factor; converted to the scalar type of `x` if needed
    """
    needs_scaling = alpha != 1.0 and x.nnz > 0
    if not needs_scaling:
        # Unit factor or no block stored: nothing to do
        return x

    if alpha == 0.0:
        # Scaling by zero drops all blocks instead of storing explicit zeros
        bsr_set_zero(x)
        return x

    if not isinstance(alpha, x.scalar_type):
        alpha = x.scalar_type(alpha)

    wp.launch(kernel=_bsr_scale_kernel, dim=x.nnz, device=x.values.device, inputs=[alpha, x.values])
    return x
# Must be a warp kernel: it is passed as ``kernel=`` to ``wp.launch`` in ``bsr_axpy`` and ``bsr_mm``
@wp.kernel
def _bsr_get_block_row(dest_offset: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
    # One thread per block: finds the row that contains block `i` and stores it at rows[dest_offset + i]
    i = wp.tid()

    # Searches the row offsets for `i + 1`; the row of block `i` is the index found minus one
    row = wp.lower_bound(bsr_offsets, i + 1) - 1

    rows[dest_offset + i] = row
# Must be a warp kernel: it is passed as ``kernel=`` to ``wp.launch`` in ``bsr_axpy``
@wp.kernel
def _bsr_axpy_add_block(
    src_offset: int,
    scale: Any,
    rows: wp.array(dtype=int),
    cols: wp.array(dtype=int),
    dst_offsets: wp.array(dtype=int),
    dst_columns: wp.array(dtype=int),
    src_values: wp.array(dtype=Any),
    dst_values: wp.array(dtype=Any),
):
    # One thread per source block: adds `scale * src_values[i]` to the destination block
    # located at the same (row, column)
    i = wp.tid()

    # Coordinates of source block `i`; they are stored at `src_offset + i` in rows/cols
    row = rows[i + src_offset]
    col = cols[i + src_offset]

    beg = dst_offsets[row]
    end = dst_offsets[row + 1]

    # Locate the destination block within its row
    # NOTE(review): the result is used unchecked, so the destination is assumed to store (row, col)
    block = wp.lower_bound(dst_columns, beg, end, col)

    dst_values[block] = dst_values[block] + scale * src_values[i]
class bsr_axpy_work_arrays:
    """Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls"""

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        """Binds the structure to `device` and forgets all buffers."""
        self.device = device
        self._sum_rows = self._sum_cols = None
        self._old_y_values = self._old_x_values = None

    @staticmethod
    def _too_small(buffer, size):
        """Tells whether `buffer` is missing or holds fewer than `size` elements."""
        return buffer is None or buffer.size < size

    def _allocate(self, device, y: BsrMatrix, sum_nnz: int):
        """Makes sure the buffers exist on `device` and are large enough.

        Args:
            device: Device the buffers must reside on; a change of device discards existing buffers
            y: Matrix whose current values will be saved; sizes the old-values buffer
            sum_nnz: Required size of the row and column index buffers
        """
        if self.device != device:
            self._reset(device)

        # Index buffers for the concatenated block coordinates
        if self._too_small(self._sum_rows, sum_nnz):
            self._sum_rows = wp.empty(shape=sum_nnz, dtype=int, device=self.device)
        if self._too_small(self._sum_cols, sum_nnz):
            self._sum_cols = wp.empty(shape=sum_nnz, dtype=int, device=self.device)

        # Buffer for the values of y
        if self._too_small(self._old_y_values, y.nnz):
            self._old_y_values = wp.empty(shape=y.nnz, dtype=y.values.dtype, device=self.device)
def bsr_axpy(
    x: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 1.0,
    work_arrays: Optional[bsr_axpy_work_arrays] = None,
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
    """
    Performs the sparse matrix addition ``y := alpha * x + beta * y`` on BSR matrices `x` and `y` and returns `y`.

    The `x` and `y` matrices are allowed to alias.

    Args:
        x: Read-only right-hand-side.
        y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for `x`
        beta: Uniform scaling factor for `y`
        work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_axpy_work_arrays` in `work_arrays`.

    Raises:
        ValueError: in the general case, if the matrices reside on different devices, have different
            block types, or have different numbers of rows/columns of blocks
    """

    if y is None:
        # If no output matrix is provided, allocate it for convenience
        y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
        beta = 0.0

    # Handle easy cases first

    # y contributes nothing: the result is a scaled copy of x
    if beta == 0.0 or y.nnz == 0:
        bsr_assign(src=x, dest=y)
        return bsr_scale(y, alpha=alpha)

    # x contributes nothing: the result is y scaled in place
    if alpha == 0.0 or x.nnz == 0:
        return bsr_scale(y, alpha=beta)

    # Convert the factors to the scalar type of y so that they can be passed to the kernels
    if not isinstance(alpha, y.scalar_type):
        alpha = y.scalar_type(alpha)
    if not isinstance(beta, y.scalar_type):
        beta = y.scalar_type(beta)

    if x == y:
        # Aliasing case: alpha * y + beta * y is a single scaling
        return bsr_scale(y, alpha=alpha.value + beta.value)

    # General case

    if x.values.device != y.values.device:
        raise ValueError("All arguments must reside on the same device")

    if x.scalar_type != y.scalar_type or x.block_shape != y.block_shape:
        raise ValueError("Matrices must have the same block type")

    if x.nrow != y.nrow or x.ncol != y.ncol:
        raise ValueError("Matrices must have the same number of rows and columns")

    if work_arrays is None:
        work_arrays = bsr_axpy_work_arrays()

    # Upper bound on the number of blocks of the sum (reached when no block overlaps)
    sum_nnz = x.nnz + y.nnz
    device = y.values.device
    work_arrays._allocate(device, y, sum_nnz)

    # Concatenate the (row, column) coordinates of all blocks: those of y first, then those of x
    wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y.nnz)
    wp.launch(kernel=_bsr_get_block_row, device=device, dim=y.nnz, inputs=[0, y.offsets, work_arrays._sum_rows])

    wp.copy(work_arrays._sum_cols, x.columns, y.nnz, 0, x.nnz)
    wp.launch(kernel=_bsr_get_block_row, device=device, dim=x.nnz, inputs=[y.nnz, x.offsets, work_arrays._sum_rows])

    # Save old y values before overwriting matrix
    wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y.nnz)

    # Increase dest array sizes if needed
    if y.columns.shape[0] < sum_nnz:
        y.columns = wp.empty(shape=(sum_nnz,), dtype=int, device=device)

    from warp.context import runtime

    # The float variant is used whatever the scalar type of y; both value pointers passed
    # below are 0 -- presumably only the topology is built here, TODO confirm against the
    # native implementation
    if device.is_cpu:
        native_func = runtime.core.bsr_matrix_from_triplets_float_host
    else:
        native_func = runtime.core.bsr_matrix_from_triplets_float_device

    # Remember the previous block count: y.nnz is about to be replaced by that of the sum
    old_y_nnz = y.nnz
    y.nnz = native_func(
        y.block_shape[0],
        y.block_shape[1],
        y.nrow,
        sum_nnz,
        work_arrays._sum_rows.ptr,
        work_arrays._sum_cols.ptr,
        0,
        y.offsets.ptr,
        y.columns.ptr,
        0,
    )

    # Grow the values array to the new block count, then start from zero and accumulate both terms
    _bsr_ensure_fits(y)
    y.values.zero_()

    # y += beta * (old y); the coordinates of the old y blocks start at index 0
    wp.launch(
        kernel=_bsr_axpy_add_block,
        device=device,
        dim=old_y_nnz,
        inputs=[
            0,
            beta,
            work_arrays._sum_rows,
            work_arrays._sum_cols,
            y.offsets,
            y.columns,
            work_arrays._old_y_values,
            y.values,
        ],
    )

    # y += alpha * x; the coordinates of the x blocks start at index old_y_nnz
    wp.launch(
        kernel=_bsr_axpy_add_block,
        device=device,
        dim=x.nnz,
        inputs=[
            old_y_nnz,
            alpha,
            work_arrays._sum_rows,
            work_arrays._sum_cols,
            y.offsets,
            y.columns,
            x.values,
            y.values,
        ],
    )

    return y
# Must be a warp kernel: it is passed as ``kernel=`` to ``wp.launch`` in ``bsr_mm``
@wp.kernel
def _bsr_mm_count_coeffs(
    z_nnz: int,
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    y_offsets: wp.array(dtype=int),
    counts: wp.array(dtype=int),
):
    # One thread per row of x: counts the block products generated by that row,
    # duplicates included, and stores the count at counts[row + 1]
    row = wp.tid()
    count = int(0)

    x_beg = x_offsets[row]
    x_end = x_offsets[row + 1]

    for x_block in range(x_beg, x_end):
        # Each block of x in column `x_col` pairs with every block of row `x_col` of y
        x_col = x_columns[x_block]
        count += y_offsets[x_col + 1] - y_offsets[x_col]

    counts[row + 1] = count

    # The first entry holds the number of z blocks passed in by the caller
    if row == 0:
        counts[0] = z_nnz
# NOTE(review): expected to be launched from bsr_mm (launch site not visible in this excerpt);
# decorated like the other kernels of this module
@wp.kernel
def _bsr_mm_list_coeffs(
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    y_offsets: wp.array(dtype=int),
    y_columns: wp.array(dtype=int),
    mm_offsets: wp.array(dtype=int),
    mm_rows: wp.array(dtype=int),
    mm_cols: wp.array(dtype=int),
):
    # One thread per row of x: lists the (row, column) coordinates of every block product
    # generated by that row, writing them from index mm_offsets[row] onwards
    row = wp.tid()
    mm_block = mm_offsets[row]

    x_beg = x_offsets[row]
    x_end = x_offsets[row + 1]

    for x_block in range(x_beg, x_end):
        x_col = x_columns[x_block]

        # Row `x_col` of y
        y_beg = y_offsets[x_col]
        y_end = y_offsets[x_col + 1]
        for y_block in range(y_beg, y_end):
            mm_cols[mm_block] = y_columns[y_block]
            mm_rows[mm_block] = row
            mm_block += 1
# NOTE(review): expected to be launched from bsr_mm (launch site not visible in this excerpt);
# decorated like the other kernels of this module
@wp.kernel
def _bsr_mm_compute_values(
    alpha: Any,
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    x_values: wp.array(dtype=Any),
    y_offsets: wp.array(dtype=int),
    y_columns: wp.array(dtype=int),
    y_values: wp.array(dtype=Any),
    mm_offsets: wp.array(dtype=int),
    mm_cols: wp.array(dtype=int),
    mm_values: wp.array(dtype=Any),
):
    # One thread per row: accumulates alpha * x * y into the blocks of that row of the result
    row = wp.tid()

    # Range of the result blocks belonging to this row
    mm_beg = mm_offsets[row]
    mm_end = mm_offsets[row + 1]

    x_beg = x_offsets[row]
    x_end = x_offsets[row + 1]
    for x_block in range(x_beg, x_end):
        x_col = x_columns[x_block]
        # Scale once per block of x rather than once per product
        ax_val = alpha * x_values[x_block]

        # Row `x_col` of y
        y_beg = y_offsets[x_col]
        y_end = y_offsets[x_col + 1]

        for y_block in range(y_beg, y_end):
            # Locate the result block with the column of the y block
            # NOTE(review): the result is used unchecked, so the result row is assumed to store that column
            mm_block = wp.lower_bound(mm_cols, mm_beg, mm_end, y_columns[y_block])
            mm_values[mm_block] = mm_values[mm_block] + ax_val * y_values[y_block]
class bsr_mm_work_arrays:
    """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls"""

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        """Binds the structure to `device` and forgets all buffers."""
        self.device = device
        self._pinned_count_buffer = None
        self._mm_row_counts = None
        self._mm_rows = self._mm_cols = None
        self._old_z_values = None
        self._old_z_offsets = self._old_z_columns = None

    @staticmethod
    def _too_small(buffer, size):
        """Tells whether `buffer` is missing or holds fewer than `size` elements."""
        return buffer is None or buffer.size < size

    def _allocate_stage_1(self, device, z: BsrMatrix, copied_z_nnz: int, z_aliasing: bool):
        """Allocations that do not depend on any computation.

        Args:
            device: Device the buffers must reside on; a change of device discards existing buffers
            z: Result matrix; sizes the row-count buffer and the saved topology buffers
            copied_z_nnz: Number of blocks of `z` to save; no value buffer is needed when zero
            z_aliasing: Whether buffers for saving the topology of `z` are needed as well
        """
        if self.device != device:
            self._reset(device)

        # Single-element pinned host buffer, only created for CUDA devices
        if self.device.is_cuda and self._pinned_count_buffer is None:
            self._pinned_count_buffer = wp.empty(shape=(1,), dtype=int, pinned=True, device="cpu")

        row_entries = z.nrow + 1
        if self._too_small(self._mm_row_counts, row_entries):
            self._mm_row_counts = wp.empty(shape=(row_entries,), dtype=int, device=self.device)

        if copied_z_nnz > 0 and self._too_small(self._old_z_values, copied_z_nnz):
            self._old_z_values = wp.empty(shape=(copied_z_nnz,), dtype=z.values.dtype, device=self.device)

        if not z_aliasing:
            return

        # Buffers for saving the columns and offsets of z
        if self._too_small(self._old_z_columns, z.nnz):
            self._old_z_columns = wp.empty(shape=(z.nnz,), dtype=z.columns.dtype, device=self.device)
        if self._too_small(self._old_z_offsets, row_entries):
            self._old_z_offsets = wp.empty(shape=(row_entries,), dtype=z.offsets.dtype, device=self.device)

    def _allocate_stage_2(self, mm_nnz: int):
        """Allocations that depend on the unmerged nnz estimate `mm_nnz`."""
        if self._too_small(self._mm_rows, mm_nnz):
            self._mm_rows = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
        if self._too_small(self._mm_cols, mm_nnz):
            self._mm_cols = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
| def bsr_mm( | |
| x: BsrMatrix[BlockType[Rows, Any, Scalar]], | |
| y: BsrMatrix[BlockType[Any, Cols, Scalar]], | |
| z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None, | |
| alpha: Scalar = 1.0, | |
| beta: Scalar = 0.0, | |
| work_arrays: Optional[bsr_mm_work_arrays] = None, | |
| ) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]: | |
| """ | |
| Performs the sparse matrix-matrix multiplication ``z := alpha * x * y + beta * z`` on BSR matrices `x`, `y` and `z`, and returns `z`. | |
| The `x`, `y` and `z` matrices are allowed to alias. | |
| If the matrix `z` is not provided as input, it will be allocated and treated as zero. | |
| Args: | |
| x: Read-only left factor of the matrix-matrix product. | |
| y: Read-only right factor of the matrix-matrix product. | |
| z: Mutable left-hand-side. If `z` is not provided, it will be allocated and treated as zero. | |
| alpha: Uniform scaling factor for the ``x * y`` product | |
| beta: Uniform scaling factor for `z` | |
| work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_mm_work_arrays` in `work_arrays`. | |
| """ | |
| if z is None: | |
| # If not output matrix is provided, allocate it for convenience | |
| z_block_shape = (x.block_shape[0], y.block_shape[1]) | |
| if z_block_shape == (1, 1): | |
| z_block_type = x.scalar_type | |
| else: | |
| z_block_type = wp.types.matrix(shape=z_block_shape, dtype=x.scalar_type) | |
| z = bsr_zeros(x.nrow, y.ncol, block_type=z_block_type, device=x.values.device) | |
| beta = 0.0 | |
| if x.values.device != y.values.device or x.values.device != z.values.device: | |
| raise ValueError("All arguments must reside on the same device") | |
| if x.scalar_type != y.scalar_type or x.scalar_type != z.scalar_type: | |
| raise ValueError("Matrices must have the same scalar type") | |
| if ( | |
| x.block_shape[0] != z.block_shape[0] | |
| or y.block_shape[1] != z.block_shape[1] | |
| or x.block_shape[1] != y.block_shape[0] | |
| ): | |
| raise ValueError("Incompatible block sizes for matrix multiplication") | |
| if x.nrow != z.nrow or z.ncol != y.ncol or x.ncol != y.nrow: | |
| raise ValueError("Incompatible number of rows/columns for matrix multiplication") | |
| device = z.values.device | |
| if alpha == 0.0 or x.nnz == 0 or y.nnz == 0: | |
| # Easy case | |
| return bsr_scale(z, beta) | |
| if not isinstance(alpha, z.scalar_type): | |
| alpha = z.scalar_type(alpha) | |
| if not isinstance(beta, z.scalar_type): | |
| beta = z.scalar_type(beta) | |
| if work_arrays is None: | |
| work_arrays = bsr_mm_work_arrays() | |
| z_aliasing = z == x or z == y | |
| copied_z_nnz = z.nnz if beta != 0.0 or z_aliasing else 0 | |
| work_arrays._allocate_stage_1(device, z, copied_z_nnz, z_aliasing) | |
| # Prefix sum of number of (unmerged) mm blocks per row | |
| wp.launch( | |
| kernel=_bsr_mm_count_coeffs, | |
| device=device, | |
| dim=z.nrow, | |
| inputs=[copied_z_nnz, x.offsets, x.columns, y.offsets, work_arrays._mm_row_counts], | |
| ) | |
| warp.utils.array_scan(work_arrays._mm_row_counts, work_arrays._mm_row_counts) | |
| # Get back total counts on host | |
| if device.is_cuda: | |
| wp.copy(dest=work_arrays._pinned_count_buffer, src=work_arrays._mm_row_counts, src_offset=z.nrow, count=1) | |
| wp.synchronize_stream(wp.get_stream(device)) | |
| mm_nnz = int(work_arrays._pinned_count_buffer.numpy()[0]) | |
| else: | |
| mm_nnz = int(work_arrays._mm_row_counts.numpy()[z.nrow]) | |
| work_arrays._allocate_stage_2(mm_nnz) | |
| # If z has a non-zero scale, save current data before overwriting it | |
| if copied_z_nnz > 0: | |
| # Copy z row and column indices | |
| wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz) | |
| wp.launch( | |
| kernel=_bsr_get_block_row, device=device, dim=copied_z_nnz, inputs=[0, z.offsets, work_arrays._mm_rows] | |
| ) | |
| # Save current z values in temporary buffer | |
| wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz) | |
| if z_aliasing: | |
| # If z is aliasing with x or y, need to save topology as well | |
| wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz) | |
| wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1) | |
| # Fill unmerged mm blocks rows and columns | |
| wp.launch( | |
| kernel=_bsr_mm_list_coeffs, | |
| device=device, | |
| dim=z.nrow, | |
| inputs=[ | |
| x.offsets, | |
| x.columns, | |
| y.offsets, | |
| y.columns, | |
| work_arrays._mm_row_counts, | |
| work_arrays._mm_rows, | |
| work_arrays._mm_cols, | |
| ], | |
| ) | |
| # Increase dest array size if needed | |
| if z.columns.shape[0] < mm_nnz: | |
| z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device) | |
| from warp.context import runtime | |
| if device.is_cpu: | |
| native_func = runtime.core.bsr_matrix_from_triplets_float_host | |
| else: | |
| native_func = runtime.core.bsr_matrix_from_triplets_float_device | |
| z.nnz = native_func( | |
| z.block_shape[0], | |
| z.block_shape[1], | |
| z.nrow, | |
| mm_nnz, | |
| work_arrays._mm_rows.ptr, | |
| work_arrays._mm_cols.ptr, | |
| 0, | |
| z.offsets.ptr, | |
| z.columns.ptr, | |
| 0, | |
| ) | |
| _bsr_ensure_fits(z) | |
| z.values.zero_() | |
| if copied_z_nnz > 0: | |
| # Add back original z values | |
| wp.launch( | |
| kernel=_bsr_axpy_add_block, | |
| device=device, | |
| dim=copied_z_nnz, | |
| inputs=[ | |
| 0, | |
| beta, | |
| work_arrays._mm_rows, | |
| work_arrays._mm_cols, | |
| z.offsets, | |
| z.columns, | |
| work_arrays._old_z_values, | |
| z.values, | |
| ], | |
| ) | |
| # Add mm blocks to z values | |
| if (warp.types.type_is_matrix(x.values.dtype) or warp.types.type_is_matrix(y.values.dtype)) and not ( | |
| warp.types.type_is_matrix(z.values.dtype) | |
| ): | |
| # Result block type is scalar, but operands are matrices | |
| # Cast result to (1x1) matrix to perform multiplication | |
| mm_values = z.values.view(wp.types.matrix(shape=(1, 1), dtype=z.scalar_type)) | |
| else: | |
| mm_values = z.values | |
| wp.launch( | |
| kernel=_bsr_mm_compute_values, | |
| device=device, | |
| dim=z.nrow, | |
| inputs=[ | |
| alpha, | |
| work_arrays._old_z_offsets if x == z else x.offsets, | |
| work_arrays._old_z_columns if x == z else x.columns, | |
| work_arrays._old_z_values if x == z else x.values, | |
| work_arrays._old_z_offsets if y == z else y.offsets, | |
| work_arrays._old_z_columns if y == z else y.columns, | |
| work_arrays._old_z_values if y == z else y.values, | |
| z.offsets, | |
| z.columns, | |
| mm_values, | |
| ], | |
| ) | |
| return z | |
# Sparse matrix-vector product kernel, launched by `bsr_mv` with one thread per block-row of A.
# For its row, each thread computes
#     y[row] = alpha * sum_k(A_values[k] * x[A_columns[k]]) + beta * y[row]
# where k ranges over [A_offsets[row], A_offsets[row + 1]).
#
# The `@wp.kernel` decorator is required: `bsr_mv` hands this function to `wp.launch(kernel=...)`,
# which expects a Warp kernel object rather than a plain Python function.
@wp.kernel
def _bsr_mv_kernel(
    alpha: Any,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    beta: Any,
    y: wp.array(dtype=Any),
):
    row = wp.tid()

    # zero-initialize with type of y elements
    scalar_zero = type(alpha)(0)
    v = y.dtype(scalar_zero)

    # When alpha is zero, A and x are never read
    if alpha != scalar_zero:
        beg = A_offsets[row]
        end = A_offsets[row + 1]
        for block in range(beg, end):
            v += A_values[block] * x[A_columns[block]]
        v *= alpha

    # When beta is zero, the previous content of y is never read
    if beta != scalar_zero:
        v += beta * y[row]

    y[row] = v
def bsr_mv(
    A: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    x: "Array[Vector[Cols, Scalar] | Scalar]",
    y: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 0.0,
    work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
) -> "Array[Vector[Rows, Scalar] | Scalar]":
    """
    Computes the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and returns `y`.

    `x` and `y` may be the same array; in that case the input is first copied to temporary storage.

    Args:
        A: Read-only matrix factor (left operand) of the product.
        x: Read-only vector factor (right operand) of the product.
        y: Mutable output vector. When omitted, a zero-filled array is allocated on the device of `A` and `beta` is ignored.
        alpha: Uniform scaling factor applied to ``A * x``. If zero, the kernel does not read `A` or `x`.
        beta: Uniform scaling factor applied to the incoming `y`. If zero, the kernel does not read `y`.
        work_buffer: Temporary storage, only used when `x` and `y` are the same vector. If given, it must hold at
            least ``y.size`` elements of the same data type as `y`; otherwise a temporary array is allocated.

    Returns:
        The output vector. When the blocks of `A` have a single row, this may be a view of `y` whose dtype has been
        switched between the scalar type and a length-1 vector type to match the block type of `A`.

    Raises:
        ValueError: If `A`, `x` and `y` are not on the same device, if their sizes are incompatible,
            or if a provided `work_buffer` is too small or of the wrong data type.
    """

    scalar_type = A.scalar_type
    device = A.values.device

    if y is None:
        # No output vector supplied: allocate a zeroed one sized after the block rows of A
        block_rows = A.block_shape[0]
        if block_rows == 1:
            out_dtype = scalar_type
        else:
            out_dtype = wp.vec(length=block_rows, dtype=scalar_type)
        y = wp.empty(shape=(A.nrow,), device=device, dtype=out_dtype)
        y.zero_()
        beta = 0.0

    # Coerce the scaling factors to the scalar type of A before passing them to the kernel
    if not isinstance(alpha, scalar_type):
        alpha = scalar_type(alpha)
    if not isinstance(beta, scalar_type):
        beta = scalar_type(beta)

    if device != x.device or device != y.device:
        raise ValueError("A, x and y must reside on the same device")

    if x.shape[0] != A.ncol:
        raise ValueError("Number of columns of A must match number of rows of x")
    if y.shape[0] != A.nrow:
        raise ValueError("Number of rows of A must match number of rows of y")

    if x == y:
        # Input and output alias: the kernel would overwrite entries it still has to read,
        # so the product is computed from a snapshot of the input instead
        if work_buffer is None:
            work_buffer = wp.empty_like(y)
        else:
            if work_buffer.size < y.size:
                raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}")
            if not wp.types.types_equal(work_buffer.dtype, y.dtype):
                raise ValueError(f"Work buffer must have same data type as y, {wp.types.type_repr(y.dtype)}")

        # Snapshot the current values before y gets overwritten
        wp.copy(dest=work_buffer, src=y, count=y.size)
        x = work_buffer

    # Reconcile vector dtypes with the block type of A:
    # matrix blocks need length-1 vecs, scalar blocks need plain scalars
    block_rows, block_cols = A.block_shape[0], A.block_shape[1]
    if warp.types.type_is_matrix(A.values.dtype):
        if block_rows == 1 and y.dtype == scalar_type:
            y = y.view(dtype=wp.vec(length=1, dtype=scalar_type))
        if block_cols == 1 and x.dtype == scalar_type:
            x = x.view(dtype=wp.vec(length=1, dtype=scalar_type))
    else:
        if block_rows == 1 and y.dtype != scalar_type:
            y = y.view(dtype=scalar_type)
        if block_cols == 1 and x.dtype != scalar_type:
            x = x.view(dtype=scalar_type)

    # One thread per block-row of A
    wp.launch(
        kernel=_bsr_mv_kernel,
        device=device,
        dim=A.nrow,
        inputs=[alpha, A.offsets, A.columns, A.values, x, beta, y],
    )

    return y