SATA / src /sata /utils /motion_data.py
zzysteve
Initial commit
5221c8c
Raw
History Blame Contribute Delete
16.5 kB
"""Motion-data helpers for skeleton loading, graph creation, reconstruction, and serialization."""
import os
import torch
import numpy as np
from os.path import join as pjoin
from sata.mydataset import SkelData
from sata.skel_pose_graph import SkelPoseGraph
def load_skeleton_from_npz(npz_path):
    """
    Load skeleton data from an npz file without text features.

    Args:
        npz_path: motion npz path containing skeleton data. The archive must
            provide the 'lo', 'go', 'qb', and 'edges' arrays.

    Returns:
        skel_data: SkelData object with tf set to None
    """
    # np.load returns a lazily-read NpzFile that keeps the archive open.
    # Reading inside a context manager guarantees the handle is released
    # once the arrays have been materialized.
    with np.load(npz_path) as data:
        # Extract intrinsic skeleton data.
        lo = data['lo']  # [nJ, 3]
        go = data['go']  # [nJ, 3]
        qb = data['qb']  # [nJ]
        edges = data['edges']  # [nE, 4]

    # Sort edges by child index (column 1) unless row i already holds index i.
    if not (np.arange(edges.shape[0]) == edges[:, 1]).all():
        edges = edges[np.argsort(edges[:, 1])]

    # Create SkelData with tf set to None.
    skel_data = SkelData(
        torch.Tensor(lo),
        torch.Tensor(go),
        torch.BoolTensor(qb),
        torch.LongTensor(edges[:, :2]).transpose(1, 0),  # [2, nE]
        torch.LongTensor(edges[:, 2:]),  # [nE, 2]
        None,  # tf is None
    )
    return skel_data
def load_skeleton_and_tf_from_npz(npz_path, tf_npz_path):
    """
    Load skeleton data and text features from npz files.

    Loading the text features is best-effort: if the tf archive cannot be
    read or has no 'tf' entry, a warning is printed and a zero array of shape
    [nJ, 768] is used instead.

    Args:
        npz_path: motion npz path containing skeleton data
        tf_npz_path: text-feature npz path containing tf

    Returns:
        skel_data: SkelData object containing tf
    """
    # Load skeleton data without tf.
    skel_data = load_skeleton_from_npz(npz_path)

    # Load text features.
    try:
        # The context manager closes the archive even when the key is missing.
        with np.load(tf_npz_path) as tf_data:
            if 'tf' not in tf_data:
                raise KeyError(f"'tf' key not found in {tf_npz_path}")
            tf = tf_data['tf']  # [nJ, 768]
    except Exception as e:
        # Deliberate fallback: a missing or corrupt tf file must not abort loading.
        print(f"[Warning] Failed to load tf from {tf_npz_path}: {e}")
        print(f"[Warning] Using zero tf instead")
        nJ = skel_data.lo.shape[0]
        tf = np.zeros((nJ, 768), dtype=np.float32)

    # Attach tf to SkelData.
    skel_data.tf = torch.Tensor(tf)
    return skel_data
def create_graph_list_from_skeleton(skel_data, seq_length):
    """
    Build one skeleton-only graph per frame.

    Args:
        skel_data: SkelData object shared by every frame
        seq_length: sequence length in frames

    Returns:
        graphs: list of seq_length SkelPoseGraph objects
    """
    graphs = []
    for _frame in range(seq_length):
        # The second argument is None: each graph carries skeleton data only.
        graphs.append(SkelPoseGraph(skel_data, None))
    return graphs
def create_graph_list_from_single_graph(skel_graph, seq_length):
    """
    Replicate the skeleton of one SkelPoseGraph into a per-frame graph list.

    Args:
        skel_graph: SkelPoseGraph object containing skeleton data only
        seq_length: sequence length in frames

    Returns:
        graphs: list of seq_length SkelPoseGraph objects
    """
    # Pull the skeleton fields out once so a single SkelData is shared by all frames.
    lo = skel_graph.lo
    go = skel_graph.go
    qb = skel_graph.qb
    edge_index = skel_graph.edge_index
    edge_feature = skel_graph.edge_feature

    # Fall back to zero text features ([nJ, 768]) when the graph has none.
    if hasattr(skel_graph, 'tf'):
        text_features = skel_graph.tf
    else:
        text_features = torch.zeros(skel_graph.lo.shape[0], 768)

    shared_skel = SkelData(
        lo=lo,
        go=go,
        qb=qb,
        edge_index=edge_index,
        edge_feature=edge_feature,
        tf=text_features,
    )

    graphs = []
    for _frame in range(seq_length):
        graphs.append(SkelPoseGraph(shared_skel, None))
    return graphs
