File size: 21,293 Bytes
4bd136e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
"""
Visualization script to convert motion tokens to SMPL-X 3D animation.
Requires VQ-VAE checkpoint, dataset stats, and SMPL-X model files.

Usage:
    # Visualize from LLM output string
    python visualize.py --tokens "<MOT_BEGIN><motion_177><motion_135>...<MOT_END>"
    
    # Visualize from saved file
    python visualize.py --input motion_output.txt
    
    # Generate and visualize in one go
    python visualize.py --prompt "walking" --stage 3
    
    # Custom paths
    python visualize.py --tokens "..." --vqvae-ckpt /path/to/vqvae.pt --smplx-dir /path/to/smplx
"""
import os
import sys
import re
import argparse
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from config import WORK_DIR, DATA_DIR

# Visualization dependencies: imported if present, otherwise installed on the
# fly (handy in throw-away notebook environments) and then imported again.
def _pip_install(requirement):
    """Install ``requirement`` with the pip of the interpreter running this script.

    A bare ``pip`` shell command is resolved through PATH and may belong to a
    different Python installation, in which case the package lands somewhere
    the follow-up ``import`` cannot see. ``sys.executable -m pip`` avoids that,
    and the argument-list form keeps the requirement string away from a shell.

    Args:
        requirement: A pip requirement specifier, e.g. ``"smplx==0.1.28"``.
    """
    import importlib   # local imports: only needed on the (rare) install path
    import subprocess

    # check=False on purpose: a failed install is not raised here; the import
    # that follows raises ImportError and reports the real problem.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", requirement],
        check=False,
    )
    # Packages installed after interpreter start-up may be invisible to the
    # import system's cached directory listings until they are refreshed.
    importlib.invalidate_caches()


try:
    import plotly.graph_objects as go
except ImportError:
    print("Installing plotly...")
    _pip_install("plotly")
    import plotly.graph_objects as go

try:
    import smplx
except ImportError:
    print("Installing smplx...")
    _pip_install("smplx==0.1.28")
    import smplx

# =====================================================================
# Configuration - can be overridden via command-line or environment
# =====================================================================
# Trained VQ-VAE weights (motion encoder/decoder).
# Override with the VQVAE_CHECKPOINT environment variable.
VQVAE_CHECKPOINT = os.getenv(
    "VQVAE_CHECKPOINT", os.path.join(DATA_DIR, "vqvae_model.pt")
)

# Mean/std normalization statistics used during VQ-VAE training.
# Override with the VQVAE_STATS_PATH environment variable.
STATS_PATH = os.getenv(
    "VQVAE_STATS_PATH", os.path.join(DATA_DIR, "vqvae_stats.pt")
)

# Directory holding the SMPL-X model files (SMPLX_NEUTRAL.npz, etc.).
# Override with the SMPLX_MODEL_DIR environment variable.
SMPLX_MODEL_DIR = os.getenv(
    "SMPLX_MODEL_DIR", os.path.join(DATA_DIR, "smplx_models")
)

# Where HTML animations are written; defaults to WORK_DIR unless
# VIS_OUTPUT_DIR is set.
OUTPUT_DIR = os.getenv("VIS_OUTPUT_DIR", WORK_DIR)

# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# VQ-VAE architecture hyper-parameters (must match the training config).
SMPL_DIM = 182
CODEBOOK_SIZE = 512
CODE_DIM = 512
VQ_ARGS = {
    "width": 512,
    "depth": 3,
    "down_t": 2,
    "stride_t": 2,
    "dilation_growth_rate": 3,
    "activation": 'relu',
    "norm": None,
    "quantizer": "ema_reset",
}

# Layout of one SMPL-X parameter frame (must match VQ-VAE training):
# PARAM_DIMS[i] is the number of consecutive columns holding PARAM_NAMES[i].
# The widths add up to 182, i.e. SMPL_DIM.
PARAM_DIMS = [10, 63, 45, 45, 3, 10, 3, 3]
PARAM_NAMES = [
    "betas",
    "body_pose",
    "left_hand_pose",
    "right_hand_pose",
    "trans",
    "expression",
    "jaw_pose",
    "eye_pose",
]

