# qbhf2's picture
# added NvidiaWarp and GarmentCode repos
# 66c9c8a
from typing import List, Dict, Set, Optional, Any, Union
import warp as wp
import re
import ast
from warp.sparse import BsrMatrix, bsr_zeros, bsr_set_from_triplets, bsr_copy, bsr_assign
from warp.types import type_length
from warp.utils import array_cast
from warp.codegen import get_annotations
from warp.fem.domain import GeometryDomain
from warp.fem.field import (
TestField,
TrialField,
FieldLike,
DiscreteField,
FieldRestriction,
make_restriction,
)
from warp.fem.quadrature import Quadrature, RegularQuadrature
from warp.fem.operator import Operator, Integrand
from warp.fem import cache
from warp.fem.types import Domain, Field, Sample, DofIndex, NULL_DOF_INDEX, OUTSIDE, make_free_sample
def _resolve_path(func, node):
    """
    Resolves variable and path from ast node/attribute (adapted from warp.codegen)

    Walks a chain of ``ast.Attribute`` nodes ending in an ``ast.Name`` (e.g. ``a.b.c``) and
    evaluates the dotted path in the scope of ``func``: its globals, overridden by its
    closure variables.

    Returns:
        A ``(obj, path)`` pair. ``path`` lists the dotted components in source order.
        ``obj`` is the evaluated object, or ``None`` when the path is empty or evaluation
        raises ``NameError``/``AttributeError``.
    """
    parts = []
    current = node
    while isinstance(current, ast.Attribute):
        parts.append(current.attr)
        current = current.value
    if isinstance(current, ast.Name):
        parts.append(current.id)

    # Components were collected from the outermost attribute inwards; restore source order
    parts.reverse()

    if not parts:
        return None, parts

    try:
        # Merge closure cells over the globals so that integrands defined inside a function
        # can refer to variables declared in that enclosing function.
        cells = func.__closure__ or []
        captured = {name: cell.cell_contents for name, cell in zip(func.__code__.co_freevars, cells)}
        scope = {**func.__globals__, **captured}
        # NOTE: eval is applied to dotted names taken from integrand source code, not external input
        return eval(".".join(parts), scope), parts
    except (NameError, AttributeError):
        return None, parts
def _path_to_ast_attribute(name: str) -> ast.Attribute:
    """
    Build a load-context AST expression for a dotted name.

    ``"a.b.c"`` becomes ``Attribute(Attribute(Name("a"), "b"), "c")``. A name without dots
    yields a plain ``ast.Name``.
    """
    head, *attrs = name.split(".")
    node = ast.Name(id=head, ctx=ast.Load())
    for attr in attrs:
        node = ast.Attribute(value=node, attr=attr, ctx=ast.Load())
    return node
class IntegrandTransformer(ast.NodeTransformer):
    """
    Rewrites calls inside an integrand's AST so they target field-specialized warp functions.

    Three call shapes are handled:

    * ``f(x, ...)`` where ``f`` is a field argument: rewritten to the field type's
      ``call_operator`` with the field prepended as the first argument;
    * ``op(f, x, ...)`` where ``op`` is an :class:`Operator` and ``f`` a field argument:
      redirected to the operator overload resolved for that field;
    * calls to another :class:`Integrand`: the callee is translated for the caller's
      field arguments and the call redirected to the specialized function.
    """

    def __init__(self, integrand: Integrand, field_args: Dict[str, FieldLike]):
        self._integrand = integrand
        self._field_args = field_args

    def visit_Call(self, call: ast.Call):
        call = self.generic_visit(call)

        callee = getattr(call.func, "id", None)
        if callee in self._field_args:
            # Shortcut for evaluating fields as f(x...)
            field = self._field_args[callee]

            arg_type = self._integrand.argspec.annotations[callee]
            operator = arg_type.call_operator

            call.func = ast.Attribute(
                value=_path_to_ast_attribute(f"{arg_type.__module__}.{arg_type.__qualname__}"),
                attr="call_operator",
                ctx=ast.Load(),
            )
            call.args = [ast.Name(id=callee, ctx=ast.Load())] + call.args

            self._replace_call_func(call, operator, field)

            return call

        func, _ = _resolve_path(self._integrand.func, call.func)

        if isinstance(func, Operator) and len(call.args) > 0:
            # Evaluating operators as op(field, x, ...)
            callee = getattr(call.args[0], "id", None)
            if callee in self._field_args:
                field = self._field_args[callee]
                self._replace_call_func(call, func, field)

        if isinstance(func, Integrand):
            key = self._translate_callee(func, call.args)
            call.func = ast.Attribute(
                value=call.func,
                attr=key,
                ctx=ast.Load(),
            )

        return call

    def _replace_call_func(self, call: ast.Call, operator: Operator, field: FieldLike):
        """
        Redirect ``call`` to the overload of ``operator`` resolved for ``field``.

        The resolved function is stored as an attribute of ``operator`` under its key so that
        the rewritten ``operator.<key>(...)`` expression can be looked up at codegen time.

        Raises:
            ValueError: if resolving the operator for this field raises ``AttributeError``.
        """
        try:
            pointer = operator.resolver(field)
            setattr(operator, pointer.key, pointer)
        except AttributeError as err:
            # Chain the original error: the AttributeError may come from deep inside the resolver
            raise ValueError(f"Operator {operator.func.__name__} is not defined for field {field.name}") from err
        call.func = ast.Attribute(value=call.func, attr=pointer.key, ctx=ast.Load())

    def _translate_callee(self, callee: Integrand, args: List[ast.AST]):
        """
        Specialize a nested integrand call for the caller's field arguments.

        Field/domain arguments found at the call site (by name) are matched, in order, to the
        callee's ``Field``/``Domain`` parameters.

        Returns:
            The key of the specialized warp function.
        """
        # Get field types for call site arguments
        call_site_field_args = []
        for arg in args:
            name = getattr(arg, "id", None)
            if name in self._field_args:
                call_site_field_args.append(self._field_args[name])

        # Reversed so that pop() hands them out in call-site order
        call_site_field_args.reverse()

        # Pass to callee in same order
        callee_field_args = {}
        for arg in callee.argspec.args:
            arg_type = callee.argspec.annotations[arg]
            if arg_type in (Field, Domain):
                callee_field_args[arg] = call_site_field_args.pop()

        return _translate_integrand(callee, callee_field_args).key
def _translate_integrand(integrand: Integrand, field_args: Dict[str, FieldLike]) -> wp.Function:
    """
    Generate (or fetch from cache) a warp function specializing ``integrand`` for ``field_args``.

    ``Field`` parameters are retyped to the bound field's ``ElementEvalArg``. ``Domain``
    parameters are retyped to the domain's ``ElementArg``. Other parameters keep their
    annotation. The body is rewritten by :class:`IntegrandTransformer`.

    Returns:
        The specialized function, also stored on ``integrand`` under its key.
    """
    argspec = integrand.argspec

    def specialized_type(arg):
        arg_type = argspec.annotations[arg]
        if arg_type == Field:
            return field_args[arg].ElementEvalArg
        if arg_type == Domain:
            return field_args[arg].ElementArg
        return arg_type

    annotations = {arg: specialized_type(arg) for arg in argspec.args}

    # Cache suffix built from the names of the actual fields (domains are skipped)
    suffix = "_".join(f.name for f in field_args.values() if isinstance(f, FieldLike))

    func = cache.get_integrand_function(
        integrand=integrand,
        suffix=suffix,
        annotations=annotations,
        code_transformers=[IntegrandTransformer(integrand, field_args)],
    )

    key = func.key
    setattr(integrand, key, integrand.module.functions[key])
    return getattr(integrand, key)
