File size: 8,362 Bytes
66c9c8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
from typing import Any

import warp as wp

from warp.fem.types import DofIndex, ElementIndex, Coords, get_node_coord
from warp.fem.geometry import GeometryPartition
from warp.fem import utils


from .function_space import FunctionSpace
from .dof_mapper import DofMapper, IdentityMapper
from .partition import make_space_partition, SpacePartition


class NodalFunctionSpace(FunctionSpace):
    """Function space where values are collocated at nodes.

    Values are stored as per-node degrees of freedom; a :class:`DofMapper`
    converts between the per-node dof type (``dof_dtype``) and the value type
    (``dtype``).

    Subclasses are expected to define ``ORDER`` (used to build :attr:`name`)
    and to override the trace interface functions declared at the bottom of
    this class when a trace space is to be generated from them.
    """

    def __init__(self, dtype: type = float, dof_mapper: DofMapper = None):
        """
        Args:
            dtype: Value type of the space. Only used to build a default
                :class:`IdentityMapper` when ``dof_mapper`` is ``None``.
            dof_mapper: Mapper between node dofs and values. When given, the
                space's value type is ``dof_mapper.value_dtype`` and ``dtype``
                is ignored.
        """
        self.dof_mapper = IdentityMapper(dtype) if dof_mapper is None else dof_mapper
        self.dtype = self.dof_mapper.value_dtype
        self.dof_dtype = self.dof_mapper.dof_dtype

        # Gradient type per value type; value types not listed here have no
        # gradient type (None).
        # NOTE(review): the scalar case maps to a 2D gradient (vec2) while vec3
        # maps to mat33 (3D); confirm whether this should depend on the space
        # dimension instead of being fixed here.
        if self.dtype == wp.float32:
            self.gradient_dtype = wp.vec2
        elif self.dtype == wp.vec2:
            self.gradient_dtype = wp.mat22
        elif self.dtype == wp.vec3:
            self.gradient_dtype = wp.mat33
        else:
            self.gradient_dtype = None

        # Number of dofs per node, as reported by the dof mapper
        self.VALUE_DOF_COUNT = self.dof_mapper.DOF_SIZE
        self.unit_dof_value = self._make_unit_dof_value(self.dof_mapper)

    @property
    def name(self):
        """Identifier built from the class qualified name, ``ORDER`` and the dof mapper.

        Dots are replaced with underscores. ``ORDER`` is not set by this class
        and must be provided by subclasses.
        """
        return f"{self.__class__.__qualname__}_{self.ORDER}_{self.dof_mapper}".replace(".", "_")

    def make_field(
        self,
        space_partition: SpacePartition = None,
        geometry_partition: GeometryPartition = None,
    ) -> "wp.fem.field.NodalField":
        """Create a :class:`NodalField` over this space.

        Args:
            space_partition: Partition of the space the field is defined on. If
                ``None``, one is created with ``make_space_partition`` from
                ``geometry_partition``.
            geometry_partition: Only used when ``space_partition`` is ``None``.

        Returns:
            A new ``NodalField`` bound to this space and the space partition.
        """
        from warp.fem.field import NodalField

        if space_partition is None:
            space_partition = make_space_partition(self, geometry_partition)

        return NodalField(space=self, space_partition=space_partition)

    @staticmethod
    def _make_unit_dof_value(dof_mapper: DofMapper):
        """Build a cached function returning the value of the unit dof addressed by a dof index.

        The dof-space element is built with ``utils.unit_element`` from the
        component given by ``get_node_coord(dof)`` and converted to a value via
        ``dof_mapper.dof_to_value``. The function is cached per ``str(dof_mapper)``.
        """
        from warp.fem import cache

        def unit_dof_value(args: Any, dof: DofIndex):
            return dof_mapper.dof_to_value(utils.unit_element(dof_mapper.dof_dtype(0.0), get_node_coord(dof)))

        return cache.get_func(unit_dof_value, str(dof_mapper))

    # Interface for generating Trace space
    #
    # These are placeholders that subclasses override (presumably with warp
    # functions taking the space argument struct -- TODO confirm). They are
    # static so that calling a non-overridden one on an instance raises
    # NotImplementedError, rather than a TypeError caused by the instance being
    # implicitly bound to ``args``.

    @staticmethod
    def _inner_cell_index(args: Any, side_index: ElementIndex):
        """Given a side, returns the index of the inner cell"""
        raise NotImplementedError

    @staticmethod
    def _outer_cell_index(args: Any, side_index: ElementIndex):
        """Given a side, returns the index of the outer cell"""
        raise NotImplementedError

    @staticmethod
    def _inner_cell_coords(args: Any, side_index: ElementIndex, side_coords: Coords):
        """Given coordinates within a side, returns coordinates within the inner cell"""
        raise NotImplementedError

    @staticmethod
    def _outer_cell_coords(args: Any, side_index: ElementIndex, side_coords: Coords):
        """Given coordinates within a side, returns coordinates within the outer cell"""
        raise NotImplementedError

    @staticmethod
    def _cell_to_side_coords(
        args: Any,
        side_index: ElementIndex,
        element_index: ElementIndex,
        element_coords: Coords,
    ):
        """Given coordinates within a cell, returns coordinates within a side, or OUTSIDE"""
        raise NotImplementedError