def process_hatD_to_qrc(hatD_full, src_batch_full, actual_frames, num_nodes_per_frame,
                        out_rep_cfg, ms_dict):
    """
    Split decoded features into q, r, and c and reshape them per frame.

    Args:
        hatD_full: [T*num_nodes, D] decoded features
        src_batch_full: complete source batch; its ptr[:-1] is passed to
            parse_hatD as the root indices
        actual_frames: actual frame count
        num_nodes_per_frame: nodes per frame
        out_rep_cfg: output representation config
        ms_dict: mean/std dictionary

    Returns:
        q: [T, nJ, 6] - joint rotations (6D representation), or None if absent
        r: [T, 1, 4] - root transform, or None if absent
        c: [T, nJ, 1] - contact, or None if absent
    """
    from sata.mymodel import parse_hatD

    parsed = parse_hatD(hatD_full, src_batch_full.ptr[:-1], out_rep_cfg, ms_dict)

    def _per_frame(flat):
        # [T*nJ, F] -> [T, nJ, F]; entries missing from the parse stay None.
        if flat is None:
            return None
        return flat.view(actual_frames, num_nodes_per_frame, -1)

    q = _per_frame(parsed.get('q'))

    r = parsed.get('r')
    if r is not None:
        # r is expected per frame already ([T, 4]); add a singleton joint axis
        # so downstream code sees [T, 1, 4].
        r = r.unsqueeze(1)

    c = _per_frame(parsed.get('c'))
    return q, r, c
def compute_qv_from_qR(qR):
    """
    Compute per-frame relative rotations from a rotation-matrix sequence.

    Args:
        qR: [T, nJ, 3, 3] rotation matrices

    Returns:
        qv: [T, nJ, 6] relative rotation R[t-1]^T @ R[t] in 6D form, i.e. the
            first two matrix columns concatenated. Frame 0 has no predecessor
            and is set to the identity rotation.
    """
    num_frames, num_joints = qR.shape[:2]

    # Start from identity everywhere so frame 0 encodes "no rotation".
    rel_rot = torch.eye(3, device=qR.device, dtype=qR.dtype).repeat(num_frames, num_joints, 1, 1)

    if num_frames > 1:
        rel_rot[1:] = torch.matmul(qR[:-1].transpose(-2, -1), qR[1:])

    # 6D representation: column 0 followed by column 1 of each matrix.
    first_col = rel_rot[..., 0]
    second_col = rel_rot[..., 1]
    return torch.cat((first_col, second_col), dim=-1)
