"""
This example illustrates using domain decomposition to solve a diffusion PDE over multiple devices
"""
from typing import Tuple
import warp as wp
import warp.fem as fem
from warp.sparse import bsr_axpy, bsr_mv
from warp.utils import array_cast
# Import example utilities
# Make sure that works both when imported as module and run as standalone file
try:
from .bsr_utils import bsr_cg
from .example_diffusion import diffusion_form, linear_form
from .plot_utils import Plot
except ImportError:
from bsr_utils import bsr_cg
from example_diffusion import diffusion_form, linear_form
from plot_utils import Plot
@fem.integrand
def mass_form(
    s: fem.Sample,
    u: fem.Field,
    v: fem.Field,
):
    """Bilinear mass form: product of the trial and test fields at the sample point."""
    return v(s) * u(s)
@wp.kernel
def scal_kernel(a: wp.array(dtype=wp.float64), alpha: wp.float64):
    """Scale every entry of ``a`` by ``alpha``, in place."""
    i = wp.tid()
    a[i] = alpha * a[i]
@wp.kernel
def sum_kernel(a: wp.indexedarray(dtype=wp.float64), b: wp.array(dtype=wp.float64)):
    """Element-wise accumulate ``b`` into the indexed view ``a``."""
    i = wp.tid()
    a[i] = a[i] + b[i]
def sum_vecs(vecs, indices, sum: wp.array, tmp: wp.array):
    """Scatter-accumulate each vector of ``vecs`` into ``sum``.

    For each ``(vec, idx)`` pair, ``vec`` is first staged into the scratch buffer
    ``tmp`` (which lives on ``sum``'s device, possibly copying across devices),
    then added into ``sum`` at the positions listed in ``idx``.
    Returns ``sum`` for convenience.
    """
    for vec, idx in zip(vecs, indices):
        # Stage the (possibly remote) source vector on sum's device
        wp.copy(dest=tmp, src=vec)
        target = wp.indexedarray(sum, idx)
        wp.launch(kernel=sum_kernel, dim=idx.shape, device=sum.device, inputs=[target, tmp])
    return sum
class DistributedSystem:
    """Plain data holder describing a linear system distributed across devices.

    Attributes are populated externally (see ``Example.update``) and consumed by
    ``mv_routine`` and ``bsr_cg``.
    """

    # Device hosting the global (aggregation) vectors
    device = None
    # Scalar dtype of the system (wp.float64 in this example)
    scalar_type: type
    # Scratch buffer on `device`, reused to stage per-rank vectors
    tmp_buf: wp.array
    # Number of rows of the global system
    nrow: int
    # Global matrix shape. Fixed: the original wrote `shape = Tuple[int, int]`,
    # which assigned the typing object as a class attribute instead of declaring
    # an annotation.
    shape: Tuple[int, int]
    # Per-rank tuple: (matrices, rhs vectors, result vectors, node index sets)
    rank_data = None
def mv_routine(A: DistributedSystem, x: wp.array, y: wp.array, alpha=1.0, beta=0.0):
    """Distributed matrix-vector multiplication routine, for example purposes

    Computes ``y = alpha * A @ x + beta * y`` where the matrix is split across
    devices: each rank holds a local matrix, local x/y buffers, and the global
    node indices it owns (``A.rank_data``). Partial per-rank products are summed
    back into ``y`` on its own device.
    """
    tmp = A.tmp_buf
    # y <- beta * y (scaling happens on y's device before accumulation)
    wp.launch(kernel=scal_kernel, dim=y.shape, device=y.device, inputs=[y, wp.float64(beta)])

    stream = wp.get_stream()

    for mat_i, x_i, y_i, idx in zip(*A.rank_data):
        # WAR copy with indexed array requiring matching shape
        # (alias tmp's storage with a view shaped like idx so wp.copy accepts it)
        tmp_i = wp.array(
            ptr=tmp.ptr, device=tmp.device, capacity=tmp.capacity, dtype=tmp.dtype, shape=idx.shape, owner=False
        )

        # Compress rhs on rank 0
        x_idx = wp.indexedarray(x, idx)
        wp.copy(dest=tmp_i, src=x_idx, count=idx.size, stream=stream)

        # Send to rank i
        wp.copy(dest=x_i, src=tmp_i, count=idx.size, stream=stream)

        with wp.ScopedDevice(x_i.device):
            # Local multiply on rank i's device must wait for the transfer above
            wp.wait_stream(stream)

            bsr_mv(A=mat_i, x=x_i, y=y_i, alpha=alpha, beta=0.0)
            wp.wait_stream(wp.get_stream(x_i.device))

            # Back to rank 0 for sum
            wp.copy(dest=tmp_i, src=y_i, count=idx.size, stream=stream)
            y_idx = wp.indexedarray(y, idx)
            wp.launch(kernel=sum_kernel, dim=idx.shape, device=y_idx.device, inputs=[y_idx, tmp_i], stream=stream)