def _get_integrand_field_arguments(
    integrand: Integrand,
    fields: Dict[str, FieldLike],
    domain: GeometryDomain = None,
):
    """
    Classify the integrand's parameters by their annotation.

    Returns:
        ``(field_args, value_args, domain_name, sample_name)``:

        * ``field_args`` maps ``Field`` parameters to the matching entry of ``fields`` and the
          ``Domain`` parameter (if any) to ``domain``;
        * ``value_args`` maps every other non-sample parameter to its annotated type;
        * ``domain_name`` / ``sample_name`` are the names of the ``Domain`` / ``Sample``
          parameters, or ``None``.

    Raises:
        ValueError: if a ``Field`` parameter has no entry in ``fields``.
    """
    field_args = {}
    value_args = {}
    domain_name = None
    sample_name = None

    argspec = integrand.argspec
    for arg in argspec.args:
        arg_type = argspec.annotations[arg]

        if arg_type == Field:
            if arg not in fields:
                raise ValueError(f"Missing field for argument '{arg}'")
            field_args[arg] = fields[arg]
            continue

        if arg_type == Domain:
            domain_name = arg
            field_args[arg] = domain
            continue

        if arg_type == Sample:
            sample_name = arg
            continue

        value_args[arg] = arg_type

    return field_args, value_args, domain_name, sample_name
def _get_test_and_trial_fields(
    fields: Dict[str, FieldLike],
):
    """
    Locate the (at most one) test field and (at most one) trial field among ``fields``.

    Returns:
        ``(test, test_name, trial, trial_name)``; missing entries are ``None``.

    Raises:
        ValueError: on duplicate test or trial fields, on a trial field without a test field,
            or when test and trial fields live on different domains.
    """
    test = None
    trial = None
    test_name = None
    trial_name = None

    for name, field in fields.items():
        if isinstance(field, TestField):
            if test is not None:
                raise ValueError("Duplicate test field argument")
            test = field
            test_name = name
        elif isinstance(field, TrialField):
            if trial is not None:
                # Previously reported as a duplicate *test* field, which was misleading
                raise ValueError("Duplicate trial field argument")
            trial = field
            trial_name = name

    if trial is not None:
        if test is None:
            raise ValueError("A trial field cannot be provided without a test field")

        if test.domain != trial.domain:
            raise ValueError("Incompatible test and trial domains")

    return test, test_name, trial, trial_name
def _gen_field_struct(field_args: Dict[str, FieldLike]):
    """
    Build (or fetch from cache) a warp struct bundling the ``EvalArg`` of every field argument.

    Geometry domains in ``field_args`` are skipped; the generated kernels take the domain
    through their own ``domain_arg`` parameter.

    Returns:
        The cached struct, keyed by a suffix of ``<name>_<EvalArg class qualname>`` pairs.
    """

    class Fields:
        pass

    annotations = get_annotations(Fields)

    struct_members = {name: arg for name, arg in field_args.items() if not isinstance(arg, GeometryDomain)}
    for member_name, member in struct_members.items():
        setattr(Fields, member_name, member.EvalArg())
        annotations[member_name] = member.EvalArg

    try:
        Fields.__annotations__ = annotations
    except AttributeError:
        # NOTE(review): setattr on a mappingproxy would itself raise; this fallback looks unreachable/ineffective
        setattr(Fields.__dict__, "__annotations__", annotations)

    suffix = "_".join(f"{member_name}_{struct.cls.__qualname__}" for member_name, struct in annotations.items())

    return cache.get_struct(Fields, suffix=suffix)
def _gen_value_struct(value_args: Dict[str, type]):
    """
    Build (or fetch from cache) a warp struct holding the integrand's non-field arguments.

    Args:
        value_args: Mapping from integrand parameter name to its annotated type

    Returns:
        The cached struct, keyed by a suffix of ``<name>_<type name>`` pairs.
    """

    class Values:
        pass

    annotations = get_annotations(Values)

    for name, arg_type in value_args.items():
        setattr(Values, name, None)
        annotations[name] = arg_type

    # Single definition (the helper was previously duplicated verbatim)
    def arg_type_name(arg_type):
        # Warp structs are named after the Python class they wrap
        if isinstance(arg_type, wp.codegen.Struct):
            return arg_type_name(arg_type.cls)
        return getattr(arg_type, "__name__", str(arg_type))

    try:
        Values.__annotations__ = annotations
    except AttributeError:
        # NOTE(review): setattr on a mappingproxy would itself raise; this fallback looks unreachable/ineffective
        setattr(Values.__dict__, "__annotations__", annotations)

    suffix = "_".join([f"{name}_{arg_type_name(arg_type)}" for name, arg_type in annotations.items()])

    return cache.get_struct(Values, suffix=suffix)
def _get_trial_arg():
    """Kernel-template placeholder: ``PassFieldArgsToIntegrand`` rewrites calls to it as ``fields.<trial name>``."""
def _get_test_arg():
    """Kernel-template placeholder: ``PassFieldArgsToIntegrand`` rewrites calls to it as ``fields.<test name>``."""
class _FieldWrappers:
    """Empty attribute namespace; ``_register_integrand_field_wrappers`` fills it with per-field callables."""
def _register_integrand_field_wrappers(integrand_func: wp.Function, fields: Dict[str, FieldLike]):
    """
    Attach to ``integrand_func`` a ``_field_wrappers`` namespace mapping each field name to that
    field's ``ElementEvalArg``.

    Generated kernels reach it as ``integrand_func._field_wrappers.<name>(...)`` (see
    ``PassFieldArgsToIntegrand``).
    """
    integrand_func._field_wrappers = wrappers = _FieldWrappers()
    for field_name, field in fields.items():
        setattr(wrappers, field_name, field.ElementEvalArg)
class PassFieldArgsToIntegrand(ast.NodeTransformer):
    """
    Rewrites calls inside a kernel template so they match the integrand's real signature.

    * ``<func_name>(...)``: the argument list is replaced, in ``arg_names`` order, by the domain
      argument, the sample, ``<func_name>.<field_wrappers_attr>.<field>(<domain>, <fields>.<field>)``
      for field arguments, or ``<values>.<name>`` for value arguments.
    * ``_get_test_arg()`` / ``_get_trial_arg()``: replaced by ``<fields>.<test_name>`` /
      ``<fields>.<trial_name>``.

    Raises:
        RuntimeError: when an integrand argument matches none of the known categories.
    """

    def __init__(
        self,
        arg_names: List[str],
        field_args: Set[str],
        value_args: Set[str],
        sample_name: str,
        domain_name: str,
        test_name: str = None,
        trial_name: str = None,
        func_name: str = "integrand_func",
        fields_var_name: str = "fields",
        values_var_name: str = "values",
        domain_var_name: str = "domain_arg",
        sample_var_name: str = "sample",
        field_wrappers_attr: str = "_field_wrappers",
    ):
        # Integrand parameter names and their classification
        self._arg_names = arg_names
        self._field_args = field_args
        self._value_args = value_args
        self._domain_name = domain_name
        self._sample_name = sample_name
        self._test_name = test_name
        self._trial_name = trial_name

        # Identifiers used by the kernel template
        self._func_name = func_name
        self._fields_var_name = fields_var_name
        self._values_var_name = values_var_name
        self._domain_var_name = domain_var_name
        self._sample_var_name = sample_var_name
        self._field_wrappers_attr = field_wrappers_attr

    @staticmethod
    def _load_name(identifier: str) -> ast.Name:
        return ast.Name(id=identifier, ctx=ast.Load())

    @staticmethod
    def _load_attr(value: ast.AST, attr: str) -> ast.Attribute:
        return ast.Attribute(value=value, attr=attr, ctx=ast.Load())

    def _integrand_arg_node(self, arg: str) -> ast.AST:
        """Build the expression passed for integrand parameter ``arg``."""
        # The domain check must precede the field check: the domain is also listed in field_args
        if arg == self._domain_name:
            return self._load_name(self._domain_var_name)
        if arg == self._sample_name:
            return self._load_name(self._sample_var_name)
        if arg in self._field_args:
            wrapper_ctor = self._load_attr(
                self._load_attr(self._load_name(self._func_name), self._field_wrappers_attr),
                arg,
            )
            return ast.Call(
                func=wrapper_ctor,
                args=[
                    self._load_name(self._domain_var_name),
                    self._load_attr(self._load_name(self._fields_var_name), arg),
                ],
                keywords=[],
            )
        if arg in self._value_args:
            return self._load_attr(self._load_name(self._values_var_name), arg)
        raise RuntimeError(f"Unhandled argument {arg}")

    def visit_Call(self, call: ast.Call):
        call = self.generic_visit(call)
        callee = getattr(call.func, "id", None)

        if callee == self._func_name:
            # Replace function arguments with ours generated structs
            call.args.clear()
            call.args.extend(self._integrand_arg_node(arg) for arg in self._arg_names)
            return call

        if callee == _get_test_arg.__name__:
            return self._load_attr(self._load_name(self._fields_var_name), self._test_name)

        if callee == _get_trial_arg.__name__:
            return self._load_attr(self._load_name(self._fields_var_name), self._trial_name)

        return call