# =====================================================================
# Import VQ-VAE architecture
# =====================================================================
try:
    # Make the directory containing this script importable so that the
    # `mGPT` package sitting next to it can be resolved.
    sign_mgpt_dir = os.path.dirname(__file__)
    if sign_mgpt_dir not in sys.path:
        sys.path.insert(0, sign_mgpt_dir)

    from mGPT.archs.mgpt_vq import VQVae
except ImportError as e:
    # The VQ-VAE architecture is mandatory: explain what is missing and stop
    # with a non-zero exit status.
    print(f"❌ Could not import VQVae: {e}")
    print("Make sure mGPT/archs/mgpt_vq.py exists in the project.")
    raise SystemExit(1)


# =====================================================================
# VQ-VAE Wrapper
# =====================================================================
class MotionGPT_VQVAE_Wrapper(nn.Module):
    """Container module holding a ``VQVae`` under the attribute ``vqvae``.

    Mirrors the wrapper used when the VQ-VAE was trained. Because the network
    is registered as the ``vqvae`` sub-module, its state-dict keys carry a
    ``vqvae.`` prefix (presumably what the training checkpoints contain;
    verify against the checkpoint being loaded).

    Args:
        smpl_dim: Per-frame feature width, forwarded as ``nfeats``.
        codebook_size: Number of codebook entries, forwarded as ``code_num``.
        code_dim: Code width, forwarded as both ``code_dim`` and
            ``output_emb_width``.
        **kwargs: Any further ``VQVae`` constructor arguments.
    """
    def __init__(self, smpl_dim=SMPL_DIM, codebook_size=CODEBOOK_SIZE,
                 code_dim=CODE_DIM, **kwargs):
        super().__init__()
        core_args = {
            "nfeats": smpl_dim,
            "code_num": codebook_size,
            "code_dim": code_dim,
            "output_emb_width": code_dim,
        }
        self.vqvae = VQVae(**core_args, **kwargs)


# =====================================================================
# Token Parsing
# =====================================================================
def parse_motion_tokens(token_str):
    """
    Parse motion tokens from an LLM output string (or pass through a sequence).

    Accepts:
      - "<MOT_BEGIN><motion_177><motion_135>...<MOT_END>"
      - "177 135 152 200 46..." (whitespace separated)
      - "177, 135, 152" or "[177, 135, 152]" (comma separated, optionally
        wrapped in brackets, as produced by printing a Python list)
      - List/tuple/array of ints

    Args:
        token_str: The raw token string, or an already-parsed sequence.

    Returns:
        List of token integers

    Raises:
        ValueError: If the input is neither a string nor list-like, or if no
            token IDs can be extracted from the string.
    """
    if isinstance(token_str, (list, tuple, np.ndarray)):
        return [int(x) for x in token_str]

    if not isinstance(token_str, str):
        raise ValueError("Tokens must be string or list-like")

    # Preferred form: explicit <motion_ID> tags. Anything else in the string
    # (e.g. <MOT_BEGIN>/<MOT_END> markers) is ignored.
    tagged_ids = re.findall(r'<motion_(\d+)>', token_str)
    if tagged_ids:
        return [int(x) for x in tagged_ids]

    # Plain numbers: split on any run of whitespace and/or commas so that
    # "1 2 3", "1,2,3" and "[1, 2, 3]" all parse to the same list.
    stripped = token_str.strip()
    unwrapped = stripped.strip("[]()").strip()
    pieces = [piece for piece in re.split(r'[\s,]+', unwrapped) if piece]
    if pieces:
        try:
            return [int(piece) for piece in pieces]
        except ValueError:
            # Not a pure list of integers; report it with the generic error below.
            pass

    raise ValueError(f"Could not parse motion tokens from: {stripped[:100]}...")


