| """Motion-data helpers for skeleton loading, graph creation, reconstruction, and serialization.""" |
|
|
| import os |
| import torch |
| import numpy as np |
| from os.path import join as pjoin |
|
|
| from sata.mydataset import SkelData |
| from sata.skel_pose_graph import SkelPoseGraph |
|
|
|
|
def load_skeleton_from_npz(npz_path):
    """
    Load skeleton data from an npz file without text features.

    Args:
        npz_path: motion npz path; the archive must provide the arrays
            'lo', 'go', 'qb' and 'edges'.

    Returns:
        skel_data: SkelData object with tf set to None

    Raises:
        KeyError: if one of the required arrays is missing from the archive.
    """
    # np.load on an .npz returns a lazily-read archive object that keeps the
    # file open. Reading the arrays inside a context manager materialises
    # them and then releases the file handle deterministically.
    with np.load(npz_path) as data:
        lo = data['lo']
        go = data['go']
        qb = data['qb']
        edges = data['edges']

    # Reorder the edge rows by their second column unless that column already
    # reads 0, 1, ..., E-1.
    if not (np.arange(edges.shape[0]) == edges[:, 1]).all():
        edges = edges[np.argsort(edges[:, 1])]

    skel_data = SkelData(
        torch.Tensor(lo),
        torch.Tensor(go),
        torch.BoolTensor(qb),
        # First two edge columns become a [2, E] index tensor.
        torch.LongTensor(edges[:, :2]).transpose(1, 0),
        # Any remaining edge columns are passed through as edge features.
        torch.LongTensor(edges[:, 2:]),
        None,  # tf: this loader carries no text features
    )

    return skel_data
|
|
|
|
def load_skeleton_and_tf_from_npz(npz_path, tf_npz_path):
    """
    Load skeleton data and text features from npz files.

    Loading the text features is best-effort: if the tf archive cannot be
    read or has no 'tf' entry, a warning is printed and an all-zero
    [num_joints, 768] array is used instead.

    Args:
        npz_path: motion npz path containing skeleton data
        tf_npz_path: text-feature npz path containing tf

    Returns:
        skel_data: SkelData object containing tf
    """
    skel_data = load_skeleton_from_npz(npz_path)

    try:
        # Context manager closes the archive's file handle once 'tf' is read.
        with np.load(tf_npz_path) as tf_data:
            if 'tf' not in tf_data:
                raise KeyError(f"'tf' key not found in {tf_npz_path}")
            tf = tf_data['tf']
    except Exception as e:
        # Deliberate catch-all: any load failure degrades to zero features
        # rather than aborting the caller.
        print(f"[Warning] Failed to load tf from {tf_npz_path}: {e}")
        print("[Warning] Using zero tf instead")
        nJ = skel_data.lo.shape[0]
        tf = np.zeros((nJ, 768), dtype=np.float32)

    skel_data.tf = torch.Tensor(tf)

    return skel_data
|
|
|
|
def create_graph_list_from_skeleton(skel_data, seq_length):
    """
    Build one SkelPoseGraph per frame from a single skeleton description.

    Args:
        skel_data: SkelData object shared by every graph in the list
        seq_length: sequence length in frames

    Returns:
        graphs: list of SkelPoseGraph with ``seq_length`` entries
    """
    graphs = []
    for _frame in range(seq_length):
        # Every frame is constructed from the same skeleton; the second
        # constructor argument is always None here.
        graphs.append(SkelPoseGraph(skel_data, None))
    return graphs
|
|
|
|
def create_graph_list_from_single_graph(skel_graph, seq_length):
    """
    Expand one SkelPoseGraph into a list of per-frame graphs.

    The skeleton fields of ``skel_graph`` are copied into a SkelData, which
    is then used to construct ``seq_length`` new SkelPoseGraph objects.

    Args:
        skel_graph: SkelPoseGraph object containing skeleton data only
        seq_length: sequence length in frames

    Returns:
        graphs: list of SkelPoseGraph
    """
    lo = skel_graph.lo
    go = skel_graph.go
    qb = skel_graph.qb
    edge_index = skel_graph.edge_index
    edge_feature = skel_graph.edge_feature

    # Fall back to zero text features (one 768-wide row per row of lo) when
    # the source graph has no 'tf' attribute at all.
    if hasattr(skel_graph, 'tf'):
        tf = skel_graph.tf
    else:
        tf = torch.zeros(skel_graph.lo.shape[0], 768)

    template = SkelData(
        lo=lo,
        go=go,
        qb=qb,
        edge_index=edge_index,
        edge_feature=edge_feature,
        tf=tf,
    )

    return [SkelPoseGraph(template, None) for _ in range(seq_length)]
