from typing import Any, Generic, Optional, Tuple, TypeVar, Union import warp as wp import warp.types import warp.utils from warp.types import Array, Cols, Matrix, Rows, Scalar, Vector # typing hints _BlockType = TypeVar("BlockType") class _MatrixBlockType(Matrix): pass class _ScalarBlockType(Generic[Scalar]): pass BlockType = Union[_MatrixBlockType[Rows, Cols, Scalar], _ScalarBlockType[Scalar]] _struct_cache = dict() class BsrMatrix(Generic[_BlockType]): """Untyped base class for BSR and CSR matrices. Should not be constructed directly but through functions such as :func:`bsr_zeros`. Attributes: nrow (int): Number of rows of blocks ncol (int): Number of columns of blocks nnz (int): Number of non-zero blocks: must be equal to ``offsets[nrow-1]``, cached on host for convenience offsets (Array[int]): Array of size at least ``1 + nrows`` such that the start and end indices of the blocks of row ``r`` are ``offsets[r]`` and ``offsets[r+1]``, respectively. columns (Array[int]): Array of size at least equal to ``nnz`` containing block column indices values (Array[BlockType]): Array of size at least equal to ``nnz`` containing block values """ @property def scalar_type(self) -> Scalar: """Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type""" return warp.types.type_scalar_type(self.values.dtype) @property def block_shape(self) -> Tuple[int, int]: """Shape of the individual blocks""" return getattr(self.values.dtype, "_shape_", (1, 1)) @property def block_size(self) -> int: """Size of the individual blocks, i.e. number of rows per block times number of columns per block""" return warp.types.type_length(self.values.dtype) @property def shape(self) -> Tuple[int, int]: """Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block""" block_shape = self.block_shape return (self.nrow * block_shape[0], self.ncol * block_shape[1]) def bsr_matrix_t(dtype: BlockType): dtype = wp.types.type_to_warp(dtype) if not warp.types.type_is_matrix(dtype) and not dtype in warp.types.scalar_types: raise ValueError( f"BsrMatrix block type must be either warp matrix or scalar; got {warp.types.type_repr(dtype)}" ) class BsrMatrixTyped(BsrMatrix): nrow: int """Number of rows of blocks""" ncol: int """Number of columns of blocks""" nnz: int """Number of non-zero blocks: equal to offsets[-1], cached on host for convenience""" offsets: wp.array(dtype=int) """Array of size at least 1 + nrows""" columns: wp.array(dtype=int) """Array of size at least equal to nnz""" values: wp.array(dtype=dtype) module = wp.get_module(BsrMatrix.__module__) if hasattr(dtype, "_shape_"): type_str = f"{warp.types.type_scalar_type(dtype).__name__}_{dtype._shape_[0]}_{dtype._shape_[1]}" else: type_str = dtype.__name__ key = f"{BsrMatrix.__qualname__}_{type_str}" if key not in _struct_cache: _struct_cache[key] = wp.codegen.Struct( cls=BsrMatrixTyped, key=key, module=module, ) return _struct_cache[key] def bsr_zeros( rows_of_blocks: int, cols_of_blocks: int, block_type: BlockType, device: wp.context.Devicelike = None, ) -> BsrMatrix: """ Constructs and returns an empty BSR or CSR matrix with the given shape Args: bsr: The BSR or CSR matrix to set to zero rows_of_blocks: Number of rows of blocks cols_of_blocks: Number of columns of blocks block_type: Type of individual blocks. For CSR matrices, this should be a scalar type; for BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`) device: Device on which to allocate the matrix arrays """ bsr = bsr_matrix_t(block_type)() bsr.nrow = rows_of_blocks bsr.ncol = cols_of_blocks bsr.nnz = 0 bsr.columns = wp.empty(shape=(0,), dtype=int, device=device) bsr.values = wp.empty(shape=(0,), dtype=block_type, device=device) bsr.offsets = wp.zeros(shape=(bsr.nrow + 1,), dtype=int, device=device) return bsr def _bsr_ensure_fits(bsr: BsrMatrix, nrow: int = None, nnz: int = None): if nrow is None: nrow = bsr.nrow if nnz is None: nnz = bsr.nnz if bsr.offsets.size < nrow + 1: bsr.offsets = wp.empty(shape=(nrow + 1,), dtype=int, device=bsr.offsets.device) if bsr.columns.size < nnz: bsr.columns = wp.empty(shape=(nnz,), dtype=int, device=bsr.columns.device) if bsr.values.size < nnz: bsr.values = wp.empty(shape=(nnz,), dtype=bsr.values.dtype, device=bsr.values.device) def bsr_set_zero(bsr: BsrMatrix, rows_of_blocks: Optional[int] = None, cols_of_blocks: Optional[int] = None): """ Sets a BSR matrix to zero, possibly changing its size Args: bsr: The BSR or CSR matrix to set to zero rows_of_blocks: If not ``None``, the new number of rows of blocks cols_of_blocks: If not ``None``, the new number of columns of blocks """ if rows_of_blocks is not None: bsr.nrow = rows_of_blocks if cols_of_blocks is not None: bsr.ncol = cols_of_blocks bsr.nnz = 0 _bsr_ensure_fits(bsr) bsr.offsets.zero_() def bsr_set_from_triplets( dest: BsrMatrix[BlockType[Rows, Cols, Scalar]], rows: "Array[int]", columns: "Array[int]", values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]", ): """ Fills a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks. The first dimension of the three input arrays must match, and determines the number of non-zeros in the constructed matrix. Args: dest: Sparse matrix to populate rows: Row index for each non-zero columns: Columns index for each non-zero values: Block values for each non-zero. Must be either a one-dimensional array with data type identical to the `dest` matrix's block type, or a 3d array with data type equal to the `dest` matrix's scalar type. """ if values.device != columns.device or values.device != rows.device or values.device != dest.values.device: raise ValueError("All arguments must reside on the same device") if values.shape[0] != rows.shape[0] or values.shape[0] != columns.shape[0]: raise ValueError("All triplet arrays must have the same length") # Accept either array1d(dtype) or contiguous array3d(scalar_type) as values if values.ndim == 1: if values.dtype != dest.values.dtype: raise ValueError("Values array type must correspond to that of dest matrix") elif values.ndim == 3: if values.shape[1:] != dest.block_shape: raise ValueError( f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})" ) if warp.types.type_scalar_type(values.dtype) != dest.scalar_type: raise ValueError("Scalar type of values array should correspond to that of matrix") if not values.is_contiguous: raise ValueError("Multi-dimensional values array should be contiguous") else: raise ValueError("Number of dimension for values array should be 1 or 3") nnz = rows.shape[0] if nnz == 0: bsr_set_zero(dest) return # Increase dest array sizes if needed _bsr_ensure_fits(dest, nnz=nnz) device = dest.values.device scalar_type = dest.scalar_type from warp.context import runtime if device.is_cpu: if scalar_type == wp.float32: native_func = runtime.core.bsr_matrix_from_triplets_float_host elif scalar_type == wp.float64: native_func = runtime.core.bsr_matrix_from_triplets_double_host else: if scalar_type == wp.float32: native_func = runtime.core.bsr_matrix_from_triplets_float_device elif scalar_type == wp.float64: native_func = runtime.core.bsr_matrix_from_triplets_double_device if not native_func: raise NotImplementedError(f"bsr_from_triplets not implemented for scalar type {scalar_type}") dest.nnz = native_func( dest.block_shape[0], dest.block_shape[1], dest.nrow, nnz, rows.ptr, columns.ptr, values.ptr, dest.offsets.ptr, dest.columns.ptr, dest.values.ptr, ) def bsr_assign(dest: BsrMatrix[BlockType[Rows, Cols, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Any]]): """Copies the content of the `src` matrix to `dest`, casting the block values if the two matrices use distinct scalar types.""" if dest.values.device != src.values.device: raise ValueError("Source and destination matrices must reside on the same device") if dest.block_shape != src.block_shape: raise ValueError("Source and destination matrices must have the same block shape") dest.nrow = src.nrow dest.ncol = src.ncol dest.nnz = src.nnz _bsr_ensure_fits(dest) wp.copy(dest=dest.offsets, src=src.offsets, count=src.nrow + 1) if src.nnz > 0: wp.copy(dest=dest.columns, src=src.columns, count=src.nnz) warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=src.nnz) def bsr_copy(A: BsrMatrix, scalar_type: Optional[Scalar] = None): """Returns a copy of matrix ``A``, possibly changing its scalar type. Args: scalar_type: If provided, the returned matrix will use this scalar type instead of the one from `A`. """ if scalar_type is None: block_type = A.values.dtype elif A.block_shape == (1, 1): block_type = scalar_type else: block_type = wp.types.matrix(shape=A.block_shape, dtype=scalar_type) copy = bsr_zeros(rows_of_blocks=A.nrow, cols_of_blocks=A.ncol, block_type=block_type, device=A.values.device) bsr_assign(dest=copy, src=A) return copy def bsr_set_transpose(dest: BsrMatrix[BlockType[Cols, Rows, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Scalar]]): """Assigns the transposed matrix `src` to matrix `dest`""" if dest.values.device != src.values.device: raise ValueError("All arguments must reside on the same device") if dest.scalar_type != src.scalar_type: raise ValueError("All arguments must have the same scalar type") transpose_block_shape = src.block_shape[::-1] if dest.block_shape != transpose_block_shape: raise ValueError(f"Destination block shape must be {transpose_block_shape}") dest.nrow = src.ncol dest.ncol = src.nrow dest.nnz = src.nnz if src.nnz == 0: return # Increase dest array sizes if needed _bsr_ensure_fits(dest) from warp.context import runtime if dest.values.device.is_cpu: if dest.scalar_type == wp.float32: native_func = runtime.core.bsr_transpose_float_host elif dest.scalar_type == wp.float64: native_func = runtime.core.bsr_transpose_double_host else: if dest.scalar_type == wp.float32: native_func = runtime.core.bsr_transpose_float_device elif dest.scalar_type == wp.float64: native_func = runtime.core.bsr_transpose_double_device if not native_func: raise NotImplementedError(f"bsr_set_transpose not implemented for scalar type {dest.scalar_type}") native_func( src.block_shape[0], src.block_shape[1], src.nrow, src.ncol, src.nnz, src.offsets.ptr, src.columns.ptr, src.values.ptr, dest.offsets.ptr, dest.columns.ptr, dest.values.ptr, ) def bsr_transposed(A: BsrMatrix): """Returns a copy of the transposed matrix `A`""" if A.block_shape == (1, 1): block_type = A.values.dtype else: block_type = wp.types.matrix(shape=A.block_shape[::-1], dtype=A.scalar_type) transposed = bsr_zeros(rows_of_blocks=A.ncol, cols_of_blocks=A.nrow, block_type=block_type, device=A.values.device) bsr_set_transpose(dest=transposed, src=A) return transposed @wp.kernel def _bsr_get_diag_kernel( A_offsets: wp.array(dtype=int), A_columns: wp.array(dtype=int), A_values: wp.array(dtype=Any), out: wp.array(dtype=Any), ): row = wp.tid() beg = A_offsets[row] end = A_offsets[row + 1] diag = wp.lower_bound(A_columns, beg, end, row) if diag < end: if A_columns[diag] == row: out[row] = A_values[diag] def bsr_get_diag(A: BsrMatrix[_BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]": """Returns the array of blocks that constitute the diagonal of a sparse matrix. Args: A: the sparse matrix from which to extract the diagonal out: if provided, the array into which to store the diagonal blocks """ dim = min(A.nrow, A.ncol) if out is None: out = wp.zeros(shape=(dim,), dtype=A.values.dtype, device=A.values.device) else: if out.dtype != A.values.dtype: raise ValueError(f"Output array must have type {A.values.dtype}") if out.device != A.values.device: raise ValueError(f"Output array must reside on device {A.values.device}") if out.shape[0] < dim: raise ValueError(f"Output array must be of length at least {dim}") wp.launch( kernel=_bsr_get_diag_kernel, dim=dim, device=A.values.device, inputs=[A.offsets, A.columns, A.values, out] ) return out @wp.kernel def _bsr_set_diag_kernel( diag: wp.array(dtype=Any), A_offsets: wp.array(dtype=int), A_columns: wp.array(dtype=int), A_values: wp.array(dtype=Any), ): row = wp.tid() A_offsets[row + 1] = row + 1 A_columns[row] = row A_values[row] = diag[row] if row == 0: A_offsets[0] = 0 @wp.kernel def _bsr_set_diag_constant_kernel( diag_value: Any, A_offsets: wp.array(dtype=int), A_columns: wp.array(dtype=int), A_values: wp.array(dtype=Any), ): row = wp.tid() A_offsets[row + 1] = row + 1 A_columns[row] = row A_values[row] = diag_value if row == 0: A_offsets[0] = 0 def bsr_set_diag( A: BsrMatrix[BlockType], diag: "Union[BlockType, Array[BlockType]]", rows_of_blocks: Optional[int] = None, cols_of_blocks: Optional[int] = None, ): """Sets `A` as a block-diagonal matrix Args: A: the sparse matrix to modify diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal, or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks. rows_of_blocks: If not ``None``, the new number of rows of blocks cols_of_blocks: If not ``None``, the new number of columns of blocks The shape of the matrix will be defined one of the following, in that order: - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal. - the first dimension of `diag`, if `diag` is an array - the current dimensions of `A` otherwise """ if rows_of_blocks is None and cols_of_blocks is not None: rows_of_blocks = cols_of_blocks if cols_of_blocks is None and rows_of_blocks is not None: cols_of_blocks = rows_of_blocks if warp.types.is_array(diag): if rows_of_blocks is None: rows_of_blocks = diag.shape[0] cols_of_blocks = diag.shape[0] if rows_of_blocks is not None: A.nrow = rows_of_blocks A.ncol = cols_of_blocks A.nnz = min(A.nrow, A.ncol) _bsr_ensure_fits(A) if warp.types.is_array(diag): wp.launch( kernel=_bsr_set_diag_kernel, dim=A.nnz, device=A.values.device, inputs=[diag, A.offsets, A.columns, A.values], ) else: if not warp.types.type_is_value(type(diag)): # Cast to launchable type diag = A.values.dtype(diag) wp.launch( kernel=_bsr_set_diag_constant_kernel, dim=A.nnz, device=A.values.device, inputs=[diag, A.offsets, A.columns, A.values], ) def bsr_diag( diag: "Union[BlockType, Array[BlockType]]", rows_of_blocks: Optional[int] = None, cols_of_blocks: Optional[int] = None, ) -> BsrMatrix["BlockType"]: """Creates and returns a block-diagonal BSR matrix from an given block value or array of block values. Args: diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal, or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks. rows_of_blocks: If not ``None``, the new number of rows of blocks cols_of_blocks: If not ``None``, the new number of columns of blocks The shape of the matrix will be defined one of the following, in that order: - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal. - the first dimension of `diag`, if `diag` is an array """ if rows_of_blocks is None and cols_of_blocks is not None: rows_of_blocks = cols_of_blocks if cols_of_blocks is None and rows_of_blocks is not None: cols_of_blocks = rows_of_blocks if warp.types.is_array(diag): if rows_of_blocks is None: rows_of_blocks = diag.shape[0] cols_of_blocks = diag.shape[0] A = bsr_zeros( rows_of_blocks, cols_of_blocks, block_type=diag.dtype, device=diag.device, ) else: if rows_of_blocks is None: raise ValueError( "rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal" ) block_type = type(diag) if not warp.types.type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2: block_type = wp.mat(shape=diag.shape, dtype=diag.dtype) A = bsr_zeros( rows_of_blocks, cols_of_blocks, block_type=block_type, ) bsr_set_diag(A, diag) return A def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None): """Sets `A` as the identity matrix Args: A: the sparse matrix to modify rows_of_blocks: if provided, the matrix will be resized as a square matrix with `rows_of_blocks` rows and columns. """ if A.block_shape == (1, 1): identity = A.scalar_type(1.0) else: from numpy import eye identity = eye(A.block_shape[0]) bsr_set_diag(A, diag=identity, rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks) def bsr_identity( rows_of_blocks: int, block_type: BlockType[Rows, Rows, Scalar], device: wp.context.Devicelike = None ) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]: """Creates and returns a square identity matrix. Args: rows_of_blocks: Number of rows and columns of blocks in the created matrix. block_type: Block type for the newly created matrix -- must be square device: Device onto which to allocate the data arrays """ A = bsr_zeros(rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks, block_type=block_type, device=device) bsr_set_identity(A) return A @wp.kernel def _bsr_scale_kernel( alpha: Any, values: wp.array(dtype=Any), ): values[wp.tid()] = alpha * values[wp.tid()] def bsr_scale(x: BsrMatrix, alpha: Scalar) -> BsrMatrix: """ Performs the operation ``x := alpha * x`` on BSR matrix `x` and returns `x` """ if alpha != 1.0 and x.nnz > 0: if alpha == 0.0: bsr_set_zero(x) else: if not isinstance(alpha, x.scalar_type): alpha = x.scalar_type(alpha) wp.launch(kernel=_bsr_scale_kernel, dim=x.nnz, device=x.values.device, inputs=[alpha, x.values]) return x @wp.kernel def _bsr_get_block_row(dest_offset: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)): i = wp.tid() row = wp.lower_bound(bsr_offsets, i + 1) - 1 rows[dest_offset + i] = row @wp.kernel def _bsr_axpy_add_block( src_offset: int, scale: Any, rows: wp.array(dtype=int), cols: wp.array(dtype=int), dst_offsets: wp.array(dtype=int), dst_columns: wp.array(dtype=int), src_values: wp.array(dtype=Any), dst_values: wp.array(dtype=Any), ): i = wp.tid() row = rows[i + src_offset] col = cols[i + src_offset] beg = dst_offsets[row] end = dst_offsets[row + 1] block = wp.lower_bound(dst_columns, beg, end, col) dst_values[block] = dst_values[block] + scale * src_values[i] class bsr_axpy_work_arrays: """Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls""" def __init__(self): self._reset(None) def _reset(self, device): self.device = device self._sum_rows = None self._sum_cols = None self._old_y_values = None self._old_x_values = None def _allocate(self, device, y: BsrMatrix, sum_nnz: int): if self.device != device: self._reset(device) if self._sum_rows is None or self._sum_rows.size < sum_nnz: self._sum_rows = wp.empty(shape=(sum_nnz), dtype=int, device=self.device) if self._sum_cols is None or self._sum_cols.size < sum_nnz: self._sum_cols = wp.empty(shape=(sum_nnz), dtype=int, device=self.device) if self._old_y_values is None or self._old_y_values.size < y.nnz: self._old_y_values = wp.empty(shape=(y.nnz), dtype=y.values.dtype, device=self.device) def bsr_axpy( x: BsrMatrix[BlockType[Rows, Cols, Scalar]], y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None, alpha: Scalar = 1.0, beta: Scalar = 1.0, work_arrays: Optional[bsr_axpy_work_arrays] = None, ) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]: """ Performs the sparse matrix addition ``y := alpha * X + beta * y`` on BSR matrices `x` and `y` and returns `y`. The `x` and `y` matrices are allowed to alias. Args: x: Read-only right-hand-side. y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero. alpha: Uniform scaling factor for `x` beta: Uniform scaling factor for `y` work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_axpy_work_arrays` in `work_arrays`. """ if y is None: # If not output matrix is provided, allocate it for convenience y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device) beta = 0.0 # Handle easy cases first if beta == 0.0 or y.nnz == 0: bsr_assign(src=x, dest=y) return bsr_scale(y, alpha=alpha) if alpha == 0.0 or x.nnz == 0: return bsr_scale(y, alpha=beta) if not isinstance(alpha, y.scalar_type): alpha = y.scalar_type(alpha) if not isinstance(beta, y.scalar_type): beta = y.scalar_type(beta) if x == y: # Aliasing case return bsr_scale(y, alpha=alpha.value + beta.value) # General case if x.values.device != y.values.device: raise ValueError("All arguments must reside on the same device") if x.scalar_type != y.scalar_type or x.block_shape != y.block_shape: raise ValueError("Matrices must have the same block type") if x.nrow != y.nrow or x.ncol != y.ncol: raise ValueError("Matrices must have the same number of rows and columns") if work_arrays is None: work_arrays = bsr_axpy_work_arrays() sum_nnz = x.nnz + y.nnz device = y.values.device work_arrays._allocate(device, y, sum_nnz) wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y.nnz) wp.launch(kernel=_bsr_get_block_row, device=device, dim=y.nnz, inputs=[0, y.offsets, work_arrays._sum_rows]) wp.copy(work_arrays._sum_cols, x.columns, y.nnz, 0, x.nnz) wp.launch(kernel=_bsr_get_block_row, device=device, dim=x.nnz, inputs=[y.nnz, x.offsets, work_arrays._sum_rows]) # Save old y values before overwriting matrix wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y.nnz) # Increase dest array sizes if needed if y.columns.shape[0] < sum_nnz: y.columns = wp.empty(shape=(sum_nnz,), dtype=int, device=device) from warp.context import runtime if device.is_cpu: native_func = runtime.core.bsr_matrix_from_triplets_float_host else: native_func = runtime.core.bsr_matrix_from_triplets_float_device old_y_nnz = y.nnz y.nnz = native_func( y.block_shape[0], y.block_shape[1], y.nrow, sum_nnz, work_arrays._sum_rows.ptr, work_arrays._sum_cols.ptr, 0, y.offsets.ptr, y.columns.ptr, 0, ) _bsr_ensure_fits(y) y.values.zero_() wp.launch( kernel=_bsr_axpy_add_block, device=device, dim=old_y_nnz, inputs=[ 0, beta, work_arrays._sum_rows, work_arrays._sum_cols, y.offsets, y.columns, work_arrays._old_y_values, y.values, ], ) wp.launch( kernel=_bsr_axpy_add_block, device=device, dim=x.nnz, inputs=[ old_y_nnz, alpha, work_arrays._sum_rows, work_arrays._sum_cols, y.offsets, y.columns, x.values, y.values, ], ) return y @wp.kernel def _bsr_mm_count_coeffs( z_nnz: int, x_offsets: wp.array(dtype=int), x_columns: wp.array(dtype=int), y_offsets: wp.array(dtype=int), counts: wp.array(dtype=int), ): row = wp.tid() count = int(0) x_beg = x_offsets[row] x_end = x_offsets[row + 1] for x_block in range(x_beg, x_end): x_col = x_columns[x_block] count += y_offsets[x_col + 1] - y_offsets[x_col] counts[row + 1] = count if row == 0: counts[0] = z_nnz @wp.kernel def _bsr_mm_list_coeffs( x_offsets: wp.array(dtype=int), x_columns: wp.array(dtype=int), y_offsets: wp.array(dtype=int), y_columns: wp.array(dtype=int), mm_offsets: wp.array(dtype=int), mm_rows: wp.array(dtype=int), mm_cols: wp.array(dtype=int), ): row = wp.tid() mm_block = mm_offsets[row] x_beg = x_offsets[row] x_end = x_offsets[row + 1] for x_block in range(x_beg, x_end): x_col = x_columns[x_block] y_beg = y_offsets[x_col] y_end = y_offsets[x_col + 1] for y_block in range(y_beg, y_end): mm_cols[mm_block] = y_columns[y_block] mm_rows[mm_block] = row mm_block += 1 @wp.kernel def _bsr_mm_compute_values( alpha: Any, x_offsets: wp.array(dtype=int), x_columns: wp.array(dtype=int), x_values: wp.array(dtype=Any), y_offsets: wp.array(dtype=int), y_columns: wp.array(dtype=int), y_values: wp.array(dtype=Any), mm_offsets: wp.array(dtype=int), mm_cols: wp.array(dtype=int), mm_values: wp.array(dtype=Any), ): row = wp.tid() mm_beg = mm_offsets[row] mm_end = mm_offsets[row + 1] x_beg = x_offsets[row] x_end = x_offsets[row + 1] for x_block in range(x_beg, x_end): x_col = x_columns[x_block] ax_val = alpha * x_values[x_block] y_beg = y_offsets[x_col] y_end = y_offsets[x_col + 1] for y_block in range(y_beg, y_end): mm_block = wp.lower_bound(mm_cols, mm_beg, mm_end, y_columns[y_block]) mm_values[mm_block] = mm_values[mm_block] + ax_val * y_values[y_block] class bsr_mm_work_arrays: """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls""" def __init__(self): self._reset(None) def _reset(self, device): self.device = device self._pinned_count_buffer = None self._mm_row_counts = None self._mm_rows = None self._mm_cols = None self._old_z_values = None self._old_z_offsets = None self._old_z_columns = None def _allocate_stage_1(self, device, z: BsrMatrix, copied_z_nnz: int, z_aliasing: bool): if self.device != device: self._reset(device) # Allocations that do not depend on any computation if self.device.is_cuda: if self._pinned_count_buffer is None: self._pinned_count_buffer = wp.empty(shape=(1,), dtype=int, pinned=True, device="cpu") if self._mm_row_counts is None or self._mm_row_counts.size < z.nrow + 1: self._mm_row_counts = wp.empty(shape=(z.nrow + 1,), dtype=int, device=self.device) if copied_z_nnz > 0: if self._old_z_values is None or self._old_z_values.size < copied_z_nnz: self._old_z_values = wp.empty(shape=(copied_z_nnz,), dtype=z.values.dtype, device=self.device) if z_aliasing: if self._old_z_columns is None or self._old_z_columns.size < z.nnz: self._old_z_columns = wp.empty(shape=(z.nnz,), dtype=z.columns.dtype, device=self.device) if self._old_z_offsets is None or self._old_z_offsets.size < z.nrow + 1: self._old_z_offsets = wp.empty(shape=(z.nrow + 1,), dtype=z.offsets.dtype, device=self.device) def _allocate_stage_2(self, mm_nnz: int): # Allocations that depend on unmerged nnz estimate if self._mm_rows is None or self._mm_rows.size < mm_nnz: self._mm_rows = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device) if self._mm_cols is None or self._mm_cols.size < mm_nnz: self._mm_cols = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device) def bsr_mm( x: BsrMatrix[BlockType[Rows, Any, Scalar]], y: BsrMatrix[BlockType[Any, Cols, Scalar]], z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None, alpha: Scalar = 1.0, beta: Scalar = 0.0, work_arrays: Optional[bsr_mm_work_arrays] = None, ) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]: """ Performs the sparse matrix-matrix multiplication ``z := alpha * x * y + beta * z`` on BSR matrices `x`, `y` and `z`, and returns `z`. The `x`, `y` and `z` matrices are allowed to alias. If the matrix `z` is not provided as input, it will be allocated and treated as zero. Args: x: Read-only left factor of the matrix-matrix product. y: Read-only right factor of the matrix-matrix product. z: Mutable left-hand-side. If `z` is not provided, it will be allocated and treated as zero. alpha: Uniform scaling factor for the ``x * y`` product beta: Uniform scaling factor for `z` work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_mm_work_arrays` in `work_arrays`. """ if z is None: # If not output matrix is provided, allocate it for convenience z_block_shape = (x.block_shape[0], y.block_shape[1]) if z_block_shape == (1, 1): z_block_type = x.scalar_type else: z_block_type = wp.types.matrix(shape=z_block_shape, dtype=x.scalar_type) z = bsr_zeros(x.nrow, y.ncol, block_type=z_block_type, device=x.values.device) beta = 0.0 if x.values.device != y.values.device or x.values.device != z.values.device: raise ValueError("All arguments must reside on the same device") if x.scalar_type != y.scalar_type or x.scalar_type != z.scalar_type: raise ValueError("Matrices must have the same scalar type") if ( x.block_shape[0] != z.block_shape[0] or y.block_shape[1] != z.block_shape[1] or x.block_shape[1] != y.block_shape[0] ): raise ValueError("Incompatible block sizes for matrix multiplication") if x.nrow != z.nrow or z.ncol != y.ncol or x.ncol != y.nrow: raise ValueError("Incompatible number of rows/columns for matrix multiplication") device = z.values.device if alpha == 0.0 or x.nnz == 0 or y.nnz == 0: # Easy case return bsr_scale(z, beta) if not isinstance(alpha, z.scalar_type): alpha = z.scalar_type(alpha) if not isinstance(beta, z.scalar_type): beta = z.scalar_type(beta) if work_arrays is None: work_arrays = bsr_mm_work_arrays() z_aliasing = z == x or z == y copied_z_nnz = z.nnz if beta != 0.0 or z_aliasing else 0 work_arrays._allocate_stage_1(device, z, copied_z_nnz, z_aliasing) # Prefix sum of number of (unmerged) mm blocks per row wp.launch( kernel=_bsr_mm_count_coeffs, device=device, dim=z.nrow, inputs=[copied_z_nnz, x.offsets, x.columns, y.offsets, work_arrays._mm_row_counts], ) warp.utils.array_scan(work_arrays._mm_row_counts, work_arrays._mm_row_counts) # Get back total counts on host if device.is_cuda: wp.copy(dest=work_arrays._pinned_count_buffer, src=work_arrays._mm_row_counts, src_offset=z.nrow, count=1) wp.synchronize_stream(wp.get_stream(device)) mm_nnz = int(work_arrays._pinned_count_buffer.numpy()[0]) else: mm_nnz = int(work_arrays._mm_row_counts.numpy()[z.nrow]) work_arrays._allocate_stage_2(mm_nnz) # If z has a non-zero scale, save current data before overwriting it if copied_z_nnz > 0: # Copy z row and column indices wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz) wp.launch( kernel=_bsr_get_block_row, device=device, dim=copied_z_nnz, inputs=[0, z.offsets, work_arrays._mm_rows] ) # Save current z values in temporary buffer wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz) if z_aliasing: # If z is aliasing with x or y, need to save topology as well wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz) wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1) # Fill unmerged mm blocks rows and columns wp.launch( kernel=_bsr_mm_list_coeffs, device=device, dim=z.nrow, inputs=[ x.offsets, x.columns, y.offsets, y.columns, work_arrays._mm_row_counts, work_arrays._mm_rows, work_arrays._mm_cols, ], ) # Increase dest array size if needed if z.columns.shape[0] < mm_nnz: z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device) from warp.context import runtime if device.is_cpu: native_func = runtime.core.bsr_matrix_from_triplets_float_host else: native_func = runtime.core.bsr_matrix_from_triplets_float_device z.nnz = native_func( z.block_shape[0], z.block_shape[1], z.nrow, mm_nnz, work_arrays._mm_rows.ptr, work_arrays._mm_cols.ptr, 0, z.offsets.ptr, z.columns.ptr, 0, ) _bsr_ensure_fits(z) z.values.zero_() if copied_z_nnz > 0: # Add back original z values wp.launch( kernel=_bsr_axpy_add_block, device=device, dim=copied_z_nnz, inputs=[ 0, beta, work_arrays._mm_rows, work_arrays._mm_cols, z.offsets, z.columns, work_arrays._old_z_values, z.values, ], ) # Add mm blocks to z values if (warp.types.type_is_matrix(x.values.dtype) or warp.types.type_is_matrix(y.values.dtype)) and not ( warp.types.type_is_matrix(z.values.dtype) ): # Result block type is scalar, but operands are matrices # Cast result to (1x1) matrix to perform multiplication mm_values = z.values.view(wp.types.matrix(shape=(1, 1), dtype=z.scalar_type)) else: mm_values = z.values wp.launch( kernel=_bsr_mm_compute_values, device=device, dim=z.nrow, inputs=[ alpha, work_arrays._old_z_offsets if x == z else x.offsets, work_arrays._old_z_columns if x == z else x.columns, work_arrays._old_z_values if x == z else x.values, work_arrays._old_z_offsets if y == z else y.offsets, work_arrays._old_z_columns if y == z else y.columns, work_arrays._old_z_values if y == z else y.values, z.offsets, z.columns, mm_values, ], ) return z @wp.kernel def _bsr_mv_kernel( alpha: Any, A_offsets: wp.array(dtype=int), A_columns: wp.array(dtype=int), A_values: wp.array(dtype=Any), x: wp.array(dtype=Any), beta: Any, y: wp.array(dtype=Any), ): row = wp.tid() # zero-initialize with type of y elements scalar_zero = type(alpha)(0) v = y.dtype(scalar_zero) if alpha != scalar_zero: beg = A_offsets[row] end = A_offsets[row + 1] for block in range(beg, end): v += A_values[block] * x[A_columns[block]] v *= alpha if beta != scalar_zero: v += beta * y[row] y[row] = v def bsr_mv( A: BsrMatrix[BlockType[Rows, Cols, Scalar]], x: "Array[Vector[Cols, Scalar] | Scalar]", y: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None, alpha: Scalar = 1.0, beta: Scalar = 0.0, work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None, ) -> "Array[Vector[Rows, Scalar] | Scalar]": """ Performs the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and returns `y`. The `x` and `y` vectors are allowed to alias. Args: A: Read-only, left matrix factor of the matrix-vector product. x: Read-only, right vector factor of the matrix-vector product. y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero. alpha: Uniform scaling factor for `x`. If zero, `x` will not be read and may be left uninitialized. beta: Uniform scaling factor for `y`. If zero, `y` will not be read and may be left uninitialized. work_buffer: Temporary storage is required if and only if `x` and `y` are the same vector. If provided the `work_buffer` array will be used for this purpose, otherwise a temporary allocation wil be performed. """ if y is None: # If no output array is provided, allocate one for convenience y_vec_len = A.block_shape[0] y_dtype = A.scalar_type if y_vec_len == 1 else wp.vec(length=y_vec_len, dtype=A.scalar_type) y = wp.empty(shape=(A.nrow,), device=A.values.device, dtype=y_dtype) y.zero_() beta = 0.0 if not isinstance(alpha, A.scalar_type): alpha = A.scalar_type(alpha) if not isinstance(beta, A.scalar_type): beta = A.scalar_type(beta) if A.values.device != x.device or A.values.device != y.device: raise ValueError("A, x and y must reside on the same device") if x.shape[0] != A.ncol: raise ValueError("Number of columns of A must match number of rows of x") if y.shape[0] != A.nrow: raise ValueError("Number of rows of A must match number of rows of y") if x == y: # Aliasing case, need temporary storage if work_buffer is None: work_buffer = wp.empty_like(y) elif work_buffer.size < y.size: raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}") elif not wp.types.types_equal(work_buffer.dtype, y.dtype): raise ValueError(f"Work buffer must have same data type as y, {wp.types.type_repr(y.dtype)}") # Save old y values before overwriting vector wp.copy(dest=work_buffer, src=y, count=y.size) x = work_buffer # Promote scalar vectors to length-1 vecs and conversely if warp.types.type_is_matrix(A.values.dtype): if A.block_shape[0] == 1: if y.dtype == A.scalar_type: y = y.view(dtype=wp.vec(length=1, dtype=A.scalar_type)) if A.block_shape[1] == 1: if x.dtype == A.scalar_type: x = x.view(dtype=wp.vec(length=1, dtype=A.scalar_type)) else: if A.block_shape[0] == 1: if y.dtype != A.scalar_type: y = y.view(dtype=A.scalar_type) if A.block_shape[1] == 1: if x.dtype != A.scalar_type: x = x.view(dtype=A.scalar_type) wp.launch( kernel=_bsr_mv_kernel, device=A.values.device, dim=A.nrow, inputs=[alpha, A.offsets, A.columns, A.values, x, beta, y], ) return y