# =====================================================================
# Model Loading
# =====================================================================
def load_vqvae(checkpoint_path, device=DEVICE, vq_args=VQ_ARGS):
    """Load the trained VQ-VAE model from a checkpoint and put it in eval mode.

    Args:
        checkpoint_path: Path to the checkpoint file. It may hold either a
            bare state dict or a dict with a ``'model_state_dict'`` entry.
        device: Device the model is moved to and the checkpoint is mapped to.
        vq_args: Extra keyword arguments forwarded to the wrapper (and from
            there to ``VQVae``); must match the training configuration.
            Only read, never modified.

    Returns:
        The ``MotionGPT_VQVAE_Wrapper`` instance in eval mode on ``device``.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"VQ-VAE checkpoint not found: {checkpoint_path}\n"
            f"Please download it and set VQVAE_CHECKPOINT environment variable "
            f"or use --vqvae-ckpt argument."
        )

    print(f"Loading VQ-VAE from: {checkpoint_path}")
    model = MotionGPT_VQVAE_Wrapper(
        smpl_dim=SMPL_DIM,
        codebook_size=CODEBOOK_SIZE,
        code_dim=CODE_DIM,
        **vq_args
    ).to(device)

    # SECURITY NOTE: weights_only=False makes torch.load unpickle arbitrary
    # objects, which can execute code. Only load checkpoints you trust.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state_dict = ckpt.get('model_state_dict', ckpt)

    # strict=False tolerates key mismatches, but a mismatch usually means the
    # architecture constants above do not match the checkpoint, leaving some
    # weights randomly initialized. Report it instead of hiding it.
    load_result = model.load_state_dict(state_dict, strict=False)
    missing = list(getattr(load_result, "missing_keys", []))
    unexpected = list(getattr(load_result, "unexpected_keys", []))
    if missing:
        print(f"⚠️  {len(missing)} model weights missing from checkpoint "
              f"(e.g. {missing[:3]})")
    if unexpected:
        print(f"⚠️  {len(unexpected)} checkpoint entries not used by the model "
              f"(e.g. {unexpected[:3]})")
    model.eval()

    print(f"✅ VQ-VAE loaded (codebook size: {CODEBOOK_SIZE})")
    return model


def load_stats(stats_path):
    """Load normalization statistics (mean/std) used during VQ-VAE training.

    Args:
        stats_path: Path to a ``torch.save``-d mapping with optional ``'mean'``
            and ``'std'`` entries. ``None``, an empty string or a non-existent
            path is not an error.

    Returns:
        ``(mean, std)``. Tensor entries are converted to NumPy arrays, other
        values are returned as stored, and missing entries default to ``0``
        (mean) and ``1`` (std). ``(None, None)`` when no stats file is
        available, in which case callers skip denormalization.
    """
    if not stats_path or not os.path.exists(stats_path):
        print(f"⚠️  Stats file not found: {stats_path}")
        print("   Will skip denormalization (may affect quality)")
        return None, None

    print(f"Loading stats from: {stats_path}")
    # SECURITY NOTE: weights_only=False makes torch.load unpickle arbitrary
    # objects, which can execute code. Only load stats files you trust.
    st = torch.load(stats_path, map_location='cpu', weights_only=False)
    mean = st.get('mean', 0)
    std = st.get('std', 1)

    # Convert to numpy. detach() first: .numpy() refuses tensors that
    # require grad.
    if torch.is_tensor(mean):
        mean = mean.detach().cpu().numpy()
    if torch.is_tensor(std):
        std = std.detach().cpu().numpy()

    print(f"✅ Stats loaded (mean shape: {np.array(mean).shape})")
    return mean, std


def load_smplx_model(model_dir, device=DEVICE):
    """Load the neutral-gender SMPL-X body model.

    The model is created with full (non-PCA) hand poses and with every
    optional parameter group enabled (global orient, body pose, betas,
    expression, jaw pose, both hand poses and translation).

    Args:
        model_dir: Path handed to ``smplx.SMPLX`` as ``model_path``.
        device: Device the model is moved to.

    Returns:
        The ``smplx.SMPLX`` model on ``device``.

    Raises:
        FileNotFoundError: If ``model_dir`` does not exist.
    """
    if not os.path.exists(model_dir):
        raise FileNotFoundError(
            f"SMPL-X model directory not found: {model_dir}\n"
            f"Please download SMPL-X models and set SMPLX_MODEL_DIR environment variable "
            f"or use --smplx-dir argument."
        )

    print(f"Loading SMPL-X from: {model_dir}")
    model = smplx.SMPLX(
        model_path=model_dir,
        model_type='smplx',
        gender='neutral',
        use_pca=False,
        create_global_orient=True,
        create_body_pose=True,
        create_betas=True,
        create_expression=True,
        create_jaw_pose=True,
        create_left_hand_pose=True,
        create_right_hand_pose=True,
        create_transl=True
    ).to(device)

    print("✅ SMPL-X loaded")
    return model


# =====================================================================
# Token Decoding
# =====================================================================
def _dequantize_tokens(quantizer, idx, code_dim):
    """Best-effort call of ``quantizer.dequantize``.

    Returns a tensor laid out as ``(N, code_dim, T_q)``, or ``None`` when the
    quantizer has no ``dequantize`` method, the call fails, or the result has
    a layout that cannot be recognized.
    """
    if not hasattr(quantizer, "dequantize"):
        return None

    seq_len = idx.shape[1]
    try:
        with torch.no_grad():
            dq = quantizer.dequantize(idx)
        if dq is None:
            return None
        dq = dq.contiguous()
        if dq.ndim != 3:
            return None
        # NOTE: if T_q happens to equal code_dim the layout is ambiguous;
        # channels-first is assumed in that case.
        if dq.shape[1] == code_dim:
            return dq
        if dq.shape[1] == seq_len:
            # (N, T_q, code_dim) -> (N, code_dim, T_q)
            return dq.permute(0, 2, 1).contiguous()
        return None
    except Exception as err:
        # Deliberately broad: quantizer implementations differ and a direct
        # codebook lookup is available as fallback. Report the failure rather
        # than hiding it.
        print(f"⚠️  quantizer.dequantize failed ({err!r}); "
              f"falling back to codebook lookup")
        return None


def decode_tokens_to_params(tokens, vqvae_model, mean=None, std=None, device=DEVICE):
    """
    Decode motion tokens to SMPL-X parameters.

    Args:
        tokens: Sequence (list or array) of motion token IDs
        vqvae_model: Trained VQ-VAE wrapper exposing ``.vqvae.quantizer``,
            ``.vqvae.decoder`` and ``.vqvae.postprocess``
        mean: Optional normalization mean
        std: Optional normalization std
        device: Device to run on

    Returns:
        numpy array of shape (T, SMPL_DIM) with SMPL-X parameters; an empty
        (0, SMPL_DIM) array when there are no tokens.

    Raises:
        ValueError: If a token ID lies outside the quantizer's codebook.
        RuntimeError: If the quantizer offers neither a usable ``dequantize``
            method nor a ``codebook``.
    """
    # len() rather than truthiness so NumPy arrays work as well as lists.
    if tokens is None or len(tokens) == 0:
        return np.zeros((0, SMPL_DIM), dtype=np.float32)

    # Prepare token indices
    idx = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)  # (1, T_q)

    quantizer = vqvae_model.vqvae.quantizer

    # Code dimension comes from the codebook when there is one.
    codebook = getattr(quantizer, "codebook", None)
    if codebook is not None:
        codebook = codebook.to(device)
        num_codes, code_dim = codebook.shape[0], codebook.shape[1]
        # Fail early with a readable message: an out-of-range ID otherwise
        # surfaces as an opaque indexing/CUDA error, and a negative ID would
        # silently wrap around in the plain codebook lookup below.
        lowest, highest = int(idx.min()), int(idx.max())
        if lowest < 0 or highest >= num_codes:
            raise ValueError(
                f"Motion token IDs must be in [0, {num_codes - 1}], "
                f"got values in [{lowest}, {highest}]"
            )
    else:
        code_dim = CODE_DIM

    # Dequantize tokens, preferring the quantizer's own method.
    x_quantized = _dequantize_tokens(quantizer, idx, code_dim)

    # Fallback: manual codebook lookup
    if x_quantized is None:
        if codebook is None:
            raise RuntimeError("No dequantize method and no codebook available")
        with torch.no_grad():
            emb = codebook[idx]  # (1, T_q, code_dim)
            x_quantized = emb.permute(0, 2, 1).contiguous()  # (1, code_dim, T_q)

    # Decode through VQ-VAE decoder
    with torch.no_grad():
        x_dec = vqvae_model.vqvae.decoder(x_quantized)
        smpl_out = vqvae_model.vqvae.postprocess(x_dec)  # (1, T_out, SMPL_DIM)
        params_np = smpl_out.squeeze(0).cpu().numpy()  # (T_out, SMPL_DIM)

    # Denormalize if stats provided (scalars broadcast via the (1, -1) reshape)
    if (mean is not None) and (std is not None):
        mean_arr = np.asarray(mean).reshape(1, -1)
        std_arr = np.asarray(std).reshape(1, -1)
        params_np = (params_np * std_arr) + mean_arr

    return params_np


# =====================================================================
# SMPL-X Parameter to Vertices
# =====================================================================
def _infer_module_device(module, default):
    """Return the device of ``module``'s first parameter or buffer, else ``default``."""
    if isinstance(module, nn.Module):
        for tensor in module.parameters():
            return tensor.device
        for tensor in module.buffers():
            return tensor.device
    return default


