# SignMotionGPT / visualize.py
# Author: rdz-falcon (commit 4bd136e: "Deploy SignMotionGPT Demo with LFS")
"""
Visualization script to convert motion tokens to SMPL-X 3D animation.
Requires VQ-VAE checkpoint, dataset stats, and SMPL-X model files.
Usage:
# Visualize from LLM output string
python visualize.py --tokens "<MOT_BEGIN><motion_177><motion_135>...<MOT_END>"
# Visualize from saved file
python visualize.py --input motion_output.txt
# Generate and visualize in one go
python visualize.py --prompt "walking" --stage 3
# Custom paths
python visualize.py --tokens "..." --vqvae-ckpt /path/to/vqvae.pt --smplx-dir /path/to/smplx
"""
import os
import sys
import re
import argparse
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from config import WORK_DIR, DATA_DIR
# Try importing visualization dependencies, installing them on first use.
def _pip_install(requirement):
    """Install *requirement* with the pip that belongs to this interpreter.

    ``sys.executable -m pip`` targets the environment that will import the
    package (a bare ``pip`` on PATH may belong to another interpreter), the
    argument list avoids going through a shell, and a failed install raises
    ``subprocess.CalledProcessError`` instead of being silently ignored.
    """
    import importlib
    import subprocess

    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", requirement])
    # Packages installed after interpreter start-up may be invisible to the
    # import system's cached directory listings until caches are invalidated.
    importlib.invalidate_caches()


try:
    import plotly.graph_objects as go
except ImportError:
    print("Installing plotly...")
    _pip_install("plotly")
    import plotly.graph_objects as go

try:
    import smplx
except ImportError:
    print("Installing smplx...")
    _pip_install("smplx==0.1.28")
    import smplx
# =====================================================================
# Configuration - can be overridden via command-line or environment
# =====================================================================
# Each path comes from its environment variable when set; otherwise it
# falls back to a default under DATA_DIR (WORK_DIR for outputs).

# Trained VQ-VAE (motion encoder/decoder) checkpoint
VQVAE_CHECKPOINT = os.environ.get("VQVAE_CHECKPOINT", os.path.join(DATA_DIR, "vqvae_model.pt"))
# Mean/std normalization statistics saved from VQ-VAE training
STATS_PATH = os.environ.get("VQVAE_STATS_PATH", os.path.join(DATA_DIR, "vqvae_stats.pt"))
# Directory holding the SMPL-X model files (e.g. SMPLX_NEUTRAL.npz)
SMPLX_MODEL_DIR = os.environ.get("SMPLX_MODEL_DIR", os.path.join(DATA_DIR, "smplx_models"))
# Destination directory for generated HTML animations
OUTPUT_DIR = os.environ.get("VIS_OUTPUT_DIR", WORK_DIR)

# Run on CUDA when it is available, otherwise on CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# VQ-VAE architecture hyper-parameters (must match the training config)
SMPL_DIM = 182        # width of one frame's packed SMPL-X parameter vector
CODEBOOK_SIZE = 512   # number of codebook entries
CODE_DIM = 512        # dimensionality of each codebook vector
VQ_ARGS = {
    "width": 512,
    "depth": 3,
    "down_t": 2,
    "stride_t": 2,
    "dilation_growth_rate": 3,
    "activation": 'relu',
    "norm": None,
    "quantizer": "ema_reset",
}

# SMPL-X parameter layout as (name, width) pairs, in the order the per-frame
# vector was packed for VQ-VAE training (must match training). The widths
# sum to SMPL_DIM.
_PARAM_LAYOUT = (
    ("betas", 10),
    ("body_pose", 63),
    ("left_hand_pose", 45),
    ("right_hand_pose", 45),
    ("trans", 3),
    ("expression", 10),
    ("jaw_pose", 3),
    ("eye_pose", 3),
)
PARAM_DIMS = [width for _, width in _PARAM_LAYOUT]
PARAM_NAMES = [name for name, _ in _PARAM_LAYOUT]
# =====================================================================
# Import VQ-VAE architecture
# =====================================================================
try:
    # Make the directory containing this script importable so the bundled
    # ``mGPT`` package resolves regardless of the current working directory.
    # ``abspath`` matters: if ``__file__`` is a bare relative filename,
    # ``dirname`` returns "" (the cwd), not this script's directory.
    sign_mgpt_dir = os.path.dirname(os.path.abspath(__file__))
    if sign_mgpt_dir not in sys.path:
        sys.path.insert(0, sign_mgpt_dir)
    from mGPT.archs.mgpt_vq import VQVae
