| from itertools import product |
|
|
| import jax.numpy as jnp |
| from compas.utilities import geometric_key, pairwise |
| from jax_fdm.datastructures import FDMesh |
|
|
| from neural_fdm.generators import evaluate_bezier_surface |
|
|
|
|
def create_mesh_from_tube_generator(generator, config, *args, **kwargs):
    """
    Build a boundary-supported quad mesh on a tube, with edge group tags.

    Every edge is first tagged ``"cable"``. The edges lying on the tube rings
    selected by ``generator.levels_rings_comp`` are then re-tagged ``"ring"``.
    All vertices on every mesh boundary become supports. When the shape loss
    is disabled in ``config``, the ring-edge vertices become supports too.

    Parameters
    ----------
    generator: `TubePointGenerator`
        The tube generator.
    config: `dict`
        The configuration dictionary. Only
        ``config["loss"]["shape"]["include"]`` is read here.

    Returns
    -------
    mesh: `jax_fdm.FDMesh`
        The tagged and supported mesh.
    """
    tube = generator
    # Ring vertices are pinned only when the shape term is not part of the loss.
    pin_ring_vertices = not config["loss"]["shape"]["include"]

    # One mesh vertex per tube point, flattened to rows of xyz coordinates.
    flat_points = jnp.reshape(tube.points_on_tube(), (-1, 3))
    faces = calculate_mesh_tube_faces(tube.num_levels - 1, tube.num_sides - 1)
    mesh = FDMesh.from_vertices_and_faces(flat_points, faces)

    # A tube has several boundaries; support the vertices on each of them.
    for boundary_vertices in mesh.vertices_on_boundaries():
        mesh.vertices_supports(boundary_vertices)

    # Default group for every edge; ring edges are overwritten below.
    mesh.edges_attribute("tag", "cable")

    # Recover the structured point layout to pick out the ring polylines.
    structured_points = jnp.reshape(flat_points, tube.shape_tube)
    ring_polylines = structured_points[tube.levels_rings_comp, :, :].tolist()
    key_from_gkey = mesh.gkey_key()

    ring_edge_count = 0
    for ring in ring_polylines:
        # Append the first point again so the ring closes on itself.
        closed_ring = ring + ring[:1]
        for start, end in pairwise(closed_ring):
            edge = (
                key_from_gkey[geometric_key(start)],
                key_from_gkey[geometric_key(end)],
            )
            # The ring may run against the mesh's stored edge direction.
            if not mesh.has_edge(edge):
                edge = edge[::-1]
                assert mesh.has_edge(edge)
            mesh.edge_attribute(edge, "tag", "ring")
            ring_edge_count += 1

            if pin_ring_vertices:
                mesh.vertices_supports(edge)

    # Sanity check: every selected ring contributes one edge per tube side.
    assert ring_edge_count == tube.num_rings * tube.num_sides

    return mesh
|
|
|
|
def create_mesh_from_bezier_generator(generator, *args, **kwargs):
    """
    Build a boundary-supported quad mesh sampled from a Bezier surface.

    The surface is evaluated on the generator's ``u`` x ``v`` parameter
    grid. The samples become mesh vertices connected by a quad grid, and
    every vertex on the mesh boundary is made a support.

    Parameters
    ----------
    generator: `BezierSurfacePointGenerator`
        The bezier surface generator.

    Returns
    -------
    mesh: `jax_fdm.FDMesh`
        The supported mesh.
    """
    surface = generator.surface
    u, v = generator.u, generator.v

    # Flatten the sampled surface to one row of xyz coordinates per vertex.
    vertices = jnp.reshape(surface.evaluate_points(u, v), (-1, 3))

    # The face grid has one fewer cell than samples in each direction.
    cells_u = u.shape[0] - 1
    cells_v = v.shape[0] - 1
    mesh = FDMesh.from_vertices_and_faces(
        vertices, calculate_mesh_grid_faces(cells_u, cells_v)
    )

    boundary_vertices = mesh.vertices_on_boundary()
    mesh.vertices_supports(boundary_vertices)
    return mesh
