# Lingteng Qiu (邱陵腾)
# rm assets & wheels
# 434b0b0
from copy import deepcopy
from re import S
from typing import Union
import torch
import open3d as o3d
import open3d.visualization.gui as gui
import open3d.visualization.rendering as rendering
from core.opt import MeshOptimizer
import numpy as np
from util.func import to_numpy
from util.snapshot import Snapshot
class Viewer:
    """Open3D GUI window for browsing optimization snapshots.

    The bottom slider selects a snapshot whose mesh is drawn (optionally colored
    by a per-vertex quantity) together with the target mesh. Ctrl-clicking the
    mesh opens a matplotlib plot that follows the clicked vertex through all
    snapshots.
    """

    def __init__(
        self,
        target_vertices: torch.Tensor,  # V,3
        target_faces: torch.Tensor,  # F,3
        snapshots: list[Snapshot],
        vertex_colors: dict[str, list[np.ndarray]],
    ):
        """Create the window, scene and control panels, then render once.

        Args:
            target_vertices: vertex positions of the target mesh (V,3).
            target_faces: triangle vertex indices of the target mesh (F,3).
            snapshots: snapshots to browse; the bottom slider indexes this list.
            vertex_colors: extra "color by" options; each value holds one
                array per snapshot (presumably one scalar per vertex, since it
                is mapped to an RGB ramp in ``_update``).
        """
        self._target_vertices = target_vertices
        self._target_faces = target_faces
        self._snapshots = snapshots
        self._vertex_colors = vertex_colors
        self._window_o3 = gui.Application.instance.create_window(
            "Continuous Remeshing", 1000, 800
        )
        self._window_o3.set_on_layout(self._layout)

        # 3D scene
        self._scene_widget = gui.SceneWidget()
        self._scene_widget.scene = rendering.Open3DScene(self._window_o3.renderer)
        self._scene_widget.scene.set_background([0.5, 0.5, 0.5, 1])
        bbox = o3d.geometry.AxisAlignedBoundingBox([-1, -1, -1], [1, 1, 1])
        self._scene_widget.setup_camera(60, bbox, [0, 0, 0])
        self._window_o3.add_child(self._scene_widget)
        self._scene_widget.set_on_mouse(self._on_mouse)

        # lights: no sun light, image-based lighting from the "park2" resource
        self._scene_widget.scene.scene.enable_sun_light(False)
        self._scene_widget.scene.scene.set_indirect_light(
            gui.Application.instance.resource_path + "/park2"
        )
        self._scene_widget.scene.scene.enable_indirect_light(True)
        self._scene_widget.scene.scene.set_indirect_light_intensity(45000)
        self._scene_widget.scene.show_skybox(False)

        # right panel
        margins = gui.Margins(*[self._window_o3.theme.default_margin] * 4)
        spacing = self._window_o3.theme.default_layout_spacing
        self._right_panel = gui.Vert(spacing, margins)

        def make_checkbox(name, checked):
            # Every control just triggers a full re-render via _update().
            checkbox = gui.Checkbox(name)
            checkbox.checked = checked
            checkbox.set_on_checked(lambda *args: self._update())
            self._right_panel.add_child(checkbox)
            return checkbox

        self._mesh_checkbox = make_checkbox("Show Mesh", True)
        self._colorbox = gui.Combobox()
        for item in [
            "Gray",
            "Relative Velocity nu",
            "Reference Edge Length l_ref",
            *self._vertex_colors.keys(),
        ]:
            self._colorbox.add_item(item)
        self._colorbox.set_on_selection_changed(lambda *args: self._update())
        self._right_panel.add_child(self._colorbox)
        # Upper end of the color ramp: values >= clim are drawn pure red.
        self._clim_slider = gui.Slider(gui.Slider.DOUBLE)
        self._clim_slider.double_value = 0.2
        self._clim_slider.set_limits(1e-3, 1)
        self._clim_slider.set_on_value_changed(lambda *args: self._update())
        self._right_panel.add_child(self._clim_slider)
        self._edges_checkbox = make_checkbox("Show Edges", True)
        self._target_mesh_checkbox = make_checkbox("Show Target Mesh", False)
        self._right_panel.add_child(gui.Label("Ctrl-Click Mesh For Plot!"))
        self._target_edges_checkbox = make_checkbox("Show Target Edges", False)
        self._positions_checkbox = make_checkbox("Plot Positions", False)
        self._gradients_checkbox = make_checkbox("Plot Gradients", False)
        self._m1_checkbox = make_checkbox("Plot m1", False)
        self._m2_checkbox = make_checkbox("Plot m2", False)
        self._nu_checkbox = make_checkbox("Plot nu", True)
        self._lref_checkbox = make_checkbox("Plot l_ref", True)
        self._window_o3.add_child(self._right_panel)

        # bottom panel: snapshot selector, starting at the last snapshot
        self._bottom_panel = gui.VGrid(cols=2, spacing=spacing)
        self._snapshot_slider = gui.Slider(gui.Slider.INT)
        self._snapshot_slider.int_value = len(self._snapshots) - 1
        self._snapshot_slider.set_limits(0, len(self._snapshots) - 1)
        self._snapshot_slider.set_on_value_changed(lambda *args: self._update())
        self._bottom_panel.add_child(self._snapshot_slider)
        self._window_o3.add_child(self._bottom_panel)
        self._update()

    def _update(self):
        """Rebuild the scene from the current widget state.

        Clears all geometry, then adds the selected snapshot's mesh (colored
        according to the combobox and clim slider) and the target mesh, each
        with or without edges as chosen by the checkboxes.
        """
        snapshot = self._snapshots[self._snapshot_slider.int_value]
        self._scene_widget.scene.clear_geometry()
        self._scene_widget.scene.show_axes(True)
        # Support both Open3D APIs: newer releases expose MaterialRecord,
        # older ones rendering.Material.
        MaterialType = (
            rendering.MaterialRecord
            if hasattr(rendering, "MaterialRecord")
            else rendering.Material
        )

        def add_mesh(
            name, color, vertices, faces, show_mesh, show_edges, vertex_colors=None
        ):
            # Add the triangle mesh and/or its wireframe to the scene under
            # names derived from `name`.
            vertices_np = vertices.detach().cpu().numpy()
            vertices_o3 = o3d.utility.Vector3dVector(vertices_np)
            faces_o3 = o3d.utility.Vector3iVector(faces.type(torch.int32).cpu().numpy())
            triangleMesh = o3d.geometry.TriangleMesh(vertices_o3, faces_o3)
            triangleMesh.compute_vertex_normals()
            if vertex_colors is not None:
                vertex_colors_np = to_numpy(vertex_colors)
                triangleMesh.vertex_colors = o3d.utility.Vector3dVector(
                    vertex_colors_np
                )
            if show_mesh:
                material = MaterialType()
                if vertex_colors is None:
                    # Uniform color with lighting; with vertex colors the
                    # material's default shader is kept.
                    material.shader = "defaultLit"
                    material.base_color = color
                self._scene_widget.scene.add_geometry(
                    f"{name}_triangleMesh", triangleMesh, material
                )
            if show_edges:
                edges_material = MaterialType()
                edges_material.base_color = [0, 0, 0, 1]
                edges_material.shader = "unlitLine"
                edges_lineset = o3d.geometry.LineSet.create_from_triangle_mesh(
                    triangleMesh
                )
                # Push the edges slightly outwards along the vertex normals so
                # they are not hidden by the surface they lie on.
                edges_lineset.points = o3d.utility.Vector3dVector(
                    vertices_np + 1e-4 * np.asarray(triangleMesh.vertex_normals)
                )
                self._scene_widget.scene.add_geometry(
                    f"{name}_edges_lineset", edges_lineset, edges_material
                )

        clim = self._clim_slider.double_value
        selected = self._colorbox.selected_text
        if selected == "Relative Velocity nu" and isinstance(
            snapshot.optimizer, MeshOptimizer
        ):
            vertex_colors = snapshot.optimizer._nu
        elif selected == "Reference Edge Length l_ref" and isinstance(
            snapshot.optimizer, MeshOptimizer
        ):
            vertex_colors = snapshot.optimizer._ref_len
        elif selected in self._vertex_colors.keys():
            vertex_colors = self._vertex_colors[selected][
                self._snapshot_slider.int_value
            ]
        else:
            vertex_colors = None
        if vertex_colors is not None:
            # Linear green (0) -> red (>= clim) ramp.
            c = (to_numpy(vertex_colors) / clim).clip(0, 1)
            vertex_colors = np.stack((c, 1 - c, np.zeros_like(c)), axis=-1)
        add_mesh(
            "mesh",
            [0.5, 0.5, 0.5, 1],
            snapshot.vertices,
            snapshot.faces,
            self._mesh_checkbox.checked,
            self._edges_checkbox.checked,
            vertex_colors,
        )
        add_mesh(
            "target",
            [0.5, 0.5, 1, 1],
            self._target_vertices,
            self._target_faces,
            self._target_mesh_checkbox.checked,
            self._target_edges_checkbox.checked,
        )

    def _layout(self, layout_context):
        """Lay out the window.

        The bottom panel spans the full width at its preferred height, the
        right panel is 250 px wide above it, and the scene fills the rest.
        """
        r = self._window_o3.content_rect
        h = self._bottom_panel.calc_preferred_size(
            layout_context, gui.Widget.Constraints()
        ).height
        self._bottom_panel.frame = gui.Rect(0, r.height - h, r.width, h)
        r.height -= h
        w = 250
        self._right_panel.frame = gui.Rect(r.width - w, 0, w, r.height)
        r.width -= w
        self._scene_widget.frame = r

    def _on_mouse(self, event):
        """Handle Ctrl+mouse-button-down as a pick; pass everything else on."""
        if event.type == gui.MouseEvent.Type.BUTTON_DOWN and event.is_modifier_down(
            gui.KeyModifier.CTRL
        ):
            self._hit_test(event)
            return gui.Widget.EventCallbackResult.HANDLED
        return gui.Widget.EventCallbackResult.IGNORED

    def _hit_test(self, event):
        """Pick the vertex under the mouse and open its plot.

        Renders a depth image, unprojects the clicked pixel to 3D, selects the
        closest vertex of the displayed snapshot mesh, and posts the plot
        request to the GUI main thread.
        """

        def depth_callback(depth_image):
            f = self._scene_widget.frame
            # Click position relative to the scene widget.
            x = event.x - f.x
            y = event.y - f.y
            depth = np.asarray(depth_image)[y, x]
            if depth == 1.0:  # clicked on nothing (i.e. the far plane)
                return
            # need to flip y https://github.com/isl-org/Open3D/issues/4244
            pos = self._scene_widget.scene.camera.unproject(
                x, f.height - y, depth, f.width, f.height
            )
            # Pick against the vertices that are actually displayed; _show_plot
            # interprets the index in snapshot.vertices as well.
            snapshot = self._snapshots[self._snapshot_slider.int_value]
            vertices = snapshot.vertices.detach().cpu()
            vertex = (vertices - torch.tensor(pos)).norm(dim=-1).argmin().item()
            gui.Application.instance.post_to_main_thread(
                self._window_o3, lambda: self._on_click(pos, vertex)
            )

        self._scene_widget.scene.scene.render_to_depth_image(depth_callback)

    def _on_click(self, pos, vertex):
        """Handle a successful pick at 3D point `pos` (currently unused)."""
        self._show_plot(vertex)

    def _show_plot(self, vertex):
        """Plot optimizer quantities of one vertex over all snapshots.

        `vertex` indexes the currently displayed snapshot. Vertex indices are
        not assumed to be stable between snapshots, so the vertex is followed
        backwards and forwards by taking, in each neighbouring snapshot, the
        vertex nearest to the previously tracked position.
        """
        ind = self._snapshot_slider.int_value
        device = self._snapshots[0].vertices.device
        vert_ind = torch.zeros(len(self._snapshots), dtype=torch.long, device=device)
        vert_ind[ind] = vertex

        def trace(i, cur_pos):
            # Store the vertex of snapshot i nearest to cur_pos and return
            # its position for the next step of the walk.
            vertices = self._snapshots[i].vertices
            vert_ind[i] = (vertices - cur_pos).norm(dim=-1).argmin(dim=0)
            return vertices[vert_ind[i]]

        start_pos = self._snapshots[ind].vertices[vertex]
        cur_pos = start_pos
        for i in range(ind - 1, -1, -1):
            cur_pos = trace(i, cur_pos)
        cur_pos = start_pos
        for i in range(ind + 1, len(self._snapshots)):
            cur_pos = trace(i, cur_pos)

        dims = slice(None, None)
        grad_scale = 100
        from cycler import cycler
        import matplotlib.pyplot as plt

        # Distinguish per-coordinate curves by line style.
        plt.gca().set_prop_cycle(cycler(linestyle=["-", "--", ":"][dims]))

        def extract(prop):
            # prop(optimizer, tracked_vertex_index) for every snapshot. Tensors
            # are stacked and detached so matplotlib can convert them even when
            # they require grad (e.g. optimized vertex positions).
            values = [
                prop(snapshot.optimizer, vert_ind[i])
                for i, snapshot in enumerate(self._snapshots)
            ]
            if isinstance(values[0], torch.Tensor):
                values = torch.stack(values).detach().cpu()
            return values

        # x-axis: optimizer step at which each snapshot was taken.
        s = [snapshot.optimizer._step for snapshot in self._snapshots]
        if self._positions_checkbox.checked:
            plt.plot(s, extract(lambda opt, v: opt.vertices[v, dims]), "b", label="pos")
        if self._gradients_checkbox.checked:
            plt.plot(
                s,
                grad_scale * extract(lambda opt, v: opt.vertices.grad[v, dims]),
                "k",
                label="grad",
            )
        if self._m1_checkbox.checked:
            plt.plot(
                s,
                grad_scale * extract(lambda opt, v: opt._m1[v, dims]),
                "r",
                label="m1",
            )
        if self._m2_checkbox.checked:
            plt.plot(
                s,
                grad_scale * extract(lambda opt, v: opt._m2[v].sqrt()),
                "-m",
                label="m2.sqrt()",
            )
        if self._nu_checkbox.checked:
            # Only needed for this plot, so only read m1/m2 when requested.
            m1 = extract(lambda opt, v: opt._m1[v])
            m2 = extract(lambda opt, v: opt._m2[v])
            velocity = m1 / m2[:, None].sqrt().add_(1e-8)  # (num_snapshots, D)
            speed = velocity.norm(dim=-1)
            plt.plot(s, speed, color="orange", label="speed")
            plt.plot(s, extract(lambda opt, v: opt._nu[v]), "-c", label="nu")
        if self._lref_checkbox.checked:
            plt.plot(
                s, extract(lambda opt, v: opt._ref_len[v]), color="gray", label="l_ref"
            )
        # Mark the displayed snapshot in step units (the x-axis is `s`, not the
        # snapshot index).
        plt.axvline(x=s[ind], color="k")
        plt.legend()
        plt.grid()
        plt.show()