|
|
|
|
def process_hatD_to_qrc(hatD_full, src_batch_full, actual_frames, num_nodes_per_frame,
                        out_rep_cfg, ms_dict):
    """
    Parse decoded features into per-frame q, r, and c tensors.

    Args:
        hatD_full: [T*num_nodes, D] decoded features
        src_batch_full: complete source batch; its ``ptr[:-1]`` is used as
            the root indices
        actual_frames: actual frame count
        num_nodes_per_frame: nodes per frame
        out_rep_cfg: output representation config
        ms_dict: mean/std dictionary

    Returns:
        q: [T, nJ, 6] - rotations (6D representation), or None if absent
        r: [T, 1, 4] - root transform, or None if absent
        c: [T, nJ, 1] - contact, or None if absent
    """
    from sata.mymodel import parse_hatD

    parsed = parse_hatD(hatD_full, src_batch_full.ptr[:-1], out_rep_cfg, ms_dict)

    q = parsed.get('q')
    r = parsed.get('r')
    c = parsed.get('c')

    # Per-node outputs are folded from a flat node axis into (frame, node, feature).
    if q is not None:
        q = q.view(actual_frames, num_nodes_per_frame, -1)
    # The root output only gains a singleton axis at position 1.
    if r is not None:
        r = r.unsqueeze(1)
    if c is not None:
        c = c.view(actual_frames, num_nodes_per_frame, -1)

    return q, r, c
|
|
|
|
def compute_qv_from_qR(qR):
    """
    Derive per-frame relative rotations from a rotation-matrix sequence.

    Frame 0 is the identity; frame t (t >= 1) is ``qR[t-1]^T @ qR[t]``.
    Each 3x3 result is encoded as its first column followed by its second
    column (6 values).

    Args:
        qR: [T, nJ, 3, 3] rotation matrices

    Returns:
        qv: [T, nJ, 6] relative rotation per frame in 6D form
    """
    num_frames, num_joints = qR.shape[0], qR.shape[1]

    # Start from identity everywhere so that frame 0 (no predecessor) stays identity.
    identity = torch.eye(3, device=qR.device, dtype=qR.dtype)
    rel_rot = identity.expand(num_frames, num_joints, 3, 3).clone()

    if num_frames > 1:
        rel_rot[1:] = torch.matmul(qR[:-1].transpose(-2, -1), qR[1:])

    first_col = rel_rot[..., 0]
    second_col = rel_rot[..., 1]
    return torch.cat((first_col, second_col), dim=-1)
