File size: 6,976 Bytes
0e267a7 69b50e2 0e267a7 69b50e2 0e267a7 69b50e2 0e267a7 69b50e2 0e267a7 69b50e2 0e267a7 69b50e2 0e267a7 69b50e2 0e267a7 69b50e2 0e267a7 69b50e2 0e267a7 69b50e2 0e267a7 69b50e2 0e267a7 69b50e2 0e267a7 69b50e2 0e267a7 69b50e2 0e267a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import os
import torch
import numpy as np
from models.llama_model import LLaMAHF, LLaMAHFConfig
import models.tae as tae
import options.option_transformer as option_trans
import warnings
import smplx
from utils import bvh, quat
from utils.face_z_align_util import rotation_6d_to_matrix, matrix_to_axis_angle, axis_angle_to_quaternion
from sentence_transformers import SentenceTransformer
warnings.filterwarnings('ignore')
# --- Export helper: motion feature sequence -> BVH file ---
def save_motion_as_bvh(motion_data, output_path, fps=30):
    """Convert one motion-feature sequence into a BVH file.

    Per-frame feature columns, as read by this function (``J = 22``):

    * ``[0:2]``            root velocity, written into the x and z
                           components of a 3-vector,
    * ``[2:8]``            6D rotation giving the frame-to-frame heading
                           change,
    * ``[8:8+3J]``         reshaped to ``(frames, J, 3)``; the y value of
                           joint 0 is used as the root height,
    * ``[8+6J:8+12J]``     reshaped to ``(frames, J, 6)``: per-joint 6D
                           rotations.

    Columns ``[8+3J:8+6J]`` are not used here.

    Args:
        motion_data: ``torch.Tensor`` or ``numpy.ndarray`` of shape
            ``(frames, features)`` or ``(1, frames, features)``.
        output_path: Destination path of the ``.bvh`` file.
        fps: Frame rate; the BVH frame time is ``1.0 / fps``.

    Returns:
        bool: ``True`` if the BVH file was written, ``False`` if the
        conversion failed. Failures are not re-raised: the error message
        and traceback are printed instead.
    """
    import traceback

    print(f"--- Starting direct conversion to BVH: {os.path.basename(output_path)} ---")
    try:
        if isinstance(motion_data, torch.Tensor):
            motion_data = motion_data.detach().cpu().numpy()
        if motion_data.ndim == 3 and motion_data.shape[0] == 1:
            motion_data = motion_data.squeeze(0)
        elif motion_data.ndim != 2:
            raise ValueError(f"Input motion data must be 2D, but got shape {motion_data.shape}")

        njoint = 22
        rot_start = 8 + 6 * njoint
        rot_end = 8 + 12 * njoint
        nfrm, nfeat = motion_data.shape

        # Fail early with a readable message instead of an opaque
        # reshape/index error further down.
        if nfrm == 0:
            raise ValueError("Input motion data contains no frames")
        if nfeat < rot_end:
            raise ValueError(
                f"Input motion data needs at least {rot_end} features per frame, but got {nfeat}"
            )

        # Per-joint rotations: 6D representation -> 3x3 matrices.
        rotations_matrix = rotation_6d_to_matrix(
            torch.from_numpy(motion_data[:, rot_start:rot_end]).reshape(nfrm, -1, 6)
        ).numpy()

        # Accumulate the per-frame heading deltas into an absolute heading
        # rotation. Each frame depends on the previous one, so this stays
        # a sequential loop.
        global_heading_diff_rot_6d = torch.from_numpy(motion_data[:, 2:8])
        global_heading_diff_rot = rotation_6d_to_matrix(global_heading_diff_rot_6d).numpy()
        global_heading_rot = np.zeros_like(global_heading_diff_rot)
        global_heading_rot[0] = global_heading_diff_rot[0]
        for i in range(1, nfrm):
            global_heading_rot[i] = np.matmul(global_heading_diff_rot[i], global_heading_rot[i - 1])

        velocities_root_xy = motion_data[:, :2]
        height = motion_data[:, 8:8 + 3 * njoint].reshape(nfrm, -1, 3)[:, 0, 1]

        # Apply the transposed (inverse) heading rotation to the root
        # joint's rotation.
        inv_global_heading_rot = np.transpose(global_heading_rot, (0, 2, 1))
        rotations_matrix[:, 0, ...] = np.matmul(inv_global_heading_rot, rotations_matrix[:, 0, ...])

        # Rotate each frame's planar velocity by the previous frame's
        # inverse heading, then integrate to get the root trajectory.
        velocities_root_xyz = np.zeros((nfrm, 3))
        velocities_root_xyz[:, 0] = velocities_root_xy[:, 0]
        velocities_root_xyz[:, 2] = velocities_root_xy[:, 1]
        velocities_root_xyz[1:, :] = np.matmul(
            inv_global_heading_rot[:-1], velocities_root_xyz[1:, :, None]
        ).squeeze(-1)
        root_translation = np.cumsum(velocities_root_xyz, axis=0)
        root_translation[:, 1] = height

        # 22 joints of axis-angle fill the first 66 of 72 pose values;
        # the last two joints (the hand entries in joint_names) stay zero.
        axis_angle = matrix_to_axis_angle(torch.from_numpy(rotations_matrix)).numpy().reshape(nfrm, -1)
        poses_24_joints = np.zeros((nfrm, 72))
        poses_24_joints[:, :66] = axis_angle

        # Skeleton hierarchy and rest-pose joint offsets come from the
        # SMPL model files on disk.
        model = smplx.create(model_path="body_models/human_model_files", model_type="smpl", gender="NEUTRAL")
        parents = model.parents.detach().cpu().numpy()
        rest_pose = model().joints.detach().cpu().numpy().squeeze()[:24, :]
        offsets = rest_pose - rest_pose[parents]
        offsets[0] = np.array([0, 0, 0])

        rotations_quat = axis_angle_to_quaternion(
            torch.from_numpy(poses_24_joints.reshape(-1, 24, 3))
        ).numpy()
        rotations_euler = np.degrees(quat.to_euler(rotations_quat, order="zyx"))

        # Only the root joint carries a per-frame translation.
        positions = np.zeros_like(rotations_quat[..., :3])
        positions[:, 0] = root_translation

        joint_names = [
            "Pelvis", "Left_hip", "Right_hip", "Spine1", "Left_knee", "Right_knee",
            "Spine2", "Left_ankle", "Right_ankle", "Spine3", "Left_foot", "Right_foot",
            "Neck", "Left_collar", "Right_collar", "Head", "Left_shoulder", "Right_shoulder",
            "Left_elbow", "Right_elbow", "Left_wrist", "Right_wrist", "Left_hand", "Right_hand",
        ]
        bvh.save(
            output_path,
            {
                "rotations": rotations_euler,
                "positions": positions,
                "offsets": offsets,
                "parents": parents,
                "names": joint_names,
                "order": "zyx",
                "frametime": 1.0 / fps,
            },
        )
        print(f"✅ BVH file saved successfully to {output_path}")
        return True
    except Exception as e:
        # Deliberate best-effort: report the failure without crashing the
        # caller, but signal it through the return value.
        print(f"❌ BVH Conversion Failed. Error: {e}")
        traceback.print_exc()
        return False
