qbhf2 commited on
Commit
ec5eedb
·
1 Parent(s): 407f992

new file: custom_utils/smpl_visualizer.py

Browse files

modified: custom_utils/startup.sh - new installs

custom_utils/smpl_visualizer.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: disable-error-code="assignment"
2
+ #
3
+ # Asymmetric properties are supported in Pyright, but not yet in mypy.
4
+ # - https://github.com/python/mypy/issues/3004
5
+ # - https://github.com/python/mypy/pull/11643
6
+ """SMPL visualizer (Skinned Mesh)
7
+
8
+ Requires a .npz model file.
9
+
10
+ See here for download instructions:
11
+ https://github.com/vchoutas/smplx?tab=readme-ov-file#downloading-the-model
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import time
17
+ from dataclasses import dataclass
18
+ from pathlib import Path
19
+ from typing import List, Tuple
20
+
21
+ import numpy as np
22
+ import tyro
23
+
24
+ import viser
25
+ import viser.transforms as tf
26
+
27
+
28
+ @dataclass(frozen=True)
29
+ class SmplFkOutputs:
30
+ T_world_joint: np.ndarray # (num_joints, 4, 4)
31
+ T_parent_joint: np.ndarray # (num_joints, 4, 4)
32
+
33
+
34
+ class SmplHelper:
35
+ """Helper for models in the SMPL family, implemented in numpy. Does not include blend skinning."""
36
+
37
+ def __init__(self, model_path: Path) -> None:
38
+ assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!"
39
+ body_dict = dict(**np.load(model_path, allow_pickle=True))
40
+
41
+ self.J_regressor = body_dict["J_regressor"]
42
+ self.weights = body_dict["weights"]
43
+ self.v_template = body_dict["v_template"]
44
+ self.posedirs = body_dict["posedirs"]
45
+ self.shapedirs = body_dict["shapedirs"]
46
+ self.faces = body_dict["f"]
47
+
48
+ self.num_joints: int = self.weights.shape[-1]
49
+ self.num_betas: int = self.shapedirs.shape[-1]
50
+ self.parent_idx: np.ndarray = body_dict["kintree_table"][0]
51
+
52
+ def get_tpose(self, betas: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
53
+ # Get shaped vertices + joint positions, when all local poses are identity.
54
+ v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas)
55
+ j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose)
56
+ return v_tpose, j_tpose
57
+
58
+ def get_outputs(
59
+ self, betas: np.ndarray, joint_rotmats: np.ndarray
60
+ ) -> SmplFkOutputs:
61
+ # Get shaped vertices + joint positions, when all local poses are identity.
62
+ v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas)
63
+ j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose)
64
+
65
+ # Local SE(3) transforms.
66
+ T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4)
67
+ T_parent_joint[:, :3, :3] = joint_rotmats
68
+ T_parent_joint[0, :3, 3] = j_tpose[0]
69
+ T_parent_joint[1:, :3, 3] = j_tpose[1:] - j_tpose[self.parent_idx[1:]]
70
+
71
+ # Forward kinematics.
72
+ T_world_joint = T_parent_joint.copy()
73
+ for i in range(1, self.num_joints):
74
+ T_world_joint[i] = T_world_joint[self.parent_idx[i]] @ T_parent_joint[i]
75
+
76
+ return SmplFkOutputs(T_world_joint, T_parent_joint)
77
+
78
+
79
+ def main(model_path: Path) -> None:
80
+ server = viser.ViserServer()
81
+ server.scene.set_up_direction("+y")
82
+
83
+ # Main loop. We'll read pose/shape from the GUI elements, compute the mesh,
84
+ # and then send the updated mesh in a loop.
85
+ model = SmplHelper(model_path)
86
+ gui_elements = make_gui_elements(
87
+ server,
88
+ num_betas=model.num_betas,
89
+ num_joints=model.num_joints,
90
+ parent_idx=model.parent_idx,
91
+ )
92
+ v_tpose, j_tpose = model.get_tpose(np.zeros((model.num_betas,)))
93
+ mesh_handle = server.scene.add_mesh_skinned(
94
+ "/human",
95
+ v_tpose,
96
+ model.faces,
97
+ bone_wxyzs=tf.SO3.identity(batch_axes=(model.num_joints,)).wxyz,
98
+ bone_positions=j_tpose,
99
+ skin_weights=model.weights,
100
+ wireframe=gui_elements.gui_wireframe.value,
101
+ color=gui_elements.gui_rgb.value,
102
+ )
103
+ server.scene.add_grid("/grid", position=(0.0, -1.3, 0.0), plane="xz")
104
+
105
+ while True:
106
+ # Do nothing if no change.
107
+ time.sleep(0.02)
108
+ if not gui_elements.changed:
109
+ continue
110
+
111
+ # Shapes changed: update vertices / joint positions.
112
+ if gui_elements.betas_changed:
113
+ v_tpose, j_tpose = model.get_tpose(
114
+ np.array([gui_beta.value for gui_beta in gui_elements.gui_betas])
115
+ )
116
+ mesh_handle.vertices = v_tpose
117
+ mesh_handle.bone_positions = j_tpose
118
+
119
+ gui_elements.changed = False
120
+ gui_elements.betas_changed = False
121
+
122
+ # Render as wireframe?
123
+ mesh_handle.wireframe = gui_elements.gui_wireframe.value
124
+
125
+ # Compute SMPL outputs.
126
+ smpl_outputs = model.get_outputs(
127
+ betas=np.array([x.value for x in gui_elements.gui_betas]),
128
+ joint_rotmats=np.stack(
129
+ [
130
+ tf.SO3.exp(np.array(x.value)).as_matrix()
131
+ for x in gui_elements.gui_joints
132
+ ],
133
+ axis=0,
134
+ ),
135
+ )
136
+
137
+ # Match transform control gizmos to joint positions.
138
+ for i, control in enumerate(gui_elements.transform_controls):
139
+ control.position = smpl_outputs.T_parent_joint[i, :3, 3]
140
+ mesh_handle.bones[i].wxyz = tf.SO3.from_matrix(
141
+ smpl_outputs.T_world_joint[i, :3, :3]
142
+ ).wxyz
143
+ mesh_handle.bones[i].position = smpl_outputs.T_world_joint[i, :3, 3]
144
+
145
+
146
+ @dataclass
147
+ class GuiElements:
148
+ """Structure containing handles for reading from GUI elements."""
149
+
150
+ gui_rgb: viser.GuiInputHandle[Tuple[int, int, int]]
151
+ gui_wireframe: viser.GuiInputHandle[bool]
152
+ gui_betas: List[viser.GuiInputHandle[float]]
153
+ gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]]
154
+ transform_controls: List[viser.TransformControlsHandle]
155
+
156
+ changed: bool
157
+ """This flag will be flipped to True whenever any input is changed."""
158
+
159
+ betas_changed: bool
160
+ """This flag will be flipped to True whenever the shape changes."""
161
+
162
+
163
+ def make_gui_elements(
164
+ server: viser.ViserServer,
165
+ num_betas: int,
166
+ num_joints: int,
167
+ parent_idx: np.ndarray,
168
+ ) -> GuiElements:
169
+ """Make GUI elements for interacting with the model."""
170
+
171
+ tab_group = server.gui.add_tab_group()
172
+
173
+ def set_changed(_) -> None:
174
+ out.changed = True # out is defined later!
175
+
176
+ def set_betas_changed(_) -> None:
177
+ out.betas_changed = True
178
+ out.changed = True
179
+
180
+ # GUI elements: mesh settings + visibility.
181
+ with tab_group.add_tab("View", viser.Icon.VIEWFINDER):
182
+ gui_rgb = server.gui.add_rgb("Color", initial_value=(90, 200, 255))
183
+ gui_wireframe = server.gui.add_checkbox("Wireframe", initial_value=False)
184
+ gui_show_controls = server.gui.add_checkbox("Handles", initial_value=True)
185
+ gui_control_size = server.gui.add_slider(
186
+ "Handle size", min=0.0, max=10.0, step=0.01, initial_value=1.0
187
+ )
188
+
189
+ gui_rgb.on_update(set_changed)
190
+ gui_wireframe.on_update(set_changed)
191
+
192
+ @gui_show_controls.on_update
193
+ def _(_):
194
+ for control in transform_controls:
195
+ control.visible = gui_show_controls.value
196
+
197
+ @gui_control_size.on_update
198
+ def _(_):
199
+ for control in transform_controls:
200
+ prefixed_joint_name = control.name
201
+ control.scale = (
202
+ 0.2
203
+ * (0.75 ** prefixed_joint_name.count("/"))
204
+ * gui_control_size.value
205
+ )
206
+
207
+ # GUI elements: shape parameters.
208
+ with tab_group.add_tab("Shape", viser.Icon.BOX):
209
+ gui_reset_shape = server.gui.add_button("Reset Shape")
210
+ gui_random_shape = server.gui.add_button("Random Shape")
211
+
212
+ @gui_reset_shape.on_click
213
+ def _(_):
214
+ for beta in gui_betas:
215
+ beta.value = 0.0
216
+
217
+ @gui_random_shape.on_click
218
+ def _(_):
219
+ for beta in gui_betas:
220
+ beta.value = np.random.normal(loc=0.0, scale=1.0)
221
+
222
+ gui_betas = []
223
+ for i in range(num_betas):
224
+ beta = server.gui.add_slider(
225
+ f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0
226
+ )
227
+ gui_betas.append(beta)
228
+ beta.on_update(set_betas_changed)
229
+
230
+ # GUI elements: joint angles.
231
+ with tab_group.add_tab("Joints", viser.Icon.ANGLE):
232
+ gui_reset_joints = server.gui.add_button("Reset Joints")
233
+ gui_random_joints = server.gui.add_button("Random Joints")
234
+
235
+ @gui_reset_joints.on_click
236
+ def _(_):
237
+ for joint in gui_joints:
238
+ joint.value = (0.0, 0.0, 0.0)
239
+
240
+ @gui_random_joints.on_click
241
+ def _(_):
242
+ rng = np.random.default_rng()
243
+ for joint in gui_joints:
244
+ joint.value = tf.SO3.sample_uniform(rng).log()
245
+
246
+ gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] = []
247
+ for i in range(num_joints):
248
+ gui_joint = server.gui.add_vector3(
249
+ label=f"Joint {i}",
250
+ initial_value=(0.0, 0.0, 0.0),
251
+ step=0.05,
252
+ )
253
+ gui_joints.append(gui_joint)
254
+
255
+ def set_callback_in_closure(i: int) -> None:
256
+ @gui_joint.on_update
257
+ def _(_):
258
+ transform_controls[i].wxyz = tf.SO3.exp(
259
+ np.array(gui_joints[i].value)
260
+ ).wxyz
261
+ out.changed = True
262
+
263
+ set_callback_in_closure(i)
264
+
265
+ # Transform control gizmos on joints.
266
+ transform_controls: List[viser.TransformControlsHandle] = []
267
+ prefixed_joint_names = [] # Joint names, but prefixed with parents.
268
+ for i in range(num_joints):
269
+ prefixed_joint_name = f"joint_{i}"
270
+ if i > 0:
271
+ prefixed_joint_name = (
272
+ prefixed_joint_names[parent_idx[i]] + "/" + prefixed_joint_name
273
+ )
274
+ prefixed_joint_names.append(prefixed_joint_name)
275
+ controls = server.scene.add_transform_controls(
276
+ f"/smpl/{prefixed_joint_name}",
277
+ depth_test=False,
278
+ scale=0.2 * (0.75 ** prefixed_joint_name.count("/")),
279
+ disable_axes=True,
280
+ disable_sliders=True,
281
+ visible=gui_show_controls.value,
282
+ )
283
+ transform_controls.append(controls)
284
+
285
+ def set_callback_in_closure(i: int) -> None:
286
+ @controls.on_update
287
+ def _(_) -> None:
288
+ axisangle = tf.SO3(transform_controls[i].wxyz).log()
289
+ gui_joints[i].value = (axisangle[0], axisangle[1], axisangle[2])
290
+
291
+ set_callback_in_closure(i)
292
+
293
+ out = GuiElements(
294
+ gui_rgb,
295
+ gui_wireframe,
296
+ gui_betas,
297
+ gui_joints,
298
+ transform_controls=transform_controls,
299
+ changed=True,
300
+ betas_changed=False,
301
+ )
302
+ return out
303
+
304
+
305
+ if __name__ == "__main__":
306
+ tyro.cli(main, description=__doc__)
custom_utils/startup.sh CHANGED
@@ -23,7 +23,7 @@ conda activate garmentcode
23
 
24
  # Установить зависимости через pip
25
  echo "Installing dependencies..."
26
- pip install scipy pyaml>=6.0 svgwrite psutil gradio bpy matplotlib svgpathtools cairosvg nicegui trimesh libigl pyrender cgal numpy pygarment
27
 
28
  # Настроить CUDA и другие пути
29
  echo "Adding paths to .bashrc..."
 
23
 
24
  # Установить зависимости через pip
25
  echo "Installing dependencies..."
26
+ pip install scipy pyaml>=6.0 svgwrite psutil gradio bpy matplotlib svgpathtools cairosvg nicegui trimesh libigl pyrender cgal numpy pygarment pyliblzfse viser[examples]
27
 
28
  # Настроить CUDA и другие пути
29
  echo "Adding paths to .bashrc..."