def get_integrate_constant_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    accumulate_dtype,
):
    """
    Build the Python kernel function used to integrate a constant form (no test or trial field).

    ``_generate_integrate_kernel`` compiles the returned function into a warp kernel. During
    that step the ``integrand_func(sample, fields, values)`` call is a placeholder whose
    arguments are rewritten by ``PassFieldArgsToIntegrand``.

    Each thread handles one domain element. It sums ``qp_weight * element_measure * integrand``
    over that element's quadrature points and atomically adds the partial sum into ``result[0]``.
    """

    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=accumulate_dtype),
    ):
        # Thread index -> element of the integration domain
        element_index = domain.element_index(domain_index_arg, wp.tid())
        elem_sum = accumulate_dtype(0.0)

        # Constant forms carry no test/trial degree of freedom
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX

        qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
        for k in range(qp_point_count):
            qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
            coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
            qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)

            sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
            vol = domain.element_measure(domain_arg, sample)

            # Placeholder call; arguments are substituted by PassFieldArgsToIntegrand
            val = integrand_func(sample, fields, values)

            elem_sum += accumulate_dtype(qp_weight * vol * val)

        # All elements contribute to the same scalar, hence the atomic add
        wp.atomic_add(result, 0, elem_sum)

    return integrate_kernel_fn
def get_integrate_linear_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """
    Build the Python kernel function used to integrate a linear form (test field only) with a quadrature.

    Each thread handles one ``(local test node, test dof)`` pair. It loops over the elements
    adjacent to the node and their quadrature points, accumulates
    ``qp_weight * element_measure * integrand``, and writes ``result[node_index, test_dof]``
    (no atomics).

    The ``integrand_func(sample, fields, values)`` call is a placeholder whose arguments are
    rewritten by ``PassFieldArgsToIntegrand`` when the kernel is generated.
    """

    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: test.space_restriction.NodeArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array2d(dtype=output_dtype),
    ):
        local_node_index, test_dof = wp.tid()

        # Row of the output array for this node
        node_index = test.space_restriction.node_partition_index(test_arg, local_node_index)
        element_count = test.space_restriction.node_element_count(test_arg, local_node_index)

        # Linear forms have no trial function
        trial_dof_index = NULL_DOF_INDEX

        val_sum = accumulate_dtype(0.0)

        for n in range(element_count):
            node_element_index = test.space_restriction.node_element_index(test_arg, local_node_index, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            # Identifies which test basis function (node within element, dof) is being evaluated
            test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)

            qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
            for k in range(qp_point_count):
                qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
                qp_coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
                qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)

                vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))

                sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
                # Placeholder call; arguments are substituted by PassFieldArgsToIntegrand
                val = integrand_func(sample, fields, values)

                val_sum += accumulate_dtype(qp_weight * vol * val)

        result[node_index, test_dof] = output_dtype(val_sum)

    return integrate_kernel_fn
def get_integrate_linear_nodal_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """
    Build the Python kernel function used to integrate a linear form nodally.

    The test function nodes act as quadrature points: for each ``(local test node, dof)``
    thread, the integrand is evaluated at the node's coordinates in every adjacent element.
    Each evaluation is weighted by the space's ``node_quadrature_weight`` and the element
    measure, and the sum is written to ``result[node_index, dof]``.

    ``_get_test_arg()`` calls are placeholders rewritten to ``fields.<test name>`` by
    ``PassFieldArgsToIntegrand``, which also rewrites the ``integrand_func`` call arguments.
    """

    def integrate_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_restriction_arg: test.space_restriction.NodeArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array2d(dtype=output_dtype),
    ):
        local_node_index, dof = wp.tid()

        node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)
        element_count = test.space_restriction.node_element_count(test_restriction_arg, local_node_index)

        trial_dof_index = NULL_DOF_INDEX

        val_sum = accumulate_dtype(0.0)

        for n in range(element_count):
            node_element_index = test.space_restriction.node_element_index(test_restriction_arg, local_node_index, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            coords = test.space.node_coords_in_element(
                domain_arg,
                _get_test_arg(),
                element_index,
                node_element_index.node_index_in_element,
            )

            # Nodes reported as OUTSIDE this element do not contribute
            if coords[0] != OUTSIDE:
                node_weight = test.space.node_quadrature_weight(
                    domain_arg,
                    _get_test_arg(),
                    element_index,
                    node_element_index.node_index_in_element,
                )

                test_dof_index = DofIndex(node_element_index.node_index_in_element, dof)

                # The node index doubles as the quadrature point index in nodal integration
                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)

                val_sum += accumulate_dtype(node_weight * vol * val)

        result[node_index, dof] = output_dtype(val_sum)

    return integrate_kernel_fn
def get_integrate_bilinear_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    trial: TrialField,
    output_dtype,
    accumulate_dtype,
):
    """
    Build the Python kernel function used to integrate a bilinear form (test and trial fields) with a quadrature.

    Threads are indexed by ``(local test node, trial node in element, test dof, trial dof)``.
    For each element adjacent to the test node, the kernel integrates the integrand over the
    element's quadrature points. It then writes one triplet entry at
    ``(row_offsets[test_node] + element) * NODES_PER_ELEMENT + trial_node``.
    The row/column indices of each block are written only by the thread with
    ``test_dof == 0 and trial_dof == 0``.

    The ``integrand_func(sample, fields, values)`` call is a placeholder whose arguments are
    rewritten by ``PassFieldArgsToIntegrand`` when the kernel is generated.
    """
    # Captured as a compile-time constant by the kernel closure
    NODES_PER_ELEMENT = trial.space.topology.NODES_PER_ELEMENT

    def integrate_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: test.space_restriction.NodeArg,
        trial_partition_arg: trial.space_partition.PartitionArg,
        trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
        fields: FieldStruct,
        values: ValueStruct,
        row_offsets: wp.array(dtype=int),
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=output_dtype),
    ):
        test_local_node_index, trial_node, test_dof, trial_dof = wp.tid()

        element_count = test.space_restriction.node_element_count(test_arg, test_local_node_index)
        test_node_index = test.space_restriction.node_partition_index(test_arg, test_local_node_index)

        # trial_node is the node index within the element, so the trial dof is fixed for this thread
        trial_dof_index = DofIndex(trial_node, trial_dof)

        for element in range(element_count):
            test_element_index = test.space_restriction.node_element_index(test_arg, test_local_node_index, element)
            element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
            qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)

            test_dof_index = DofIndex(
                test_element_index.node_index_in_element,
                test_dof,
            )

            # One block per (test node, adjacent element, trial node): sum restarts per element
            val_sum = accumulate_dtype(0.0)

            for k in range(qp_point_count):
                qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
                coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)

                qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
                vol = domain.element_measure(domain_arg, make_free_sample(element_index, coords))

                sample = Sample(
                    element_index,
                    coords,
                    qp_index,
                    qp_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                # Placeholder call; arguments are substituted by PassFieldArgsToIntegrand
                val = integrand_func(sample, fields, values)
                val_sum += accumulate_dtype(qp_weight * vol * val)

            block_offset = (row_offsets[test_node_index] + element) * NODES_PER_ELEMENT + trial_node
            triplet_values[block_offset, test_dof, trial_dof] = output_dtype(val_sum)

            # Set row and column indices
            # (only one thread per block writes them)
            if test_dof == 0 and trial_dof == 0:
                trial_node_index = trial.space_partition.partition_node_index(
                    trial_partition_arg,
                    trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index, trial_node),
                )
                triplet_rows[block_offset] = test_node_index
                triplet_cols[block_offset] = trial_node_index

    return integrate_kernel_fn