class Example:
    """Solves a diffusion PDE by domain decomposition across all available CUDA devices.

    Each device assembles the system for its geometry partition; a distributed
    matrix-vector product (``mv_routine``) then drives a CG solve coordinated
    from the main device.
    """

    def __init__(self, stage=None, quiet=False):
        # Penalty weight for the weakly-imposed boundary conditions
        self._bd_weight = 100.0
        self._quiet = quiet

        self._geo = fem.Grid2D(res=wp.vec2i(25))

        self._main_device = wp.get_device("cuda")

        with wp.ScopedDevice(self._main_device):
            # Degree-3 polynomial space over the grid; solution stored as a field
            self._scalar_space = fem.make_polynomial_space(self._geo, degree=3)
            self._scalar_field = self._scalar_space.make_field()

        self.renderer = Plot(stage)

    def update(self):
        """Assemble per-device subsystems and solve the global system with distributed CG."""
        devices = wp.get_cuda_devices()
        main_device = self._main_device

        rhs_vecs = []
        res_vecs = []
        matrices = []
        indices = []

        # Build local system for each device
        for k, device in enumerate(devices):
            with wp.ScopedDevice(device):
                # Construct the partition corresponding to the k'th device
                geo_partition = fem.LinearGeometryPartition(self._geo, k, len(devices))
                matrix, rhs, partition_node_indices = self._assemble_local_system(geo_partition)

                rhs_vecs.append(rhs)
                res_vecs.append(wp.empty_like(rhs))
                matrices.append(matrix)
                # Index sets are needed on the main device for scatter/gather
                indices.append(partition_node_indices.to(main_device))

        # Global rhs as sum of all local rhs
        glob_rhs = wp.zeros(n=self._scalar_space.node_count(), dtype=wp.float64, device=main_device)
        tmp = wp.empty_like(glob_rhs)
        sum_vecs(rhs_vecs, indices, glob_rhs, tmp)

        # Distributed CG
        global_res = wp.zeros_like(glob_rhs)

        A = DistributedSystem()
        # NOTE(review): `device` is the last loop variable here, not `main_device`;
        # presumably main_device was intended — though A.device is never read by
        # mv_routine, so this has no effect. Confirm before relying on it.
        A.device = device
        A.scalar_type = glob_rhs.dtype
        A.nrow = self._scalar_space.node_count()
        A.shape = (A.nrow, A.nrow)
        A.tmp_buf = tmp
        A.rank_data = (matrices, rhs_vecs, res_vecs, indices)

        bsr_cg(
            A,
            x=global_res,
            b=glob_rhs,
            use_diag_precond=False,
            quiet=self._quiet,
            mv_routine=mv_routine,
            device=main_device,
        )

        # Cast float64 solve result into the field's dof storage
        array_cast(in_array=global_res, out_array=self._scalar_field.dof_values)

    def render(self):
        self.renderer.add_surface("solution", self._scalar_field)

    def _assemble_local_system(self, geo_partition: fem.GeometryPartition):
        """Assemble the diffusion system restricted to one geometry partition.

        Returns ``(matrix, rhs, node_indices)`` where ``node_indices`` maps the
        partition's nodes into the global node numbering.
        """
        scalar_space = self._scalar_space
        space_partition = fem.make_space_partition(scalar_space, geo_partition)
        domain = fem.Cells(geometry=geo_partition)

        # Right-hand-side
        test = fem.make_test(space=scalar_space, space_partition=space_partition, domain=domain)
        rhs = fem.integrate(linear_form, fields={"v": test})

        # Weakly-imposed boundary conditions on all sides
        boundary = fem.BoundarySides(geometry=geo_partition)
        bd_test = fem.make_test(space=scalar_space, space_partition=space_partition, domain=boundary)
        bd_trial = fem.make_trial(space=scalar_space, space_partition=space_partition, domain=boundary)
        bd_matrix = fem.integrate(mass_form, fields={"u": bd_trial, "v": bd_test})

        # Diffusion form
        trial = fem.make_trial(space=scalar_space, space_partition=space_partition, domain=domain)
        matrix = fem.integrate(diffusion_form, fields={"u": trial, "v": test}, values={"nu": 1.0})

        # matrix += bd_weight * bd_matrix (penalty enforcement of boundary conditions)
        bsr_axpy(y=matrix, x=bd_matrix, alpha=self._bd_weight)

        return matrix, rhs, space_partition.space_node_indices()
if __name__ == "__main__":
    wp.init()
    # Backward (adjoint) kernel generation is not needed for this example
    wp.set_module_options({"enable_backward": False})

    example = Example()
    example.update()
    example.render()
    example.renderer.plot()