| from copy import deepcopy |
| from re import S |
| from typing import Union |
| import torch |
| import open3d as o3d |
| import open3d.visualization.gui as gui |
| import open3d.visualization.rendering as rendering |
| from core.opt import MeshOptimizer |
| import numpy as np |
| from util.func import to_numpy |
|
|
| from util.snapshot import Snapshot |
|
|
|
|
class Viewer:
    """Interactive Open3D window for stepping through optimization snapshots.

    The window consists of a 3D scene showing the mesh of the selected
    snapshot (and optionally the target mesh), a right-hand panel with
    display and plot options, and a bottom slider that selects the snapshot.
    Ctrl-clicking the mesh opens a matplotlib plot of optimizer quantities of
    the nearest vertex over all snapshots.
    """

    # Built-in entries of the colour combobox.
    _COLOR_GRAY = "Gray"
    _COLOR_NU = "Relative Velocity nu"
    _COLOR_LREF = "Reference Edge Length l_ref"

    # Fixed width given to the right-hand options panel in `_layout`.
    _RIGHT_PANEL_WIDTH = 250

    def __init__(
        self,
        target_vertices: torch.Tensor,
        target_faces: torch.Tensor,
        snapshots: "list[Snapshot]",
        vertex_colors: dict[str, list[np.array]],
    ):
        """Create the window, build all widgets and draw the last snapshot.

        Args:
            target_vertices: Vertex positions of the target mesh.
            target_faces: Triangle indices of the target mesh.
            snapshots: Snapshots to browse; each one is read through its
                ``vertices``, ``faces`` and ``optimizer`` attributes.
            vertex_colors: Extra colour modes. Maps a combobox label to one
                array of per-vertex scalars per snapshot.

        Raises:
            ValueError: If ``snapshots`` is empty.
        """
        # Fail early with a clear message instead of an IndexError raised
        # from `_update` after the window has already been created.
        if not snapshots:
            raise ValueError("Viewer needs at least one snapshot to display")

        self._target_vertices = target_vertices
        self._target_faces = target_faces
        self._snapshots = snapshots
        self._vertex_colors = vertex_colors

        self._window_o3 = gui.Application.instance.create_window(
            "Continuous Remeshing", 1000, 800
        )
        self._window_o3.set_on_layout(self._layout)

        self._init_scene()
        self._init_right_panel()
        self._init_bottom_panel()

        self._update()

    def _init_scene(self):
        """Create the 3D scene widget, its camera and its lighting."""
        self._scene_widget = gui.SceneWidget()
        self._scene_widget.scene = rendering.Open3DScene(self._window_o3.renderer)
        self._scene_widget.scene.set_background([0.5, 0.5, 0.5, 1])
        bbox = o3d.geometry.AxisAlignedBoundingBox([-1, -1, -1], [1, 1, 1])
        self._scene_widget.setup_camera(60, bbox, [0, 0, 0])
        self._window_o3.add_child(self._scene_widget)
        self._scene_widget.set_on_mouse(self._on_mouse)

        # Lighting: image-based indirect light only, no sun, no skybox.
        low_level_scene = self._scene_widget.scene.scene
        low_level_scene.enable_sun_light(False)
        low_level_scene.set_indirect_light(
            gui.Application.instance.resource_path + "/park2"
        )
        low_level_scene.enable_indirect_light(True)
        low_level_scene.set_indirect_light_intensity(45000)
        self._scene_widget.scene.show_skybox(False)

    def _add_checkbox(self, name, checked):
        """Append a checkbox that redraws the scene when toggled."""
        checkbox = gui.Checkbox(name)
        checkbox.checked = checked
        checkbox.set_on_checked(lambda *args: self._update())
        self._right_panel.add_child(checkbox)
        return checkbox

    def _init_right_panel(self):
        """Build the options panel (display toggles, colour mode, plot toggles)."""
        margin = self._window_o3.theme.default_margin
        spacing = self._window_o3.theme.default_layout_spacing
        self._right_panel = gui.Vert(
            spacing, gui.Margins(margin, margin, margin, margin)
        )

        self._mesh_checkbox = self._add_checkbox("Show Mesh", True)

        self._colorbox = gui.Combobox()
        for item in (
            self._COLOR_GRAY,
            self._COLOR_NU,
            self._COLOR_LREF,
            *self._vertex_colors,
        ):
            self._colorbox.add_item(item)
        self._colorbox.set_on_selection_changed(lambda *args: self._update())
        self._right_panel.add_child(self._colorbox)

        # Upper end of the colour scale; the lower limit is kept above zero
        # because the value is used as a divisor in `_current_vertex_colors`.
        self._clim_slider = gui.Slider(gui.Slider.DOUBLE)
        self._clim_slider.set_limits(1e-3, 1)
        self._clim_slider.double_value = 0.2
        self._clim_slider.set_on_value_changed(lambda *args: self._update())
        self._right_panel.add_child(self._clim_slider)

        self._edges_checkbox = self._add_checkbox("Show Edges", True)
        self._target_mesh_checkbox = self._add_checkbox("Show Target Mesh", False)
        self._right_panel.add_child(gui.Label("Ctrl-Click Mesh For Plot!"))
        self._target_edges_checkbox = self._add_checkbox("Show Target Edges", False)
        self._positions_checkbox = self._add_checkbox("Plot Positions", False)
        self._gradients_checkbox = self._add_checkbox("Plot Gradients", False)
        self._m1_checkbox = self._add_checkbox("Plot m1", False)
        self._m2_checkbox = self._add_checkbox("Plot m2", False)
        self._nu_checkbox = self._add_checkbox("Plot nu", True)
        self._lref_checkbox = self._add_checkbox("Plot l_ref", True)
        self._window_o3.add_child(self._right_panel)

    def _init_bottom_panel(self):
        """Build the bottom panel holding the snapshot slider."""
        spacing = self._window_o3.theme.default_layout_spacing
        last_index = len(self._snapshots) - 1

        self._bottom_panel = gui.VGrid(cols=2, spacing=spacing)
        self._snapshot_slider = gui.Slider(gui.Slider.INT)
        # Limits first, so the initial value cannot be clamped to a stale range.
        self._snapshot_slider.set_limits(0, last_index)
        self._snapshot_slider.int_value = last_index
        self._snapshot_slider.set_on_value_changed(lambda *args: self._update())
        self._bottom_panel.add_child(self._snapshot_slider)
        self._window_o3.add_child(self._bottom_panel)

    def _update(self):
        """Rebuild the scene geometry from the current widget state."""
        index = self._snapshot_slider.int_value
        snapshot = self._snapshots[index]

        self._scene_widget.scene.clear_geometry()
        self._scene_widget.scene.show_axes(True)

        self._add_mesh(
            "mesh",
            [0.5, 0.5, 0.5, 1],
            snapshot.vertices,
            snapshot.faces,
            self._mesh_checkbox.checked,
            self._edges_checkbox.checked,
            self._current_vertex_colors(snapshot, index),
        )
        self._add_mesh(
            "target",
            [0.5, 0.5, 1, 1],
            self._target_vertices,
            self._target_faces,
            self._target_mesh_checkbox.checked,
            self._target_edges_checkbox.checked,
        )

    def _current_vertex_colors(self, snapshot, index):
        """Return RGB vertex colours for the selected colour mode, or None.

        The selected per-vertex scalar is divided by the colour-limit slider
        value, clipped to [0, 1] and mapped to ``(c, 1 - c, 0)``, i.e. green
        at zero and red at or above the limit. ``None`` is returned when the
        mode is "Gray" or when an optimizer-based mode is selected for a
        snapshot whose optimizer is not a ``MeshOptimizer``.
        """
        mode = self._colorbox.selected_text
        if mode == self._COLOR_NU and isinstance(snapshot.optimizer, MeshOptimizer):
            values = snapshot.optimizer._nu
        elif mode == self._COLOR_LREF and isinstance(
            snapshot.optimizer, MeshOptimizer
        ):
            values = snapshot.optimizer._ref_len
        elif mode in self._vertex_colors:
            values = self._vertex_colors[mode][index]
        else:
            return None

        clim = self._clim_slider.double_value
        c = (to_numpy(values) / clim).clip(0, 1)
        return np.stack((c, 1 - c, np.zeros_like(c)), axis=-1)

    def _add_mesh(
        self, name, color, vertices, faces, show_mesh, show_edges, vertex_colors=None
    ):
        """Add one triangle mesh and/or its wireframe to the scene.

        Args:
            name: Prefix of the geometry names registered in the scene.
            color: RGBA base colour used when no vertex colours are given.
            vertices: Tensor of vertex positions.
            faces: Tensor of triangle vertex indices.
            show_mesh: Whether to add the shaded surface.
            show_edges: Whether to add the black edge line set.
            vertex_colors: Optional per-vertex RGB colours.
        """
        # `MaterialRecord` is used when the installed Open3D provides it,
        # otherwise the code falls back to `Material`.
        material_type = (
            rendering.MaterialRecord
            if hasattr(rendering, "MaterialRecord")
            else rendering.Material
        )

        vertices_np = vertices.detach().cpu().numpy()
        faces_np = faces.type(torch.int32).cpu().numpy()
        triangle_mesh = o3d.geometry.TriangleMesh(
            o3d.utility.Vector3dVector(vertices_np),
            o3d.utility.Vector3iVector(faces_np),
        )
        triangle_mesh.compute_vertex_normals()
        if vertex_colors is not None:
            triangle_mesh.vertex_colors = o3d.utility.Vector3dVector(
                to_numpy(vertex_colors)
            )

        if show_mesh:
            material = material_type()
            # NOTE(review): the lit shader and flat base colour are applied
            # only when there are no vertex colours; with vertex colours the
            # material keeps its defaults. Confirm this grouping is intended.
            if vertex_colors is None:
                material.shader = "defaultLit"
                material.base_color = color
            self._scene_widget.scene.add_geometry(
                f"{name}_triangleMesh", triangle_mesh, material
            )

        if show_edges:
            edges_material = material_type()
            edges_material.base_color = [0, 0, 0, 1]
            edges_material.shader = "unlitLine"
            edges_lineset = o3d.geometry.LineSet.create_from_triangle_mesh(
                triangle_mesh
            )
            # Shift the line vertices slightly along the vertex normals,
            # presumably so the edges are not hidden by the surface itself.
            edges_lineset.points = o3d.utility.Vector3dVector(
                vertices_np + 1e-4 * np.asarray(triangle_mesh.vertex_normals)
            )
            self._scene_widget.scene.add_geometry(
                f"{name}_edges_lineset", edges_lineset, edges_material
            )

    def _layout(self, layout_context):
        """Place the bottom panel, the right panel and the scene widget."""
        rect = self._window_o3.content_rect

        # Bottom panel: full width, preferred height, at the bottom edge.
        bottom_height = self._bottom_panel.calc_preferred_size(
            layout_context, gui.Widget.Constraints()
        ).height
        self._bottom_panel.frame = gui.Rect(
            0, rect.height - bottom_height, rect.width, bottom_height
        )
        rect.height -= bottom_height

        # Right panel: fixed width, remaining height, at the right edge.
        panel_width = self._RIGHT_PANEL_WIDTH
        self._right_panel.frame = gui.Rect(
            rect.width - panel_width, 0, panel_width, rect.height
        )
        rect.width -= panel_width

        # The scene gets whatever is left.
        self._scene_widget.frame = rect

    def _on_mouse(self, event):
        """Start a hit test on Ctrl + mouse button down; ignore other events."""
        is_ctrl_click = (
            event.type == gui.MouseEvent.Type.BUTTON_DOWN
            and event.is_modifier_down(gui.KeyModifier.CTRL)
        )
        if not is_ctrl_click:
            return gui.Widget.EventCallbackResult.IGNORED
        self._hit_test(event)
        return gui.Widget.EventCallbackResult.HANDLED

    def _hit_test(self, event):
        """Find the vertex under the mouse cursor and forward it to `_on_click`.

        A depth image of the scene is rendered; the depth at the click
        position is unprojected to a 3D point and the vertex of the current
        snapshot's optimizer closest to that point is selected.
        """

        def depth_callback(depth_image):
            f = self._scene_widget.frame
            depth = np.asarray(depth_image)[event.y - f.y, event.x - f.x]
            if depth == 1.0:
                # Nothing was hit: the click landed on the background.
                return

            # NOTE(review): y is passed as `f.height - event.y` (measured from
            # the bottom, without subtracting f.y) — confirm this matches the
            # convention of `Camera.unproject` in the Open3D version in use.
            pos = self._scene_widget.scene.camera.unproject(
                event.x - f.x, f.height - event.y, depth, f.width, f.height
            )

            opt = self._snapshots[self._snapshot_slider.int_value].optimizer
            distances = (opt.vertices.detach().cpu() - torch.tensor(pos)).norm(dim=-1)
            vertex = distances.argmin().item()

            # The plot is opened through the application's main-thread queue
            # rather than directly from this render callback.
            gui.Application.instance.post_to_main_thread(
                self._window_o3, lambda: self._on_click(pos, vertex)
            )

        self._scene_widget.scene.scene.render_to_depth_image(depth_callback)

    def _on_click(self, pos, vertex):
        """Handle a successful hit test by plotting the selected vertex."""
        self._show_plot(vertex)

    def _track_vertex(self, ind, vertex):
        """Follow a vertex through all snapshots by nearest-neighbour matching.

        Starting from ``vertex`` in snapshot ``ind``, the closest vertex of
        each neighbouring snapshot is picked step by step, first backwards to
        snapshot 0 and then forwards to the last snapshot. Each match is made
        against the position of the previously matched vertex.

        Returns:
            Long tensor with one vertex index per snapshot.
        """
        device = self._snapshots[0].vertices.device
        vert_ind = torch.zeros(len(self._snapshots), dtype=torch.long, device=device)
        vert_ind[ind] = vertex

        def walk(indices):
            cur_pos = self._snapshots[ind].vertices[vertex]
            for i in indices:
                vertices = self._snapshots[i].vertices
                vert_ind[i] = (vertices - cur_pos).norm(dim=-1).argmin(dim=0)
                cur_pos = vertices[vert_ind[i]]

        walk(range(ind - 1, -1, -1))
        walk(range(ind + 1, len(self._snapshots)))
        return vert_ind

    def _show_plot(self, vertex):
        """Plot the enabled optimizer quantities of one vertex over all steps.

        Args:
            vertex: Index of the vertex in the currently selected snapshot.
        """
        ind = self._snapshot_slider.int_value
        vert_ind = self._track_vertex(ind, vertex)

        # Coordinate components to plot (all of them) and the factor by
        # which gradient-sized quantities are scaled in the plot.
        dims = slice(None, None)
        grad_scale = 100

        # Imported here so the plotting packages are only needed on click.
        from cycler import cycler
        import matplotlib.pyplot as plt

        # One line style per plotted coordinate component.
        plt.gca().set_prop_cycle(cycler(linestyle=["-", "--", ":"][dims]))

        def extract(prop):
            """Evaluate ``prop(optimizer, vertex_index)`` for every snapshot."""
            values = [
                prop(snap.optimizer, v) for snap, v in zip(self._snapshots, vert_ind)
            ]
            if isinstance(values[0], torch.Tensor):
                # Detach so matplotlib can convert the result to numpy even
                # if the source tensors require gradients.
                values = torch.stack(values).detach().cpu()
            return values

        steps = [snap.optimizer._step for snap in self._snapshots]

        if self._positions_checkbox.checked:
            plt.plot(
                steps, extract(lambda opt, v: opt.vertices[v, dims]), "b", label="pos"
            )
        if self._gradients_checkbox.checked:
            plt.plot(
                steps,
                grad_scale * extract(lambda opt, v: opt.vertices.grad[v, dims]),
                "k",
                label="grad",
            )
        if self._m1_checkbox.checked:
            plt.plot(
                steps,
                grad_scale * extract(lambda opt, v: opt._m1[v, dims]),
                "r",
                label="m1",
            )
        if self._m2_checkbox.checked:
            plt.plot(
                steps,
                grad_scale * extract(lambda opt, v: opt._m2[v].sqrt()),
                "-m",
                label="m2.sqrt()",
            )
        if self._nu_checkbox.checked:
            # Only read m1/m2 when the speed curve is actually requested, so
            # the other plots do not depend on these attributes.
            m1 = extract(lambda opt, v: opt._m1[v])
            m2 = extract(lambda opt, v: opt._m2[v])
            velocity = m1 / m2[:, None].sqrt().add_(1e-8)
            speed = velocity.norm(dim=-1)
            plt.plot(steps, speed, color="orange", label="speed")
            plt.plot(steps, extract(lambda opt, v: opt._nu[v]), "-c", label="nu")
        if self._lref_checkbox.checked:
            plt.plot(
                steps,
                extract(lambda opt, v: opt._ref_len[v]),
                color="gray",
                label="l_ref",
            )

        # Mark the selected snapshot. The x axis is in optimizer steps, so
        # the marker uses the snapshot's step, not its index.
        plt.axvline(x=steps[ind], color="k")
        plt.legend()
        plt.grid()
        plt.show()
