Spaces:
Sleeping
Sleeping
| from typing import Any | |
| import warp as wp | |
| from warp.fem.types import DofIndex, ElementIndex, Coords, get_node_coord | |
| from warp.fem.geometry import GeometryPartition | |
| from warp.fem import utils | |
| from .function_space import FunctionSpace | |
| from .dof_mapper import DofMapper, IdentityMapper | |
| from .partition import make_space_partition, SpacePartition | |
class NodalFunctionSpace(FunctionSpace):
    """Function space where values are collocated at nodes.

    Subclasses are expected to define ``ORDER`` (read by :attr:`name`) and,
    if a trace space is to be built from them, to override the trace
    interface methods at the end of this class.
    """

    def __init__(self, dtype: type = float, dof_mapper: DofMapper = None):
        """Set up value, dof and gradient types from a dof mapper.

        Args:
            dtype: Value type used to build an ``IdentityMapper`` when no
                ``dof_mapper`` is given.
            dof_mapper: Mapping between degrees of freedom and field values.
                Defaults to ``IdentityMapper(dtype)``.
        """
        self.dof_mapper = IdentityMapper(dtype) if dof_mapper is None else dof_mapper
        self.dtype = self.dof_mapper.value_dtype
        self.dof_dtype = self.dof_mapper.dof_dtype

        # Gradient type per supported value type.
        # NOTE(review): scalar -> vec2 and vec2 -> mat22 look 2D-specific while
        # vec3 -> mat33 looks 3D-specific; confirm this is intended.
        if self.dtype == wp.float32:
            self.gradient_dtype = wp.vec2
        elif self.dtype == wp.vec2:
            self.gradient_dtype = wp.mat22
        elif self.dtype == wp.vec3:
            self.gradient_dtype = wp.mat33
        else:
            # No gradient type is defined for other value types.
            self.gradient_dtype = None

        self.VALUE_DOF_COUNT = self.dof_mapper.DOF_SIZE
        self.unit_dof_value = self._make_unit_dof_value(self.dof_mapper)

    @property
    def name(self):
        """Identifier made of the class qualname, ``ORDER`` and the dof mapper.

        Dots are replaced by underscores. Callers use this as a cache key and
        read it as an attribute, hence the property.
        """
        return f"{self.__class__.__qualname__}_{self.ORDER}_{self.dof_mapper}".replace(".", "_")

    def make_field(
        self,
        space_partition: SpacePartition = None,
        geometry_partition: GeometryPartition = None,
    ) -> "wp.fem.field.NodalField":
        """Create a ``NodalField`` over this space.

        Args:
            space_partition: Partition of this space's nodes. When ``None``,
                one is built with ``make_space_partition(self, geometry_partition)``.
            geometry_partition: Geometry partition used only when
                ``space_partition`` is ``None``.

        Returns:
            A ``NodalField`` bound to this space and the chosen partition.
        """
        # Local import, presumably to avoid a circular import with warp.fem.field.
        from warp.fem.field import NodalField

        if space_partition is None:
            space_partition = make_space_partition(self, geometry_partition)

        return NodalField(space=self, space_partition=space_partition)

    @staticmethod
    def _make_unit_dof_value(dof_mapper: DofMapper):
        """Build the function mapping a dof index to its unit value.

        The generated function applies ``utils.unit_element`` to a zero dof
        value and the coordinate given by ``get_node_coord(dof)``, then maps the
        result through ``dof_mapper.dof_to_value``. It is fetched from
        ``cache.get_func``, keyed by ``str(dof_mapper)``.

        Declared static because it takes no ``self``, yet ``__init__`` calls it
        through the instance.
        """
        from warp.fem import cache

        def unit_dof_value(args: Any, dof: DofIndex):
            return dof_mapper.dof_to_value(utils.unit_element(dof_mapper.dof_dtype(0.0), get_node_coord(dof)))

        return cache.get_func(unit_dof_value, str(dof_mapper))

    # Interface for generating Trace space.
    # These placeholders take no ``self`` and are called as
    # ``space._inner_cell_index(args, ...)``. They are static so that a
    # subclass which does not override them raises NotImplementedError
    # rather than a TypeError about argument counts.

    @staticmethod
    def _inner_cell_index(args: Any, side_index: ElementIndex):
        """Given a side, returns the index of the inner cell"""
        raise NotImplementedError

    @staticmethod
    def _outer_cell_index(args: Any, side_index: ElementIndex):
        """Given a side, returns the index of the outer cell"""
        raise NotImplementedError

    @staticmethod
    def _inner_cell_coords(args: Any, side_index: ElementIndex, side_coords: Coords):
        """Given coordinates within a side, returns coordinates within the inner cell"""
        raise NotImplementedError

    @staticmethod
    def _outer_cell_coords(args: Any, side_index: ElementIndex, side_coords: Coords):
        """Given coordinates within a side, returns coordinates within the outer cell"""
        raise NotImplementedError

    @staticmethod
    def _cell_to_side_coords(
        args: Any,
        side_index: ElementIndex,
        element_index: ElementIndex,
        element_coords: Coords,
    ):
        """Given coordinates within a cell, returns coordinates within a side, or OUTSIDE"""
        raise NotImplementedError
