GarmentCode / NvidiaWarp-GarmentCode /examples /fem /example_diffusion_mgpu.py
qbhf2's picture
added NvidiaWarp and GarmentCode repos
66c9c8a
"""
This example illustrates using domain decomposition to solve a diffusion PDE over multiple devices
"""
from typing import Tuple
import warp as wp
import warp.fem as fem
from warp.sparse import bsr_axpy, bsr_mv
from warp.utils import array_cast
# Import example utilities
# Make sure this works both when imported as a module and when run as a standalone file
try:
from .bsr_utils import bsr_cg
from .example_diffusion import diffusion_form, linear_form
from .plot_utils import Plot
except ImportError:
from bsr_utils import bsr_cg
from example_diffusion import diffusion_form, linear_form
from plot_utils import Plot
@fem.integrand
def mass_form(s: fem.Sample, u: fem.Field, v: fem.Field):
    # Mass bilinear form integrand: product of the two field values at sample `s`.
    # In this file `u` is bound to a trial function and `v` to a test function.
    u_val = u(s)
    v_val = v(s)
    return u_val * v_val
@wp.kernel
def scal_kernel(a: wp.array(dtype=wp.float64), alpha: wp.float64):
    # In-place scaling: each thread multiplies its own entry of `a` by `alpha`
    i = wp.tid()
    a[i] = alpha * a[i]
@wp.kernel
def sum_kernel(a: wp.indexedarray(dtype=wp.float64), b: wp.array(dtype=wp.float64)):
    # In-place accumulation through an indexed view: a[i] <- a[i] + b[i]
    i = wp.tid()
    a[i] = b[i] + a[i]
def sum_vecs(vecs, indices, sum: wp.array, tmp: wp.array):
    """Accumulate several local vectors into the entries of ``sum`` selected by their indices.

    Args:
        vecs: Local vectors, one per entry of ``indices``.
        indices: For each local vector, the index array selecting which entries
            of ``sum`` it contributes to.
        sum: Destination array, updated in place.
        tmp: Scratch array that each local vector is copied into before the
            accumulation kernel reads it.

    Returns:
        The ``sum`` array that was passed in.
    """
    target_device = sum.device

    for local_vec, node_ids in zip(vecs, indices):
        # Stage the local vector in the scratch buffer
        wp.copy(dest=tmp, src=local_vec)

        # View of `sum` restricted to the entries this local vector touches
        selected = wp.indexedarray(sum, node_ids)
        wp.launch(kernel=sum_kernel, dim=node_ids.shape, device=target_device, inputs=[selected, tmp])

    return sum
class DistributedSystem:
    """Plain attribute container describing a linear system split over several devices.

    The caller creates an empty instance and assigns each attribute before
    passing it to the solver together with ``mv_routine``.
    """

    # Device handle for the system; assigned by the caller
    device = None
    # Scalar type of the global vectors
    scalar_type: type
    # Scratch buffer used by `mv_routine` to stage data between devices
    tmp_buf: wp.array
    # Number of rows of the global system
    nrow: int
    # (rows, columns) of the global system.
    # Declared as an annotation, like the other typed attributes, rather than
    # binding the typing alias itself as the attribute's value.
    shape: Tuple[int, int]
    # Tuple of parallel per-device sequences, unpacked by `mv_routine` as
    # (matrices, local input vectors, local output vectors, index arrays)
    rank_data = None
def mv_routine(A: DistributedSystem, x: wp.array, y: wp.array, alpha=1.0, beta=0.0):
    """Distributed matrix-vector product, for example purposes.

    Scales ``y`` by ``beta`` in place, then for each entry of ``A.rank_data``
    gathers the selected entries of ``x``, multiplies them by the local matrix
    (scaled by ``alpha``) on that matrix's device, and adds the local result
    back into the selected entries of ``y``.

    Args:
        A: Distributed system providing ``tmp_buf`` and ``rank_data``.
        x: Global input vector.
        y: Global output vector, updated in place.
        alpha: Factor applied to each local matrix-vector product.
        beta: Factor applied to the incoming content of ``y``.
    """
    scratch = A.tmp_buf

    # y <- beta * y
    wp.launch(kernel=scal_kernel, dim=y.shape, device=y.device, inputs=[y, wp.float64(beta)])

    main_stream = wp.get_stream()

    for local_mat, local_x, local_y, node_ids in zip(*A.rank_data):
        # Workaround: copying from/to an indexed array requires matching shapes,
        # so alias the scratch buffer's memory with the shape of the index array
        scratch_view = wp.array(
            ptr=scratch.ptr,
            device=scratch.device,
            capacity=scratch.capacity,
            dtype=scratch.dtype,
            shape=node_ids.shape,
            owner=False,
        )

        # Gather the entries of x selected by this partition (rank 0 side)
        gathered_x = wp.indexedarray(x, node_ids)
        wp.copy(dest=scratch_view, src=gathered_x, count=node_ids.size, stream=main_stream)

        # Send them to the device holding the local system
        wp.copy(dest=local_x, src=scratch_view, count=node_ids.size, stream=main_stream)

        with wp.ScopedDevice(local_x.device):
            # Local product must not start before the transfer above has completed
            wp.wait_stream(main_stream)
            bsr_mv(A=local_mat, x=local_x, y=local_y, alpha=alpha, beta=0.0)

        # Conversely, wait for the local product before reading its result
        wp.wait_stream(wp.get_stream(local_x.device))

        # Bring the local result back and accumulate it into y
        wp.copy(dest=scratch_view, src=local_y, count=node_ids.size, stream=main_stream)
        scattered_y = wp.indexedarray(y, node_ids)
        wp.launch(
            kernel=sum_kernel,
            dim=node_ids.shape,
            device=scattered_y.device,
            inputs=[scattered_y, scratch_view],
            stream=main_stream,
        )