|
|
|
|
def show(
    target_vertices: torch.Tensor,
    target_faces: torch.Tensor,
    snapshots: "list[Snapshot]",
    vertex_colors: Union[dict[str, list[np.array]], None] = None,
):
    """Open the snapshot viewer window and run the GUI event loop.

    Args:
        target_vertices: Vertex positions of the target mesh.
        target_faces: Triangle indices of the target mesh.
        snapshots: Snapshots to browse.
        vertex_colors: Optional extra colour modes. Maps a label to one
            array of per-vertex values per snapshot. Defaults to no extra
            modes.

    Raises:
        ValueError: If an entry of ``vertex_colors`` does not provide exactly
            one array per snapshot with one value per vertex.
    """
    # A fresh dict per call; a `{}` default would be shared between calls.
    if vertex_colors is None:
        vertex_colors = {}

    if vertex_colors:
        expected = [snap.vertices.shape[0] for snap in snapshots]
        for name, colors in vertex_colors.items():
            actual = [c.shape[0] for c in colors]
            # Explicit check instead of `assert`, which is removed under -O.
            if actual != expected:
                raise ValueError(
                    f"vertex_colors[{name!r}] has per-snapshot sizes {actual}, "
                    f"expected {expected} to match the snapshot vertex counts"
                )

    gui.Application.instance.initialize()
    viewer = Viewer(target_vertices, target_faces, snapshots, vertex_colors)
    gui.Application.instance.run()
|
|