vae-fdm / src /neural_fdm /mesh.py
Efradeca's picture
Upload folder using huggingface_hub
fc7d689 verified
from itertools import product
import jax.numpy as jnp
from compas.utilities import geometric_key, pairwise
from jax_fdm.datastructures import FDMesh
from neural_fdm.generators import evaluate_bezier_surface
def create_mesh_from_tube_generator(generator, config, *args, **kwargs):
"""
Boundary-supported mesh on a tube. The mesh has group tags.
Parameters
----------
generator: `TubePointGenerator`
The tube generator.
config: `dict`
The configuration dictionary.
Returns
-------
mesh: `jax_fdm.FDMesh`
The mesh.
"""
# shorthands
tube = generator
fix_rings = not config["loss"]["shape"]["include"]
# generate base FD Mesh
points = tube.points_on_tube()
points = jnp.reshape(points, (-1, 3))
num_u = tube.num_levels
num_v = tube.num_sides
faces = calculate_mesh_tube_faces(num_u - 1, num_v - 1)
mesh = FDMesh.from_vertices_and_faces(points, faces)
# define structural system
for vertices in mesh.vertices_on_boundaries():
mesh.vertices_supports(vertices)
# tag edges as either rings or cables
# first assume all edges ar cables
mesh.edges_attribute("tag", "cable")
# then, search for ring edges by geometric key
points = jnp.reshape(points, tube.shape_tube)
points_rings = points[tube.levels_rings_comp, :, :].tolist()
gkey_key = mesh.gkey_key()
num_ring_edges = 0
for points_ring in points_rings:
for line in pairwise(points_ring + points_ring[:1]):
edge = tuple([gkey_key[geometric_key(pt)] for pt in line])
if not mesh.has_edge(edge):
u, v = edge
edge = v, u
assert mesh.has_edge(edge)
mesh.edge_attribute(edge, "tag", "ring")
num_ring_edges += 1
# NOTE: fix ring supports if no shape error in loss
if fix_rings:
mesh.vertices_supports(edge)
assert num_ring_edges == tube.num_rings * tube.num_sides
return mesh
def create_mesh_from_bezier_generator(generator, *args, **kwargs):
"""
Boundary-supported mesh on bezier surface.
Parameters
----------
generator: `BezierSurfacePointGenerator`
The bezier surface generator.
Returns
-------
mesh: `jax_fdm.FDMesh`
The mesh.
"""
# unpack parameters
bezier = generator.surface
u = generator.u
v = generator.v
# generate base FD Mesh
srf_points = bezier.evaluate_points(u, v)
srf_points = jnp.reshape(srf_points, (-1, 3))
num_u = u.shape[0]
num_v = v.shape[0]
faces = calculate_mesh_grid_faces(num_u - 1, num_v - 1)
mesh = FDMesh.from_vertices_and_faces(srf_points, faces)
# define structural system
mesh.vertices_supports(mesh.vertices_on_boundary())
return mesh
def create_mesh_from_grid(grid, u, v):
"""
Boundary-supported mesh on a grid of Bezier control points.
Parameters
----------
grid: `PointGrid`
The grid of control points.
u: `jax.Array`
The parameter values along the `u` direction in the range [0, 1].
v: `jax.Array`
The parameter values along the `v` direction in the range [0, 1].
Returns
-------
mesh: `jax_fdm.FDMesh`
The mesh.
"""
# generate base FD Mesh
srf_points = calculate_bezier_surface_points_from_grid(grid, u, v)
num_u = u.shape[0]
num_v = v.shape[0]
faces = calculate_mesh_grid_faces(num_u - 1, num_v - 1)
mesh = FDMesh.from_vertices_and_faces(srf_points, faces)
# define structural system
mesh.vertices_supports(mesh.vertices_on_boundary())
return mesh
def calculate_mesh_grid_faces(nx, ny):
"""
Generate the indices of the mesh faces of the grid.
Parameters
----------
nx: `int`
The number of points along the `x` direction.
ny: `int`
The number of points along the `y` direction.
Returns
-------
faces: `list` of `list` of `int`
The indices of the mesh faces.
"""
faces = []
for i, j in product(range(nx), range(ny)):
face = [
i * (ny + 1) + j,
(i + 1) * (ny + 1) + j,
(i + 1) * (ny + 1) + j + 1,
i * (ny + 1) + j + 1,
]
faces.append(face)
return faces
def calculate_mesh_tube_faces(nx, ny):
"""
Generate the indices of the mesh faces of a tube.
Parameters
----------
nx: `int`
The number of points along the `x` direction.
ny: `int`
The number of points along the `y` direction.
Returns
-------
faces: `list` of `list` of `int`
The indices of the mesh faces.
"""
faces = calculate_mesh_grid_faces(nx, ny)
num_xy = (nx + 1) * (ny + 1)
starts = range(0, num_xy, ny + 1)
ends = range(ny, num_xy + ny, ny + 1)
for (a, b), (d, c) in zip(pairwise(starts), pairwise(ends)):
face = [d, c, b, a]
faces.append(face)
return faces
def calculate_bezier_surface_points_from_grid(grid, u, v):
"""
Evaluate points on a Bezier surface from a grid of control points.
Parameters
----------
grid: `PointGrid`
The grid of control points.
u: `jax.Array`
The parameter values along the `u` direction in the range [0, 1].
v: `jax.Array`
The parameter values along the `v` direction in the range [0, 1].
Returns
-------
points: `jax.Array`
The points on the surface.
"""
# generate control points grid
control_points = grid.points()
# sample surface points on bezier
surface_points = evaluate_bezier_surface(control_points, u, v)
return jnp.reshape(surface_points, (-1, 3))