def show(
    target_vertices: torch.Tensor,  # V,3
    target_faces: torch.Tensor,  # F,3
    snapshots: list[Snapshot],
    vertex_colors: Union[dict[str, list[np.ndarray]], None] = None,
):
    """Open the interactive viewer and block until the GUI application exits.

    Args:
        target_vertices: vertex positions of the target mesh (V,3).
        target_faces: triangle vertex indices of the target mesh (F,3).
        snapshots: snapshots to browse in the viewer.
        vertex_colors: optional extra "color by" options. For every key, one
            array per snapshot whose first dimension equals that snapshot's
            vertex count. Defaults to no extra options.

    Raises:
        ValueError: if an entry of `vertex_colors` does not have one array per
            snapshot with matching vertex counts.
    """
    # None instead of a mutable `{}` default shared across calls.
    if vertex_colors is None:
        vertex_colors = {}
    expected_sizes = [s.vertices.shape[0] for s in snapshots]
    for name, vc in vertex_colors.items():
        sizes = [c.shape[0] for c in vc]
        # Explicit check rather than `assert`, which is stripped under -O.
        if sizes != expected_sizes:
            raise ValueError(
                f"vertex_colors[{name!r}] has per-snapshot sizes {sizes}, "
                f"expected {expected_sizes}"
            )
    gui.Application.instance.initialize()
    # Hold a reference to the viewer while the event loop runs.
    viewer = Viewer(target_vertices, target_faces, snapshots, vertex_colors)
    gui.Application.instance.run()