File size: 10,863 Bytes
66c9c8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
from typing import Optional, Type

import warp as wp

from warp.fem.types import ElementIndex
from warp.fem.geometry import Geometry, DeformedGeometry
from warp.fem import cache


class SpaceTopology:
    """
    Interface class for defining the topology of a function space.

    The topology only considers the indices of the nodes in each element, and as such,
    the connectivity pattern of the function space.
    It does not specify the actual location of the nodes within the elements, or the valuation function.

    :meth:`node_count` and :meth:`element_node_index` raise ``NotImplementedError`` in this
    base class and must be provided by concrete topologies.
    """

    dimension: int
    """Embedding dimension of the function space"""

    NODES_PER_ELEMENT: int
    """Number of interpolation nodes per element of the geometry.

    .. note:: This will change to be defined per-element in future versions
    """

    @wp.struct
    class TopologyArg:
        """Structure containing arguments to be passed to device functions"""

        pass

    def __init__(self, geometry: Geometry, nodes_per_element: int):
        """
        Args:
            geometry: Geometry over which the topology is defined
            nodes_per_element: Number of nodes in each element; stored as a Warp constant
        """
        self._geometry = geometry
        self.dimension = geometry.dimension
        self.NODES_PER_ELEMENT = wp.constant(nodes_per_element)
        # Type of the geometry argument taken by `element_node_index`; the geometry's cell argument by default
        self.ElementArg = geometry.CellArg

    @property
    def geometry(self) -> Geometry:
        """Underlying geometry"""
        return self._geometry

    def node_count(self) -> int:
        """Number of nodes in the interpolation basis"""
        raise NotImplementedError

    def topo_arg_value(self, device) -> "TopologyArg":
        """Value of the topology argument structure to be passed to device functions"""
        return SpaceTopology.TopologyArg()

    @property
    def name(self):
        """Identifier built from the class name and node count; also used as suffix for generated kernels"""
        return f"{self.__class__.__name__}_{self.NODES_PER_ELEMENT}"

    def __str__(self):
        return self.name

    @staticmethod
    def element_node_index(
        geo_arg: "ElementArg", topo_arg: "TopologyArg", element_index: ElementIndex, node_index_in_elt: int
    ):
        """Global node index for a given node in a given element"""
        raise NotImplementedError

    def element_node_indices(self, out: Optional[wp.array] = None) -> wp.array:
        """Returns an array containing the global index for each node of each element

        Args:
            out: Optional array to fill in place. Must have shape ``(cell_count, NODES_PER_ELEMENT)``
                and data type ``int32``. If ``None``, a new array is allocated.

        Returns:
            The filled array (``out`` itself when it was provided).

        Raises:
            ValueError: if ``out`` does not have the expected shape or data type.
        """

        NODES_PER_ELEMENT = self.NODES_PER_ELEMENT

        @cache.dynamic_kernel(suffix=self.name)
        def fill_element_node_indices(
            geo_cell_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_node_indices: wp.array2d(dtype=int),
        ):
            element_index = wp.tid()
            for n in range(NODES_PER_ELEMENT):
                element_node_indices[element_index, n] = self.element_node_index(
                    geo_cell_arg, topo_arg, element_index, n
                )

        shape = (self.geometry.cell_count(), NODES_PER_ELEMENT)
        if out is None:
            element_node_indices = wp.empty(
                shape=shape,
                dtype=int,
            )
        else:
            if out.shape != shape or out.dtype != wp.int32:
                raise ValueError(f"Out element node indices array must have shape {shape} and data type 'int32'")
            element_node_indices = out

        # One thread per element; argument structs are built for the device holding the output array
        wp.launch(
            dim=element_node_indices.shape[0],
            kernel=fill_element_node_indices,
            inputs=[
                self.geometry.cell_arg_value(device=element_node_indices.device),
                self.topo_arg_value(device=element_node_indices.device),
                element_node_indices,
            ],
            device=element_node_indices.device,
        )

        return element_node_indices

    # Interface generating trace space topology

    def trace(self) -> "TraceSpaceTopology":
        """Trace of the function space over lower-dimensional elements of the geometry"""

        return TraceSpaceTopology(self)

    @property
    def is_trace(self) -> bool:
        """Whether this topology is defined on the trace of the geometry"""
        return self.dimension == self.geometry.dimension - 1

    def full_space_topology(self) -> "SpaceTopology":
        """Returns the full space topology from which this topology is derived"""
        return self

    def __eq__(self, other: "SpaceTopology") -> bool:
        """Checks whether two topologies are compatible

        Two topologies are considered compatible when they share the same geometry and the same name.
        Comparison with an object that is not a :class:`SpaceTopology` is delegated back to Python
        (``NotImplemented``) instead of failing on a missing attribute.

        Note: as ``__hash__`` is not defined alongside ``__eq__``, instances are not hashable.
        """
        if not isinstance(other, SpaceTopology):
            return NotImplemented
        return self.geometry == other.geometry and self.name == other.name

    def is_derived_from(self, other: "SpaceTopology") -> bool:
        """Checks whether two topologies are equal, or `self` is the trace of `other`"""
        if self.dimension == other.dimension:
            return self == other
        if self.dimension + 1 == other.dimension:
            # `self` is one dimension lower: compare the topology it was derived from
            return self.full_space_topology() == other
        return False