def _split_global_orient(body_t, num_body_joints):
    """Split a body-pose tensor into ``(global_orient, body_pose)``.

    Anything longer than ``num_body_joints * 3`` is treated as carrying the
    global orientation in its first three columns (exact match when three
    columns longer, best effort otherwise). Shorter inputs are zero-padded up
    to ``num_body_joints * 3`` and get a zero global orientation.
    """
    expected_no_go = num_body_joints * 3
    width = body_t.shape[1]

    if width > expected_no_go:
        return body_t[:, :3].contiguous(), body_t[:, 3:].contiguous()

    global_orient = body_t.new_zeros((body_t.shape[0], 3))
    if width < expected_no_go:
        body_t = F.pad(body_t, (0, expected_no_go - width))
    return global_orient, body_t


def params_to_vertices(params_seq, smplx_model, batch_size=32):
    """
    Convert SMPL-X parameters to 3D vertices.

    Args:
        params_seq: numpy array (T, SMPL_DIM), columns laid out as described
            by PARAM_NAMES / PARAM_DIMS
        smplx_model: loaded SMPL-X model
        batch_size: number of frames sent through the model at once

    Returns:
        verts: numpy array (T, V, 3)
        faces: numpy array (F, 3)

    Raises:
        ValueError: If ``params_seq`` has no frames or ``batch_size`` < 1.
    """
    T = params_seq.shape[0]
    if T == 0:
        raise ValueError("params_seq contains no frames; nothing to convert")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Column range of every parameter group inside one frame.
    starts = np.cumsum([0] + PARAM_DIMS[:-1])
    ends = starts + np.array(PARAM_DIMS)
    column_ranges = [
        (name, int(st), int(ed)) for name, st, ed in zip(PARAM_NAMES, starts, ends)
    ]

    # Build inputs on the device the model actually lives on, so a model
    # loaded on a non-default device still works; DEVICE is the fallback.
    device = _infer_module_device(smplx_model, DEVICE)

    # Infer number of body joints
    num_body_joints = getattr(smplx_model, "NUM_BODY_JOINTS", 21)

    all_verts = []
    with torch.no_grad():
        for s in range(0, T, batch_size):
            batch = params_seq[s:s+batch_size]  # (B, SMPL_DIM)

            # Slice out each parameter group and move it to the model device.
            tensor_parts = {
                name: torch.from_numpy(batch[:, st:ed].astype(np.float32)).to(device)
                for name, st, ed in column_ranges
            }

            # Body pose may or may not include the global orientation.
            global_orient, body_pose_only = _split_global_orient(
                tensor_parts['body_pose'], num_body_joints
            )

            # Call SMPL-X (the single eye pose drives both eyes)
            out = smplx_model(
                betas=tensor_parts['betas'],
                global_orient=global_orient,
                body_pose=body_pose_only,
                left_hand_pose=tensor_parts['left_hand_pose'],
                right_hand_pose=tensor_parts['right_hand_pose'],
                expression=tensor_parts['expression'],
                jaw_pose=tensor_parts['jaw_pose'],
                leye_pose=tensor_parts['eye_pose'],
                reye_pose=tensor_parts['eye_pose'],
                transl=tensor_parts['trans'],
                return_verts=True
            )

            all_verts.append(out.vertices.detach().cpu().numpy())  # (B, V, 3)

    verts_all = np.concatenate(all_verts, axis=0)  # (T, V, 3)
    faces = smplx_model.faces.astype(np.int32)

    return verts_all, faces


