Spaces:
Sleeping
Sleeping
| import matplotlib.pyplot as plt | |
| from mesh_model.mesh_analysis.quadmesh_analysis import QuadMeshTopoAnalysis | |
| from mesh_model.mesh_struct.mesh_elements import Dart, Node | |
| from mesh_model.mesh_struct.mesh import Mesh | |
| import numpy as np | |
| from mesh_model.reader import read_gmsh | |
| def plot_mesh(mesh: Mesh, debug=False, scores=False) -> None: | |
| """ | |
| Plot a mesh using matplotlib | |
| :param mesh: a Mesh | |
| :param debug: debug mode to plot darts ID and nodes ID | |
| """ | |
| if scores: | |
| fig, ax = plt.subplots(figsize=(10, 6)) | |
| else: | |
| fig, ax = plt.subplots(figsize=(10, 10)) | |
| subplot_mesh(mesh, debug=debug, scores=scores) | |
| plt.tight_layout() | |
| plt.savefig("bunny.png", dpi=300, bbox_inches="tight", pad_inches=0) | |
| plt.show(block=True) | |
| def save_mesh_plot(mesh: Mesh, filename: str, debug=False, scores=True) -> None: | |
| if scores: | |
| fig, ax = plt.subplots(figsize=(10, 11)) | |
| else: | |
| fig, ax = plt.subplots(figsize=(10, 10)) | |
| subplot_mesh(mesh, debug=debug, scores=scores) | |
| plt.savefig(filename, dpi=300, bbox_inches="tight", pad_inches=0) | |
| def subplot_mesh(mesh: Mesh, debug=False, id=None, scores=False, irregularities=False) -> None: | |
| """ | |
| Plot a mesh using matplotlib for subplots with many meshes | |
| :param mesh: a Mesh | |
| """ | |
| faces = mesh.active_faces() | |
| nodes = mesh.active_nodes() | |
| nodes = np.array([list[:2] for list in nodes]) | |
| active_darts = mesh.active_darts() | |
| d = Dart(mesh, active_darts[0][0]) | |
| d1 = d.get_beta(1) | |
| d11 = d1.get_beta(1) | |
| d111 = d11.get_beta(1) | |
| tri, quad = False, False | |
| if d111 == d: | |
| tri=True | |
| else: | |
| quad=True | |
| if tri: | |
| for dart_id in faces: | |
| d1 = Dart(mesh, dart_id) | |
| d2 = d1.get_beta(1) | |
| d3 = d2.get_beta(1) | |
| n1 = d1.get_node() | |
| n2 = d2.get_node() | |
| n3 = d3.get_node() | |
| # Nodes coordinates | |
| p1 = np.array([n1.x(), n1.y()]) | |
| p2 = np.array([n2.x(), n2.y()]) | |
| p3 = np.array([n3.x(), n3.y()]) | |
| polygon = np.array([(n1.x(), n1.y()), (n2.x(), n2.y()), (n3.x(), n3.y()), (n1.x(), n1.y())]) | |
| plt.plot(polygon[:, 0], polygon[:, 1], 'k-') | |
| if debug: | |
| # Plot darts ID | |
| mid1 = (p1 + p2) / 2 | |
| mid2 = (p2 + p3) / 2 | |
| mid3 = (p3 + p1) / 2 | |
| centroid = (p1 + p2 + p3) / 3 | |
| pos1 = mid1 +0.2* (centroid - mid1) | |
| pos2 = mid2 +0.2* (centroid - mid2) | |
| pos3 = mid3 +0.2* (centroid - mid3) | |
| plt.text(*pos1, f"{d1.id}", color='blue', fontsize=8, ha='center', va='center') | |
| plt.text(*pos2, f"{d2.id}", color='blue', fontsize=8, ha='center', va='center') | |
| plt.text(*pos3, f"{d3.id}", color='blue', fontsize=8, ha='center', va='center') | |
| if debug: | |
| # Plot nodes ID | |
| n_id =0 | |
| for n_info in mesh.nodes: | |
| if n_info[2] >=0: | |
| plt.text(n_info[0] + 0.03, n_info[1] - 0.02, f"{n_id}", fontsize=10, color='red', ha='right', va='top') | |
| n_id+=1 | |
| elif quad: | |
| for dart_id in faces: | |
| d1 = Dart(mesh, dart_id) | |
| d2 = d1.get_beta(1) | |
| d3 = d2.get_beta(1) | |
| d4 = d3.get_beta(1) | |
| n1 = d1.get_node() | |
| n2 = d2.get_node() | |
| n3 = d3.get_node() | |
| n4 = d4.get_node() | |
| polygon = np.array([(n1.x(), n1.y()), (n2.x(), n2.y()), (n3.x(), n3.y()), (n4.x(), n4.y()), (n1.x(), n1.y())]) | |
| plt.plot(polygon[:, 0], polygon[:, 1], 'k-', zorder=1) | |
| # plt.fill(polygon[:, 0], polygon[:, 1], facecolor="#B2DFDB", edgecolor="k", zorder=1) | |
| if debug: | |
| #Plot darts ID | |
| # Nodes coordinates | |
| p1 = np.array([n1.x(), n1.y()]) | |
| p2 = np.array([n2.x(), n2.y()]) | |
| p3 = np.array([n3.x(), n3.y()]) | |
| p4 = np.array([n4.x(), n4.y()]) | |
| mid1 = (p1 + p2) / 2 | |
| mid2 = (p2 + p3) / 2 | |
| mid3 = (p3 + p4) / 2 | |
| mid4 = (p4 + p1) / 2 | |
| centroid = (p1 + p2 + p3 + p4) / 4 | |
| pos1 = mid1 + 0.2 * (centroid - mid1) | |
| pos2 = mid2 + 0.2 * (centroid - mid2) | |
| pos3 = mid3 + 0.2 * (centroid - mid3) | |
| pos4 = mid4 + 0.2 * (centroid - mid4) | |
| plt.text(*pos1, f"{d1.id}", color='blue', fontsize=8, ha='center', va='center') | |
| plt.text(*pos2, f"{d2.id}", color='blue', fontsize=8, ha='center', va='center') | |
| plt.text(*pos3, f"{d3.id}", color='blue', fontsize=8, ha='center', va='center') | |
| plt.text(*pos4, f"{d4.id}", color='blue', fontsize=8, ha='center', va='center') | |
| if debug : | |
| # Plot nodes ID | |
| n_id = 0 | |
| for n_info in mesh.nodes: | |
| if n_info[2] >= 0: | |
| plt.text(n_info[0] + 0.03, n_info[1] - 0.02, f"{n_id}", fontsize=10, color='red', ha='right', | |
| va='top') | |
| n_id += 1 | |
| if irregularities: | |
| n_id = 0 | |
| for n_info in mesh.nodes: | |
| n = Node(mesh, n_id) | |
| s = -1*n.get_score() | |
| teal = "#008080" | |
| salmon = "#FA8072" | |
| green_pastel = "#77DD77" | |
| if s > 0: | |
| color = teal | |
| radius = 0.25 | |
| show_text = True | |
| elif s < 0: | |
| color = salmon | |
| radius = 0.25 | |
| show_text = True | |
| else: # s == 0 | |
| color = green_pastel | |
| radius = 0.1 # plus petit | |
| show_text = False | |
| # Dessiner le cercle | |
| circle = plt.Circle((n_info[0], n_info[1]), radius=radius, | |
| color=color, zorder=2) | |
| plt.gca().add_patch(circle) | |
| # Ajouter le texte si nécessaire | |
| if show_text: | |
| plt.text(n_info[0], n_info[1], f"{s:.0f}", fontsize=17, | |
| color="white", ha="center", va="center", zorder=3) | |
| n_id += 1 | |
| else: | |
| raise NotImplementedError | |
| teal = "#008080" | |
| # Tracer les sommets | |
| if not irregularities: | |
| plt.plot(nodes[:, 0], nodes[:, 1], 'o', color='teal') # 'ro' pour des points rouges | |
| plt.grid(False) | |
| plt.axis('off') | |
| if scores: | |
| if quad: | |
| ma = QuadMeshTopoAnalysis(mesh) | |
| _, score, ideal_score = ma.global_score() | |
| if id is not None: | |
| plt.title(f"Mesh {id} \n s: {score:.0f} - s*: {ideal_score:.0f}", fontsize=25, pad=10) | |
| else: | |
| plt.title(f"s: {score:.0f} - s*: {ideal_score:.0f}", fontsize=30, pad=10) | |
| def plot_dataset(dataset: list[Mesh]) -> None: | |
| """ | |
| Plot all the meshes of a dataset with subplot. | |
| :param dataset: a list with all the meshes | |
| """ | |
| nb_mesh = len(dataset) | |
| sqrt_mesh = np.sqrt(nb_mesh) | |
| if float(sqrt_mesh).is_integer(): | |
| nb_lines = int(sqrt_mesh) | |
| nb_columns = int(sqrt_mesh) | |
| else: | |
| nb_lines = round(sqrt_mesh) | |
| nb_columns = int(sqrt_mesh)+1 | |
| _, _ = plt.subplots(nb_lines, nb_columns, figsize=(20,22)) # 20,22 ou 27,15 | |
| for i, mesh in enumerate(dataset, 1): | |
| plt.subplot(nb_lines, nb_columns, i) | |
| subplot_mesh(mesh, id=i-1, scores=True) | |
| #plt.title('Mesh {}'.format(i)) | |
| plt.tight_layout() | |
| plt.show() | |
| def save_dataset_plot(dataset: list[Mesh], filename: str) -> None: | |
| """ | |
| Plot all the meshes of a dataset with subplot. | |
| :param dataset: a list with all the meshes | |
| :param filename: the name of the file | |
| """ | |
| nb_mesh = len(dataset) | |
| if nb_mesh == 1: | |
| save_mesh_plot(dataset[0], filename, scores=True) | |
| else: | |
| sqrt_mesh = np.sqrt(nb_mesh) | |
| if float(sqrt_mesh).is_integer(): | |
| nb_lines = int(sqrt_mesh) | |
| nb_columns = int(sqrt_mesh) | |
| else: | |
| nb_lines = round(sqrt_mesh) | |
| nb_columns = int(sqrt_mesh)+1 | |
| _, _ = plt.subplots(nb_lines, nb_columns, figsize=(20,22)) # 20,22 ou 27, 15 | |
| for i, mesh in enumerate(dataset, 1): | |
| plt.subplot(nb_lines, nb_columns, i) | |
| subplot_mesh(mesh, id=i-1, scores=True) | |
| plt.tight_layout() | |
| plt.savefig(filename, dpi=300, bbox_inches="tight", pad_inches=0) | |
| def dataset_plt(dataset: list[Mesh]): | |
| """ | |
| Plot all the meshes of a dataset with subplot. | |
| :param dataset: a list with all the meshes | |
| """ | |
| nb_mesh = len(dataset) | |
| sqrt_mesh = np.sqrt(nb_mesh) | |
| if float(sqrt_mesh).is_integer(): | |
| nb_lines = int(sqrt_mesh) | |
| nb_columns = int(sqrt_mesh) | |
| else: | |
| nb_lines = round(sqrt_mesh) | |
| nb_columns = int(sqrt_mesh) +1 | |
| fig, _ = plt.subplots(nb_lines, nb_columns) | |
| for i, mesh in enumerate(dataset, 1): | |
| plt.subplot(nb_lines, nb_columns, i) | |
| subplot_mesh(mesh) | |
| plt.title('Mesh {}'.format(i)) | |
| plt.tight_layout() | |
| return fig | |