# Source: move-it/inference.py on the Hugging Face Hub, uploaded by zirobtc
# via huggingface_hub (commit d2a17a9).
import torch
import numpy as np
import os
import time
import subprocess
import sys
import smplx
# --- Model Imports ---
from models.llama_model import LLaMAHF, LLaMAHFConfig
from models.tae import Causal_HumanTAE
from sentence_transformers import SentenceTransformer
# --- Direct Imports from Cloned Repo's `utils` folder ---
from utils import bvh, quat
from utils.face_z_align_util import rotation_6d_to_matrix, matrix_to_axis_angle, axis_angle_to_quaternion
# --- A simple logging helper ---
def log_step(message):
    """Print ``message`` to stdout, prefixed with the current local time.

    Output format: ``[YYYY-MM-DD HH:MM:SS] - <message>``.
    """
    now = time.strftime("%Y-%m-%d %H:%M:%S")
    print("[{}] - {}".format(now, message))
# --- Conversion of 272-dim motion features to a BVH file, with step-by-step logging ---
# Names written for the 24 BVH joints. Taken from the official conversion
# script; assumed to follow the SMPL joint-index order used by `parents`.
_SMPL_JOINT_NAMES = (
    "Pelvis", "Left_hip", "Right_hip", "Spine1", "Left_knee", "Right_knee",
    "Spine2", "Left_ankle", "Right_ankle", "Spine3", "Left_foot", "Right_foot",
    "Neck", "Left_collar", "Right_collar", "Head", "Left_shoulder",
    "Right_shoulder", "Left_elbow", "Right_elbow", "Left_wrist", "Right_wrist",
    "Left_palm", "Right_palm",
)


def _recover_smpl_pose(motion, njoint):
    """Decode a (nfrm, D) feature array into SMPL axis-angle poses and root translation.

    Column layout read here (njoint joints):
      [0:2]                      root planar velocity (x, z)
      [2:8]                      per-frame heading-difference rotation (6D)
      [8 : 8+3*njoint]           joint positions; joint 0's y is used as root height
      [8+6*njoint : 8+12*njoint] per-joint rotations (6D)
    Columns [8+3*njoint : 8+6*njoint] are not used by this conversion.

    Returns:
        (poses, root_translation): ``poses`` is (nfrm, 72). Its first
        ``3 * njoint`` columns are per-joint axis-angle rotations and the rest
        are zero. ``root_translation`` is (nfrm, 3).
    """
    nfrm = motion.shape[0]

    log_step("Extracting rotation, velocity, and position data...")
    rotations_6d = torch.from_numpy(
        motion[:, 8 + 6 * njoint:8 + 12 * njoint]
    ).reshape(nfrm, -1, 6)
    rotations_matrix = rotation_6d_to_matrix(rotations_6d).numpy()
    heading_diff_rot = rotation_6d_to_matrix(torch.from_numpy(motion[:, 2:8])).numpy()
    velocities_root_xy = motion[:, :2]
    positions_no_heading = motion[:, 8:8 + 3 * njoint].reshape(nfrm, -1, 3)
    height = positions_no_heading[:, 0, 1]
    log_step(f"Extracted rotations matrix with shape: {rotations_matrix.shape}")

    log_step("Reconstructing global heading...")
    # Chain the relative heading rotations: R_abs[t] = R_rel[t] @ R_abs[t-1].
    global_heading_rot = [heading_diff_rot[0]]
    for rel_rot in heading_diff_rot[1:]:
        global_heading_rot.append(np.matmul(rel_rot, global_heading_rot[-1]))
    global_heading_rot = np.array(global_heading_rot)
    # For rotation matrices the transpose is the inverse.
    inv_global_heading_rot = np.transpose(global_heading_rot, (0, 2, 1))
    # Compose the root joint's rotation with the inverse accumulated heading.
    rotations_matrix[:, 0, ...] = np.matmul(inv_global_heading_rot, rotations_matrix[:, 0, ...])

    log_step("Reconstructing root translation...")
    velocities_root_xyz = np.zeros((nfrm, 3))
    velocities_root_xyz[:, 0] = velocities_root_xy[:, 0]
    velocities_root_xyz[:, 2] = velocities_root_xy[:, 1]
    # The velocity at frame t (t >= 1) is rotated by the inverse heading of
    # frame t-1; frame 0's velocity is used as stored.
    velocities_root_xyz[1:, :] = np.matmul(
        inv_global_heading_rot[:-1], velocities_root_xyz[1:, :, None]
    ).squeeze(-1)
    root_translation = np.cumsum(velocities_root_xyz, axis=0)
    root_translation[:, 1] = height
    log_step(f"Reconstructed root translation with shape: {root_translation.shape}")

    log_step("Converting rotation matrices to axis-angle format...")
    axis_angle = matrix_to_axis_angle(torch.from_numpy(rotations_matrix)).numpy().reshape(nfrm, -1)
    # Joints beyond the encoded ones (up to SMPL's 24) keep zero rotation.
    poses = np.zeros((nfrm, 72))
    poses[:, :axis_angle.shape[1]] = axis_angle
    log_step(f"Padded pose data to 24 joints for SMPL standard, new shape: {poses.shape}")
    return poses, root_translation