# =====================================================================
# Visualization
# =====================================================================
def animate_motion(verts, faces, title="Generated Motion", output_path=None, fps=20):
    """
    Create interactive 3D animation using Plotly.

    Args:
        verts: numpy array (T, V, 3) with at least one frame
        faces: numpy array (F, 3) of vertex indices
        title: Plot title
        output_path: Path to save HTML file (parent directories are created);
            nothing is written when empty/None
        fps: Frames per second for animation, must be positive

    Returns:
        Plotly figure object

    Raises:
        ValueError: If ``verts`` is not a non-empty (T, V, 3) array or
            ``fps`` is not positive.
    """
    if verts.ndim != 3 or verts.shape[0] == 0:
        raise ValueError(
            f"verts must have shape (T, V, 3) with T >= 1, got {verts.shape}"
        )
    if fps <= 0:
        # 1000 // fps below would divide by zero or go negative otherwise.
        raise ValueError(f"fps must be positive, got {fps}")

    num_frames = verts.shape[0]
    i, j, k = faces.T.tolist()
    frame_ms = 1000 // fps

    def frame_mesh(t, **extra):
        """Mesh trace for frame ``t``; topology and look are shared by all frames."""
        return go.Mesh3d(
            x=verts[t, :, 0],
            y=verts[t, :, 1],
            z=verts[t, :, 2],
            i=i, j=j, k=k,
            flatshading=True,
            opacity=0.7,
            **extra
        )

    # Initial mesh (the only trace that carries a name)
    mesh = frame_mesh(0, name=title)

    # One animation frame per time step, named by its index
    frames = [
        go.Frame(data=[frame_mesh(t)], name=str(t))
        for t in range(num_frames)
    ]

    # Create figure
    fig = go.Figure(data=[mesh], frames=frames)

    fig.update_layout(
        title_text=title,
        scene=dict(
            aspectmode='data',
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            camera=dict(eye=dict(x=0, y=-2, z=0.7))
        ),
        updatemenus=[dict(
            type="buttons",
            buttons=[
                dict(
                    label="Play",
                    method="animate",
                    args=[None, {
                        "frame": {"duration": frame_ms, "redraw": True},
                        "fromcurrent": True
                    }]
                ),
                dict(
                    label="Pause",
                    method="animate",
                    # Animating to the [None] frame list with zero duration
                    # is Plotly's idiom for stopping playback.
                    args=[[None], {
                        "frame": {"duration": 0, "redraw": False}
                    }]
                )
            ]
        )]
    )

    # Save HTML
    if output_path:
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        fig.write_html(output_path)
        print(f"✅ Animation saved to: {output_path}")

    return fig