def get_integrate_bilinear_nodal_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    test: TestField,
    output_dtype,
    accumulate_dtype,
):
    """
    Build the Python kernel function used to integrate a bilinear form nodally.

    Test nodes act as quadrature points and the trial dof is taken at the same node, so every
    local test node produces a single diagonal block. Threads are indexed by
    ``(local node, test dof, trial dof)``. Each writes its entry of
    ``triplet_values[local_node_index]``, and row and column indices are both set to the node's
    partition index.

    ``_get_test_arg()`` calls are placeholders rewritten to ``fields.<test name>`` by
    ``PassFieldArgsToIntegrand``, which also rewrites the ``integrand_func`` call arguments.
    """

    def integrate_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_restriction_arg: test.space_restriction.NodeArg,
        fields: FieldStruct,
        values: ValueStruct,
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=output_dtype),
    ):
        local_node_index, test_dof, trial_dof = wp.tid()

        element_count = test.space_restriction.node_element_count(test_restriction_arg, local_node_index)
        node_index = test.space_restriction.node_partition_index(test_restriction_arg, local_node_index)

        val_sum = accumulate_dtype(0.0)

        for n in range(element_count):
            node_element_index = test.space_restriction.node_element_index(test_restriction_arg, local_node_index, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)

            coords = test.space.node_coords_in_element(
                domain_arg,
                _get_test_arg(),
                element_index,
                node_element_index.node_index_in_element,
            )

            # Nodes reported as OUTSIDE this element do not contribute
            if coords[0] != OUTSIDE:
                node_weight = test.space.node_quadrature_weight(
                    domain_arg,
                    _get_test_arg(),
                    element_index,
                    node_element_index.node_index_in_element,
                )

                # Test and trial dofs share the same node within the element
                test_dof_index = DofIndex(node_element_index.node_index_in_element, test_dof)
                trial_dof_index = DofIndex(node_element_index.node_index_in_element, trial_dof)

                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)

                val_sum += accumulate_dtype(node_weight * vol * val)

        triplet_values[local_node_index, test_dof, trial_dof] = output_dtype(val_sum)
        # Diagonal block: row and column are the same node (all dof threads write the same value)
        triplet_rows[local_node_index] = node_index
        triplet_cols[local_node_index] = node_index

    return integrate_kernel_fn
def _generate_integrate_kernel(
    integrand: Integrand,
    domain: GeometryDomain,
    nodal: bool,
    quadrature: Quadrature,
    test: Optional[TestField],
    test_name: str,
    trial: Optional[TrialField],
    trial_name: str,
    fields: Dict[str, FieldLike],
    output_dtype: type,
    accumulate_dtype: type,
    kernel_options: Optional[Dict[str, Any]] = None,
) -> tuple:
    """
    Build (or fetch from cache) the warp kernel integrating ``integrand`` for the given fields.

    The kernel flavor is selected from ``test``/``trial``/``nodal``: constant, linear
    (quadrature or nodal) or bilinear (quadrature or nodal).

    Args:
        kernel_options: Options forwarded to the kernel builder; ``None`` means no options

    Returns:
        A ``(kernel, FieldStruct, ValueStruct)`` tuple. The structs must be instantiated and
        filled by the caller to launch the kernel.
    """
    # Avoid a shared mutable default argument
    if kernel_options is None:
        kernel_options = {}

    output_dtype = wp.types.type_scalar_type(output_dtype)

    # Extract field arguments from integrand
    field_args, value_args, domain_name, sample_name = _get_integrand_field_arguments(
        integrand, fields=fields, domain=domain
    )

    FieldStruct = _gen_field_struct(field_args)
    ValueStruct = _gen_value_struct(value_args)

    # Check if kernel exist in cache
    kernel_suffix = f"_itg_{wp.types.type_typestr(output_dtype)}{wp.types.type_typestr(accumulate_dtype)}_{domain.name}_{FieldStruct.key}"
    if nodal:
        kernel_suffix += "_nodal"
    else:
        kernel_suffix += quadrature.name

    if test:
        kernel_suffix += f"_test_{test.space_partition.name}_{test.space.name}"
    if trial:
        kernel_suffix += f"_trial_{trial.space_partition.name}_{trial.space.name}"

    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        suffix=kernel_suffix,
    )
    if kernel is not None:
        return kernel, FieldStruct, ValueStruct

    # Not found in cache, transform integrand and generate kernel
    integrand_func = _translate_integrand(
        integrand,
        field_args,
    )

    _register_integrand_field_wrappers(integrand_func, fields)

    if test is None and trial is None:
        integrate_kernel_fn = get_integrate_constant_kernel(
            integrand_func,
            domain,
            quadrature,
            FieldStruct,
            ValueStruct,
            accumulate_dtype=accumulate_dtype,
        )
    elif trial is None:
        if nodal:
            integrate_kernel_fn = get_integrate_linear_nodal_kernel(
                integrand_func,
                domain,
                FieldStruct,
                ValueStruct,
                test=test,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )
        else:
            integrate_kernel_fn = get_integrate_linear_kernel(
                integrand_func,
                domain,
                quadrature,
                FieldStruct,
                ValueStruct,
                test=test,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )
    else:
        if nodal:
            integrate_kernel_fn = get_integrate_bilinear_nodal_kernel(
                integrand_func,
                domain,
                FieldStruct,
                ValueStruct,
                test=test,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )
        else:
            integrate_kernel_fn = get_integrate_bilinear_kernel(
                integrand_func,
                domain,
                quadrature,
                FieldStruct,
                ValueStruct,
                test=test,
                trial=trial,
                output_dtype=output_dtype,
                accumulate_dtype=accumulate_dtype,
            )

    # The transformer substitutes the real integrand arguments into the kernel template
    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        kernel_fn=integrate_kernel_fn,
        suffix=kernel_suffix,
        kernel_options=kernel_options,
        code_transformers=[
            PassFieldArgsToIntegrand(
                arg_names=integrand.argspec.args,
                field_args=field_args.keys(),
                value_args=value_args.keys(),
                sample_name=sample_name,
                domain_name=domain_name,
                test_name=test_name,
                trial_name=trial_name,
            )
        ],
    )

    return kernel, FieldStruct, ValueStruct
