"""Motion-data helpers for skeleton loading, graph creation, reconstruction, and serialization.""" import os import torch import numpy as np from os.path import join as pjoin from sata.mydataset import SkelData from sata.skel_pose_graph import SkelPoseGraph def load_skeleton_from_npz(npz_path): """ Load skeleton data from an npz file without text features. Args: npz_path: motion npz path containing skeleton data Returns: skel_data: SkelData object with tf set to None """ # Load skeleton data. data = np.load(npz_path) # Extract intrinsic skeleton data. lo = data['lo'] # [nJ, 3] go = data['go'] # [nJ, 3] qb = data['qb'] # [nJ] edges = data['edges'] # [nE, 4] # Sort edges by child index. if not (np.arange(edges.shape[0]) == edges[:, 1]).all(): edges = edges[np.argsort(edges[:, 1])] # Create SkelData with tf set to None. skel_data = SkelData( torch.Tensor(lo), torch.Tensor(go), torch.BoolTensor(qb), torch.LongTensor(edges[:, :2]).transpose(1, 0), # [2, nE] torch.LongTensor(edges[:, 2:]), # [nE, 2] None, # tf is None ) return skel_data def load_skeleton_and_tf_from_npz(npz_path, tf_npz_path): """ Load skeleton data and text features from npz files. Args: npz_path: motion npz path containing skeleton data tf_npz_path: text-feature npz path containing tf Returns: skel_data: SkelData object containing tf """ # Load skeleton data without tf. skel_data = load_skeleton_from_npz(npz_path) # Load text features. try: tf_data = np.load(tf_npz_path) if 'tf' not in tf_data: raise KeyError(f"'tf' key not found in {tf_npz_path}") tf = tf_data['tf'] # [nJ, 768] except Exception as e: print(f"[Warning] Failed to load tf from {tf_npz_path}: {e}") print(f"[Warning] Using zero tf instead") nJ = skel_data.lo.shape[0] tf = np.zeros((nJ, 768), dtype=np.float32) # Attach tf to SkelData. skel_data.tf = torch.Tensor(tf) return skel_data def create_graph_list_from_skeleton(skel_data, seq_length): """ Create a graph list of the requested length from skeleton data. Args: skel_data: SkelData object seq_length: sequence length in frames Returns: graphs: list of SkelPoseGraph """ # Create one graph per frame with skeleton data only. graphs = [SkelPoseGraph(skel_data, None) for _ in range(seq_length)] return graphs def create_graph_list_from_single_graph(skel_graph, seq_length): """ Create a graph list of the requested length from one SkelPoseGraph. Args: skel_graph: SkelPoseGraph object containing skeleton data only seq_length: sequence length in frames Returns: graphs: list of SkelPoseGraph """ # Extract SkelData to avoid repeated conversion. skel_data = SkelData( lo=skel_graph.lo, go=skel_graph.go, qb=skel_graph.qb, edge_index=skel_graph.edge_index, edge_feature=skel_graph.edge_feature, tf=skel_graph.tf if hasattr(skel_graph, 'tf') else torch.zeros(skel_graph.lo.shape[0], 768) ) # Create the graph list. graphs = [SkelPoseGraph(skel_data, None) for _ in range(seq_length)] return graphs def process_hatD_to_qrc(hatD_full, src_batch_full, actual_frames, num_nodes_per_frame, out_rep_cfg, ms_dict): """ Convert hatD into q, r, and c. Args: hatD_full: [T*num_nodes, D] decoded features src_batch_full: complete source batch used for post-processing actual_frames: actual frame count num_nodes_per_frame: nodes per frame out_rep_cfg: output representation config ms_dict: mean/std dictionary Returns: q: [T, nJ, 6] - quaternion (6D representation) r: [T, 1, 4] - root transform c: [T, nJ, 1] - contact """ from sata.mymodel import parse_hatD # Extract q, r, and c directly with parse_hatD. root_ids = src_batch_full.ptr[:-1] out = parse_hatD(hatD_full, root_ids, out_rep_cfg, ms_dict) # Extract q, r, and c according to the config. q = out.get('q', None) # quaternion [T*nJ, 6] r = out.get('r', None) # root position [T, 4] c = out.get('c', None) # contact [T*nJ, 1] # Reshape q, r, c to [T, nJ, ...] format if q is not None: q = q.view(actual_frames, num_nodes_per_frame, -1) # [T, nJ, 6] if r is not None: # r is already [T, 4] because parse_hatD handles root_ids. r = r.unsqueeze(1) # [T, 1, 4] for downstream consistency if c is not None: c = c.view(actual_frames, num_nodes_per_frame, -1) # [T, nJ, 1] return q, r, c def compute_qv_from_qR(qR): """ Compute angular velocity from a rotation-matrix sequence. qR: [T, nJ, 3, 3] rotation matrices Returns qv: [T, nJ, 6] angular velocity in 6D representation Based on motion_to_graph.py: q_vel[1:] = rotations[:-1].swapaxes(-2, -1) @ rotations[1:] """ T, nJ = qR.shape[0], qR.shape[1] # Initialize as identity matrices. q_vel_R = torch.eye(3, device=qR.device, dtype=qR.dtype)[None, None, ...].repeat(T, nJ, 1, 1) # Compute relative rotation: R[t-1].T @ R[t]. if T > 1: q_vel_R[1:] = qR[:-1].transpose(-2, -1) @ qR[1:] # Convert to 6D representation. q_vel_flat = q_vel_R.reshape(-1, 3, 3) # [T*nJ, 3, 3] # Use the first two columns as the 6D representation. qv_flat = torch.cat([q_vel_flat[:, :, 0], q_vel_flat[:, :, 1]], dim=-1) # [T*nJ, 6] qv = qv_flat.reshape(T, nJ, 6) return qv def reconstruct_p_pv_qv_from_qrc(q, r, c, src_batch, consq_n, device): """ Reconstruct p, pv, and qv from q, r, and c. Mirrors the implementation in reconstruction_qrc_2_same.py. Args: q: [T, nJ, 6] - quaternion (6D representation) r: [T, 1, 4] - root transform c: [T, nJ, 1] - contact src_batch: Batch object containing skeleton info consq_n: T, sequence length device: torch device Returns: p: [T-1, nJ, 3] - joint positions (excluding frame 1) pv: [T-1, nJ, 3] - joint velocities (excluding frame 1) qv: [T-1, nJ, 6] - joint angular velocities (excluding frame 1) q_out: [T-1, nJ, 6] - joint rotations (excluding frame 1) r_out: [T-1, 4] - root transform (excluding frame 1) c_out: [T-1, nJ, 1] - contact (excluding frame 1) Note: every output drops the first frame to keep temporal dimensions consistent. """ from sata.mymodel import FK, accum_root from sata.utils import tensor_utils nJ = q.shape[1] # 1. Convert q from 6D representation to rotation matrices qR. q_flat = q.reshape(-1, 6) # [T*nJ, 6] qR_flat = tensor_utils.tensor_q2qR(q_flat) # [T*nJ, 3, 3] qR = qR_flat.reshape(consq_n, nJ, 3, 3) # [T, nJ, 3, 3] # 2. Run forward kinematics. r_squeezed = r.squeeze(1) # [T, 1, 4] -> [T, 4] fk_T_flat = FK( lo=src_batch.lo, # [T*nJ, 3] qR=qR_flat, # [T*nJ, 3, 3] r=r_squeezed, # [T, 4] root_ids=src_batch.ptr[:-1], # [T] skel_depth=src_batch.skel_depth, # [T*nJ] skel_edge_index=src_batch.edge_index, # [2, T*nE] ) # [T*nJ, 4, 4] # Extract positions. p_flat = fk_T_flat[..., :3, 3] # [T*nJ, 3] p_full = p_flat.reshape(consq_n, nJ, 3) # [T, nJ, 3] # 3. Compute angular velocity qv. qv_full = compute_qv_from_qR(qR) # [T, nJ, 6] # 4. Compute positional velocity pv. # Accumulate root transforms. r_for_accum = r # [T, 1, 4] rT_accum = accum_root(r_for_accum, consq_n, apply_height=False, grad_truncate_k=0) # [T, 1, 4, 4] facing_transforms = rT_accum[:, 0, :, :] # [T, 4, 4] # Convert p to global coordinates. p_T = tensor_utils.tensor_p2T(p_full.reshape(-1, 3)) # [T*nJ, 4, 4] p_T = p_T.reshape(consq_n, nJ, 4, 4) # [T, nJ, 4, 4] facing_T_expanded = facing_transforms.unsqueeze(1) # [T, 1, 4, 4] global_p_T = facing_T_expanded @ p_T # [T, nJ, 4, 4] global_p = global_p_T[..., :3, 3] # [T, nJ, 3] # Compute global position differences. global_p_vel = torch.zeros_like(global_p) if consq_n > 1: global_p_vel[1:] = global_p[1:] - global_p[:-1] # Convert back to the facing frame. facing_inv_rot = torch.inverse(facing_transforms)[:, :3, :3] # [T, 3, 3] facing_inv_rot = facing_inv_rot.unsqueeze(1) # [T, 1, 3, 3] local_p_vel = (facing_inv_rot @ global_p_vel.unsqueeze(-1)).squeeze(-1) # [T, nJ, 3] # Multiply by FPS (30). pv_full = local_p_vel * 30.0 # 5. Drop the first frame for every feature to keep dimensions consistent. # pv and qv need the previous frame, so their first frame is undefined. # Drop the first frame from every feature for consistency. if consq_n > 1: p = p_full[1:] # [T-1, nJ, 3] pv = pv_full[1:] # [T-1, nJ, 3] qv = qv_full[1:] # [T-1, nJ, 6] q_out = q[1:] # [T-1, nJ, 6] r_out = r_squeezed[1:] # [T-1, 4] c_out = c[1:] # [T-1, nJ, 1] else: # Single-frame fallback; this should not happen in normal inputs. p = p_full pv = pv_full qv = qv_full q_out = q r_out = r_squeezed c_out = c return p, pv, qv, q_out, r_out, c_out def save_processed_with_tf_and_meta(data_dict, output_dir, filename): """ Save processed data, joint_text_features, and metadata. Args: data_dict: skeleton data, motion features, tf, text, m_len, and related metadata output_dir: output root directory filename: file stem without extension """ import json # Create subdirectories. processed_dir = pjoin(output_dir, 'processed') tf_dir = pjoin(output_dir, 'joint_text_features') os.makedirs(processed_dir, exist_ok=True) os.makedirs(tf_dir, exist_ok=True) # 1. Save the npz file to processed/ with skeleton and motion data, excluding tf. npz_dict = {} for key, value in data_dict.items(): # Skip tf and metadata fields. if key in ['tf', 'text', 'src_filename', 'is_segment', 'segment_info']: continue if isinstance(value, torch.Tensor): npz_dict[key] = value.cpu().numpy() else: npz_dict[key] = value npz_path = pjoin(processed_dir, f'{filename}.npz') np.savez(npz_path, **npz_dict) # 2. Save tf to joint_text_features/. if 'tf' in data_dict: tf_path = pjoin(tf_dir, f'{filename}.npz') tf_value = data_dict['tf'] if isinstance(tf_value, torch.Tensor): tf_value = tf_value.cpu().numpy() np.savez(tf_path, tf=tf_value) # 3. Save metadata to meta/ for segments and text records. if data_dict.get('is_segment', False) or 'text' in data_dict: meta_dir = pjoin(output_dir, 'meta') os.makedirs(meta_dir, exist_ok=True) meta_path = pjoin(meta_dir, f'{filename}.json') meta_dict = {} if 'text' in data_dict: meta_dict['text'] = data_dict['text'] if data_dict.get('is_segment', False): meta_dict['is_segment'] = True meta_dict['segment_info'] = data_dict.get('segment_info', {}) meta_dict['src_filename'] = data_dict.get('src_filename', filename) meta_dict['m_len'] = data_dict.get('m_len', 0) with open(meta_path, 'w') as f: json.dump(meta_dict, f, indent=2) def bvh_2_SkelPoseGraph(bvh_path): """ Load a skeleton from BVH and convert it to SkelPoseGraph without tf. Args: bvh_path: BVH file path Returns: skel_graph: skeleton graph object (SkelPoseGraph) Note: This path has no text features, so decoding uses zero tf. """ from fairmotion.data import bvh from sata.conversions.motion_to_graph import skel_2_graph print(f"Loading skeleton from BVH: {bvh_path}") motion = bvh.load(bvh_path, ignore_root_skel=True, ee_as_joint=True) # The skeleton must be normalized before graph conversion. from sata.utils.motion_utils import motion_normalize_h2s motion, tpose = motion_normalize_h2s(motion, False) skel = motion.skel text_feature = np.zeros((skel.num_joints(), 768), dtype=np.float32) skel_graph = skel_2_graph(skel, text_feature) print(f" Skeleton joints: {skel.num_joints()}") print(" [Warning] BVH has no text features; decoding will use zero tf") return skel_graph def fix_skeleton_coordinate_system(motion): """ Convert a motion from Z-up to Y-up coordinates with local-axis retargeting. Converts the entire motion from Z-up to Y-up, including: 1. Skeleton OFFSET conversion 2. Root position conversion for every frame 3. Local rotation retargeting for every joint in every frame Key idea: - When OFFSET changes, local rotations must be adjusted to preserve the visual pose. - Local rotation conversion: R_new = R_fix @ R_old @ R_fix^T Observations: - Original OFFSET: (0, 0.184, 0) -> Y-up - Current OFFSET: (0, 0, -0.184) -> Z-down - Requires a -90 or 90 degree rotation around the X axis Args: motion: fairmotion Motion object in Z-up coordinates Returns: motion: converted Motion object in Y-up coordinates """ from fairmotion.ops import conversions from fairmotion.ops import motion as motion_ops import copy # Create a deep copy. motion_copy = copy.deepcopy(motion) # Define the coordinate conversion rotation matrix. # (0, 0, -0.184) -> (0, 0.184, 0) # This uses a 90 degree rotation around the X axis. axis_angle = np.array([np.pi / 2, 0.0, 0.0]) R_fix = conversions.A2R(axis_angle) R_fix_T = R_fix.T # transpose; for rotation matrices, transpose equals inverse # Step 1: convert skeleton OFFSET values. for joint in motion_copy.skel.joints: old_offset = joint.xform_from_parent_joint[:3, 3].copy() # Apply rotation to the offset vector. new_offset = np.dot(R_fix, old_offset) joint.xform_from_parent_joint[:3, 3] = new_offset # Keep the transform rotation part as identity. joint.xform_from_parent_joint[:3, :3] = np.eye(3) # Step 2: convert each frame. for frame_idx in range(motion_copy.num_frames()): pose = motion_copy.get_pose_by_frame(frame_idx) for joint_idx, joint in enumerate(motion_copy.skel.joints): # Read the current joint local transform matrix (4x4). T_old = pose.data[joint_idx].copy() # Split rotation (3x3) and translation (3,). R_old = T_old[:3, :3] p_old = T_old[:3, 3] # Root joints need global position conversion. if joint.parent_joint is None: # Root joint: convert global position. p_new = np.dot(R_fix, p_old) # Convert root rotation too. R_new = np.dot(np.dot(R_fix, R_old), R_fix_T) else: # Non-root joint: retarget the local coordinate frame. # R_new = R_fix @ R_old @ R_fix^T R_new = np.dot(np.dot(R_fix, R_old), R_fix_T) # Local translation is usually defined by OFFSET; keep it consistent here. p_new = np.dot(R_fix, p_old) # Build the new transform matrix. T_new = np.eye(4) T_new[:3, :3] = R_new T_new[:3, 3] = p_new # Update pose data. pose.data[joint_idx] = T_new # Step 3: apply an extra global rotation fix. # Apply a -90 degree rotation around X to the entire motion sequence. # This independent global transform adjusts the final motion direction. rx = conversions.A2R(np.array([-np.pi / 2, 0.0, 0.0])) motion_final = motion_ops.rotate(motion_copy, rx) return motion_final