|
|
|
|
def create_mesh_from_grid(grid, u, v):
    """
    Build a boundary-supported quad mesh from a grid of Bezier control points.

    The Bezier surface defined by ``grid`` is evaluated at the parameter
    values ``u`` x ``v``. The samples become mesh vertices connected by a
    quad grid, and every vertex on the mesh boundary is made a support.

    Parameters
    ----------
    grid: `PointGrid`
        The grid of control points.
    u: `jax.Array`
        The parameter values along the `u` direction in the range [0, 1].
    v: `jax.Array`
        The parameter values along the `v` direction in the range [0, 1].

    Returns
    -------
    mesh: `jax_fdm.FDMesh`
        The supported mesh.
    """
    vertices = calculate_bezier_surface_points_from_grid(grid, u, v)

    # One face cell between each pair of consecutive parameter samples.
    faces = calculate_mesh_grid_faces(u.shape[0] - 1, v.shape[0] - 1)
    mesh = FDMesh.from_vertices_and_faces(vertices, faces)

    boundary_vertices = mesh.vertices_on_boundary()
    mesh.vertices_supports(boundary_vertices)
    return mesh
|
|
|
|
def calculate_mesh_grid_faces(nx, ny):
    """
    Generate the vertex indices of the quad faces of a structured grid.

    Vertices are assumed to be numbered row by row, with ``ny + 1`` vertices
    per row and ``nx + 1`` rows. Vertex ``(i, j)`` therefore has index
    ``i * (ny + 1) + j``.

    Parameters
    ----------
    nx: `int`
        The number of face cells along the `x` direction (one less than the
        number of vertex rows).
    ny: `int`
        The number of face cells along the `y` direction (one less than the
        number of vertices per row).

    Returns
    -------
    faces: `list` of `list` of `int`
        The ``nx * ny`` quad faces in row-major order. Each face lists the
        indices of its four vertices. The result is empty when either count
        is zero or negative.
    """
    row = ny + 1  # vertices per row
    return [
        [
            i * row + j,
            (i + 1) * row + j,
            (i + 1) * row + j + 1,
            i * row + j + 1,
        ]
        for i, j in product(range(nx), range(ny))
    ]
|
|
|
|
def calculate_mesh_tube_faces(nx, ny):
    """
    Generate the vertex indices of the quad faces of a tube.

    A tube is a structured grid closed along `y`. It contains the open grid
    faces from ``calculate_mesh_grid_faces(nx, ny)`` plus one seam face per
    row. The seam face joins the last vertex of each row back to the first.
    Vertices are numbered row by row, with ``ny + 1`` vertices per row.

    Parameters
    ----------
    nx: `int`
        The number of face rows along the `x` direction (one less than the
        number of vertex rows).
    ny: `int`
        The number of open-grid face cells per row along the `y` direction
        (one less than the number of vertices per row).

    Returns
    -------
    faces: `list` of `list` of `int`
        The ``nx * ny`` open-grid faces followed by the ``nx`` seam faces.
    """
    faces = calculate_mesh_grid_faces(nx, ny)

    row = ny + 1  # vertices per row
    for i in range(nx):
        first = i * row
        last = first + ny
        # Same winding as the grid faces, with column ``ny + 1`` wrapped to 0.
        faces.append([last, last + row, first + row, first])

    return faces
|
|
|
|
def calculate_bezier_surface_points_from_grid(grid, u, v):
    """
    Sample a Bezier surface defined by a grid of control points.

    Parameters
    ----------
    grid: `PointGrid`
        The grid of control points.
    u: `jax.Array`
        The parameter values along the `u` direction in the range [0, 1].
    v: `jax.Array`
        The parameter values along the `v` direction in the range [0, 1].

    Returns
    -------
    points: `jax.Array`
        The surface points, flattened to shape ``(-1, 3)`` (one row of
        coordinates per sample).
    """
    sampled = evaluate_bezier_surface(grid.points(), u, v)
    flat_points = jnp.reshape(sampled, (-1, 3))
    return flat_points
|
|