def _launch_integrate_kernel(
    kernel: wp.Kernel,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    domain: GeometryDomain,
    nodal: bool,
    quadrature: Quadrature,
    test: Optional[TestField],
    trial: Optional[TrialField],
    fields: Dict[str, FieldLike],
    values: Dict[str, Any],
    accumulate_dtype: type,
    temporary_store: Optional[cache.TemporaryStore],
    output_dtype: type,
    output: Optional[Union[wp.array, BsrMatrix]],
    device,
):
    """
    Launch a kernel produced by ``_generate_integrate_kernel`` and collect its result.

    The form kind is deduced from ``test``/``trial``:

    * constant form (no test, no trial): returns ``output`` when given (cast from the
      accumulation type if needed), otherwise the scalar read back from the device;
    * linear form (test only): returns ``output`` or a detached temporary array;
    * bilinear form (test and trial): returns ``output`` or a newly allocated BSR matrix,
      filled from the computed triplets.

    Raises:
        RuntimeError: if a user-provided ``output`` has an incompatible size, dtype or shape.
    """
    # Set-up launch arguments
    domain_elt_arg = domain.element_arg_value(device=device)
    domain_elt_index_arg = domain.element_index_arg_value(device=device)

    # quadrature is None for nodal integration, whose kernels take no qp_arg
    if quadrature is not None:
        qp_arg = quadrature.arg_value(device=device)

    field_arg_values = FieldStruct()
    for k, v in fields.items():
        setattr(field_arg_values, k, v.eval_arg_value(device=device))

    value_struct_values = ValueStruct()
    for k, v in values.items():
        setattr(value_struct_values, k, v)

    # Constant form
    if test is None and trial is None:
        if output is not None and output.dtype == accumulate_dtype:
            if output.size < 1:
                raise RuntimeError("Output array must be of size at least 1")
            accumulate_array = output
        else:
            accumulate_temporary = cache.borrow_temporary(
                shape=(1),
                device=device,
                dtype=accumulate_dtype,
                temporary_store=temporary_store,
                requires_grad=output is not None and output.requires_grad,
            )
            accumulate_array = accumulate_temporary.array

        # The constant-form kernel accumulates with atomic adds, so start from zero
        accumulate_array.zero_()
        wp.launch(
            kernel=kernel,
            dim=domain.element_count(),
            inputs=[
                qp_arg,
                domain_elt_arg,
                domain_elt_index_arg,
                field_arg_values,
                value_struct_values,
                accumulate_array,
            ],
            device=device,
        )

        # Identity test: did we accumulate directly into the caller's array?
        if output is accumulate_array:
            return output
        elif output is None:
            return accumulate_array.numpy()[0]
        else:
            array_cast(in_array=accumulate_array, out_array=output)
            return output

    test_arg = test.space_restriction.node_arg(device=device)

    # Linear form
    if trial is None:
        # If an output array is provided with the correct type, accumulate directly into it
        # Otherwise, grab a temporary array
        if output is None:
            if type_length(output_dtype) == test.space.VALUE_DOF_COUNT:
                output_shape = (test.space_partition.node_count(),)
            elif type_length(output_dtype) == 1:
                output_shape = (test.space_partition.node_count(), test.space.VALUE_DOF_COUNT)
            else:
                raise RuntimeError(
                    f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.space.VALUE_DOF_COUNT}"
                )

            output_temporary = cache.borrow_temporary(
                temporary_store=temporary_store,
                shape=output_shape,
                dtype=output_dtype,
                device=device,
            )

            output = output_temporary.array

        else:
            output_temporary = None

            if output.shape[0] < test.space_partition.node_count():
                raise RuntimeError(f"Output array must have at least {test.space_partition.node_count()} rows")

            output_dtype = output.dtype
            if type_length(output_dtype) != test.space.VALUE_DOF_COUNT:
                if type_length(output_dtype) != 1:
                    raise RuntimeError(
                        f"Incompatible output type {wp.types.type_repr(output_dtype)}, must be scalar or vector of length {test.space.VALUE_DOF_COUNT}"
                    )
                # A scalar-typed output must be a 2d (node, dof) array with one column per dof.
                # (Previously `and`: 2d arrays were never checked and 1d arrays raised IndexError.)
                if output.ndim != 2 or output.shape[1] != test.space.VALUE_DOF_COUNT:
                    raise RuntimeError(
                        f"Incompatible output array shape, last dimension must be of size {test.space.VALUE_DOF_COUNT}"
                    )

        # Launch the integration on the kernel on a 2d scalar view of the actual array
        output.zero_()

        def as_2d_array(array):
            # Non-owning (node, dof) scalar view over the same memory, gradient included
            return wp.array(
                data=None,
                ptr=array.ptr,
                capacity=array.capacity,
                owner=False,
                device=array.device,
                shape=(test.space_partition.node_count(), test.space.VALUE_DOF_COUNT),
                dtype=wp.types.type_scalar_type(output_dtype),
                grad=None if array.grad is None else as_2d_array(array.grad),
            )

        output_view = output if output.ndim == 2 else as_2d_array(output)

        if nodal:
            wp.launch(
                kernel=kernel,
                dim=(test.space_restriction.node_count(), test.space.VALUE_DOF_COUNT),
                inputs=[
                    domain_elt_arg,
                    domain_elt_index_arg,
                    test_arg,
                    field_arg_values,
                    value_struct_values,
                    output_view,
                ],
                device=device,
            )
        else:
            wp.launch(
                kernel=kernel,
                dim=(test.space_restriction.node_count(), test.space.VALUE_DOF_COUNT),
                inputs=[
                    qp_arg,
                    domain_elt_arg,
                    domain_elt_index_arg,
                    test_arg,
                    field_arg_values,
                    value_struct_values,
                    output_view,
                ],
                device=device,
            )

        if output_temporary is not None:
            return output_temporary.detach()

        return output

    # Bilinear form

    if test.space.VALUE_DOF_COUNT == 1 and trial.space.VALUE_DOF_COUNT == 1:
        block_type = output_dtype
    else:
        block_type = cache.cached_mat_type(
            shape=(test.space.VALUE_DOF_COUNT, trial.space.VALUE_DOF_COUNT), dtype=output_dtype
        )

    # Nodal: one diagonal block per test node; otherwise one block per (node, element, trial node)
    if nodal:
        nnz = test.space_restriction.node_count()
    else:
        nnz = test.space_restriction.total_node_element_count() * trial.space.topology.NODES_PER_ELEMENT

    triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
    triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
    triplet_values_temp = cache.borrow_temporary(
        temporary_store,
        shape=(
            nnz,
            test.space.VALUE_DOF_COUNT,
            trial.space.VALUE_DOF_COUNT,
        ),
        dtype=output_dtype,
        device=device,
    )
    triplet_cols = triplet_cols_temp.array
    triplet_rows = triplet_rows_temp.array
    triplet_values = triplet_values_temp.array

    triplet_values.zero_()

    if nodal:
        wp.launch(
            kernel=kernel,
            dim=triplet_values.shape,
            inputs=[
                domain_elt_arg,
                domain_elt_index_arg,
                test_arg,
                field_arg_values,
                value_struct_values,
                triplet_rows,
                triplet_cols,
                triplet_values,
            ],
            device=device,
        )

    else:
        offsets = test.space_restriction.partition_element_offsets()

        trial_partition_arg = trial.space_partition.partition_arg_value(device)
        trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
        wp.launch(
            kernel=kernel,
            dim=(
                test.space_restriction.node_count(),
                trial.space.topology.NODES_PER_ELEMENT,
                test.space.VALUE_DOF_COUNT,
                trial.space.VALUE_DOF_COUNT,
            ),
            inputs=[
                qp_arg,
                domain_elt_arg,
                domain_elt_index_arg,
                test_arg,
                trial_partition_arg,
                trial_topology_arg,
                field_arg_values,
                value_struct_values,
                offsets,
                triplet_rows,
                triplet_cols,
                triplet_values,
            ],
            device=device,
        )

    if output is not None:
        if output.nrow != test.space_partition.node_count() or output.ncol != trial.space_partition.node_count():
            raise RuntimeError(
                f"Output matrix must have {test.space_partition.node_count()} rows and {trial.space_partition.node_count()} columns of blocks"
            )

    else:
        output = bsr_zeros(
            rows_of_blocks=test.space_partition.node_count(),
            cols_of_blocks=trial.space_partition.node_count(),
            block_type=block_type,
            device=device,
        )

    bsr_set_from_triplets(output, triplet_rows, triplet_cols, triplet_values)

    # Do not wait for garbage collection
    triplet_values_temp.release()
    triplet_rows_temp.release()
    triplet_cols_temp.release()

    return output