except ImportError as e:
    print(f"❌ Could not import VQVae: {e}")
    print("Make sure mGPT/archs/mgpt_vq.py exists in the project.")
    sys.exit(1)
# =====================================================================
# VQ-VAE Wrapper
# =====================================================================
class MotionGPT_VQVAE_Wrapper(nn.Module):
    """Container mirroring the VQ-VAE training setup.

    The actual ``VQVae`` network is held in the ``vqvae`` attribute, so its
    parameters appear under the ``vqvae.`` prefix in ``state_dict()``. Any
    extra keyword arguments are forwarded to ``VQVae`` unchanged; the output
    embedding width is tied to ``code_dim``.
    """

    def __init__(self, smpl_dim=SMPL_DIM, codebook_size=CODEBOOK_SIZE,
                 code_dim=CODE_DIM, **kwargs):
        super().__init__()
        core = VQVae(
            nfeats=smpl_dim,
            code_num=codebook_size,
            code_dim=code_dim,
            output_emb_width=code_dim,
            **kwargs,
        )
        self.vqvae = core
# =====================================================================
# Token Parsing
# =====================================================================
def parse_motion_tokens(token_str):
    """
    Parse motion tokens from LLM output or a sequence of ids.

    Accepts:
    - "<MOT_BEGIN><motion_177><motion_135>...<MOT_END>" (any surrounding
      text is ignored; only ``<motion_N>`` tags are read)
    - "177 135 152 200 46..." (whitespace- and/or comma-separated)
    - list / tuple of ints (or int-like strings)
    - numpy array of ints of any shape (flattened in C order)

    Returns:
        List of token integers.

    Raises:
        ValueError: if the input is not a string/list-like, or no tokens can
            be parsed from the string.
    """
    if isinstance(token_str, np.ndarray):
        # ravel() so a (1, T) array from a model output parses like a flat list
        return [int(x) for x in token_str.ravel()]
    if isinstance(token_str, (list, tuple)):
        return [int(x) for x in token_str]
    if not isinstance(token_str, str):
        raise ValueError("Tokens must be string or list-like")

    # Preferred format: explicit <motion_ID> tags anywhere in the text
    matches = re.findall(r'<motion_(\d+)>', token_str)
    if matches:
        return [int(x) for x in matches]

    # Fallback: bare ids separated by whitespace and/or commas
    stripped = token_str.strip()
    if stripped:
        try:
            return [int(x) for x in re.split(r'[\s,]+', stripped) if x]
        except ValueError:
            pass

    preview = stripped[:100] + ("..." if len(stripped) > 100 else "")
    raise ValueError(f"Could not parse motion tokens from: {preview}")