|
|
|
|
def reconstruct_p_pv_qv_from_qrc(q, r, c, src_batch, consq_n, device):
    """
    Rebuild p, pv, and qv from q, r, and c.
    Mirrors the implementation in reconstruction_qrc_2_same.py.

    Args:
        q: [T, nJ, 6] - joint rotations (6D representation)
        r: [T, 1, 4] - root transform
        c: [T, nJ, 1] - contact
        src_batch: Batch object containing skeleton info (lo, ptr,
            skel_depth, edge_index are read)
        consq_n: T, sequence length
        device: torch device (not used by this function)

    Returns:
        p: [T-1, nJ, 3] - joint positions
        pv: [T-1, nJ, 3] - joint velocities
        qv: [T-1, nJ, 6] - joint relative rotations
        q_out: [T-1, nJ, 6] - joint rotations
        r_out: [T-1, 4] - root transform
        c_out: [T-1, nJ, 1] - contact

    Note: when consq_n > 1 every output drops the first frame so that all
    temporal dimensions agree; with consq_n <= 1 nothing is dropped.
    """
    from sata.mymodel import FK, accum_root
    from sata.utils import tensor_utils

    num_joints = q.shape[1]

    # 6D rotations -> rotation matrices, kept both flat (for FK) and per frame.
    rot_flat = tensor_utils.tensor_q2qR(q.reshape(-1, 6))
    rot = rot_flat.reshape(consq_n, num_joints, 3, 3)

    # Forward kinematics; joint positions are the translation column of the
    # returned transforms.
    root = r.squeeze(1)
    fk_transforms = FK(
        lo=src_batch.lo,
        qR=rot_flat,
        r=root,
        root_ids=src_batch.ptr[:-1],
        skel_depth=src_batch.skel_depth,
        skel_edge_index=src_batch.edge_index,
    )
    p_full = fk_transforms[..., :3, 3].reshape(consq_n, num_joints, 3)

    qv_full = compute_qv_from_qR(rot)

    # Accumulated root transforms; index 0 on axis 1 is taken as the
    # per-frame facing transform.
    accumulated = accum_root(r, consq_n, apply_height=False, grad_truncate_k=0)
    facing = accumulated[:, 0, :, :]

    # Lift positions to 4x4 transforms and move them through the facing transform.
    p_as_T = tensor_utils.tensor_p2T(p_full.reshape(-1, 3))
    p_as_T = p_as_T.reshape(consq_n, num_joints, 4, 4)
    world_p = (facing.unsqueeze(1) @ p_as_T)[..., :3, 3]

    # Frame-to-frame displacement; frame 0 has no predecessor and stays zero.
    world_vel = torch.zeros_like(world_p)
    if consq_n > 1:
        world_vel[1:] = world_p[1:] - world_p[:-1]

    # Rotate the displacement back with the inverse facing rotation.
    inv_facing_rot = torch.inverse(facing)[:, :3, :3].unsqueeze(1)
    local_vel = (inv_facing_rot @ world_vel.unsqueeze(-1)).squeeze(-1)

    # NOTE(review): 30.0 is presumably the frame rate - confirm against the
    # reference implementation.
    pv_full = local_vel * 30.0

    if consq_n <= 1:
        return p_full, pv_full, qv_full, q, root, c

    return p_full[1:], pv_full[1:], qv_full[1:], q[1:], root[1:], c[1:]