class NodalFunctionSpaceTrace(NodalFunctionSpace):
    """Trace of a NodalFunctionSpace.

    Elements of the trace are sides of the underlying space's geometry. Each
    side element exposes ``2 * space.NODES_PER_ELEMENT`` nodes. The first
    ``space.NODES_PER_ELEMENT`` belong to the inner cell and the rest to the
    outer cell.
    """

    def __init__(self, space: NodalFunctionSpace):
        """Wrap ``space`` and derive the trace's element layout from it.

        Args:
            space: Cell-based nodal space. It must provide ``geometry``,
                ``NODES_PER_ELEMENT``, ``DIMENSION``, ``SpaceArg`` and
                ``space_arg_value``, and implement the trace interface
                (``_inner_cell_index`` etc.).
        """
        self._space = space

        super().__init__(space.dtype, space.dof_mapper)
        self.geometry = space.geometry

        # Inner-cell nodes followed by outer-cell nodes.
        self.NODES_PER_ELEMENT = wp.constant(2 * space.NODES_PER_ELEMENT)
        # Sides are one dimension lower than cells.
        self.DIMENSION = space.DIMENSION - 1

        self.SpaceArg = space.SpaceArg
        self.space_arg_value = space.space_arg_value

    def node_count(self) -> int:
        """Number of nodes, delegated to the wrapped space."""
        return self._space.node_count()

    @property
    def name(self):
        """Name of the wrapped space with a ``_Trace`` suffix."""
        return f"{self._space.name}_Trace"

    # The factories below take no ``self`` and receive the wrapped space
    # explicitly, hence @staticmethod. Each returns a function obtained from
    # ``cache.get_func`` keyed by ``space.name``. The inner functions' bodies
    # are kept in the exact form the Warp code generator was given.

    @staticmethod
    def _make_element_node_index(space: NodalFunctionSpace):
        """Build the side-element -> global node index function.

        Local indices below ``space.NODES_PER_ELEMENT`` are looked up in the
        inner cell. The rest are shifted down by ``NODES_PER_ELEMENT`` and
        looked up in the outer cell.
        """
        from warp.fem import cache

        NODES_PER_ELEMENT = space.NODES_PER_ELEMENT

        def trace_element_node_index(args: space.SpaceArg, element_index: ElementIndex, node_index_in_elt: int):
            if node_index_in_elt < NODES_PER_ELEMENT:
                inner_element = space._inner_cell_index(args, element_index)
                return space.element_node_index(args, inner_element, node_index_in_elt)

            outer_element = space._outer_cell_index(args, element_index)
            return space.element_node_index(args, outer_element, node_index_in_elt - NODES_PER_ELEMENT)

        return cache.get_func(trace_element_node_index, space.name)

    @staticmethod
    def _make_node_coords_in_element(space: NodalFunctionSpace):
        """Build the function giving a trace node's coordinates within its side.

        The node's coordinates in its neighbouring cell (inner or outer, split
        as in ``_make_element_node_index``) are converted to side coordinates
        with ``space._cell_to_side_coords``.
        """
        from warp.fem import cache

        NODES_PER_ELEMENT = space.NODES_PER_ELEMENT

        def trace_node_coords_in_element(
            args: space.SpaceArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            if node_index_in_elt < NODES_PER_ELEMENT:
                neighbour_elem = space._inner_cell_index(args, element_index)
                neighbour_coords = space.node_coords_in_element(args, neighbour_elem, node_index_in_elt)
            else:
                neighbour_elem = space._outer_cell_index(args, element_index)
                neighbour_coords = space.node_coords_in_element(
                    args,
                    neighbour_elem,
                    node_index_in_elt - NODES_PER_ELEMENT,
                )

            return space._cell_to_side_coords(args, element_index, neighbour_elem, neighbour_coords)

        return cache.get_func(trace_node_coords_in_element, space.name)

    @staticmethod
    def _make_element_inner_weight(space: NodalFunctionSpace):
        """Build the inner-side weight function.

        Evaluates ``space.element_inner_weight`` on the inner cell at the side
        coordinates mapped into that cell. The local node index is passed
        through unchanged.
        """
        from warp.fem import cache

        def trace_element_inner_weight(
            args: space.SpaceArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            return space.element_inner_weight(
                args,
                space._inner_cell_index(args, element_index),
                space._inner_cell_coords(args, element_index, coords),
                node_index_in_elt,
            )

        return cache.get_func(trace_element_inner_weight, space.name)

    @staticmethod
    def _make_element_outer_weight(space: NodalFunctionSpace):
        """Build the outer-side weight function.

        Evaluates ``space.element_outer_weight`` on the outer cell. The local
        node index is shifted down by ``space.NODES_PER_ELEMENT``.
        """
        from warp.fem import cache

        NODES_PER_ELEMENT = space.NODES_PER_ELEMENT

        def trace_element_outer_weight(
            args: space.SpaceArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            return space.element_outer_weight(
                args,
                space._outer_cell_index(args, element_index),
                space._outer_cell_coords(args, element_index, coords),
                node_index_in_elt - NODES_PER_ELEMENT,
            )

        return cache.get_func(trace_element_outer_weight, space.name)

    @staticmethod
    def _make_element_inner_weight_gradient(space: NodalFunctionSpace):
        """Build the inner-side weight-gradient function.

        This is the gradient counterpart of ``_make_element_inner_weight``.
        """
        from warp.fem import cache

        def trace_element_inner_weight_gradient(
            args: space.SpaceArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            return space.element_inner_weight_gradient(
                args,
                space._inner_cell_index(args, element_index),
                space._inner_cell_coords(args, element_index, coords),
                node_index_in_elt,
            )

        return cache.get_func(trace_element_inner_weight_gradient, space.name)

    @staticmethod
    def _make_element_outer_weight_gradient(space: NodalFunctionSpace):
        """Build the outer-side weight-gradient function.

        This is the gradient counterpart of ``_make_element_outer_weight``.
        """
        from warp.fem import cache

        NODES_PER_ELEMENT = space.NODES_PER_ELEMENT

        def trace_element_outer_weight_gradient(
            args: space.SpaceArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            return space.element_outer_weight_gradient(
                args,
                space._outer_cell_index(args, element_index),
                space._outer_cell_coords(args, element_index, coords),
                node_index_in_elt - NODES_PER_ELEMENT,
            )

        return cache.get_func(trace_element_outer_weight_gradient, space.name)

    def __eq__(self, other: "NodalFunctionSpaceTrace") -> bool:
        """Two traces are equal when they wrap equal underlying spaces."""
        if not isinstance(other, NodalFunctionSpaceTrace):
            # Let Python fall back instead of raising AttributeError on ``_space``.
            return NotImplemented
        return self._space == other._space

    def __hash__(self) -> int:
        # Defining __eq__ alone sets __hash__ to None; hash consistently with it.
        return hash(self._space)