# =====================================================================
# Model Loading
# =====================================================================
def load_vqvae(checkpoint_path, device=DEVICE, vq_args=VQ_ARGS):
    """Load the trained VQ-VAE model from a checkpoint.

    Args:
        checkpoint_path: Path to a ``torch.save`` checkpoint: either a state
            dict itself, or a dict holding one under ``'model_state_dict'``
            or ``'state_dict'``.
        device: Device the model is built on and the checkpoint mapped to.
        vq_args: Keyword arguments forwarded to ``VQVae`` (not modified).

    Returns:
        ``MotionGPT_VQVAE_Wrapper`` in eval mode.

    Raises:
        FileNotFoundError: if *checkpoint_path* does not exist.
        TypeError: if the checkpoint does not contain a dict.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"VQ-VAE checkpoint not found: {checkpoint_path}\n"
            f"Please download it and set VQVAE_CHECKPOINT environment variable "
            f"or use --vqvae-ckpt argument."
        )
    print(f"Loading VQ-VAE from: {checkpoint_path}")
    model = MotionGPT_VQVAE_Wrapper(
        smpl_dim=SMPL_DIM,
        codebook_size=CODEBOOK_SIZE,
        code_dim=CODE_DIM,
        **vq_args
    ).to(device)

    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    if not isinstance(ckpt, dict):
        raise TypeError(
            f"Unsupported checkpoint contents in {checkpoint_path}: "
            f"expected a dict / state dict, got {type(ckpt).__name__}"
        )
    state_dict = ckpt.get('model_state_dict', ckpt.get('state_dict', ckpt))

    # strict=False tolerates key differences, but a checkpoint that does not
    # match this architecture would then load "successfully" and decode
    # garbage, so surface any mismatch instead of hiding it.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        print(
            f"⚠️ {len(incompatible.missing_keys)} model keys missing from checkpoint "
            f"(e.g. {incompatible.missing_keys[:3]})"
        )
    if incompatible.unexpected_keys:
        print(
            f"⚠️ {len(incompatible.unexpected_keys)} unexpected checkpoint keys ignored "
            f"(e.g. {incompatible.unexpected_keys[:3]})"
        )
    model.eval()
    print(f"✅ VQ-VAE loaded (codebook size: {CODEBOOK_SIZE})")
    return model
def load_stats(stats_path):
    """Read the mean/std normalization statistics saved from VQ-VAE training.

    Args:
        stats_path: Path to a ``torch.save``-d mapping with optional
            ``'mean'`` and ``'std'`` entries. Falsy or non-existent paths are
            tolerated.

    Returns:
        ``(mean, std)``, with tensors converted to numpy arrays and missing
        entries defaulting to ``0`` and ``1``; or ``(None, None)`` when the
        file is unavailable, so callers can skip denormalization.
    """
    available = bool(stats_path) and os.path.exists(stats_path)
    if not available:
        print(f"⚠️ Stats file not found: {stats_path}")
        print(" Will skip denormalization (may affect quality)")
        return None, None

    print(f"Loading stats from: {stats_path}")
    # NOTE: weights_only=False unpickles arbitrary objects; load trusted files only.
    stats = torch.load(stats_path, map_location='cpu', weights_only=False)

    def _as_numpy(value):
        return value.cpu().numpy() if torch.is_tensor(value) else value

    mean, std = (_as_numpy(stats.get(key, fallback))
                 for key, fallback in (('mean', 0), ('std', 1)))
    print(f"✅ Stats loaded (mean shape: {np.array(mean).shape})")
    return mean, std
def load_smplx_model(model_dir, device=DEVICE):
    """Build the neutral-gender SMPL-X body model and move it to *device*.

    Hand poses are full axis-angle (``use_pca=False``), and the model is
    created with its own global-orient / body / hand / jaw pose, betas,
    expression and translation parameters.

    Args:
        model_dir: Directory passed to ``smplx.SMPLX`` as ``model_path``.
        device: Target torch device.

    Returns:
        The ``smplx.SMPLX`` module.

    Raises:
        FileNotFoundError: if *model_dir* does not exist.
    """
    if not os.path.exists(model_dir):
        message = (
            f"SMPL-X model directory not found: {model_dir}\n"
            "Please download SMPL-X models and set SMPLX_MODEL_DIR environment variable "
            "or use --smplx-dir argument."
        )
        raise FileNotFoundError(message)

    print(f"Loading SMPL-X from: {model_dir}")
    parts = ("global_orient", "body_pose", "betas", "expression",
             "jaw_pose", "left_hand_pose", "right_hand_pose", "transl")
    create_flags = {f"create_{part}": True for part in parts}
    body_model = smplx.SMPLX(
        model_path=model_dir,
        model_type='smplx',
        gender='neutral',
        use_pca=False,
        **create_flags,
    )
    body_model = body_model.to(device)
    print("✅ SMPL-X loaded")
    return body_model
# =====================================================================
# Token Decoding
# =====================================================================
def _dequantize_tokens(quantizer, idx, codebook, code_dim):
    """Map token ids shaped (1, T_q) to code vectors shaped (1, code_dim, T_q).

    Tries ``quantizer.dequantize`` first and falls back to a direct codebook
    lookup when that method is missing, raises, or returns an unexpected shape.
    """
    t_q = idx.shape[1]
    if hasattr(quantizer, "dequantize"):
        try:
            with torch.no_grad():
                dq = quantizer.dequantize(idx)
        except Exception as exc:  # best-effort: the codebook lookup below is the fallback
            print(f"⚠️ quantizer.dequantize failed ({exc}); falling back to codebook lookup")
            dq = None
        if dq is not None and dq.ndim == 3:
            # Check the (N, T_q, code_dim) layout first: it is the layout the
            # codebook lookup below produces, and checking it first keeps the
            # ambiguous T_q == code_dim case consistent with that lookup.
            if dq.shape[1] == t_q and dq.shape[2] == code_dim:
                return dq.permute(0, 2, 1).contiguous()
            if dq.shape[1] == code_dim and dq.shape[2] == t_q:
                return dq.contiguous()

    if codebook is None:
        raise RuntimeError("No dequantize method and no codebook available")
    with torch.no_grad():
        emb = codebook[idx]  # (1, T_q, code_dim)
    return emb.permute(0, 2, 1).contiguous()  # (1, code_dim, T_q)


def decode_tokens_to_params(tokens, vqvae_model, mean=None, std=None, device=DEVICE):
    """
    Decode motion tokens to SMPL-X parameters.

    Args:
        tokens: Sequence (list / tuple / 1-D array) of motion token ids
        vqvae_model: Trained VQ-VAE wrapper (uses ``.vqvae.quantizer``,
            ``.vqvae.decoder`` and ``.vqvae.postprocess``)
        mean: Optional normalization mean (broadcast over frames)
        std: Optional normalization std (broadcast over frames)
        device: Device to run on

    Returns:
        numpy array of shape (T, SMPL_DIM) with SMPL-X parameters; an empty
        (0, SMPL_DIM) array when *tokens* is empty. Denormalized as
        ``x * std + mean`` only when both stats are given.

    Raises:
        ValueError: if any token id is outside ``[0, codebook size)``.
        RuntimeError: if the quantizer has neither a usable ``dequantize``
            nor a ``codebook``.
    """
    # len() rather than truthiness: `not tokens` is ambiguous for numpy arrays
    if tokens is None or len(tokens) == 0:
        return np.zeros((0, SMPL_DIM), dtype=np.float32)

    quantizer = vqvae_model.vqvae.quantizer
    codebook = quantizer.codebook.to(device) if hasattr(quantizer, "codebook") else None
    code_dim = codebook.shape[1] if codebook is not None else CODE_DIM
    num_codes = codebook.shape[0] if codebook is not None else CODEBOOK_SIZE

    ids = np.asarray(tokens, dtype=np.int64).reshape(-1)
    # Reject bad ids up front; otherwise they fail deep inside an embedding
    # lookup (on CUDA as an opaque device-side assert).
    bad = ids[(ids < 0) | (ids >= num_codes)]
    if bad.size:
        raise ValueError(
            f"Motion token ids must be in [0, {num_codes}); "
            f"got out-of-range ids {sorted(set(bad.tolist()))[:10]}"
        )
    idx = torch.from_numpy(ids).to(device).unsqueeze(0)  # (1, T_q)

    x_quantized = _dequantize_tokens(quantizer, idx, codebook, code_dim)

    # Decode through the VQ-VAE decoder
    with torch.no_grad():
        x_dec = vqvae_model.vqvae.decoder(x_quantized)
        smpl_out = vqvae_model.vqvae.postprocess(x_dec)  # expected (1, T_out, SMPL_DIM)
    params_np = smpl_out.squeeze(0).cpu().numpy()  # (T_out, SMPL_DIM)

    # Undo the training-time normalization when stats are available
    if mean is not None and std is not None:
        params_np = params_np * np.asarray(std).reshape(1, -1) + np.asarray(mean).reshape(1, -1)
    return params_np
# =====================================================================
# SMPL-X Parameter to Vertices
# =====================================================================
def _module_device(module, default):
    """Return the device of *module*'s first parameter or buffer, else *default*."""
    for tensor in module.parameters():
        return tensor.device
    for tensor in module.buffers():
        return tensor.device
    return default