# =====================================================================
# Main Visualization Pipeline
# =====================================================================
def visualize(
    tokens,
    vqvae_ckpt=VQVAE_CHECKPOINT,
    stats_path=STATS_PATH,
    smplx_dir=SMPLX_MODEL_DIR,
    output_html=None,
    title="Generated Motion",
    fps=20
):
    """
    Complete visualization pipeline: tokens -> vertices -> animation.

    Args:
        tokens: Motion tokens (string or list of ints)
        vqvae_ckpt: Path to VQ-VAE checkpoint
        stats_path: Path to normalization stats
        smplx_dir: Path to SMPL-X model directory
        output_html: Path to save HTML animation; defaults to
            ``motion_animation.html`` inside OUTPUT_DIR
        title: Animation title
        fps: Frames per second

    Returns:
        Plotly figure object, or None when there are no tokens or the decoder
        produces no frames.
    """
    banner = "="*60
    print(banner)
    print("Motion Visualization Pipeline")
    print(banner)

    # Parse tokens
    print("\n[1/5] Parsing tokens...")
    token_list = parse_motion_tokens(tokens)
    print(f"   Parsed {len(token_list)} tokens")
    if not token_list:
        print("❌ No tokens to visualize")
        return None

    # Load models
    print("\n[2/5] Loading VQ-VAE...")
    vq_model = load_vqvae(vqvae_ckpt, device=DEVICE)

    print("\n[3/5] Loading normalization stats...")
    mean, std = load_stats(stats_path)

    print("\n[4/5] Loading SMPL-X model...")
    smplx_model = load_smplx_model(smplx_dir, device=DEVICE)

    # Decode tokens
    print("\n[5/5] Decoding and rendering...")
    print("   Decoding tokens to SMPL-X parameters...")
    params = decode_tokens_to_params(token_list, vq_model, mean, std, device=DEVICE)
    print(f"   Decoded params shape: {params.shape}")

    if params.shape[0] == 0:
        print("❌ No frames produced from decoder")
        return None

    # Convert to vertices
    print("   Converting parameters to vertices...")
    verts, faces = params_to_vertices(params, smplx_model, batch_size=32)
    print(f"   Vertices shape: {verts.shape}, Faces: {faces.shape}")

    # Create animation
    print("   Creating animation...")
    if output_html is None:
        output_html = os.path.join(OUTPUT_DIR, "motion_animation.html")

    fig = animate_motion(verts, faces, title=title, output_path=output_html, fps=fps)

    print("\n" + banner)
    print("✅ Visualization complete!")
    print(banner)

    return fig