def integrate(
    integrand: Integrand,
    domain: Optional[GeometryDomain] = None,
    quadrature: Optional[Quadrature] = None,
    nodal: bool = False,
    fields: Optional[Dict[str, FieldLike]] = None,
    values: Optional[Dict[str, Any]] = None,
    accumulate_dtype: type = wp.float64,
    output_dtype: Optional[type] = None,
    output: Optional[Union[BsrMatrix, wp.array]] = None,
    device=None,
    temporary_store: Optional[cache.TemporaryStore] = None,
    kernel_options: Optional[Dict[str, Any]] = None,
):
    """
    Integrates a constant, linear or bilinear form, and returns a scalar, array, or sparse matrix, respectively.

    Args:
        integrand: Form to be integrated, must have :func:`integrand` decorator
        domain: Integration domain. If None, deduced from fields
        quadrature: Quadrature formula. If None, deduced from domain and fields degree.
        nodal: For linear or bilinear form only, use the test function nodes as the quadrature points. Assumes Lagrange interpolation functions are used, and no differential or DG operator is evaluated on the test or trial functions.
        fields: Discrete, test, and trial fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names. None means no fields.
        values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names. None means no values.
        temporary_store: shared pool from which to allocate temporary arrays
        accumulate_dtype: Scalar type to be used for accumulating integration samples
        output: Sparse matrix or warp array into which to store the result of the integration
        output_dtype: Scalar type for returned results if `output` is not provided. If None, defaults to `accumulate_dtype`
        device: Device on which to perform the integration
        kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``). None means no options.

    Raises:
        ValueError: if ``integrand`` is not an :class:`Integrand`, if no domain can be deduced,
            or if the nodal/quadrature/domain arguments are inconsistent.
        NotImplementedError: if the integration domain differs from the test field's domain.
    """
    # Fresh containers per call instead of shared mutable default arguments
    if fields is None:
        fields = {}
    if values is None:
        values = {}
    if kernel_options is None:
        kernel_options = {}

    if not isinstance(integrand, Integrand):
        raise ValueError("integrand must be tagged with @warp.fem.integrand decorator")

    test, test_name, trial, trial_name = _get_test_and_trial_fields(fields)

    if domain is None:
        if quadrature is not None:
            domain = quadrature.domain
        elif test is not None:
            domain = test.domain

    if domain is None:
        raise ValueError("Must provide at least one of domain, quadrature, or test field")
    if test is not None and domain != test.domain:
        raise NotImplementedError("Mixing integration and test domain is not supported yet")

    if nodal:
        if quadrature is not None:
            raise ValueError("Cannot specify quadrature for nodal integration")

        if test is None:
            raise ValueError("Nodal integration requires specifying a test function")

        if trial is not None and test.space_partition != trial.space_partition:
            raise ValueError(
                "Bilinear nodal integration requires test and trial to be defined on the same function space"
            )
    else:
        if quadrature is None:
            # Default quadrature order: sum of the degrees of all fields involved
            order = sum(field.degree for field in fields.values())
            quadrature = RegularQuadrature(domain=domain, order=order)
        elif domain != quadrature.domain:
            raise ValueError("Incompatible integration and quadrature domain")

    # Canonicalize types
    accumulate_dtype = wp.types.type_to_warp(accumulate_dtype)
    if output is not None:
        if isinstance(output, BsrMatrix):
            output_dtype = output.scalar_type
        else:
            output_dtype = output.dtype
    elif output_dtype is None:
        output_dtype = accumulate_dtype
    else:
        output_dtype = wp.types.type_to_warp(output_dtype)

    kernel, FieldStruct, ValueStruct = _generate_integrate_kernel(
        integrand=integrand,
        domain=domain,
        nodal=nodal,
        quadrature=quadrature,
        test=test,
        test_name=test_name,
        trial=trial,
        trial_name=trial_name,
        fields=fields,
        accumulate_dtype=accumulate_dtype,
        output_dtype=output_dtype,
        kernel_options=kernel_options,
    )

    return _launch_integrate_kernel(
        kernel=kernel,
        FieldStruct=FieldStruct,
        ValueStruct=ValueStruct,
        domain=domain,
        nodal=nodal,
        quadrature=quadrature,
        test=test,
        trial=trial,
        fields=fields,
        values=values,
        accumulate_dtype=accumulate_dtype,
        temporary_store=temporary_store,
        output_dtype=output_dtype,
        output=output,
        device=device,
    )
def get_interpolate_to_field_function(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    dest: FieldRestriction,
):
    """
    Creates the per-node evaluation function used when interpolating into a discrete field.

    The returned plain Python function is meant to be compiled by warp: ``_generate_interpolate_kernel``
    wraps it with ``cache.get_integrand_function``. It is kept separate from the kernel to allow
    differentiating the interpolation.

    Args:
        integrand_func: Translated integrand, invoked as ``integrand_func(sample, fields, values)``
        domain: Geometry domain used to resolve element indices and element measures
        FieldStruct: Generated struct type holding the field arguments of the integrand
        ValueStruct: Generated struct type holding the value arguments of the integrand
        dest: Restriction of the destination field; its space restriction defines the nodes to visit

    Returns:
        A function ``(local_node_index, domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, fields, values)``
        returning ``(val_sum, vol_sum)``. ``val_sum`` is the element-measure-weighted sum of the integrand
        values over all elements adjacent to the node. ``vol_sum`` is the sum of those element measures.
        The caller is responsible for the final division.
    """
    # Captured by the closure below; the accumulator has the destination space's value type
    value_type = dest.space.dtype
    def interpolate_to_field_fn(
        local_node_index: int,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        dest_node_arg: dest.space_restriction.NodeArg,
        dest_eval_arg: dest.field.EvalArg,
        fields: FieldStruct,
        values: ValueStruct,
    ):
        # Index of the node within the destination space partition
        node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
        # Number of domain elements this node belongs to
        element_count = dest.space_restriction.node_element_count(dest_node_arg, local_node_index)
        # Interpolation samples carry no test/trial degree of freedom
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX
        # Nodal samples use a unit weight
        node_weight = 1.0
        # Volume-weighted average across elements
        # Superfluous if the interpolated function is continuous, but helpful for visualizing discontinuous spaces
        val_sum = value_type(0.0)
        vol_sum = float(0.0)
        for n in range(element_count):
            # Locate the n-th element adjacent to this node, and the node's position inside it
            node_element_index = dest.space_restriction.node_element_index(dest_node_arg, local_node_index, n)
            element_index = domain.element_index(domain_index_arg, node_element_index.domain_element_index)
            coords = dest.space.node_coords_in_element(
                domain_arg,
                dest_eval_arg.space_arg,
                element_index,
                node_element_index.node_index_in_element,
            )
            # Elements for which the node coordinates are flagged OUTSIDE do not contribute
            if coords[0] != OUTSIDE:
                # The node partition index is used as the sample's point index
                sample = Sample(
                    element_index,
                    coords,
                    node_index,
                    node_weight,
                    test_dof_index,
                    trial_dof_index,
                )
                vol = domain.element_measure(domain_arg, sample)
                val = integrand_func(sample, fields, values)
                vol_sum += vol
                val_sum += vol * val
        return val_sum, vol_sum
    return interpolate_to_field_fn