def _build_bvh_data(poses, root_translation, fps, smpl_model_path):
    """Build the dict passed to ``bvh.save`` from (nfrm, 72) poses and (nfrm, 3) translation."""
    log_step("Loading SMPL model to create BVH skeleton...")
    model = smplx.create(model_path=smpl_model_path, model_type="smpl", gender="NEUTRAL")
    parents = model.parents.detach().cpu().numpy()
    rest_pose = model().joints.detach().cpu().numpy().squeeze()[:24, :]
    # Bone offset = child rest position minus parent rest position. The root's
    # parent entry is presumably -1, so its row is replaced by its absolute rest position.
    offsets = rest_pose - rest_pose[parents]
    offsets[0] = rest_pose[0]

    log_step("Converting axis-angle to euler angles for BVH...")
    order = "zyx"
    rotations_quat = axis_angle_to_quaternion(torch.from_numpy(poses.reshape(-1, 24, 3))).numpy()
    rotations_euler = np.degrees(quat.to_euler(rotations_quat, order=order))
    # Every joint sits at its rest offset; only the root is displaced by the
    # recovered translation. No unit scaling is applied to either.
    root_displacement = np.zeros_like(rotations_quat[..., :3])
    root_displacement[:, 0] = root_translation

    log_step("Assembling final BVH data structure...")
    return {
        "rotations": rotations_euler,
        "positions": offsets + root_displacement,
        "offsets": offsets,
        "parents": parents,
        "names": list(_SMPL_JOINT_NAMES),
        "order": order,
        "frametime": 1.0 / fps,
    }


def convert_to_bvh(motion_data_272, output_path="outputs/final_motion.bvh", fps=60,
                   smpl_model_path="body_models/human_model_files"):
    """Convert 272-dim motion features into a BVH file using the SMPL skeleton.

    Args:
        motion_data_272: Array of shape (nfrm, 272) or (1, nfrm, 272). Anything
            ``np.asarray`` accepts works, e.g. a CPU tensor.
        output_path: Destination .bvh path. Its parent directory is created if missing.
        fps: Frame rate; the BVH frame time is ``1 / fps``.
        smpl_model_path: Directory passed to ``smplx.create`` for the SMPL model.

    Returns:
        ``output_path`` when the file was written, otherwise ``None``. Failures
        are logged with a traceback instead of raised, so a failed conversion
        does not abort the caller.
    """
    log_step("--- Starting Conversion to BVH Format ---")
    njoint = 22
    expected_dim = 8 + 12 * njoint
    try:
        motion = np.asarray(motion_data_272)
        if motion.ndim == 3 and motion.shape[0] == 1:
            motion = motion[0]  # drop a leading batch dimension of size 1
        if motion.ndim != 2 or motion.shape[1] < expected_dim:
            raise ValueError(
                f"expected motion of shape (nfrm, {expected_dim}) or "
                f"(1, nfrm, {expected_dim}), got {np.shape(motion_data_272)}"
            )
        if motion.shape[0] == 0:
            raise ValueError("motion contains no frames")
        log_step(f"Input motion has {motion.shape[0]} frames and {motion.shape[1]} dimensions.")

        poses, root_translation = _recover_smpl_pose(motion, njoint)
        bvh_data = _build_bvh_data(poses, root_translation, fps, smpl_model_path)

        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        log_step(f"Saving BVH file to {output_path}...")
        bvh.save(output_path, bvh_data)
    except Exception as e:
        import traceback
        log_step(f"❌ BVH Conversion Failed. Error: {e}")
        traceback.print_exc()
        return None
    log_step(f"✅ BVH file saved successfully to {output_path}")
    return output_path