# =====================================================================
# CLI
# =====================================================================
def main():
    """Command-line entry point: obtain motion tokens, then run ``visualize``.

    Exactly one token source is required: ``--tokens`` (inline string),
    ``--input`` (file containing the tokens) or ``--prompt`` (generate tokens
    with ``inference.inference`` first, using the model selected by
    ``--stage``).
    """
    parser = argparse.ArgumentParser(
        description="Visualize motion tokens as 3D SMPL-X animation"
    )

    # Input options (mutually exclusive)
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument(
        "--tokens",
        type=str,
        help="Motion tokens string (e.g., '<MOT_BEGIN><motion_177>...<MOT_END>' or '177 135 152...')"
    )
    input_group.add_argument(
        "--input",
        type=str,
        help="Path to file containing motion tokens"
    )
    input_group.add_argument(
        "--prompt",
        type=str,
        help="Generate tokens from text prompt first (requires --stage)"
    )

    # Generation options (if using --prompt)
    parser.add_argument(
        "--stage",
        type=int,
        default=3,
        choices=[1, 2, 3],
        help="Stage model to use for generation (default: 3)"
    )

    # Model paths
    parser.add_argument(
        "--vqvae-ckpt",
        type=str,
        default=VQVAE_CHECKPOINT,
        help=f"Path to VQ-VAE checkpoint (default: {VQVAE_CHECKPOINT})"
    )
    parser.add_argument(
        "--stats",
        type=str,
        default=STATS_PATH,
        help=f"Path to normalization stats (default: {STATS_PATH})"
    )
    parser.add_argument(
        "--smplx-dir",
        type=str,
        default=SMPLX_MODEL_DIR,
        help=f"Path to SMPL-X model directory (default: {SMPLX_MODEL_DIR})"
    )

    # Output options
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Path to save HTML animation (default: motion_animation.html)"
    )
    parser.add_argument(
        "--title",
        type=str,
        default="Generated Motion",
        help="Animation title"
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=20,
        help="Frames per second for animation (default: 20)"
    )

    args = parser.parse_args()

    # The frame duration is computed as 1000 // fps, so fps must be positive.
    if args.fps <= 0:
        parser.error("--fps must be a positive integer")

    # Get tokens. Compare against None rather than relying on truthiness:
    # an explicitly empty value (e.g. --prompt "") must still select its own
    # branch instead of falling through to --tokens, which would then be None.
    if args.prompt is not None:
        # Generate tokens first using inference.py (imported lazily so the
        # generation stack is only loaded when it is actually needed)
        print("Generating motion tokens from prompt...")
        from inference import inference
        tokens = inference(
            prompt=args.prompt,
            stage=args.stage,
            output_file=None,
            per_prompt_vocab=True
        )
    elif args.input is not None:
        # Read from file; explicit encoding so the result does not depend on
        # the platform's default locale
        with open(args.input, 'r', encoding='utf-8') as f:
            tokens = f.read().strip()
    else:
        # Direct token string
        tokens = args.tokens

    # Visualize
    visualize(
        tokens=tokens,
        vqvae_ckpt=args.vqvae_ckpt,
        stats_path=args.stats,
        smplx_dir=args.smplx_dir,
        output_html=args.output,
        title=args.title,
        fps=args.fps
    )


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()