Nekochu committed (verified)
Commit e8911be · Parent: ebf1dc0

Fix render

Files changed (2):
  1. README.md +1 -0
  2. app_new.py +165 -170
README.md CHANGED
@@ -8,6 +8,7 @@ sdk_version: "6.1.0"
 app_file: app_new.py
 pinned: false
 python_version: "3.10"
+short_description: Text-to-3D motion generation using ONNX INT8 models
 ---
 
 # MoMask: Text-to-Motion Generation
app_new.py CHANGED
@@ -13,10 +13,9 @@ import clip
 import matplotlib
 matplotlib.use('Agg')
 import matplotlib.pyplot as plt
-from matplotlib.animation import FuncAnimation
-import mpl_toolkits.mplot3d.axes3d as p3
-from mpl_toolkits.mplot3d.art3d import Poly3DCollection
+from matplotlib.animation import FuncAnimation, FFMpegWriter
 from pathlib import Path
+
 # ============ Quaternion Operations ============
 def qinv(q):
     """Invert quaternion"""
@@ -30,11 +29,11 @@ def qrot(q, v):
     assert q.shape[-1] == 4
     assert v.shape[-1] == 3
     assert q.shape[:-1] == v.shape[:-1]
-
+
     original_shape = list(v.shape)
     q = q.contiguous().view(-1, 4)
     v = v.contiguous().view(-1, 3)
-
+
     qvec = q[:, 1:]
     uv = torch.cross(qvec, v, dim=1)
     uuv = torch.cross(qvec, uv, dim=1)
@@ -66,10 +65,9 @@ def get_session(name):
     path = ONNX_DIR / f"{name}.onnx"
     if not path.exists():
         raise FileNotFoundError(f"Model not found: {path}")
-    sessions[name] = ort.InferenceSession(str(path), providers=['CPUExecutionProvider'])
+    sessions[name] = ort.InferenceSession(str(path), providers=["CPUExecutionProvider"])
     return sessions[name]
 
-
 # ============ Motion Recovery ============
 def recover_root_rot_pos(data):
     """Recover root rotation and position from motion data"""
@@ -104,75 +102,111 @@ def recover_from_ric(data, joints_num=22):
 # ============ Visualization ============
 def plot_3d_motion(save_path, joints, title, fps=20):
     """Create MP4 video of 3D skeleton motion"""
-    kinematic_tree = T2M_KINEMATIC_CHAIN
-    figsize = (6, 6)
-    radius = 4
-
-    # Prepare title
-    title_sp = title.split(' ')
-    if len(title_sp) > 10:
-        title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:])])
-
-    # Prepare data
-    data = joints.copy().reshape(len(joints), -1, 3)
-    frame_number = data.shape[0]
-
-    MINS = data.min(axis=0).min(axis=0)
-    MAXS = data.max(axis=0).max(axis=0)
-
-    height_offset = MINS[1]
-    data[:, :, 1] -= height_offset
-    trajec = data[:, 0, [0, 2]]
-    data[..., 0] -= data[:, 0:1, 0]
-    data[..., 2] -= data[:, 0:1, 2]
-
-    fig = plt.figure(figsize=figsize)
-    ax = p3.Axes3D(fig)
-
-    ax.set_xlim3d([-radius / 2, radius / 2])
-    ax.set_ylim3d([0, radius])
-    ax.set_zlim3d([0, radius])
-    fig.suptitle(title, fontsize=12)
-    ax.grid(False)
-
-    colors = ['red', 'blue', 'black', 'red', 'blue']
-
-    def plot_xzPlane(minx, maxx, miny, minz, maxz):
-        verts = [[minx, miny, minz], [minx, miny, maxz],
-                 [maxx, miny, maxz], [maxx, miny, minz]]
-        xz_plane = Poly3DCollection([verts])
-        xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
-        ax.add_collection3d(xz_plane)
-
-    def update(index):
-        for line in ax.lines[:]: line.remove()
-        for coll in ax.collections[:]: coll.remove()
-        ax.view_init(elev=120, azim=-90)
-        ax.dist = 7.5
-
-        plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0,
-                     MINS[2] - trajec[index, 1], MAXS[2] - trajec[index, 1])
-
-        if index > 1:
-            ax.plot3D(trajec[:index, 0] - trajec[index, 0],
-                      np.zeros_like(trajec[:index, 0]),
-                      trajec[:index, 1] - trajec[index, 1],
-                      linewidth=1.0, color='blue')
-
-        for i, (chain, color) in enumerate(zip(kinematic_tree, colors)):
-            linewidth = 4.0 if i < 5 else 2.0
-            ax.plot3D(data[index, chain, 0], data[index, chain, 1],
-                      data[index, chain, 2], linewidth=linewidth, color=color)
-
-        plt.axis('off')
-        ax.set_xticklabels([])
-        ax.set_yticklabels([])
-        ax.set_zticklabels([])
-
-    ani = FuncAnimation(fig, update, frames=frame_number, interval=1000/fps, repeat=False)
-    ani.save(save_path, fps=fps)
+    fig = plt.figure(figsize=(8, 8))
+    ax = fig.add_subplot(111, projection="3d")
+    COLORS = ["red", "blue", "black", "green", "purple"]
+
+    def init():
+        ax.set_xlim(-1.5, 1.5)
+        ax.set_ylim(-1.5, 1.5)
+        ax.set_zlim(0, 2)
+        ax.set_xlabel("X")
+        ax.set_ylabel("Z")
+        ax.set_zlabel("Y (up)")
+        ax.set_title(title)
+        return []
+
+    lines = []
+    for i, chain in enumerate(T2M_KINEMATIC_CHAIN):
+        line, = ax.plot([], [], [], color=COLORS[i], linewidth=2, marker="o", markersize=3)
+        lines.append(line)
+
+    def update(frame):
+        data = joints[frame]
+        for i, chain in enumerate(T2M_KINEMATIC_CHAIN):
+            x = [data[j, 0] for j in chain]
+            y = [data[j, 2] for j in chain]
+            z = [data[j, 1] for j in chain]
+            lines[i].set_data(x, y)
+            lines[i].set_3d_properties(z)
+        ax.view_init(elev=20, azim=45 + frame * 0.5)
+        return lines
+
+    ani = FuncAnimation(fig, update, frames=len(joints), init_func=init, blit=False, interval=1000//fps)
+    writer = FFMpegWriter(fps=fps, bitrate=2000)
+    ani.save(save_path, writer=writer)
     plt.close()
-    return save_path
+# ============ BVH Export ============
+def joints_to_bvh(joints, output_path, fps=20):
+    """Convert joint positions to BVH format for Blender import."""
+    n_frames, n_joints, _ = joints.shape
+
+    joint_names = [
+        "Hips", "LeftUpLeg", "RightUpLeg", "Spine", "LeftLeg", "RightLeg",
+        "Spine1", "LeftFoot", "RightFoot", "Spine2", "LeftToe", "RightToe",
+        "Neck", "LeftShoulder", "RightShoulder", "Head", "LeftArm", "RightArm",
+        "LeftForeArm", "RightForeArm", "LeftHand", "RightHand"
+    ]
+
+    parents = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19]
+
+    offsets = np.zeros((n_joints, 3))
+    ref_frame = joints[0]
+    for i in range(n_joints):
+        if parents[i] >= 0:
+            offsets[i] = ref_frame[i] - ref_frame[parents[i]]
+
+    scale = 100.0
+    offsets *= scale
+    joints_scaled = joints * scale
+
+    with open(output_path, "w") as f:
+        f.write("HIERARCHY" + chr(10))
+
+        def write_joint(idx, indent):
+            name = joint_names[idx]
+            off = offsets[idx]
+            prefix = " " * indent
+
+            children = [i for i, p in enumerate(parents) if p == idx]
+
+            if idx == 0:
+                f.write(f"ROOT {name}" + chr(10))
+            else:
+                f.write(f"{prefix}JOINT {name}" + chr(10))
+
+            f.write(f"{prefix}{{" + chr(10))
+            f.write(f"{prefix} OFFSET {off[0]:.6f} {off[1]:.6f} {off[2]:.6f}" + chr(10))
+
+            if idx == 0:
+                f.write(f"{prefix} CHANNELS 6 Xposition Yposition Zposition Xrotation Yrotation Zrotation" + chr(10))
+            else:
+                f.write(f"{prefix} CHANNELS 3 Xrotation Yrotation Zrotation" + chr(10))
+
+            if children:
+                for child in children:
+                    write_joint(child, indent + 1)
+            else:
+                f.write(f"{prefix} End Site" + chr(10))
+                f.write(f"{prefix} {{" + chr(10))
+                f.write(f"{prefix} OFFSET 0.0 0.0 0.0" + chr(10))
+                f.write(f"{prefix} }}" + chr(10))
+
+            f.write(f"{prefix}}}" + chr(10))
+
+        write_joint(0, 0)
+
+        f.write("MOTION" + chr(10))
+        f.write(f"Frames: {n_frames}" + chr(10))
+        f.write(f"Frame Time: {1.0/fps:.6f}" + chr(10))
+
+        for frame in range(n_frames):
+            root_pos = joints_scaled[frame, 0]
+            values = [root_pos[0], root_pos[1], root_pos[2]]
+            values.extend([0.0] * 3 * n_joints)
+            f.write(" ".join(f"{v:.6f}" for v in values) + chr(10))
+
+    return output_path
 
 # ============ Sampling Utilities ============
 def cosine_schedule(t):
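
Note (reviewer's addition, not part of the commit): FFMpegWriter shells out to an ffmpeg binary resolved via matplotlib's rcParams (by default the "ffmpeg" executable on PATH), so MP4 export fails inside ani.save() when ffmpeg is missing. A minimal guard that could run at startup:

import shutil

# Fail early with a clear message instead of deep inside ani.save();
# this check is a suggestion, not code from the commit.
if shutil.which("ffmpeg") is None:
    raise RuntimeError("ffmpeg not found on PATH; plot_3d_motion cannot write MP4")

Also worth flagging: joints_to_bvh writes zeros for every rotation channel, so the exported BVH animates only the root translation, with joint offsets frozen at the first frame's pose.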
@@ -183,7 +217,7 @@ def top_k_filter(logits, k=0.9):
     """Apply top-k filtering"""
     k = int((1 - k) * logits.shape[-1])
     val, ind = torch.topk(logits, k, dim=-1)
-    probs = torch.full_like(logits, float('-inf'))
+    probs = torch.full_like(logits, float("-inf"))
     probs.scatter_(-1, ind, val)
     return probs
 
@@ -191,187 +225,147 @@ def gumbel_sample(logits, temperature=1.0):
     """Gumbel softmax sampling"""
     gumbels = -torch.log(-torch.log(torch.rand_like(logits) + 1e-8) + 1e-8)
     return ((logits / max(temperature, 1e-10)) + gumbels).argmax(dim=-1)
-
 # ============ Main Generation Pipeline ============
-def generate_motion(text, motion_length=0, seed=None):
-    """Generate motion from text prompt
-
-    Args:
-        text: Text description of motion
-        motion_length: Length in seconds (0 = auto-estimate)
-        seed: Random seed for reproducibility
-
-    Returns:
-        joints: 3D joint positions (N, 22, 3)
-        video_path: Path to rendered MP4
-    """
+def generate_motion(text, motion_length=0, seed=None, export_bvh=False):
+    """Generate motion from text prompt"""
     if seed is not None:
         torch.manual_seed(seed)
         np.random.seed(seed)
 
-    # Load mean/std for denormalization
     mean = np.load(ONNX_DIR / "mean.npy")
     std = np.load(ONNX_DIR / "std.npy")
 
-    # 1. Tokenize text with CLIP
     tokens = clip.tokenize([text], truncate=True)
 
-    # 2. Encode text with CLIP
     clip_sess = get_session("clip_text")
-    text_emb = clip_sess.run(None, {'text_tokens': tokens.numpy()})[0]  # (1, 512)
+    text_emb = clip_sess.run(None, {"text_tokens": tokens.numpy()})[0]
 
-    # 3. Estimate motion length
     if motion_length <= 0:
         len_sess = get_session("length_estimator")
-        len_logits = len_sess.run(None, {'text_embedding': text_emb})[0]
-        # Sample from distribution
+        len_logits = len_sess.run(None, {"text_embedding": text_emb})[0]
         probs = torch.softmax(torch.from_numpy(len_logits), dim=-1)
         token_len = torch.multinomial(probs, 1).item()
     else:
-        # Convert seconds to tokens (20 fps, 4 frames per token)
        token_len = int(motion_length * 20 / 4)
 
-    token_len = max(2, min(token_len, 49))  # Clamp to valid range
+    token_len = max(2, min(token_len, 49))
     m_length = token_len * 4
-    max_len = 49  # Fixed sequence length for ONNX model
+    max_len = 49
 
     print(f"Generating motion: '{text}' ({m_length} frames, {m_length/20:.1f}s)")
 
-    # 4. Initialize with mask/pad tokens
-    mask_id = 512  # Mask token ID
-    pad_id = 513   # Pad token ID
+    mask_id = 512
+    pad_id = 513
     ids = torch.full((1, max_len), pad_id, dtype=torch.long)
     ids[:, :token_len] = mask_id
     scores = torch.zeros(1, max_len)
-    scores[:, token_len:] = 1e5  # High scores for padding (won't be masked)
+    scores[:, token_len:] = 1e5
 
-    # Create padding mask (True = padded)
     padding_mask = np.zeros((1, max_len), dtype=bool)
     padding_mask[:, token_len:] = True
 
-    # 5. Iterative generation with MaskTransformer
     mask_sess = get_session("mask_transformer")
 
     for step in range(TIMESTEPS):
         t = step / TIMESTEPS
         rand_mask_prob = cosine_schedule(torch.tensor(t)).item()
-
-        # Number of tokens to mask (only in valid region)
         num_masked = max(1, int(rand_mask_prob * token_len))
 
-        # Get lowest scoring positions to mask (only in valid region)
         valid_scores = scores[:, :token_len].clone()
         _, sorted_idx = valid_scores.sort(dim=1)
         mask_pos = sorted_idx[:, :num_masked]
         is_mask = torch.zeros(1, token_len, dtype=torch.bool)
         is_mask.scatter_(1, mask_pos, True)
 
-        # Apply mask only to valid positions
         ids[:, :token_len] = torch.where(is_mask, mask_id, ids[:, :token_len])
 
-        # Run transformer with fixed max_len
         logits = mask_sess.run(None, {
-            'motion_ids': ids.numpy(),
-            'cond_vector': text_emb,
-            'padding_mask': padding_mask
-        })[0]  # (1, num_tokens, max_len)
-
-        logits = torch.from_numpy(logits)  # (1, 514, max_len)
+            "motion_ids": ids.numpy(),
+            "cond_vector": text_emb,
+            "padding_mask": padding_mask
+        })[0]
 
-        # Get logits for valid positions only
-        logits = logits[:, :512, :token_len]  # Remove mask/pad tokens, trim to valid len
-        logits = logits.permute(0, 2, 1)  # (1, token_len, 512)
+        logits = torch.from_numpy(logits)
+        logits = logits[:, :512, :token_len]
+        logits = logits.permute(0, 2, 1)
 
-        # Apply temperature and top-k filtering
         filtered_logits = top_k_filter(logits / TEMPERATURE, TOPK_FILTER)
+        new_ids = gumbel_sample(filtered_logits, TEMPERATURE)
 
-        # Sample new tokens
-        new_ids = gumbel_sample(filtered_logits, TEMPERATURE)  # (1, token_len)
-
-        # Get confidence scores
         probs = torch.softmax(filtered_logits, dim=-1)
         new_scores = probs.gather(-1, new_ids.unsqueeze(-1)).squeeze(-1)
 
-        # Update only masked positions (in valid region)
         ids[:, :token_len] = torch.where(is_mask, new_ids, ids[:, :token_len])
         scores[:, :token_len] = torch.where(is_mask, new_scores, scores[:, :token_len])
-
-    # 6. Residual refinement with ResidualTransformer
     res_sess = get_session("residual_transformer")
     num_quantizers = 6
 
-    # Load token embeddings for residual transformer
-    res_token_embed = np.load(ONNX_DIR / "res_token_embed.npy")  # (5, 513, 512)
+    res_token_embed = np.load(ONNX_DIR / "res_token_embed.npy")
 
-    # Initialize all quantizer codes
     all_codes = torch.zeros(1, max_len, num_quantizers, dtype=torch.long)
     all_codes[:, :, 0] = ids
 
-    # Accumulate code embeddings for residual refinement
     history_sum = np.zeros((1, max_len, 512), dtype=np.float32)
     motion_ids = ids.clone()
 
     for q in range(1, num_quantizers):
-        # Get token embeddings for current motion_ids using layer q-1
-        token_embed = res_token_embed[q-1]  # (513, 512)
-        # Gather embeddings for each position (clamp padding to valid range)
+        token_embed = res_token_embed[q-1]
        clamped_ids = np.clip(motion_ids[0].numpy(), 0, 512)
-        gathered = token_embed[clamped_ids]  # (max_len, 512)
-        history_sum += gathered[np.newaxis, :, :]  # Accumulate
+        gathered = token_embed[clamped_ids]
+        history_sum += gathered[np.newaxis, :, :]
 
         q_id = np.array([q], dtype=np.int64)
 
         logits = res_sess.run(None, {
-            'motion_codes': history_sum.astype(np.float32),
-            'q_id': q_id,
-            'cond_vector': text_emb,
-            'padding_mask': padding_mask
+            "motion_codes": history_sum.astype(np.float32),
+            "q_id": q_id,
+            "cond_vector": text_emb,
+            "padding_mask": padding_mask
         })[0]
 
         logits = torch.from_numpy(logits)[:, :512, :token_len].permute(0, 2, 1)
         new_ids_q = gumbel_sample(logits, 1.0)
         all_codes[:, :token_len, q] = new_ids_q
-        motion_ids[:, :token_len] = new_ids_q  # Update for next iteration
+        motion_ids[:, :token_len] = new_ids_q
 
-    # 7. Decode motion with VQVAE (only valid tokens)
     decoder_sess = get_session("vqvae_decoder")
     valid_codes = all_codes[:, :token_len, :].numpy()
     motion = decoder_sess.run(None, {
-        'code_indices': valid_codes
-    })[0]  # (1, token_len, 263)
+        "code_indices": valid_codes
+    })[0]
 
-    # Upsample to full length (token_len -> m_length via stride=2, down_t=2 -> 4x)
     motion = np.repeat(motion, 4, axis=1)[:, :m_length, :]
-
-    # 8. Denormalize
     motion = motion * std + mean
 
-    # 9. Recover 3D joint positions
     motion_tensor = torch.from_numpy(motion).float()
     joints = recover_from_ric(motion_tensor, JOINTS_NUM)
-    joints = joints.squeeze(0).numpy()  # (m_length, 22, 3)
+    joints = joints.squeeze(0).numpy()
 
-    # 10. Render video
-    video_path = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name
+    video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
     plot_3d_motion(video_path, joints, text, fps=20)
 
-    return joints, video_path
-
+    bvh_path = None
+    if export_bvh:
+        bvh_path = tempfile.NamedTemporaryFile(suffix=".bvh", delete=False).name
+        joints_to_bvh(joints, bvh_path, fps=20)
+        print(f"BVH exported: {bvh_path}")
+
+    return joints, video_path, bvh_path
 # ============ Gradio Interface ============
 def create_demo():
     import gradio as gr
 
-    def generate_fn(text, length, seed):
+    def generate_fn(text, length, seed, export_bvh):
         if not text or text.strip() == "":
-            return None
+            return None, None
         seed = int(seed) if seed else None
         length = float(length) if length else 0
-        _, video_path = generate_motion(text, length, seed)
-        return video_path
+        joints, video_path, bvh_path = generate_motion(text, length, seed, export_bvh)
+        return video_path, bvh_path
 
     with gr.Blocks(title="MoMask") as demo:
         gr.Markdown("## [MoMask](https://github.com/EricGuo5513/momask-codes) - Text to Motion")
-        gr.Markdown("Generate 3D human skeleton animations from text descriptions.")
+        gr.Markdown("Generate 3D human skeleton animations from text descriptions. Download BVH for Blender!")
 
         with gr.Row():
             with gr.Column():
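
For intuition on the masking loop above: cosine_schedule(t) sets the fraction of motion tokens re-masked at each of the TIMESTEPS iterations, so generation starts fully masked and commits tokens progressively. A standalone sketch, assuming the usual MaskGIT-style schedule cos(t * pi / 2) (the helper's actual body lies outside this diff):

import math

def cosine_schedule(t):
    # assumed MaskGIT-style schedule; the app's definition is not shown in this diff
    return math.cos(t * math.pi / 2)

TIMESTEPS = 10  # illustrative value
for step in range(TIMESTEPS):
    t = step / TIMESTEPS
    print(f"step {step}: ~{cosine_schedule(t):.0%} of tokens re-masked")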
@@ -383,41 +377,42 @@ def create_demo():
                                    info="0 = auto-estimate")
                seed = gr.Number(label="Seed", value=42,
                                 info="For reproducibility")
+                export_bvh = gr.Checkbox(label="Export BVH for Blender", value=True)
                btn = gr.Button("Generate", variant="primary")
 
            with gr.Column():
                video = gr.Video(label="Generated Motion")
+                bvh_file = gr.File(label="BVH Download")
 
        gr.Examples(
            examples=[
-                ["A person walks forward", 0, 42],
-                ["A person is running on a treadmill", 0, 123],
-                ["A person jumps up and then lands", 0, 456],
-                ["A person does a salsa dance", 0, 789],
-                ["A person kicks with their right leg", 0, 101],
+                ["A person walks forward", 0, 42, True],
+                ["A person is running on a treadmill", 0, 123, True],
+                ["A person jumps up and then lands", 0, 456, True],
+                ["A person does a salsa dance", 0, 789, True],
+                ["A person kicks with their right leg", 0, 101, True],
            ],
-            inputs=[text, length, seed],
-            outputs=video,
+            inputs=[text, length, seed, export_bvh],
+            outputs=[video, bvh_file],
            fn=generate_fn,
            cache_examples=False,
        )
 
-        btn.click(fn=generate_fn, inputs=[text, length, seed], outputs=video)
+        btn.click(fn=generate_fn, inputs=[text, length, seed, export_bvh], outputs=[video, bvh_file])
 
    return demo
 
 # ============ CLI ============
 if __name__ == "__main__":
    if len(sys.argv) > 1:
-        # CLI mode: python app.py "motion description" [length] [seed]
        text = sys.argv[1]
        length = float(sys.argv[2]) if len(sys.argv) > 2 else 0
        seed = int(sys.argv[3]) if len(sys.argv) > 3 else 42
 
-        joints, video_path = generate_motion(text, length, seed)
-        print(f"Generated: {video_path}")
+        joints, video_path, bvh_path = generate_motion(text, length, seed, export_bvh=True)
+        print(f"Video: {video_path}")
+        print(f"BVH: {bvh_path}")
        print(f"Joints shape: {joints.shape}")
    else:
-        # Gradio mode
        demo = create_demo()
-        demo.launch()
+        demo.launch()
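
To exercise the updated pipeline without the UI, the new signature can also be called directly; a minimal sketch using only names from this diff (outputs land in temp files):

from app_new import generate_motion

# text, length in seconds (0 = auto-estimate), seed, plus the new export_bvh flag
joints, video_path, bvh_path = generate_motion(
    "A person walks forward", motion_length=0, seed=42, export_bvh=True
)
print(joints.shape)          # (frames, 22, 3)
print(video_path, bvh_path)  # .mp4 and .bvh paths

The equivalent CLI invocation handled by the __main__ block is: python app_new.py "A person walks forward" 0 42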
 
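A downloaded .bvh can be pulled into Blender with the stock BVH importer. A hedged sketch for Blender's Python console (importer options vary across Blender versions, and global_scale here is my assumption to undo the exporter's x100 scaling):

import bpy

bpy.ops.import_anim.bvh(
    filepath="/path/to/generated.bvh",  # file produced by joints_to_bvh
    global_scale=0.01,                  # exporter multiplies positions by 100
)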