def strip_data_parallel_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from keys.

    Keys that do not start with ``prefix`` are kept unchanged. Only one
    leading occurrence is removed per key.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def text_to_filename_stem(text, max_len=100):
    """Turn a free-form prompt into a safe file-name stem.

    Every character that is not alphanumeric, ``-`` or ``_`` (spaces, path
    separators, punctuation, ...) becomes ``_`` so the result cannot leave
    the output directory. The stem is truncated to ``max_len`` characters,
    and an empty result falls back to ``'motion'``.
    """
    stem = ''.join(ch if ch.isalnum() or ch in '-_' else '_' for ch in text)
    return stem[:max_len] or 'motion'


def main():
    """Generate a motion for ``args.text`` and save it as a BVH file."""
    comp_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    args = option_trans.get_args_parser()
    torch.manual_seed(args.seed)

    # --- Load models ---
    print("Loading models for MotionStreamer...")
    t5_model = SentenceTransformer('sentencet5-xxl/')
    t5_model.eval()
    for p in t5_model.parameters():
        p.requires_grad = False

    print("Loading Causal TAE (t2m_babel) checkpoint...")
    tae_net = tae.Causal_HumanTAE(
        hidden_size=1024, down_t=2, stride_t=2, depth=3, dilation_growth_rate=3,
        latent_dim=16, clip_range=[-30, 20]
    )
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    tae_ckpt = torch.load('Causal_TAE_t2m_babel/net_last.pth', map_location='cpu')
    tae_net.load_state_dict(tae_ckpt['net'], strict=True)
    tae_net.eval()
    tae_net.to(comp_device)

    config = LLaMAHFConfig.from_name('Normal_size')
    config.block_size = 78
    trans_encoder = LLaMAHF(config, args.num_diffusion_head_layers, args.latent_dim, comp_device)

    print("Loading your trained MotionStreamer checkpoint from 'motionstreamer_model/latest.pth'...")
    # Path is relative to the directory the script is run from.
    checkpoint_path = 'motionstreamer_model/latest.pth'
    trans_ckpt = torch.load(checkpoint_path, map_location='cpu')
    # The saved keys may carry a 'module.' prefix (presumably from a
    # DataParallel/DDP-wrapped model at training time - confirm); strip it
    # so the keys match the bare model under strict loading.
    unwrapped_state_dict = strip_data_parallel_prefix(trans_ckpt['trans'])
    trans_encoder.load_state_dict(unwrapped_state_dict, strict=True)
    print("Successfully loaded unwrapped checkpoint.")
    trans_encoder.eval()
    trans_encoder.to(comp_device)

    # --- Normalisation statistics ---
    print("Loading mean/std from BABEL dataset...")
    mean = np.load('babel_272/t2m_babel_mean_std/Mean.npy')
    std = np.load('babel_272/t2m_babel_mean_std/Std.npy')

    # --- Generate ---
    # Empty motion history: zero frames of 16-dim latents (16 matches the
    # latent_dim given to the TAE above).
    motion_history = torch.empty(0, 16).to(comp_device)
    cfg_scale = 10.0
    print(f"Generating motion for text: '{args.text}' with CFG scale: {cfg_scale}")
    with torch.no_grad():
        # Use the new two-forward sampling method to match training
        _, motion_latents = trans_encoder.sample_for_eval_CFG_babel_inference_two_forward(
            B_text=args.text,
            A_motion=motion_history,
            tokenizer='t5-xxl',
            clip_model=t5_model,
            device=comp_device,
            cfg=cfg_scale,
            length=240,
            temperature=1.3
        )
        print("Decoding latents to full motion...")
        motion_seqs = tae_net.forward_decoder(motion_latents)

    # Undo the dataset normalisation before exporting.
    motion = motion_seqs.detach().cpu().numpy()
    motion_denormalized = motion * std + mean

    # --- Save ---
    output_dir = 'demo_output_streamer'
    # exist_ok avoids the check-then-create race of os.path.exists().
    os.makedirs(output_dir, exist_ok=True)
    # Sanitise the prompt so separators/punctuation in the text cannot
    # produce an invalid path or one outside output_dir.
    output_bvh_path = os.path.join(
        output_dir, f'{text_to_filename_stem(args.text)}_cfg{cfg_scale}.bvh'
    )
    save_motion_as_bvh(motion_denormalized, output_bvh_path, fps=30)


if __name__ == '__main__':
    main()