def get_interpolate_to_field_kernel(
    interpolate_to_field_fn: wp.Function,
    domain: GeometryDomain,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    dest: FieldRestriction,
):
    """
    Creates the kernel function that writes interpolated values to the nodes of a discrete field.

    Each thread handles one local node of ``dest.space_restriction``. The launcher uses
    ``dest.space_restriction.node_count()`` as the launch dimension. Each thread evaluates
    ``interpolate_to_field_fn`` and stores the weighted average ``val_sum / vol_sum`` into the
    destination field. Nodes whose accumulated measure is not positive are left untouched.

    Args:
        interpolate_to_field_fn: Per-node function as produced by ``get_interpolate_to_field_function``
            (after warp compilation), returning ``(val_sum, vol_sum)``
        domain: Geometry domain providing the element argument types
        FieldStruct: Generated struct type holding the field arguments of the integrand
        ValueStruct: Generated struct type holding the value arguments of the integrand
        dest: Restriction of the destination field

    Returns:
        The kernel function, to be compiled with ``cache.get_integrand_kernel``.
    """
    def interpolate_to_field_kernel_fn(
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        dest_node_arg: dest.space_restriction.NodeArg,
        dest_eval_arg: dest.field.EvalArg,
        fields: FieldStruct,
        values: ValueStruct,
    ):
        # One thread per local node of the destination space restriction
        local_node_index = wp.tid()
        val_sum, vol_sum = interpolate_to_field_fn(
            local_node_index, domain_arg, domain_index_arg, dest_node_arg, dest_eval_arg, fields, values
        )
        # Skip nodes with no contributing element to avoid dividing by zero
        if vol_sum > 0.0:
            node_index = dest.space_restriction.node_partition_index(dest_node_arg, local_node_index)
            dest.field.set_node_value(dest_eval_arg, node_index, val_sum / vol_sum)
    return interpolate_to_field_kernel_fn
def get_interpolate_to_array_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    value_type: type,
):
    """
    Creates the kernel function that evaluates the integrand at every quadrature point and stores the results.

    Each thread handles one element: the thread id is mapped through ``domain.element_index``.
    The thread loops over the element's quadrature points and writes the integrand value to
    ``result[qp_index]``, where ``qp_index`` comes from ``quadrature.point_index``.

    Args:
        integrand_func: Translated integrand, invoked as ``integrand_func(sample, fields, values)``
        domain: Geometry domain used to resolve element indices (``interpolate`` passes ``quadrature.domain``)
        quadrature: Quadrature formula defining the sample points
        FieldStruct: Generated struct type holding the field arguments of the integrand
        ValueStruct: Generated struct type holding the value arguments of the integrand
        value_type: dtype of the ``result`` array

    Returns:
        The kernel function, to be compiled with ``cache.get_integrand_kernel``.
    """
    def interpolate_to_array_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: quadrature.domain.ElementArg,
        domain_index_arg: quadrature.domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
        result: wp.array(dtype=value_type),
    ):
        # One thread per domain element
        element_index = domain.element_index(domain_index_arg, wp.tid())
        # Interpolation samples carry no test/trial degree of freedom
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX
        qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
        for k in range(qp_point_count):
            qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
            coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
            qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
            sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
            # NOTE(review): no bounds check here; `result` presumably must hold at least as many
            # entries as the quadrature has points in total -- confirm with the caller
            result[qp_index] = integrand_func(sample, fields, values)
    return interpolate_to_array_kernel_fn
def get_interpolate_nonvalued_kernel(
    integrand_func: wp.Function,
    domain: GeometryDomain,
    quadrature: Quadrature,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
):
    """
    Creates the kernel function that evaluates the integrand at every quadrature point without storing results.

    Used when ``interpolate`` is called with ``dest=None``. The integrand's return value is discarded,
    so the integrand itself is responsible for dealing with the interpolation result. Each thread
    handles one element and loops over that element's quadrature points.

    Args:
        integrand_func: Translated integrand, invoked as ``integrand_func(sample, fields, values)``
        domain: Geometry domain used to resolve element indices (``interpolate`` passes ``quadrature.domain``)
        quadrature: Quadrature formula defining the sample points
        FieldStruct: Generated struct type holding the field arguments of the integrand
        ValueStruct: Generated struct type holding the value arguments of the integrand

    Returns:
        The kernel function, to be compiled with ``cache.get_integrand_kernel``.
    """
    def interpolate_nonvalued_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: quadrature.domain.ElementArg,
        domain_index_arg: quadrature.domain.ElementIndexArg,
        fields: FieldStruct,
        values: ValueStruct,
    ):
        # One thread per domain element
        element_index = domain.element_index(domain_index_arg, wp.tid())
        # Interpolation samples carry no test/trial degree of freedom
        test_dof_index = NULL_DOF_INDEX
        trial_dof_index = NULL_DOF_INDEX
        qp_point_count = quadrature.point_count(domain_arg, qp_arg, element_index)
        for k in range(qp_point_count):
            qp_index = quadrature.point_index(domain_arg, qp_arg, element_index, k)
            coords = quadrature.point_coords(domain_arg, qp_arg, element_index, k)
            qp_weight = quadrature.point_weight(domain_arg, qp_arg, element_index, k)
            sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
            # Evaluated for its side effects only; the return value is ignored
            integrand_func(sample, fields, values)
    return interpolate_nonvalued_kernel_fn