|
|
|
|
def _to_json_native(value):
    """Convert torch / numpy values into plain Python objects for json.dump."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().tolist()
    if isinstance(value, (np.ndarray, np.generic)):
        return value.tolist()
    raise TypeError(
        f"Object of type {type(value).__name__} is not JSON serializable"
    )


def save_processed_with_tf_and_meta(data_dict, output_dir, filename):
    """
    Save processed data, joint_text_features, and metadata.

    Files written under ``output_dir``:
        processed/<filename>.npz            every entry of data_dict except
                                            'tf', 'text', 'src_filename',
                                            'is_segment', 'segment_info'
        joint_text_features/<filename>.npz  the 'tf' entry, if present
        meta/<filename>.json                only when 'text' is present or
                                            'is_segment' is truthy

    Args:
        data_dict: skeleton data, motion features, tf, text, m_len, and related metadata
        output_dir: output root directory
        filename: file stem without extension
    """
    import json

    processed_dir = pjoin(output_dir, 'processed')
    tf_dir = pjoin(output_dir, 'joint_text_features')
    os.makedirs(processed_dir, exist_ok=True)
    os.makedirs(tf_dir, exist_ok=True)

    # These keys go to the tf archive or the JSON metadata, not the motion npz.
    excluded_keys = {'tf', 'text', 'src_filename', 'is_segment', 'segment_info'}
    npz_dict = {
        # detach() so tensors that still require grad can be converted.
        key: value.detach().cpu().numpy() if isinstance(value, torch.Tensor) else value
        for key, value in data_dict.items()
        if key not in excluded_keys
    }
    np.savez(pjoin(processed_dir, f'{filename}.npz'), **npz_dict)

    if 'tf' in data_dict:
        tf_value = data_dict['tf']
        if isinstance(tf_value, torch.Tensor):
            tf_value = tf_value.detach().cpu().numpy()
        np.savez(pjoin(tf_dir, f'{filename}.npz'), tf=tf_value)

    is_segment = data_dict.get('is_segment', False)
    if is_segment or 'text' in data_dict:
        meta_dir = pjoin(output_dir, 'meta')
        os.makedirs(meta_dir, exist_ok=True)
        meta_path = pjoin(meta_dir, f'{filename}.json')

        meta_dict = {}
        if 'text' in data_dict:
            meta_dict['text'] = data_dict['text']
        if is_segment:
            meta_dict['is_segment'] = True
            meta_dict['segment_info'] = data_dict.get('segment_info', {})
            meta_dict['src_filename'] = data_dict.get('src_filename', filename)
            meta_dict['m_len'] = data_dict.get('m_len', 0)

        with open(meta_path, 'w') as f:
            # The default hook lets numpy / torch scalars and arrays (e.g. an
            # np.int64 m_len) be written instead of raising TypeError.
            json.dump(meta_dict, f, indent=2, default=_to_json_native)
|
|
|
|
def bvh_2_SkelPoseGraph(bvh_path):
    """
    Read a BVH file and turn its skeleton into a SkelPoseGraph without tf.

    Args:
        bvh_path: BVH file path

    Returns:
        skel_graph: skeleton graph object (SkelPoseGraph)

    Note:
        BVH carries no text features, so an all-zero [num_joints, 768]
        array is passed as the text feature.
    """
    from fairmotion.data import bvh
    from sata.conversions.motion_to_graph import skel_2_graph

    print(f"Loading skeleton from BVH: {bvh_path}")
    loaded = bvh.load(bvh_path, ignore_root_skel=True, ee_as_joint=True)

    from sata.utils.motion_utils import motion_normalize_h2s
    # The second return value of the normaliser is not needed here.
    normalized, _unused = motion_normalize_h2s(loaded, False)

    skeleton = normalized.skel
    zero_tf = np.zeros((skeleton.num_joints(), 768), dtype=np.float32)
    skel_graph = skel_2_graph(skeleton, zero_tf)

    print(f" Skeleton joints: {skeleton.num_joints()}")
    print(" [Warning] BVH has no text features; decoding will use zero tf")

    return skel_graph
|
|
|
|
def fix_skeleton_coordinate_system(motion):
    """
    Convert a motion from Z-up to Y-up coordinates with local-axis retargeting.

    Works on a deep copy of the input and performs, in order:
    1. Skeleton offset conversion: each joint's parent-relative translation
       is rotated by R_fix and its parent-relative rotation is reset to
       identity.
    2. Per-frame, per-joint conversion of every pose transform:
       rotation  R_new = R_fix @ R_old @ R_fix^T
       position  p_new = R_fix @ p_old
       (root and non-root joints are treated the same way).
    3. A final rotation of the whole motion by -90 degrees around X.

    R_fix is a +90 degree rotation around the X axis.

    Observations recorded by the original author:
    - Original OFFSET: (0, 0.184, 0) -> Y-up
    - Current OFFSET: (0, 0, -0.184) -> Z-down

    Args:
        motion: fairmotion Motion object in Z-up coordinates

    Returns:
        motion: converted Motion object in Y-up coordinates
    """
    from fairmotion.ops import conversions
    from fairmotion.ops import motion as motion_ops

    import copy

    converted = copy.deepcopy(motion)

    quarter_turn = np.pi / 2
    R_fix = conversions.A2R(np.array([quarter_turn, 0.0, 0.0]))
    R_fix_T = R_fix.T

    # Step 1: skeleton offsets.
    for joint in converted.skel.joints:
        rotated_offset = R_fix.dot(joint.xform_from_parent_joint[:3, 3].copy())
        joint.xform_from_parent_joint[:3, 3] = rotated_offset
        joint.xform_from_parent_joint[:3, :3] = np.eye(3)

    # Step 2: every joint transform of every frame gets the same change of
    # basis, so no root / non-root distinction is needed.
    for frame_idx in range(converted.num_frames()):
        pose = converted.get_pose_by_frame(frame_idx)
        for joint_idx, _joint in enumerate(converted.skel.joints):
            previous = pose.data[joint_idx].copy()

            updated = np.eye(4)
            updated[:3, :3] = R_fix.dot(previous[:3, :3]).dot(R_fix_T)
            updated[:3, 3] = R_fix.dot(previous[:3, 3])

            pose.data[joint_idx] = updated

    # Step 3: rotate the whole motion by the opposite quarter turn around X.
    rx = conversions.A2R(np.array([-np.pi / 2, 0.0, 0.0]))
    return motion_ops.rotate(converted, rx)
|
|
| |