class NodalFunctionSpaceTrace(NodalFunctionSpace):
    """Trace of a NodalFunctionSpace.

    Elements of the trace space are sides of the underlying space. Each side
    exposes ``2 * space.NODES_PER_ELEMENT`` nodes: indices below
    ``space.NODES_PER_ELEMENT`` refer to nodes of the side's inner cell, the
    remaining indices to nodes of its outer cell.
    """

    def __init__(self, space: NodalFunctionSpace):
        """Wrap ``space``, reusing its dtype, dof mapper, geometry and argument struct.

        Args:
            space: Underlying nodal function space. It must implement the trace
                interface (``_inner_cell_index``, ``_outer_cell_index``, ...).
        """
        self._space = space

        super().__init__(space.dtype, space.dof_mapper)
        self.geometry = space.geometry

        # Inner-cell nodes followed by outer-cell nodes
        self.NODES_PER_ELEMENT = wp.constant(2 * space.NODES_PER_ELEMENT)
        self.DIMENSION = space.DIMENSION - 1

        self.SpaceArg = space.SpaceArg
        self.space_arg_value = space.space_arg_value

    def node_count(self) -> int:
        """Number of nodes; identical to the underlying space's node count."""
        return self._space.node_count()

    @property
    def name(self):
        """Underlying space name suffixed with ``_Trace``."""
        return f"{self._space.name}_Trace"

    @staticmethod
    def _make_element_node_index(space: NodalFunctionSpace):
        """Build a cached function mapping a (side, local node) pair to a global node index.

        Local indices below ``space.NODES_PER_ELEMENT`` are looked up in the
        inner cell, the others (shifted down by ``space.NODES_PER_ELEMENT``) in
        the outer cell.
        """
        from warp.fem import cache

        NODES_PER_ELEMENT = space.NODES_PER_ELEMENT

        def trace_element_node_index(args: space.SpaceArg, element_index: ElementIndex, node_index_in_elt: int):
            if node_index_in_elt < NODES_PER_ELEMENT:
                inner_element = space._inner_cell_index(args, element_index)
                return space.element_node_index(args, inner_element, node_index_in_elt)

            outer_element = space._outer_cell_index(args, element_index)
            return space.element_node_index(args, outer_element, node_index_in_elt - NODES_PER_ELEMENT)

        return cache.get_func(trace_element_node_index, space.name)

    @staticmethod
    def _make_node_coords_in_element(space: NodalFunctionSpace):
        """Build a cached function returning a trace node's coordinates within the side.

        The node's coordinates in its (inner or outer) cell are converted to
        side coordinates with ``space._cell_to_side_coords``.
        """
        from warp.fem import cache

        NODES_PER_ELEMENT = space.NODES_PER_ELEMENT

        def trace_node_coords_in_element(
            args: space.SpaceArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            if node_index_in_elt < NODES_PER_ELEMENT:
                neighbour_elem = space._inner_cell_index(args, element_index)
                neighbour_coords = space.node_coords_in_element(args, neighbour_elem, node_index_in_elt)
            else:
                neighbour_elem = space._outer_cell_index(args, element_index)
                neighbour_coords = space.node_coords_in_element(
                    args,
                    neighbour_elem,
                    node_index_in_elt - NODES_PER_ELEMENT,
                )

            return space._cell_to_side_coords(args, element_index, neighbour_elem, neighbour_coords)

        return cache.get_func(trace_node_coords_in_element, space.name)

    @staticmethod
    def _make_element_inner_weight(space: NodalFunctionSpace):
        """Build a cached function evaluating a node's weight on the inner-cell side.

        The side coordinates are mapped to inner-cell coordinates and forwarded,
        with the node index unchanged, to ``space.element_inner_weight``.
        """
        from warp.fem import cache

        def trace_element_inner_weight(
            args: space.SpaceArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            # NOTE(review): node_index_in_elt is not range-checked, so outer-cell
            # node indices (>= space.NODES_PER_ELEMENT) reach the inner cell's
            # weight function; confirm callers or the underlying space handle this.
            return space.element_inner_weight(
                args,
                space._inner_cell_index(args, element_index),
                space._inner_cell_coords(args, element_index, coords),
                node_index_in_elt,
            )

        return cache.get_func(trace_element_inner_weight, space.name)

    @staticmethod
    def _make_element_outer_weight(space: NodalFunctionSpace):
        """Build a cached function evaluating a node's weight on the outer-cell side.

        The side coordinates are mapped to outer-cell coordinates, and the node
        index is shifted down by ``space.NODES_PER_ELEMENT`` before calling
        ``space.element_outer_weight``.
        """
        from warp.fem import cache

        NODES_PER_ELEMENT = space.NODES_PER_ELEMENT

        def trace_element_outer_weight(
            args: space.SpaceArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            # NOTE(review): inner-cell node indices (< NODES_PER_ELEMENT) become
            # negative after the shift; confirm callers only query outer nodes here.
            return space.element_outer_weight(
                args,
                space._outer_cell_index(args, element_index),
                space._outer_cell_coords(args, element_index, coords),
                node_index_in_elt - NODES_PER_ELEMENT,
            )

        return cache.get_func(trace_element_outer_weight, space.name)

    @staticmethod
    def _make_element_inner_weight_gradient(space: NodalFunctionSpace):
        """Build a cached function evaluating a node's weight gradient on the inner-cell side.

        Mirrors :meth:`_make_element_inner_weight`, forwarding to
        ``space.element_inner_weight_gradient``.
        """
        from warp.fem import cache

        def trace_element_inner_weight_gradient(
            args: space.SpaceArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            # NOTE(review): same unchecked node index as trace_element_inner_weight
            return space.element_inner_weight_gradient(
                args,
                space._inner_cell_index(args, element_index),
                space._inner_cell_coords(args, element_index, coords),
                node_index_in_elt,
            )

        return cache.get_func(trace_element_inner_weight_gradient, space.name)

    @staticmethod
    def _make_element_outer_weight_gradient(space: NodalFunctionSpace):
        """Build a cached function evaluating a node's weight gradient on the outer-cell side.

        Mirrors :meth:`_make_element_outer_weight`, forwarding to
        ``space.element_outer_weight_gradient``.
        """
        from warp.fem import cache

        NODES_PER_ELEMENT = space.NODES_PER_ELEMENT

        def trace_element_outer_weight_gradient(
            args: space.SpaceArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            # NOTE(review): same index shift as trace_element_outer_weight
            return space.element_outer_weight_gradient(
                args,
                space._outer_cell_index(args, element_index),
                space._outer_cell_coords(args, element_index, coords),
                node_index_in_elt - NODES_PER_ELEMENT,
            )

        return cache.get_func(trace_element_outer_weight_gradient, space.name)

    def __eq__(self, other: object) -> bool:
        """Two traces are equal when they wrap equal underlying spaces."""
        if not isinstance(other, NodalFunctionSpaceTrace):
            # Let Python fall back to the reflected comparison / identity instead
            # of raising AttributeError on objects without a ``_space`` attribute.
            return NotImplemented
        return self._space == other._space

    def __hash__(self) -> int:
        # Defining __eq__ alone implicitly sets __hash__ to None; hash consistently
        # with __eq__ by delegating to the wrapped space.
        return hash(self._space)