def _generate_interpolate_kernel(
    integrand: Integrand,
    domain: GeometryDomain,
    dest: Optional[Union[FieldLike, wp.array]],
    quadrature: Optional[Quadrature],
    fields: Dict[str, FieldLike],
    kernel_options: Optional[Dict[str, Any]] = None,
) -> tuple:
    """
    Generates, or fetches from the integrand kernel cache, the warp kernel performing an interpolation.

    The kernel variant depends on ``dest``:

    - :class:`FieldRestriction`: a per-node kernel writing averaged values into the destination field
    - warp array: a per-element kernel storing the integrand value at each quadrature point
    - ``None``: a per-element kernel evaluating the integrand for its side effects only

    Args:
        integrand: Integrand to interpolate
        domain: Domain over which the interpolation is performed
        dest: Destination field restriction, warp array, or ``None``
        quadrature: Quadrature defining the samples; required unless ``dest`` is a field restriction
        fields: Discrete fields passed to the integrand, keyed by integrand parameter name
        kernel_options: Options forwarded to the kernel builder. ``None`` (the default) means no options.

    Returns:
        A ``(kernel, FieldStruct, ValueStruct)`` tuple. ``FieldStruct`` and ``ValueStruct`` are the generated
        argument struct types that must be instantiated to launch ``kernel``.
    """
    # Avoid a shared mutable default argument
    if kernel_options is None:
        kernel_options = {}

    # Extract field arguments from integrand
    field_args, value_args, domain_name, sample_name = _get_integrand_field_arguments(
        integrand, fields=fields, domain=domain
    )

    # Translate the integrand against the concrete field arguments
    integrand_func = _translate_integrand(
        integrand,
        field_args,
    )
    _register_integrand_field_wrappers(integrand_func, fields)

    # Generate the field and value argument structs
    FieldStruct = _gen_field_struct(field_args)
    ValueStruct = _gen_value_struct(value_args)

    def _make_code_transformers():
        # Builds a fresh transformer list for each code generation call,
        # so no transformer instance is shared between the function and the kernel
        return [
            PassFieldArgsToIntegrand(
                arg_names=integrand.argspec.args,
                field_args=field_args.keys(),
                value_args=value_args.keys(),
                sample_name=sample_name,
                domain_name=domain_name,
            )
        ]

    # Cache key suffix identifying the kernel variant for this kind of destination
    if isinstance(dest, FieldRestriction):
        kernel_suffix = (
            f"_itp_{FieldStruct.key}_{dest.domain.name}_{dest.space_restriction.space_partition.name}_{dest.space.name}"
        )
    elif wp.types.is_array(dest):
        kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}_{wp.types.type_repr(dest.dtype)}"
    else:
        kernel_suffix = f"_itp_{FieldStruct.key}_{quadrature.name}"

    # Check if kernel exists in cache
    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        suffix=kernel_suffix,
    )
    if kernel is not None:
        return kernel, FieldStruct, ValueStruct

    # Generate interpolation kernel
    if isinstance(dest, FieldRestriction):
        # The computation is split into a kernel + a warp function for differentiability
        interpolate_fn = get_interpolate_to_field_function(
            integrand_func,
            domain,
            dest=dest,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )
        interpolate_fn = cache.get_integrand_function(
            integrand=integrand,
            func=interpolate_fn,
            suffix=kernel_suffix,
            code_transformers=_make_code_transformers(),
        )
        interpolate_kernel_fn = get_interpolate_to_field_kernel(
            interpolate_fn,
            domain,
            dest=dest,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )
    elif wp.types.is_array(dest):
        interpolate_kernel_fn = get_interpolate_to_array_kernel(
            integrand_func,
            domain=domain,
            quadrature=quadrature,
            value_type=dest.dtype,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )
    else:
        interpolate_kernel_fn = get_interpolate_nonvalued_kernel(
            integrand_func,
            domain=domain,
            quadrature=quadrature,
            FieldStruct=FieldStruct,
            ValueStruct=ValueStruct,
        )

    kernel = cache.get_integrand_kernel(
        integrand=integrand,
        kernel_fn=interpolate_kernel_fn,
        suffix=kernel_suffix,
        kernel_options=kernel_options,
        code_transformers=_make_code_transformers(),
    )

    return kernel, FieldStruct, ValueStruct
def _launch_interpolate_kernel(
    kernel: wp.Kernel,
    FieldStruct: wp.codegen.Struct,
    ValueStruct: wp.codegen.Struct,
    domain: GeometryDomain,
    dest: Optional[Union[FieldRestriction, wp.array]],
    quadrature: Optional[Quadrature],
    fields: Dict[str, FieldLike],
    values: Dict[str, Any],
    device,
) -> None:
    """
    Launches an interpolation kernel generated by ``_generate_interpolate_kernel``.

    The launch configuration depends on ``dest``:

    - :class:`FieldRestriction`: one thread per node of ``dest.space_restriction``
    - warp array: one thread per domain element, and ``dest`` is passed as the output array
    - otherwise (``None``): one thread per domain element, with no output argument

    Args:
        kernel: Kernel returned by ``_generate_interpolate_kernel``
        FieldStruct: Field argument struct type matching ``kernel``
        ValueStruct: Value argument struct type matching ``kernel``
        domain: Domain over which the interpolation is performed
        dest: Destination field restriction, warp array, or ``None``
        quadrature: Quadrature formula; must not be ``None`` unless ``dest`` is a field restriction
        fields: Discrete fields passed to the integrand, keyed by integrand parameter name
        values: Additional values passed to the integrand, keyed by integrand parameter name
        device: Device on which to launch the kernel
    """
    # Set-up launch arguments shared by all kernel variants
    elt_arg = domain.element_arg_value(device=device)
    elt_index_arg = domain.element_index_arg_value(device=device)

    field_arg_values = FieldStruct()
    for k, v in fields.items():
        setattr(field_arg_values, k, v.eval_arg_value(device=device))

    value_struct_values = ValueStruct()
    for k, v in values.items():
        setattr(value_struct_values, k, v)

    if isinstance(dest, FieldRestriction):
        # Node-based interpolation: one thread per destination node
        dest_node_arg = dest.space_restriction.node_arg(device=device)
        dest_eval_arg = dest.field.eval_arg_value(device=device)

        wp.launch(
            kernel=kernel,
            dim=dest.space_restriction.node_count(),
            inputs=[
                elt_arg,
                elt_index_arg,
                dest_node_arg,
                dest_eval_arg,
                field_arg_values,
                value_struct_values,
            ],
            device=device,
        )
        return

    # Quadrature-based interpolation: one thread per domain element
    qp_arg = quadrature.arg_value(device)
    inputs = [qp_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values]
    if wp.types.is_array(dest):
        # The array kernel takes the output array as its last argument
        inputs.append(dest)

    wp.launch(
        kernel=kernel,
        dim=domain.element_count(),
        inputs=inputs,
        device=device,
    )
def interpolate(
    integrand: Integrand,
    dest: Optional[Union[DiscreteField, FieldRestriction, wp.array]] = None,
    quadrature: Optional[Quadrature] = None,
    fields: Dict[str, FieldLike] = {},
    values: Dict[str, Any] = {},
    device=None,
    kernel_options: Dict[str, Any] = {},
):
    """
    Interpolates a function at a finite set of sample points, and optionally assigns the result to a discrete field or a raw warp array.

    Args:
        integrand: Function to be interpolated; must have the :func:`integrand` decorator
        dest: Where to store the interpolation result. Can be either
         - a :class:`DiscreteField`, or the restriction of a discrete field to a domain (from :func:`make_restriction`). In this case, interpolation is performed at each node.
         - a plain warp array. In this case, the `quadrature` argument defining the interpolation locations must be provided, and the result of the `integrand` at each quadrature point is assigned to the array.
         - ``None``. In this case, the `quadrature` argument must also be provided, and the `integrand` function is responsible for dealing with the interpolation result.
        quadrature: Quadrature formula defining the interpolation samples when `dest` is not a discrete field or field restriction.
        fields: Discrete fields to be passed to the integrand. Keys in the dictionary must match integrand parameter names.
        values: Additional variable values to be passed to the integrand; can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
        device: Device on which to perform the interpolation
        kernel_options: Overloaded options to be passed to the kernel builder (e.g., ``{"enable_backward": True}``)
    """
    if not isinstance(integrand, Integrand):
        raise ValueError("integrand must be tagged with @integrand decorator")

    # Interpolation only evaluates the integrand; test/trial fields are meaningless here
    test_field, _, trial_field, _ = _get_test_and_trial_fields(fields)
    if not (test_field is None and trial_field is None):
        raise ValueError("Test or Trial fields should not be used for interpolation")

    # A bare discrete field is promoted to a restriction so both cases share one code path
    target = make_restriction(dest) if isinstance(dest, DiscreteField) else dest

    # Field destinations carry their own domain; otherwise the quadrature defines it
    if isinstance(target, FieldRestriction):
        interp_domain = target.domain
    elif quadrature is not None:
        interp_domain = quadrature.domain
    else:
        raise ValueError("When not interpolating to a field, a quadrature formula must be provided")

    kernel, FieldStruct, ValueStruct = _generate_interpolate_kernel(
        integrand=integrand,
        domain=interp_domain,
        dest=target,
        quadrature=quadrature,
        fields=fields,
        kernel_options=kernel_options,
    )

    return _launch_interpolate_kernel(
        kernel=kernel,
        FieldStruct=FieldStruct,
        ValueStruct=ValueStruct,
        domain=interp_domain,
        dest=target,
        quadrature=quadrature,
        fields=fields,
        values=values,
        device=device,
    )