def main(text_prompt="a person walks forward",
         causal_tae_checkpoint='./Causal_TAE/net_last.pth',
         motion_streamer_checkpoint=None,
         output_dir="outputs"):
    """Generate motion for ``text_prompt`` and write it as ``<output_dir>/final_motion.bvh``.

    Pipeline:
      1. Load the causal TAE.
      2. Load the sentence-T5-XXL text encoder.
      3. Build the LLaMAHF transformer, loading its weights when
         ``motion_streamer_checkpoint`` is given.
      4. Sample latents autoregressively.
      5. Decode them to 272-dim features and convert those to BVH.

    Returns:
        The BVH path on success, ``None`` if conversion failed (see ``convert_to_bvh``).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    log_step(f"Using device: {device}")
    os.makedirs(output_dir, exist_ok=True)

    log_step("Loading Causal Temporal Autoencoder (TAE)...")
    causal_tae = Causal_HumanTAE(
        latent_dim=16, down_t=2, depth=3, stride_t=2, clip_range=[-30.0, 20.0]
    ).to(device)
    state_dict = torch.load(causal_tae_checkpoint, map_location=device, weights_only=True)['net']
    causal_tae.load_state_dict(state_dict, strict=True)
    causal_tae.eval()
    log_step("✅ TAE loaded successfully.")

    log_step("Loading Text Encoder (T5-XXL)...")
    text_encoder = SentenceTransformer('sentence-transformers/sentence-t5-xxl', device=device)
    log_step("✅ Text Encoder loaded successfully.")

    log_step("Loading MotionStreamer model architecture...")
    config = LLaMAHFConfig.from_name("Normal_size")
    motion_streamer = LLaMAHF(config).to(device)
    if motion_streamer_checkpoint is None:
        # The original script never loaded transformer weights but still
        # reported success.
        # NOTE(review): unless LLaMAHF loads weights in its constructor
        # (not visible here), this model is randomly initialized.
        log_step("⚠️ No MotionStreamer checkpoint given; using weights from LLaMAHF's "
                 "constructor only (random unless it loads them itself).")
    else:
        log_step(f"Loading MotionStreamer weights from {motion_streamer_checkpoint}...")
        ckpt = torch.load(motion_streamer_checkpoint, map_location=device, weights_only=True)
        # NOTE(review): assumes a training checkpoint of the form
        # {'trans': state_dict}, possibly saved from DataParallel ("module."
        # key prefix). Falls back to treating the file as a bare state_dict.
        trans_state = ckpt.get('trans', ckpt)
        prefix = 'module.'
        trans_state = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in trans_state.items()
        }
        motion_streamer.load_state_dict(trans_state, strict=True)
        log_step("✅ MotionStreamer loaded successfully.")
    motion_streamer.eval()

    log_step(f"Starting motion generation for text: '{text_prompt}'")
    with torch.no_grad():
        # NOTE(review): the end-of-motion reference latent is derived from an
        # all-zero 4-frame pose; confirm this matches what the model expects.
        impossible_pose = torch.zeros(1, 4, 272, device=device)
        reference_end_latent, _, _ = causal_tae.encode(impossible_pose)
        reference_end_token = reference_end_latent.detach()
        log_step("Autoregressive generation started...")
        motion_latents = motion_streamer.sample_for_eval_CFG_inference(
            clip_text=[text_prompt], clip_model=text_encoder, tokenizer='t5-xxl',
            device=device, reference_end_token=reference_end_token,
            cfg=4.5, threshold=3.0, temperature=1.0, length=312
        )
    log_step("✅ Autoregressive generation finished.")

    log_step("Decoding latents into 272-dim motion data...")
    with torch.no_grad():
        generated_motion_272 = causal_tae.forward_decoder(motion_latents)
    log_step(f"272-dim motion data shape: {generated_motion_272.shape}")
    return convert_to_bvh(
        generated_motion_272.cpu().numpy(),
        output_path=os.path.join(output_dir, "final_motion.bvh"),
    )
if __name__ == "__main__":
    # Run the end-to-end demo only when executed as a script, not on import.
    main()