class TraceSpaceTopology(SpaceTopology):
    """Auto-generated trace topology defining the node indices associated to the geometry sides

    Each side carries ``2 * NODES_PER_ELEMENT`` nodes of the original topology: local indices below
    the original ``NODES_PER_ELEMENT`` refer to nodes of the side's inner cell, the remaining ones
    to nodes of the side's outer cell.
    """

    def __init__(self, topo: SpaceTopology):
        """
        Args:
            topo: Topology whose trace over the geometry sides is being built
        """
        super().__init__(topo.geometry, 2 * topo.NODES_PER_ELEMENT)

        self._topo = topo
        self.dimension = topo.dimension - 1
        # Elements of the trace are sides, so device functions take the geometry's side argument
        self.ElementArg = topo.geometry.SideArg

        # The topology argument struct and its value are those of the original topology
        self.TopologyArg = topo.TopologyArg
        self.topo_arg_value = topo.topo_arg_value

        self.inner_cell_index = self._make_inner_cell_index()
        self.outer_cell_index = self._make_outer_cell_index()
        self.neighbor_cell_index = self._make_neighbor_cell_index()

        # Must come last: the generated function calls `self.neighbor_cell_index`
        self.element_node_index = self._make_element_node_index()

    def node_count(self) -> int:
        """Number of nodes, identical to that of the original topology"""
        return self._topo.node_count()

    @property
    def name(self):
        """Name of the original topology with a ``_Trace`` suffix"""
        return f"{self._topo.name}_Trace"

    def _make_inner_cell_index(self):
        """Builds the device function returning the inner cell of a side and the node's index in that cell

        The node index is passed through `wp.select` so that nodes that do not belong to the inner
        cell (local index not below the original ``NODES_PER_ELEMENT``) are expected to yield ``-1``
        (assuming `wp.select` returns its second argument when the condition is false).
        """
        NODES_PER_ELEMENT = self._topo.NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def inner_cell_index(args: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
            index_in_inner_cell = wp.select(node_index_in_elt < NODES_PER_ELEMENT, -1, node_index_in_elt)
            return self.geometry.side_inner_cell_index(args, element_index), index_in_inner_cell

        return inner_cell_index

    def _make_outer_cell_index(self):
        """Builds the device function returning the outer cell of a side and the node's index in that cell

        The returned node index is the side-local index shifted by the original ``NODES_PER_ELEMENT``;
        it is therefore negative for nodes that belong to the inner cell.
        """
        NODES_PER_ELEMENT = self._topo.NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def outer_cell_index(args: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
            return self.geometry.side_outer_cell_index(args, element_index), node_index_in_elt - NODES_PER_ELEMENT

        return outer_cell_index

    def _make_neighbor_cell_index(self):
        """Builds the device function returning the cell owning a side node and the node's index in that cell

        Side-local indices below the original ``NODES_PER_ELEMENT`` map to the inner cell,
        the others to the outer cell.
        """
        NODES_PER_ELEMENT = self._topo.NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def neighbor_cell_index(args: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
            if node_index_in_elt < NODES_PER_ELEMENT:
                return self.geometry.side_inner_cell_index(args, element_index), node_index_in_elt
            else:
                return (
                    self.geometry.side_outer_cell_index(args, element_index),
                    node_index_in_elt - NODES_PER_ELEMENT,
                )

        return neighbor_cell_index

    def _make_element_node_index(self):
        """Builds the device function returning the global node index of a side node

        The side node is first mapped to its owning cell and cell-local index, then the original
        topology's ``element_node_index`` is evaluated on that cell.
        """

        @cache.dynamic_func(suffix=self.name)
        def trace_element_node_index(
            geo_side_arg: self.geometry.SideArg,
            topo_arg: self._topo.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            cell_index, index_in_cell = self.neighbor_cell_index(geo_side_arg, element_index, node_index_in_elt)

            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            return self._topo.element_node_index(geo_cell_arg, topo_arg, cell_index, index_in_cell)

        return trace_element_node_index

    def full_space_topology(self) -> SpaceTopology:
        """Returns the full space topology from which this topology is derived"""
        return self._topo

    def __eq__(self, other: "TraceSpaceTopology") -> bool:
        """Checks whether two trace topologies are derived from compatible topologies

        When `other` is not a trace topology, ``NotImplemented`` is returned so that Python falls
        back to the other operand's comparison, rather than failing on the missing ``_topo`` attribute.
        """
        if not isinstance(other, TraceSpaceTopology):
            return NotImplemented
        return self._topo == other._topo


class DiscontinuousSpaceTopologyMixin:
    """Mixin implementing the node numbering of discontinuous topologies, where nodes are per-element

    Relies on the class it is combined with to provide ``geometry``, ``NODES_PER_ELEMENT``
    and ``TopologyArg``.
    """

    def __init__(self, *args, **kwargs):
        # Let the next class in the MRO initialize itself first; the generated
        # device function below reads attributes that it sets up
        super().__init__(*args, **kwargs)
        self.element_node_index = self._make_element_node_index()

    def node_count(self):
        """Total number of nodes: one independent set of nodes for each cell of the geometry"""
        cell_count = self.geometry.cell_count()
        return cell_count * self.NODES_PER_ELEMENT

    @property
    def name(self):
        """Identifier combining the geometry name and the per-element node count"""
        return f"{self.geometry.name}_D{self.NODES_PER_ELEMENT}"

    def _make_element_node_index(self):
        """Builds the device function returning the global index of a node of an element

        Nodes are numbered contiguously, element after element.
        """
        NODES_PER_ELEMENT = self.NODES_PER_ELEMENT
        as_device_func = cache.dynamic_func(suffix=self.name)

        def element_node_index(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return NODES_PER_ELEMENT * element_index + node_index_in_elt

        return as_device_func(element_node_index)


class DiscontinuousSpaceTopology(DiscontinuousSpaceTopologyMixin, SpaceTopology):
    """Topology for generic discontinuous spaces

    Plain combination of :class:`DiscontinuousSpaceTopologyMixin` and :class:`SpaceTopology`;
    adds no behavior of its own.
    """


class DeformedGeometrySpaceTopology(SpaceTopology):
    """Topology over a deformed geometry, forwarding node indexing to an existing base topology"""

    def __init__(self, geometry: DeformedGeometry, base_topology: SpaceTopology):
        """
        Args:
            geometry: Deformed geometry over which this topology is defined
            base_topology: Topology to which node counting and indexing are forwarded
        """
        super().__init__(geometry, base_topology.NODES_PER_ELEMENT)

        self.base = base_topology

        # Topology argument struct, its value and the node count are taken as-is from the base topology
        self.TopologyArg = base_topology.TopologyArg
        self.topo_arg_value = base_topology.topo_arg_value
        self.node_count = base_topology.node_count

        self.element_node_index = self._make_element_node_index()

    @property
    def name(self):
        """Name of the base topology suffixed with the name of the geometry's field"""
        return f"{self.base.name}_{self.geometry.field.name}"

    def _make_element_node_index(self):
        """Builds the device function forwarding to the base topology's ``element_node_index``

        The ``elt_arg`` member of this geometry's cell argument is what gets passed
        to the base topology.
        """
        as_device_func = cache.dynamic_func(suffix=self.name)

        def element_node_index(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return self.base.element_node_index(elt_arg.elt_arg, topo_arg, element_index, node_index_in_elt)

        return as_device_func(element_node_index)


def forward_base_topology(topology_class: Type[SpaceTopology], geometry: Geometry, *args, **kwargs) -> SpaceTopology:
    """
    Constructs a `topology_class` topology for `geometry`, handling deformed geometries transparently.

    For a regular geometry, this simply returns ``topology_class(geometry, *args, **kwargs)``.

    For a :class:`DeformedGeometry`, the `topology_class` instance is built over ``geometry.base`` instead
    (with the same additional arguments), then wrapped in a :class:`DeformedGeometrySpaceTopology`
    defined over `geometry`.
    """

    if not isinstance(geometry, DeformedGeometry):
        return topology_class(geometry, *args, **kwargs)

    undeformed_topology = topology_class(geometry.base, *args, **kwargs)
    return DeformedGeometrySpaceTopology(geometry, undeformed_topology)