class Example:
    """Diffusion example whose linear system is assembled per CUDA device and solved with a distributed CG."""

    def __init__(self, stage=None, quiet=False):
        """Create the geometry, the scalar function space/field and the renderer.

        Args:
            stage: Forwarded to ``Plot``.
            quiet: Forwarded to ``bsr_cg`` as its ``quiet`` argument.
        """
        # Weight of the boundary mass matrix added to the diffusion matrix
        self._bd_weight = 100.0
        self._quiet = quiet

        self._geo = fem.Grid2D(res=wp.vec2i(25))

        # Device holding the global vectors and the solution field
        self._main_device = wp.get_device("cuda")

        with wp.ScopedDevice(self._main_device):
            self._scalar_space = fem.make_polynomial_space(self._geo, degree=3)
            self._scalar_field = self._scalar_space.make_field()

        self.renderer = Plot(stage)

    def update(self):
        """Assemble one local system per CUDA device, solve the global system and store the solution."""
        devices = wp.get_cuda_devices()
        main_device = self._main_device

        rhs_vecs = []
        res_vecs = []
        matrices = []
        indices = []

        # Build local system for each device
        for k, device in enumerate(devices):
            with wp.ScopedDevice(device):
                # Construct the partition corresponding to the k'th device
                geo_partition = fem.LinearGeometryPartition(self._geo, k, len(devices))
                matrix, rhs, partition_node_indices = self._assemble_local_system(geo_partition)

                rhs_vecs.append(rhs)
                res_vecs.append(wp.empty_like(rhs))
                matrices.append(matrix)
                # Node indices are consumed on the main device when gathering/scattering
                indices.append(partition_node_indices.to(main_device))

        # Global rhs as sum of all local rhs
        glob_rhs = wp.zeros(n=self._scalar_space.node_count(), dtype=wp.float64, device=main_device)
        tmp = wp.empty_like(glob_rhs)
        sum_vecs(rhs_vecs, indices, glob_rhs, tmp)

        # Distributed CG
        global_res = wp.zeros_like(glob_rhs)
        A = DistributedSystem()
        # The global vectors and the scratch buffer live on the main device; do not
        # rely on the loop variable left over from the per-device assembly loop above
        A.device = main_device
        A.scalar_type = glob_rhs.dtype
        A.nrow = self._scalar_space.node_count()
        A.shape = (A.nrow, A.nrow)
        A.tmp_buf = tmp
        A.rank_data = (matrices, rhs_vecs, res_vecs, indices)

        bsr_cg(
            A,
            x=global_res,
            b=glob_rhs,
            use_diag_precond=False,
            quiet=self._quiet,
            mv_routine=mv_routine,
            device=main_device,
        )

        # Transfer the float64 solution into the field's degrees of freedom
        array_cast(in_array=global_res, out_array=self._scalar_field.dof_values)

    def render(self):
        """Register the solution field with the renderer as the surface named "solution"."""
        self.renderer.add_surface("solution", self._scalar_field)

    def _assemble_local_system(self, geo_partition: fem.GeometryPartition):
        """Assemble the system restricted to one geometry partition.

        Args:
            geo_partition: Geometry partition to integrate over.

        Returns:
            A tuple ``(matrix, rhs, node_indices)``: the diffusion matrix with the
            weighted boundary mass matrix added, the right-hand-side vector, and the
            space node indices of the partition.
        """
        scalar_space = self._scalar_space
        space_partition = fem.make_space_partition(scalar_space, geo_partition)

        domain = fem.Cells(geometry=geo_partition)

        # Right-hand-side
        test = fem.make_test(space=scalar_space, space_partition=space_partition, domain=domain)
        rhs = fem.integrate(linear_form, fields={"v": test})

        # Weakly-imposed boundary conditions on all sides
        boundary = fem.BoundarySides(geometry=geo_partition)
        bd_test = fem.make_test(space=scalar_space, space_partition=space_partition, domain=boundary)
        bd_trial = fem.make_trial(space=scalar_space, space_partition=space_partition, domain=boundary)
        bd_matrix = fem.integrate(mass_form, fields={"u": bd_trial, "v": bd_test})

        # Diffusion form
        trial = fem.make_trial(space=scalar_space, space_partition=space_partition, domain=domain)
        matrix = fem.integrate(diffusion_form, fields={"u": trial, "v": test}, values={"nu": 1.0})

        # matrix <- matrix + bd_weight * bd_matrix
        bsr_axpy(y=matrix, x=bd_matrix, alpha=self._bd_weight)

        return matrix, rhs, space_partition.space_node_indices()
if __name__ == "__main__":
    wp.init()
    # Turn off the `enable_backward` module option before any kernel is launched
    wp.set_module_options({"enable_backward": False})

    demo = Example()
    demo.update()
    demo.render()

    demo.renderer.plot()