def reconstruct_p_pv_qv_from_qrc(q, r, c, src_batch, consq_n, device, fps=30.0):
    """
    Reconstruct p, pv, and qv from q, r, and c.

    Mirrors the implementation in reconstruction_qrc_2_same.py.

    Args:
        q: [T, nJ, 6] - joint rotations (6D representation)
        r: [T, 1, 4] - root transform
        c: [T, nJ, 1] - contact
        src_batch: Batch object containing skeleton info; lo, ptr, skel_depth,
            and edge_index are read from it
        consq_n: T, sequence length
        device: torch device; kept for interface compatibility, not used in
            this function
        fps: frame rate used to scale per-frame displacement into velocity.
            Defaults to 30.0, the value that was previously hard-coded.

    Returns:
        p: [T-1, nJ, 3] - joint positions (excluding frame 1)
        pv: [T-1, nJ, 3] - joint velocities (excluding frame 1)
        qv: [T-1, nJ, 6] - joint angular velocities (excluding frame 1)
        q_out: [T-1, nJ, 6] - joint rotations (excluding frame 1)
        r_out: [T-1, 4] - root transform (excluding frame 1)
        c_out: [T-1, nJ, 1] - contact (excluding frame 1)

    Note: every output drops the first frame to keep temporal dimensions
    consistent. When consq_n <= 1 nothing is dropped and the full-length
    tensors are returned.
    """
    from sata.mymodel import FK, accum_root
    from sata.utils import tensor_utils

    nJ = q.shape[1]

    # 1. Convert q from 6D representation to rotation matrices qR.
    q_flat = q.reshape(-1, 6)  # [T*nJ, 6]
    qR_flat = tensor_utils.tensor_q2qR(q_flat)  # [T*nJ, 3, 3]
    qR = qR_flat.reshape(consq_n, nJ, 3, 3)  # [T, nJ, 3, 3]

    # 2. Run forward kinematics.
    r_squeezed = r.squeeze(1)  # [T, 1, 4] -> [T, 4]
    fk_T_flat = FK(
        lo=src_batch.lo,  # [T*nJ, 3]
        qR=qR_flat,  # [T*nJ, 3, 3]
        r=r_squeezed,  # [T, 4]
        root_ids=src_batch.ptr[:-1],  # [T]
        skel_depth=src_batch.skel_depth,  # [T*nJ]
        skel_edge_index=src_batch.edge_index,  # [2, T*nE]
    )  # [T*nJ, 4, 4]

    # Positions are the translation column of each FK transform.
    p_flat = fk_T_flat[..., :3, 3]  # [T*nJ, 3]
    p_full = p_flat.reshape(consq_n, nJ, 3)  # [T, nJ, 3]

    # 3. Compute angular velocity qv.
    qv_full = compute_qv_from_qR(qR)  # [T, nJ, 6]

    # 4. Compute positional velocity pv.
    # Accumulate root transforms.
    rT_accum = accum_root(r, consq_n, apply_height=False, grad_truncate_k=0)  # [T, 1, 4, 4]
    facing_transforms = rT_accum[:, 0, :, :]  # [T, 4, 4]

    # Convert p to global coordinates.
    p_T = tensor_utils.tensor_p2T(p_full.reshape(-1, 3))  # [T*nJ, 4, 4]
    p_T = p_T.reshape(consq_n, nJ, 4, 4)  # [T, nJ, 4, 4]
    global_p_T = facing_transforms.unsqueeze(1) @ p_T  # [T, nJ, 4, 4]
    global_p = global_p_T[..., :3, 3]  # [T, nJ, 3]

    # Frame-to-frame global displacement; frame 0 has no predecessor and stays zero.
    global_p_vel = torch.zeros_like(global_p)
    if consq_n > 1:
        global_p_vel[1:] = global_p[1:] - global_p[:-1]

    # Rotate the displacement back into each frame's facing frame.
    facing_inv_rot = torch.inverse(facing_transforms)[:, :3, :3]  # [T, 3, 3]
    facing_inv_rot = facing_inv_rot.unsqueeze(1)  # [T, 1, 3, 3]
    local_p_vel = (facing_inv_rot @ global_p_vel.unsqueeze(-1)).squeeze(-1)  # [T, nJ, 3]

    # Scale per-frame displacement to per-second velocity.
    pv_full = local_p_vel * fps

    # 5. pv and qv need the previous frame, so their first frame is undefined.
    # Drop the first frame from every feature for consistency.
    if consq_n > 1:
        p = p_full[1:]  # [T-1, nJ, 3]
        pv = pv_full[1:]  # [T-1, nJ, 3]
        qv = qv_full[1:]  # [T-1, nJ, 6]
        q_out = q[1:]  # [T-1, nJ, 6]
        r_out = r_squeezed[1:]  # [T-1, 4]
        c_out = c[1:]  # [T-1, nJ, 1]
    else:
        # Single-frame fallback; this should not happen in normal inputs.
        p = p_full
        pv = pv_full
        qv = qv_full
        q_out = q
        r_out = r_squeezed
        c_out = c

    return p, pv, qv, q_out, r_out, c_out