def _split_body_pose(body_t, num_body_joints):
    """Split the packed body-pose slice into ``(global_orient, body_pose)``.

    The slice may or may not start with a 3-value global orientation; when it
    does not, a zero orientation is used. Any other width is handled
    best-effort and always yields exactly ``num_body_joints * 3`` body-pose
    values (truncated or zero-padded) so SMPL-X accepts it.
    """
    batch = body_t.shape[0]
    width = body_t.shape[1]
    expected_no_go = num_body_joints * 3
    expected_with_go = expected_no_go + 3

    def zero_orient():
        return torch.zeros((batch, 3), dtype=body_t.dtype, device=body_t.device)

    if width == expected_with_go:
        return body_t[:, :3].contiguous(), body_t[:, 3:].contiguous()
    if width == expected_no_go:
        return zero_orient(), body_t

    # Best-effort fallback for unexpected widths
    if width > expected_no_go:
        global_orient, rest = body_t[:, :3].contiguous(), body_t[:, 3:]
    else:
        global_orient, rest = zero_orient(), body_t
    rest = rest[:, :expected_no_go]
    body_pose = F.pad(rest, (0, expected_no_go - rest.shape[1])).contiguous()
    return global_orient, body_pose


def params_to_vertices(params_seq, smplx_model, batch_size=32):
    """
    Convert SMPL-X parameters to 3D vertices.

    Args:
        params_seq: numpy array (T, SMPL_DIM), packed per PARAM_NAMES / PARAM_DIMS
        smplx_model: loaded SMPL-X model; inputs are placed on the model's own
            device (falling back to DEVICE if it has no parameters/buffers)
        batch_size: number of frames per SMPL-X forward pass (>= 1)

    Returns:
        verts: numpy array (T, V, 3); (0, V, 3) for an empty sequence
        faces: numpy array (F, 3), int32

    Raises:
        ValueError: if batch_size < 1 or params_seq is not 2-D with at least
            sum(PARAM_DIMS) columns.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    params_seq = np.asarray(params_seq)
    # Column slice for each named parameter group
    starts = np.cumsum([0] + PARAM_DIMS[:-1])
    ends = starts + np.array(PARAM_DIMS)
    if params_seq.ndim != 2 or params_seq.shape[1] < ends[-1]:
        raise ValueError(
            f"params_seq must have shape (T, >={int(ends[-1])}), got {params_seq.shape}"
        )

    faces = smplx_model.faces.astype(np.int32)
    num_frames = params_seq.shape[0]
    if num_frames == 0:
        # Nothing to pose: keep the (T, V, 3) contract with T == 0
        v_template = getattr(smplx_model, "v_template", None)
        num_verts = v_template.shape[0] if v_template is not None else 0
        return np.zeros((0, num_verts, 3), dtype=np.float32), faces

    device = _module_device(smplx_model, DEVICE)
    num_body_joints = getattr(smplx_model, "NUM_BODY_JOINTS", 21)
    chunks = []
    with torch.no_grad():
        for s in range(0, num_frames, batch_size):
            batch = params_seq[s:s + batch_size]  # (B, SMPL_DIM)
            parts = {
                name: torch.from_numpy(batch[:, st:ed].astype(np.float32)).to(device)
                for name, st, ed in zip(PARAM_NAMES, starts, ends)
            }
            global_orient, body_pose = _split_body_pose(parts['body_pose'], num_body_joints)
            out = smplx_model(
                betas=parts['betas'],
                global_orient=global_orient,
                body_pose=body_pose,
                left_hand_pose=parts['left_hand_pose'],
                right_hand_pose=parts['right_hand_pose'],
                expression=parts['expression'],
                jaw_pose=parts['jaw_pose'],
                # A single packed eye pose drives both eyes
                leye_pose=parts['eye_pose'],
                reye_pose=parts['eye_pose'],
                transl=parts['trans'],
                return_verts=True
            )
            chunks.append(out.vertices.detach().cpu().numpy())  # (B, V, 3)

    return np.concatenate(chunks, axis=0), faces
# =====================================================================
# Visualization
# =====================================================================
def animate_motion(verts, faces, title="Generated Motion", output_path=None, fps=20):
    """
    Create an interactive 3D mesh animation using Plotly.

    Args:
        verts: array-like (T, V, 3) vertex positions, T >= 1
        faces: array-like (F, 3) triangle vertex indices
        title: Plot title (also used as the initial trace name)
        output_path: If given, the figure is written there as standalone HTML
            (parent directories are created)
        fps: Playback frames per second; must be positive

    Returns:
        Plotly figure object

    Raises:
        ValueError: if verts is not a non-empty (T, V, 3) array or fps <= 0
    """
    verts = np.asarray(verts)
    if verts.ndim != 3 or verts.shape[0] == 0 or verts.shape[2] != 3:
        raise ValueError(f"verts must be a non-empty (T, V, 3) array, got shape {verts.shape}")
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    i, j, k = np.asarray(faces).T.tolist()
    # Styling shared by the initial mesh and every animation frame
    mesh_style = dict(i=i, j=j, k=k, flatshading=True, opacity=0.7)

    def mesh_at(t, **extra):
        return go.Mesh3d(
            x=verts[t, :, 0],
            y=verts[t, :, 1],
            z=verts[t, :, 2],
            **mesh_style,
            **extra
        )

    frames = [go.Frame(data=[mesh_at(t)], name=str(t)) for t in range(verts.shape[0])]
    fig = go.Figure(data=[mesh_at(0, name=title)], frames=frames)
    fig.update_layout(
        title_text=title,
        scene=dict(
            aspectmode='data',
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            camera=dict(eye=dict(x=0, y=-2, z=0.7))
        ),
        updatemenus=[dict(
            type="buttons",
            buttons=[
                dict(
                    label="Play",
                    method="animate",
                    args=[None, {
                        "frame": {"duration": 1000 // fps, "redraw": True},
                        "fromcurrent": True,
                        "transition": {"duration": 0}
                    }]
                ),
                dict(
                    label="Pause",
                    method="animate",
                    args=[[None], {
                        "frame": {"duration": 0, "redraw": False},
                        # "immediate" interrupts the running animation; with
                        # the default queueing mode the pause request would
                        # only take effect after the animation finishes.
                        "mode": "immediate",
                        "transition": {"duration": 0}
                    }]
                )
            ]
        )]
    )

    if output_path:
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        fig.write_html(output_path)
        print(f"✅ Animation saved to: {output_path}")
    return fig
# =====================================================================
# Main Visualization Pipeline
# =====================================================================
def visualize(
    tokens,
    vqvae_ckpt=VQVAE_CHECKPOINT,
    stats_path=STATS_PATH,
    smplx_dir=SMPLX_MODEL_DIR,
    output_html=None,
    title="Generated Motion",
    fps=20
):
    """
    Run the whole pipeline: tokens -> SMPL-X params -> mesh vertices -> HTML animation.

    Args:
        tokens: Motion tokens (any form accepted by ``parse_motion_tokens``)
        vqvae_ckpt: Path to the VQ-VAE checkpoint
        stats_path: Path to the normalization stats
        smplx_dir: Path to the SMPL-X model directory
        output_html: Where to write the HTML animation; defaults to
            ``motion_animation.html`` inside OUTPUT_DIR
        title: Animation title
        fps: Frames per second

    Returns:
        The Plotly figure, or None when no tokens were parsed or the decoder
        produced zero frames.
    """
    rule = "=" * 60

    def step(number, label):
        print(f"\n[{number}/5] {label}...")

    print(rule)
    print("Motion Visualization Pipeline")
    print(rule)

    step(1, "Parsing tokens")
    token_list = parse_motion_tokens(tokens)
    print(f" Parsed {len(token_list)} tokens")
    if len(token_list) == 0:
        print("❌ No tokens to visualize")
        return None

    step(2, "Loading VQ-VAE")
    vq_model = load_vqvae(vqvae_ckpt, device=DEVICE)
    step(3, "Loading normalization stats")
    mean, std = load_stats(stats_path)
    step(4, "Loading SMPL-X model")
    smplx_model = load_smplx_model(smplx_dir, device=DEVICE)

    step(5, "Decoding and rendering")
    print(" Decoding tokens to SMPL-X parameters...")
    params = decode_tokens_to_params(token_list, vq_model, mean, std, device=DEVICE)
    print(f" Decoded params shape: {params.shape}")
    if params.shape[0] == 0:
        print("❌ No frames produced from decoder")
        return None

    print(" Converting parameters to vertices...")
    verts, faces = params_to_vertices(params, smplx_model, batch_size=32)
    print(f" Vertices shape: {verts.shape}, Faces: {faces.shape}")

    print(" Creating animation...")
    html_path = (
        os.path.join(OUTPUT_DIR, "motion_animation.html")
        if output_html is None
        else output_html
    )
    fig = animate_motion(verts, faces, title=title, output_path=html_path, fps=fps)

    print(f"\n{rule}")
    print("✅ Visualization complete!")
    print(rule)
    return fig
# =====================================================================
# CLI
# =====================================================================
def main():
    """Command-line entry point: parse arguments, obtain tokens, render them.

    Tokens come from exactly one of --tokens, --input (a UTF-8 text file) or
    --prompt (generated via ``inference.inference``). Exits with status 1
    when the pipeline produces nothing to render.
    """
    parser = argparse.ArgumentParser(
        description="Visualize motion tokens as 3D SMPL-X animation"
    )
    # Input options (mutually exclusive)
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument(
        "--tokens",
        type=str,
        help="Motion tokens string (e.g., '<MOT_BEGIN><motion_177>...<MOT_END>' or '177 135 152...')"
    )
    input_group.add_argument(
        "--input",
        type=str,
        help="Path to file containing motion tokens"
    )
    input_group.add_argument(
        "--prompt",
        type=str,
        help="Generate tokens from text prompt first (model chosen by --stage)"
    )
    # Generation options (if using --prompt)
    parser.add_argument(
        "--stage",
        type=int,
        default=3,
        choices=[1, 2, 3],
        help="Stage model to use for generation (default: 3)"
    )
    # Model paths
    parser.add_argument(
        "--vqvae-ckpt",
        type=str,
        default=VQVAE_CHECKPOINT,
        help=f"Path to VQ-VAE checkpoint (default: {VQVAE_CHECKPOINT})"
    )
    parser.add_argument(
        "--stats",
        type=str,
        default=STATS_PATH,
        help=f"Path to normalization stats (default: {STATS_PATH})"
    )
    parser.add_argument(
        "--smplx-dir",
        type=str,
        default=SMPLX_MODEL_DIR,
        help=f"Path to SMPL-X model directory (default: {SMPLX_MODEL_DIR})"
    )
    # Output options
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help=f"Path to save HTML animation (default: {os.path.join(OUTPUT_DIR, 'motion_animation.html')})"
    )
    parser.add_argument(
        "--title",
        type=str,
        default="Generated Motion",
        help="Animation title"
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=20,
        help="Frames per second for animation (default: 20)"
    )
    args = parser.parse_args()
    if args.fps <= 0:
        parser.error("--fps must be a positive integer")

    # Get tokens
    if args.prompt is not None:
        # Generate tokens first using inference.py (imported lazily so the
        # other input modes don't need the generation stack)
        print("Generating motion tokens from prompt...")
        from inference import inference
        tokens = inference(
            prompt=args.prompt,
            stage=args.stage,
            output_file=None,
            per_prompt_vocab=True
        )
    elif args.input is not None:
        with open(args.input, 'r', encoding='utf-8') as f:
            tokens = f.read().strip()
    else:
        tokens = args.tokens

    fig = visualize(
        tokens=tokens,
        vqvae_ckpt=args.vqvae_ckpt,
        stats_path=args.stats,
        smplx_dir=args.smplx_dir,
        output_html=args.output,
        title=args.title,
        fps=args.fps
    )
    if fig is None:
        # visualize() already printed why nothing was rendered
        sys.exit(1)


if __name__ == "__main__":
    main()