Fix render
Files changed:
- README.md (+1 -0)
- app_new.py (+165 -170)
README.md CHANGED

@@ -8,6 +8,7 @@ sdk_version: "6.1.0"
 app_file: app_new.py
 pinned: false
 python_version: "3.10"
+short_description: Text-to-3D motion generation using ONNX INT8 models
 ---
 
 # MoMask: Text-to-Motion Generation
app_new.py CHANGED

@@ -13,10 +13,9 @@ import clip
 import matplotlib
 matplotlib.use('Agg')
 import matplotlib.pyplot as plt
-from matplotlib.animation import FuncAnimation
-import mpl_toolkits.mplot3d.axes3d as p3
-from mpl_toolkits.mplot3d.art3d import Poly3DCollection
+from matplotlib.animation import FuncAnimation, FFMpegWriter
 from pathlib import Path
+
 # ============ Quaternion Operations ============
 def qinv(q):
     """Invert quaternion"""
@@ -30,11 +29,11 @@ def qrot(q, v):
     assert q.shape[-1] == 4
     assert v.shape[-1] == 3
     assert q.shape[:-1] == v.shape[:-1]
-
+
     original_shape = list(v.shape)
     q = q.contiguous().view(-1, 4)
     v = v.contiguous().view(-1, 3)
-
+
     qvec = q[:, 1:]
     uv = torch.cross(qvec, v, dim=1)
     uuv = torch.cross(qvec, uv, dim=1)
@@ -66,10 +65,9 @@ def get_session(name):
     path = ONNX_DIR / f"{name}.onnx"
     if not path.exists():
         raise FileNotFoundError(f"Model not found: {path}")
-    sessions[name] = ort.InferenceSession(str(path), providers=[
+    sessions[name] = ort.InferenceSession(str(path), providers=["CPUExecutionProvider"])
     return sessions[name]
 
-
 # ============ Motion Recovery ============
 def recover_root_rot_pos(data):
     """Recover root rotation and position from motion data"""
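Note on the hunk above: the rewritten line pins inference to CPU. A quick way to confirm the provider took effect, as a sketch assuming the surrounding app_new.py context (where get_session caches ort.InferenceSession objects in a module-level sessions dict):

    # Sketch: verify which ONNX Runtime provider is active.
    # get_session and the "clip_text" model name come from the diff above.
    sess = get_session("clip_text")
    print(sess.get_providers())  # expected: ['CPUExecutionProvider']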
@@ -104,75 +102,111 @@ def recover_from_ric(data, joints_num=22):
 # ============ Visualization ============
 def plot_3d_motion(save_path, joints, title, fps=20):
     """Create MP4 video of 3D skeleton motion"""
-    [~33 removed setup lines (figure, axes, limits, trajectory) not rendered in the diff view]
-    def plot_xzPlane(minx, maxx, miny, minz, maxz):
-        verts = [[minx, miny, minz], [minx, miny, maxz],
-                 [maxx, miny, maxz], [maxx, miny, minz]]
-        xz_plane = Poly3DCollection([verts])
-        xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
-        ax.add_collection3d(xz_plane)
-
-    def update(index):
-        for line in ax.lines[:]: line.remove()
-        for coll in ax.collections[:]: coll.remove()
-        ax.view_init(elev=120, azim=-90)
-        ax.dist = 7.5
-
-        plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0,
-                     MINS[2] - trajec[index, 1], MAXS[2] - trajec[index, 1])
-
-        if index > 1:
-            ax.plot3D(trajec[:index, 0] - trajec[index, 0],
-                      np.zeros_like(trajec[:index, 0]),
-                      trajec[:index, 1] - trajec[index, 1],
-                      linewidth=1.0, color='blue')
-
-        for i, (chain, color) in enumerate(zip(kinematic_tree, colors)):
-            linewidth = 4.0 if i < 5 else 2.0
-            ax.plot3D(data[index, chain, 0], data[index, chain, 1],
-                      data[index, chain, 2], linewidth=linewidth, color=color)
-
-        plt.axis('off')
-        ax.set_xticklabels([])
-        ax.set_yticklabels([])
-        ax.set_zticklabels([])
-
-    ani = FuncAnimation(fig, update, frames=frame_number, interval=1000/fps, repeat=False)
-    ani.save(save_path, fps=fps)
+    fig = plt.figure(figsize=(8, 8))
+    ax = fig.add_subplot(111, projection="3d")
+    COLORS = ["red", "blue", "black", "green", "purple"]
+
+    def init():
+        ax.set_xlim(-1.5, 1.5)
+        ax.set_ylim(-1.5, 1.5)
+        ax.set_zlim(0, 2)
+        ax.set_xlabel("X")
+        ax.set_ylabel("Z")
+        ax.set_zlabel("Y (up)")
+        ax.set_title(title)
+        return []
+
+    lines = []
+    for i, chain in enumerate(T2M_KINEMATIC_CHAIN):
+        line, = ax.plot([], [], [], color=COLORS[i], linewidth=2, marker="o", markersize=3)
+        lines.append(line)
+
+    def update(frame):
+        data = joints[frame]
+        for i, chain in enumerate(T2M_KINEMATIC_CHAIN):
+            x = [data[j, 0] for j in chain]
+            y = [data[j, 2] for j in chain]
+            z = [data[j, 1] for j in chain]
+            lines[i].set_data(x, y)
+            lines[i].set_3d_properties(z)
+        ax.view_init(elev=20, azim=45 + frame * 0.5)
+        return lines
+
+    ani = FuncAnimation(fig, update, frames=len(joints), init_func=init, blit=False, interval=1000//fps)
+    writer = FFMpegWriter(fps=fps, bitrate=2000)
+    ani.save(save_path, writer=writer)
     plt.close()
-
+# ============ BVH Export ============
+def joints_to_bvh(joints, output_path, fps=20):
+    """Convert joint positions to BVH format for Blender import."""
+    n_frames, n_joints, _ = joints.shape
+
+    joint_names = [
+        "Hips", "LeftUpLeg", "RightUpLeg", "Spine", "LeftLeg", "RightLeg",
+        "Spine1", "LeftFoot", "RightFoot", "Spine2", "LeftToe", "RightToe",
+        "Neck", "LeftShoulder", "RightShoulder", "Head", "LeftArm", "RightArm",
+        "LeftForeArm", "RightForeArm", "LeftHand", "RightHand"
+    ]
+
+    parents = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19]
+
+    offsets = np.zeros((n_joints, 3))
+    ref_frame = joints[0]
+    for i in range(n_joints):
+        if parents[i] >= 0:
+            offsets[i] = ref_frame[i] - ref_frame[parents[i]]
+
+    scale = 100.0
+    offsets *= scale
+    joints_scaled = joints * scale
+
+    with open(output_path, "w") as f:
+        f.write("HIERARCHY" + chr(10))
+
+        def write_joint(idx, indent):
+            name = joint_names[idx]
+            off = offsets[idx]
+            prefix = " " * indent
+
+            children = [i for i, p in enumerate(parents) if p == idx]
+
+            if idx == 0:
+                f.write(f"ROOT {name}" + chr(10))
+            else:
+                f.write(f"{prefix}JOINT {name}" + chr(10))
+
+            f.write(f"{prefix}{{" + chr(10))
+            f.write(f"{prefix} OFFSET {off[0]:.6f} {off[1]:.6f} {off[2]:.6f}" + chr(10))
+
+            if idx == 0:
+                f.write(f"{prefix} CHANNELS 6 Xposition Yposition Zposition Xrotation Yrotation Zrotation" + chr(10))
+            else:
+                f.write(f"{prefix} CHANNELS 3 Xrotation Yrotation Zrotation" + chr(10))
+
+            if children:
+                for child in children:
+                    write_joint(child, indent + 1)
+            else:
+                f.write(f"{prefix} End Site" + chr(10))
+                f.write(f"{prefix} {{" + chr(10))
+                f.write(f"{prefix}  OFFSET 0.0 0.0 0.0" + chr(10))
+                f.write(f"{prefix} }}" + chr(10))
+
+            f.write(f"{prefix}}}" + chr(10))
+
+        write_joint(0, 0)
+
+        f.write("MOTION" + chr(10))
+        f.write(f"Frames: {n_frames}" + chr(10))
+        f.write(f"Frame Time: {1.0/fps:.6f}" + chr(10))
+
+        for frame in range(n_frames):
+            root_pos = joints_scaled[frame, 0]
+            values = [root_pos[0], root_pos[1], root_pos[2]]
+            values.extend([0.0] * 3 * n_joints)
+            f.write(" ".join(f"{v:.6f}" for v in values) + chr(10))
+
+    return output_path
 
 # ============ Sampling Utilities ============
 def cosine_schedule(t):
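Note on joints_to_bvh: the MOTION block writes the root translation followed by zeros for every rotation channel, so an importer (e.g. Blender) gets the skeleton hierarchy and rest-pose offsets but no limb animation. A minimal smoke test, as a sketch assuming app_new.py's imports and definitions are in scope:

    import numpy as np

    # 40 frames of a dummy 22-joint skeleton; joints stacked vertically,
    # root drifting forward so the one animated channel is visible.
    dummy = np.zeros((40, 22, 3), dtype=np.float32)
    dummy[:, :, 1] = np.linspace(0.0, 1.7, 22)[None, :]
    dummy[:, 0, 2] = np.linspace(0.0, 1.0, 40)
    joints_to_bvh(dummy, "smoke.bvh", fps=20)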
@@ -183,7 +217,7 @@ def top_k_filter(logits, k=0.9):
     """Apply top-k filtering"""
     k = int((1 - k) * logits.shape[-1])
     val, ind = torch.topk(logits, k, dim=-1)
-    probs = torch.full_like(logits, float(
+    probs = torch.full_like(logits, float("-inf"))
     probs.scatter_(-1, ind, val)
     return probs
 
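For reference, the two helpers above are easy to exercise in isolation (a sketch; run with app_new.py's definitions in scope, torch only). With the default k=0.9, int((1 - 0.9) * 512) = 51 logits per position survive the filter:

    import torch

    logits = torch.randn(1, 8, 512)         # (batch, token positions, codebook)
    filtered = top_k_filter(logits, 0.9)    # everything outside the top 51 -> -inf
    new_ids = gumbel_sample(filtered, 1.0)  # Gumbel-max sampling over the last dim
    print(new_ids.shape)                    # torch.Size([1, 8])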
@@ -191,187 +225,147 @@ def gumbel_sample(logits, temperature=1.0):
     """Gumbel softmax sampling"""
     gumbels = -torch.log(-torch.log(torch.rand_like(logits) + 1e-8) + 1e-8)
     return ((logits / max(temperature, 1e-10)) + gumbels).argmax(dim=-1)
 
 # ============ Main Generation Pipeline ============
-def generate_motion(text, motion_length=0, seed=None):
-    """Generate motion from text prompt
-
-    Args:
-        text: Text description of motion
-        motion_length: Length in seconds (0 = auto-estimate)
-        seed: Random seed for reproducibility
-
-    Returns:
-        joints: 3D joint positions (N, 22, 3)
-        video_path: Path to rendered MP4
-    """
+def generate_motion(text, motion_length=0, seed=None, export_bvh=False):
+    """Generate motion from text prompt"""
     if seed is not None:
         torch.manual_seed(seed)
         np.random.seed(seed)
 
-    # Load mean/std for denormalization
     mean = np.load(ONNX_DIR / "mean.npy")
     std = np.load(ONNX_DIR / "std.npy")
 
-    # 1. Tokenize text with CLIP
     tokens = clip.tokenize([text], truncate=True)
 
-    # 2. Encode text with CLIP
     clip_sess = get_session("clip_text")
-    text_emb = clip_sess.run(None, {
+    text_emb = clip_sess.run(None, {"text_tokens": tokens.numpy()})[0]
 
-    # 3. Estimate motion length
     if motion_length <= 0:
         len_sess = get_session("length_estimator")
-        len_logits = len_sess.run(None, {
-        # Sample from distribution
+        len_logits = len_sess.run(None, {"text_embedding": text_emb})[0]
         probs = torch.softmax(torch.from_numpy(len_logits), dim=-1)
         token_len = torch.multinomial(probs, 1).item()
     else:
-        # Convert seconds to tokens (20 fps, 4 frames per token)
         token_len = int(motion_length * 20 / 4)
 
     token_len = max(2, min(token_len, 49))
     m_length = token_len * 4
     max_len = 49
 
     print(f"Generating motion: '{text}' ({m_length} frames, {m_length/20:.1f}s)")
 
-    pad_id = 513  # Pad token ID
+    mask_id = 512
+    pad_id = 513
     ids = torch.full((1, max_len), pad_id, dtype=torch.long)
     ids[:, :token_len] = mask_id
     scores = torch.zeros(1, max_len)
     scores[:, token_len:] = 1e5
 
-    # Create padding mask (True = padded)
     padding_mask = np.zeros((1, max_len), dtype=bool)
     padding_mask[:, token_len:] = True
 
-    # 5. Iterative generation with MaskTransformer
     mask_sess = get_session("mask_transformer")
 
     for step in range(TIMESTEPS):
         t = step / TIMESTEPS
         rand_mask_prob = cosine_schedule(torch.tensor(t)).item()
-
-        # Number of tokens to mask (only in valid region)
         num_masked = max(1, int(rand_mask_prob * token_len))
 
-        # Get lowest scoring positions to mask (only in valid region)
         valid_scores = scores[:, :token_len].clone()
         _, sorted_idx = valid_scores.sort(dim=1)
         mask_pos = sorted_idx[:, :num_masked]
         is_mask = torch.zeros(1, token_len, dtype=torch.bool)
         is_mask.scatter_(1, mask_pos, True)
 
-        # Apply mask only to valid positions
         ids[:, :token_len] = torch.where(is_mask, mask_id, ids[:, :token_len])
 
-        # Run transformer with fixed max_len
         logits = mask_sess.run(None, {
-        })[0]
-
-        logits = torch.from_numpy(logits)  # (1, 514, max_len)
+            "motion_ids": ids.numpy(),
+            "cond_vector": text_emb,
+            "padding_mask": padding_mask
+        })[0]
 
-        logits = logits[:, :512, :token_len]
-        logits = logits.permute(0, 2, 1)
+        logits = torch.from_numpy(logits)
+        logits = logits[:, :512, :token_len]
+        logits = logits.permute(0, 2, 1)
 
-        # Apply temperature and top-k filtering
         filtered_logits = top_k_filter(logits / TEMPERATURE, TOPK_FILTER)
-
-        # Sample new tokens
-        new_ids = gumbel_sample(filtered_logits, TEMPERATURE)  # (1, token_len)
-
-        # Get confidence scores
+        new_ids = gumbel_sample(filtered_logits, TEMPERATURE)
 
         probs = torch.softmax(filtered_logits, dim=-1)
         new_scores = probs.gather(-1, new_ids.unsqueeze(-1)).squeeze(-1)
 
-        # Update only masked positions (in valid region)
         ids[:, :token_len] = torch.where(is_mask, new_ids, ids[:, :token_len])
         scores[:, :token_len] = torch.where(is_mask, new_scores, scores[:, :token_len])
 
-    # 6. Residual refinement with ResidualTransformer
     res_sess = get_session("residual_transformer")
     num_quantizers = 6
 
-    res_token_embed = np.load(ONNX_DIR / "res_token_embed.npy")  # (5, 513, 512)
+    res_token_embed = np.load(ONNX_DIR / "res_token_embed.npy")
 
-    # Initialize all quantizer codes
     all_codes = torch.zeros(1, max_len, num_quantizers, dtype=torch.long)
     all_codes[:, :, 0] = ids
 
-    # Accumulate code embeddings for residual refinement
     history_sum = np.zeros((1, max_len, 512), dtype=np.float32)
     motion_ids = ids.clone()
 
     for q in range(1, num_quantizers):
-        token_embed = res_token_embed[q-1]  # (513, 512)
-        # Gather embeddings for each position (clamp padding to valid range)
+        token_embed = res_token_embed[q-1]
         clamped_ids = np.clip(motion_ids[0].numpy(), 0, 512)
         gathered = token_embed[clamped_ids]
         history_sum += gathered[np.newaxis, :, :]
 
         q_id = np.array([q], dtype=np.int64)
 
         logits = res_sess.run(None, {
+            "motion_codes": history_sum.astype(np.float32),
+            "q_id": q_id,
+            "cond_vector": text_emb,
+            "padding_mask": padding_mask
         })[0]
 
         logits = torch.from_numpy(logits)[:, :512, :token_len].permute(0, 2, 1)
         new_ids_q = gumbel_sample(logits, 1.0)
         all_codes[:, :token_len, q] = new_ids_q
         motion_ids[:, :token_len] = new_ids_q
 
-    # 7. Decode motion with VQVAE (only valid tokens)
     decoder_sess = get_session("vqvae_decoder")
     valid_codes = all_codes[:, :token_len, :].numpy()
     motion = decoder_sess.run(None, {
-    })[0]
+        "code_indices": valid_codes
+    })[0]
 
-    # Upsample to full length (token_len -> m_length via stride=2, down_t=2 -> 4x)
     motion = np.repeat(motion, 4, axis=1)[:, :m_length, :]
-
-    # 8. Denormalize
     motion = motion * std + mean
 
-    # 9. Recover 3D joint positions
     motion_tensor = torch.from_numpy(motion).float()
     joints = recover_from_ric(motion_tensor, JOINTS_NUM)
     joints = joints.squeeze(0).numpy()
 
-    video_path = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name
+    video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
     plot_3d_motion(video_path, joints, text, fps=20)
 
+    bvh_path = None
+    if export_bvh:
+        bvh_path = tempfile.NamedTemporaryFile(suffix=".bvh", delete=False).name
+        joints_to_bvh(joints, bvh_path, fps=20)
+        print(f"BVH exported: {bvh_path}")
 
+    return joints, video_path, bvh_path
 # ============ Gradio Interface ============
 def create_demo():
     import gradio as gr
 
-    def generate_fn(text, length, seed):
+    def generate_fn(text, length, seed, export_bvh):
         if not text or text.strip() == "":
-            return None
+            return None, None
         seed = int(seed) if seed else None
         length = float(length) if length else 0
-        return video_path
+        joints, video_path, bvh_path = generate_motion(text, length, seed, export_bvh)
+        return video_path, bvh_path
 
     with gr.Blocks(title="MoMask") as demo:
         gr.Markdown("## [MoMask](https://github.com/EricGuo5513/momask-codes) - Text to Motion")
-        gr.Markdown("Generate 3D human skeleton animations from text descriptions.")
+        gr.Markdown("Generate 3D human skeleton animations from text descriptions. Download BVH for Blender!")
 
         with gr.Row():
             with gr.Column():
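The length bookkeeping in generate_motion is worth spelling out: 20 fps and 4 motion frames per token, with token_len clamped to [2, 49], i.e. 8 to 196 frames (0.4 to 9.8 s). A tiny self-check mirroring the arithmetic above:

    def seconds_to_tokens(seconds):
        # Mirrors generate_motion: 20 fps / 4 frames per token, clamped to [2, 49].
        return max(2, min(int(seconds * 20 / 4), 49))

    assert seconds_to_tokens(3.0) == 15    # 15 tokens * 4 frames = 60 frames = 3.0 s
    assert seconds_to_tokens(60.0) == 49   # clamped to the model maximum (9.8 s)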
@@ -383,41 +377,42 @@ def create_demo():
                                    info="0 = auto-estimate")
                 seed = gr.Number(label="Seed", value=42,
                                  info="For reproducibility")
+                export_bvh = gr.Checkbox(label="Export BVH for Blender", value=True)
                 btn = gr.Button("Generate", variant="primary")
 
             with gr.Column():
                 video = gr.Video(label="Generated Motion")
+                bvh_file = gr.File(label="BVH Download")
 
         gr.Examples(
             examples=[
-                ["A person walks forward", 0, 42],
-                ["A person is running on a treadmill", 0, 123],
-                ["A person jumps up and then lands", 0, 456],
-                ["A person does a salsa dance", 0, 789],
-                ["A person kicks with their right leg", 0, 101],
+                ["A person walks forward", 0, 42, True],
+                ["A person is running on a treadmill", 0, 123, True],
+                ["A person jumps up and then lands", 0, 456, True],
+                ["A person does a salsa dance", 0, 789, True],
+                ["A person kicks with their right leg", 0, 101, True],
             ],
-            inputs=[text, length, seed],
-            outputs=video,
+            inputs=[text, length, seed, export_bvh],
+            outputs=[video, bvh_file],
             fn=generate_fn,
             cache_examples=False,
         )
 
-        btn.click(fn=generate_fn, inputs=[text, length, seed], outputs=video)
+        btn.click(fn=generate_fn, inputs=[text, length, seed, export_bvh], outputs=[video, bvh_file])
 
     return demo
 
 # ============ CLI ============
 if __name__ == "__main__":
     if len(sys.argv) > 1:
-        # CLI mode: python app.py "motion description" [length] [seed]
         text = sys.argv[1]
         length = float(sys.argv[2]) if len(sys.argv) > 2 else 0
         seed = int(sys.argv[3]) if len(sys.argv) > 3 else 42
 
-        joints, video_path = generate_motion(text, length, seed)
-        print(f"
+        joints, video_path, bvh_path = generate_motion(text, length, seed, export_bvh=True)
+        print(f"Video: {video_path}")
+        print(f"BVH: {bvh_path}")
         print(f"Joints shape: {joints.shape}")
     else:
-        # Gradio mode
         demo = create_demo()
         demo.launch()
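Usage after this commit (a sketch based on the argv handling above; the ONNX models under ONNX_DIR must be present, and bvh_path is None unless export_bvh is set):

    # CLI: python app_new.py "A person walks forward" 0 42
    # Programmatic equivalent:
    joints, video_path, bvh_path = generate_motion(
        "A person walks forward", motion_length=0, seed=42, export_bvh=True)
    print(video_path, bvh_path, joints.shape)  # joints: (m_length, 22, 3)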