def _meta_json_default(value):
    """Convert tensors and NumPy values to plain Python objects for json.dump."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().tolist()
    if isinstance(value, (np.ndarray, np.generic)):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def save_processed_with_tf_and_meta(data_dict, output_dir, filename):
    """
    Save processed data, joint_text_features, and metadata.

    Writes up to three files under output_dir:
        processed/<filename>.npz            - every entry except tf and metadata
        joint_text_features/<filename>.npz  - the 'tf' entry, if present
        meta/<filename>.json                - only for segments or text records

    Args:
        data_dict: skeleton data, motion features, tf, text, m_len, and related metadata
        output_dir: output root directory
        filename: file stem without extension
    """
    import json

    # Keys that are stored outside the processed npz (tf) or in the JSON metadata.
    excluded_keys = ('tf', 'text', 'src_filename', 'is_segment', 'segment_info')

    # Create subdirectories.
    processed_dir = pjoin(output_dir, 'processed')
    tf_dir = pjoin(output_dir, 'joint_text_features')
    os.makedirs(processed_dir, exist_ok=True)
    os.makedirs(tf_dir, exist_ok=True)

    # 1. Save the npz file to processed/ with skeleton and motion data, excluding tf.
    npz_dict = {}
    for key, value in data_dict.items():
        if key in excluded_keys:
            continue
        if isinstance(value, torch.Tensor):
            # detach() allows tensors that still carry autograd history.
            npz_dict[key] = value.detach().cpu().numpy()
        else:
            npz_dict[key] = value
    npz_path = pjoin(processed_dir, f'{filename}.npz')
    np.savez(npz_path, **npz_dict)

    # 2. Save tf to joint_text_features/.
    if 'tf' in data_dict:
        tf_path = pjoin(tf_dir, f'{filename}.npz')
        tf_value = data_dict['tf']
        if isinstance(tf_value, torch.Tensor):
            tf_value = tf_value.detach().cpu().numpy()
        np.savez(tf_path, tf=tf_value)

    # 3. Save metadata to meta/ for segments and text records.
    is_segment = data_dict.get('is_segment', False)
    if is_segment or 'text' in data_dict:
        meta_dir = pjoin(output_dir, 'meta')
        os.makedirs(meta_dir, exist_ok=True)
        meta_path = pjoin(meta_dir, f'{filename}.json')

        meta_dict = {}
        if 'text' in data_dict:
            meta_dict['text'] = data_dict['text']
        if is_segment:
            meta_dict['is_segment'] = True
            meta_dict['segment_info'] = data_dict.get('segment_info', {})
        meta_dict['src_filename'] = data_dict.get('src_filename', filename)
        meta_dict['m_len'] = data_dict.get('m_len', 0)

        with open(meta_path, 'w', encoding='utf-8') as f:
            # default= is only consulted for values json cannot encode natively,
            # e.g. an np.int64 m_len or tensor entries inside segment_info.
            json.dump(meta_dict, f, indent=2, default=_meta_json_default)
def bvh_2_SkelPoseGraph(bvh_path):
    """
    Build a skeleton graph from a BVH file, using zero text features.

    Args:
        bvh_path: BVH file path

    Returns:
        skel_graph: result of skel_2_graph for the normalized skeleton

    Note:
        BVH files carry no text features, so an all-zero [num_joints, 768]
        array is passed to skel_2_graph in their place.
    """
    from fairmotion.data import bvh
    from sata.conversions.motion_to_graph import skel_2_graph

    print(f"Loading skeleton from BVH: {bvh_path}")
    loaded_motion = bvh.load(bvh_path, ignore_root_skel=True, ee_as_joint=True)

    # The skeleton must be normalized before graph conversion; the second
    # return value of the normalizer is not needed here.
    from sata.utils.motion_utils import motion_normalize_h2s
    normalized_motion, _tpose = motion_normalize_h2s(loaded_motion, False)
    skeleton = normalized_motion.skel

    zero_text_features = np.zeros((skeleton.num_joints(), 768), dtype=np.float32)
    skel_graph = skel_2_graph(skeleton, zero_text_features)

    print(f" Skeleton joints: {skeleton.num_joints()}")
    print(" [Warning] BVH has no text features; decoding will use zero tf")
    return skel_graph
def fix_skeleton_coordinate_system(motion):
    """
    Re-express a motion in a coordinate system rotated about the X axis.

    Intended to convert a Z-up motion to Y-up. The input is deep-copied and
    the copy is processed in three steps:
        1. Every joint offset (translation of xform_from_parent_joint) is
           rotated by R_fix, a +90 degree rotation about X, and the rotation
           part of that transform is reset to identity.
        2. For every frame and joint, the local rotation is conjugated,
           R_new = R_fix @ R_old @ R_fix^T, and the local translation is
           rotated by R_fix. Root and non-root joints are treated the same.
        3. The whole motion is rotated by -90 degrees about X through
           fairmotion's motion_ops.rotate.

    Args:
        motion: fairmotion Motion object in Z-up coordinates

    Returns:
        motion: converted Motion object; the input object is not modified
    """
    from fairmotion.ops import conversions
    from fairmotion.ops import motion as motion_ops
    import copy

    converted = copy.deepcopy(motion)

    # +90 degrees about X maps (0, 0, -0.184) to (0, 0.184, 0).
    quarter_turn = np.pi / 2
    R_fix = conversions.A2R(np.array([quarter_turn, 0.0, 0.0]))
    R_fix_T = R_fix.T  # inverse of R_fix, since it is a rotation matrix

    # Step 1: rotate skeleton offsets and clear any parent-transform rotation.
    for joint in converted.skel.joints:
        offset_before = joint.xform_from_parent_joint[:3, 3].copy()
        joint.xform_from_parent_joint[:3, 3] = np.dot(R_fix, offset_before)
        joint.xform_from_parent_joint[:3, :3] = np.eye(3)

    # Step 2: retarget every local joint transform in every frame.
    for frame_idx in range(converted.num_frames()):
        pose = converted.get_pose_by_frame(frame_idx)
        for joint_idx, _joint in enumerate(converted.skel.joints):
            local_xform = pose.data[joint_idx].copy()  # 4x4 local transform

            retargeted = np.eye(4)
            # Change of basis for the rotation; plain rotation for the translation.
            retargeted[:3, :3] = R_fix.dot(local_xform[:3, :3]).dot(R_fix_T)
            retargeted[:3, 3] = R_fix.dot(local_xform[:3, 3])

            pose.data[joint_idx] = retargeted

    # Step 3: independent global fix, -90 degrees about X over the whole sequence,
    # which adjusts the final motion direction.
    rx = conversions.A2R(np.array([-quarter_turn, 0.0, 0.0]))
    return motion_ops.rotate(converted, rx)