File size: 3,783 Bytes
66c9c8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from typing import Any, Optional

import warp as wp

from warp.fem.types import DofIndex, get_node_coord
from warp.fem.geometry import GeometryPartition
from warp.fem import cache, utils


from .function_space import FunctionSpace
from .dof_mapper import DofMapper, IdentityMapper
from .partition import make_space_partition, SpacePartition

from .basis_space import BasisSpace


class CollocatedFunctionSpace(FunctionSpace):
    """Function space where values are collocated at the nodes of a basis space.

    The topology, space argument, order and element weight functions are all
    taken from the wrapped basis; the dof mapper supplies the value and
    degree-of-freedom types.
    """

    def __init__(self, basis: BasisSpace, dtype: type = float, dof_mapper: DofMapper = None):
        """Build a collocated space on top of ``basis``.

        Args:
            basis: basis space providing the topology, the space argument
                structure, the order and the element weight functions.
            dtype: value type; only used to build an :class:`IdentityMapper`
                when ``dof_mapper`` is ``None``.
            dof_mapper: mapping between degrees of freedom and values. When
                given, the value type is taken from it rather than ``dtype``.
        """
        super().__init__(topology=basis.topology)

        if dof_mapper is None:
            dof_mapper = IdentityMapper(dtype)

        self.dof_mapper = dof_mapper
        self.dtype = dof_mapper.value_dtype
        self.dof_dtype = dof_mapper.dof_dtype
        self.VALUE_DOF_COUNT = dof_mapper.DOF_SIZE

        self._basis = basis
        self.SpaceArg = basis.BasisArg

        self.ORDER = basis.ORDER

        # Requires `dof_mapper`, `_basis` and `SpaceArg` to be set already:
        # building the function reads `self.name` and `self.SpaceArg`.
        self.unit_dof_value = self._make_unit_dof_value(dof_mapper)

        # Per-element functions are generated by the basis
        self.node_coords_in_element = basis.make_node_coords_in_element()
        self.node_quadrature_weight = basis.make_node_quadrature_weight()
        self.element_inner_weight = basis.make_element_inner_weight()
        self.element_inner_weight_gradient = basis.make_element_inner_weight_gradient()
        self.element_outer_weight = basis.make_element_outer_weight()
        self.element_outer_weight_gradient = basis.make_element_outer_weight_gradient()

        # For backward compatibility: mirror whichever of these the basis exposes
        for legacy_attr in ("node_grid", "node_triangulation", "node_tets", "node_hexes"):
            if hasattr(basis, legacy_attr):
                setattr(self, legacy_attr, getattr(basis, legacy_attr))

    def space_arg_value(self, device):
        """Return the basis argument value for ``device`` (``SpaceArg`` is the basis' ``BasisArg``)."""
        basis = self._basis
        return basis.basis_arg_value(device)

    @property
    def name(self):
        """Name built from the basis name and the dof mapper, with dots replaced by underscores."""
        raw_name = f"{self._basis.name}_{self.dof_mapper}"
        return raw_name.replace(".", "_")

    @property
    def degree(self):
        """Maximum polynomial degree of the underlying basis (same value as ``ORDER``)"""
        return self.ORDER

    def make_field(
        self,
        space_partition: Optional[SpacePartition] = None,
    ) -> "wp.fem.field.NodalField":
        """Create a nodal field over this space.

        Args:
            space_partition: partition to restrict the field to; when ``None``,
                one is created from this space's topology with
                :func:`make_space_partition`.
        """

        # Imported locally rather than at module level
        # (presumably to avoid a circular import -- TODO confirm)
        from warp.fem.field import NodalField

        partition = space_partition
        if partition is None:
            partition = make_space_partition(space_topology=self.topology)

        return NodalField(space=self, space_partition=partition)

    def _make_unit_dof_value(self, dof_mapper: DofMapper):
        """Create the function returning the value associated to a unit degree of freedom.

        The generated function maps the unit element of the mapper's dof type,
        selected by the node coordinate of ``dof``, through ``dof_mapper.dof_to_value``.
        """

        # NOTE(review): the nested function is handed to `cache.dynamic_func`, which
        # presumably processes it from source; its body is left exactly as written.
        @cache.dynamic_func(suffix=self.name)
        def unit_dof_value(geo_arg: self.topology.ElementArg, space_arg: self.SpaceArg, dof: DofIndex):
            return dof_mapper.dof_to_value(utils.unit_element(dof_mapper.dof_dtype(0.0), get_node_coord(dof)))

        return unit_dof_value

    def node_count(self):
        """Number of nodes, as reported by the space topology."""
        topology = self.topology
        return topology.node_count()

    def node_positions(self, out: Optional[wp.array] = None) -> wp.array:
        """Return the node positions computed by the basis; ``out`` is forwarded unchanged."""
        basis = self._basis
        return basis.node_positions(out=out)

    def trace(self) -> "CollocatedFunctionSpace":
        """Return the trace of this space as a :class:`CollocatedFunctionSpaceTrace`."""
        trace_space = CollocatedFunctionSpaceTrace(self)
        return trace_space


class CollocatedFunctionSpaceTrace(CollocatedFunctionSpace):
    """Trace of a :class:`CollocatedFunctionSpace`

    Two trace spaces compare equal when the spaces they were built from
    compare equal; the hash is derived from the same underlying space.
    """

    def __init__(self, space: CollocatedFunctionSpace):
        """Build the trace of ``space`` from the trace of its basis, reusing its dtype and dof mapper."""
        # Must be assigned before calling the base constructor: the base
        # constructor reads `self.name`, which this class derives from `_space`.
        self._space = space
        super().__init__(space._basis.trace(), space.dtype, space.dof_mapper)

    @property
    def name(self):
        """Name of the underlying space suffixed with ``_Trace``."""
        return f"{self._space.name}_Trace"

    def __eq__(self, other: object) -> bool:
        """Compare the underlying spaces; defer to Python's fallback for non-trace objects."""
        if not isinstance(other, CollocatedFunctionSpaceTrace):
            # Previously this raised AttributeError on `other._space`
            return NotImplemented
        return self._space == other._space

    def __hash__(self) -> int:
        """Hash consistent with ``__eq__``.

        Defining ``__eq__`` alone would set ``__hash__`` to ``None`` and make
        trace spaces unusable as dictionary keys or set members.
        """
        return hash((CollocatedFunctionSpaceTrace, self._space))