SATA / src /Visualization /decode_z_sliding.py
zzysteve
Initial commit
5221c8c
Raw
History Blame Contribute Delete
19.3 kB
"""
Batch decode latent vectors to BVH files.
Supports sliding-window decoding and multiple input formats. The Gradio demo
uses decode_z_skel_to_bvh directly; the CLI keeps batch decoding support for
local debugging.
"""
import argparse
import os
import numpy as np
import torch
from fairmotion.data import bvh
from torch_geometric.data import Batch
# Required conversion helpers
from sata.conversions.graph_to_motion import hatD_recon_motion
from sata.skel_pose_graph import SkelPoseGraph
from sata.conversions.graph_to_motion import graph_2_skel
from fairmotion.core import motion as motion_class
from sata.utils.model_loading import load_model_by_type
# Reuse shared SATA motion-data helpers.
from sata.utils.motion_data import (
SkelData,
load_skeleton_and_tf_from_npz,
create_graph_list_from_skeleton,
fix_skeleton_coordinate_system,
)
from sata.utils.sliding_decode import decode_to_hatD_with_sliding_window
# Default directory layout, rooted in a "visual" folder next to this script:
#   visual/z.txt        one latent/code file path per line
#   visual/skel.txt     one skeleton .npz path per line; the matching tf file
#                       is located by swapping the "processed" directory for
#                       "joint_text_features"
#   visual/bvh_output   decoded BVH files (OUTPUT_DIR)
# Z_DIR / SKEL_DIR / SKEL_TF_DIR are default locations only; the batch CLI
# reads its inputs from the two .txt lists above.
VISUAL_BASE_DIR = os.path.join(os.path.dirname(__file__), "visual")
_SKEL_ROOT = os.path.join(VISUAL_BASE_DIR, "skel")
Z_DIR = os.path.join(VISUAL_BASE_DIR, "z")
SKEL_DIR = os.path.join(_SKEL_ROOT, "processed")
SKEL_TF_DIR = os.path.join(_SKEL_ROOT, "joint_text_features")
OUTPUT_DIR = os.path.join(VISUAL_BASE_DIR, "bvh_output")
def _derive_tf_npz_path(skel_path):
    """Return the joint-text-feature npz path paired with a skeleton npz path.

    Every directory component named exactly ``processed`` is replaced with
    ``joint_text_features``; the file name is kept. If no directory component
    is named ``processed``, ``skel_path`` is returned unchanged.
    """
    import pathlib
    parts = pathlib.PurePath(skel_path).parts
    if len(parts) < 2 or 'processed' not in parts[:-1]:
        return skel_path
    # Compare whole components so e.g. "preprocessed" is left untouched.
    dirs = ['joint_text_features' if part == 'processed' else part for part in parts[:-1]]
    return str(pathlib.PurePath(*dirs, parts[-1]))


def _cast_payload(tensor, model_type, device):
    """Return ``tensor`` on ``device``: int64 for RVQ codes, float32 for VAE latents."""
    # Cast RVQ indices straight to int64 (no float round-trip, which would
    # lose precision for large integer values).
    tensor = tensor.long() if model_type == "rvq" else tensor.float()
    return tensor.to(device)


def _load_pt_payload(z_path, model_type, device):
    """Extract the latent/code tensor from a .pt file (e.g. gen_dec.py output).

    Accepts either a bare tensor or a dict. For a dict, the payload key is
    chosen by priority: VAE z_pred > z > sample; RVQ idx_pred > code_idx.
    Optional 'text', 'm_len' and 'is_segment' entries are printed.
    """
    # NOTE(security): torch.load unpickles arbitrary objects; only decode .pt
    # files from trusted sources.
    encoded_data = torch.load(z_path, map_location=device)
    if isinstance(encoded_data, torch.Tensor):
        # Saved with torch.save(tensor, path): the tensor itself is the payload.
        print(" Loaded raw tensor from pt")
        return _cast_payload(encoded_data, model_type, device)
    if not isinstance(encoded_data, dict):
        raise ValueError(
            f"Unsupported .pt payload type {type(encoded_data).__name__} in {z_path}")
    # Payload keys in priority order, and the error raised when none is present.
    preferred_keys, missing_msg = {
        "vae": (('z_pred', 'z', 'sample'), "pt file does not contain any of: z_pred, z, sample"),
        "rvq": (('idx_pred', 'code_idx'), "pt file does not contain either idx_pred or code_idx"),
    }[model_type]
    data_key = next((key for key in preferred_keys if key in encoded_data), None)
    if data_key is None:
        raise KeyError(missing_msg)
    data = _cast_payload(encoded_data[data_key], model_type, device)
    print(f" Loaded key from pt: {data_key}")
    # Print metadata when present.
    if 'text' in encoded_data:
        print(f" Text metadata: {encoded_data['text']}")
    if 'm_len' in encoded_data:
        print(f" Original length: {encoded_data['m_len']}")
    if encoded_data.get('is_segment'):
        print(" Note: this file contains a segment")
    return data


def _load_latent_payload(z_path, model_type, device):
    """Load VAE latents or RVQ code indices from a .npy, .npz or .pt file.

    Returns:
        torch.Tensor on ``device``: float32 for "vae", int64 for "rvq".
    Raises:
        ValueError: unknown model_type, unsupported extension, an empty .npz,
            or an unexpected .pt payload type.
        KeyError: a .pt dict without the payload key expected for model_type.
    """
    if model_type not in ("vae", "rvq"):
        raise ValueError(f"Unsupported model_type: {model_type!r}. Expected 'vae' or 'rvq'")
    ext = os.path.splitext(z_path)[1].lower()
    if ext == '.npy':
        return _cast_payload(torch.from_numpy(np.load(z_path)), model_type, device)
    if ext == '.npz':
        # Common payload keys, tried in order before falling back to the first array.
        possible_keys = ('z', 'z_pred', 'sample', 'code_idx', 'codes', 'idx_pred')
        # The context manager closes the file handle the NpzFile keeps open.
        with np.load(z_path) as data_loaded:
            stored_keys = list(data_loaded.keys())
            if not stored_keys:
                raise ValueError(f"npz file contains no arrays: {z_path}")
            data_key = next((key for key in possible_keys if key in data_loaded),
                            stored_keys[0])
            array = data_loaded[data_key]
        print(f" Loaded key from npz: {data_key}")
        return _cast_payload(torch.from_numpy(array), model_type, device)
    if ext == '.pt':
        return _load_pt_payload(z_path, model_type, device)
    raise ValueError(f"Unsupported file format: {z_path}. Only .npy, .npz, and .pt are supported")


def _resolve_output_path(output_dir, output_name):
    """Join ``output_dir`` and ``output_name``, appending '.bvh' only when missing."""
    if not output_name.lower().endswith('.bvh'):
        output_name += '.bvh'
    return os.path.join(output_dir, output_name)


def decode_z_skel_to_bvh(z_path, skel_path, model_epoch, output_dir,
                         output_name="decoded_motion.bvh", device="cuda:0",
                         window_size=64, overlap=16, use_sliding_window=None,
                         model_type="vae", model_state=None, fps=20):
    """
    Decode a latent/code file with a target skeleton and save it as BVH.

    Args:
        z_path: path to a latent/code file (.npy, .npz, or .pt)
            - VAE: latent vectors [T, z_dim]
            - RVQ: code_idx [T, Q]
        skel_path: target skeleton .npz path. The matching tf file is loaded
            from the sibling "joint_text_features" directory.
        model_epoch: model checkpoint name, for example "ckpt0"; only used when
            model_state is None
        output_dir: output directory path; created if missing
        output_name: output file name; ".bvh" is appended only when missing,
            so "walk" and "walk.bvh" both produce walk.bvh
        device: compute device, default "cuda:0"
        window_size: sliding-window size, default 64
        overlap: sliding-window overlap, default 16
        use_sliding_window: force sliding-window decoding on/off; None enables it
            when the sequence is longer than window_size. A sequence that fits
            in one window is always decoded in a single pass.
        model_type: "vae" or "rvq", default "vae"
        model_state: optional preloaded (model, cfg, ms_dict); loaded on demand when None
        fps: frame rate assigned to the output motion, default 20

    Returns:
        output_path: saved BVH path

    Raises:
        ValueError: unknown model_type, non-.npz skeleton, or an unsupported
            or empty latent file.
        KeyError: a .pt dict without the payload key expected for model_type.
    """
    # 0. Validate cheap arguments before loading any model or data.
    if model_type not in ("vae", "rvq"):
        raise ValueError(f"Unsupported model_type: {model_type!r}. Expected 'vae' or 'rvq'")
    if not skel_path.endswith('.npz'):
        raise ValueError(f"Unsupported skeleton format: {skel_path}. Only .npz is supported")

    # 1. Load the model or reuse a preloaded model state.
    if model_state is not None:
        model, cfg, ms_dict = model_state
        print(f"[Model] Reusing preloaded model ({model_type.upper()})")
    else:
        print(f"Loading model: {model_epoch} ({model_type.upper()})")
        model, cfg, ms_dict = load_model_by_type(model_type, model_epoch, device)
    model = model.to(device)
    model.eval()

    # 2. Load latent vectors (VAE, float) or code indices (RVQ, long).
    print(f"Loading {'latent vectors' if model_type == 'vae' else 'code_idx'}: {z_path}")
    latent_codes = _load_latent_payload(z_path, model_type, device)
    data_length = latent_codes.shape[0]
    print(f" {'z' if model_type == 'vae' else 'code_idx'} shape: {latent_codes.shape}")

    # 3. Load the skeleton plus its tf file and convert it to a graph.
    print(f"Loading skeleton from NPZ: {skel_path}")
    tf_npz_path = _derive_tf_npz_path(skel_path)
    skel_data = load_skeleton_and_tf_from_npz(skel_path, tf_npz_path)
    print(f" Skeleton joints: {skel_data.lo.shape[0]}")
    print(f" tf shape: {skel_data.tf.shape}")
    skel_graph = SkelPoseGraph(skel_data, None)

    # 4. Decide whether to use sliding-window decoding.
    if use_sliding_window is None:
        use_sliding_window = data_length > window_size
    sliding = bool(use_sliding_window) and data_length > window_size

    # 5. Decode to motion.
    print("Decoding...")
    out_rep_cfg = cfg["representation"]["out"]
    with torch.no_grad():
        if sliding:
            print(f" Long sequence ({data_length} frames); using sliding-window decoding")
            from sata.utils.motion_data import create_graph_list_from_single_graph
            src_graphs_list = create_graph_list_from_single_graph(skel_graph, data_length)
            hatD, src_batch, num_frames, _num_nodes_per_frame = \
                decode_to_hatD_with_sliding_window(
                    model, latent_codes, src_graphs_list,
                    data_length, window_size, overlap, device, model_type
                )
        else:
            print(f" Short sequence ({data_length} frames); using single-pass decoding")
            # One copy of the skeleton graph per frame.
            src_batch = Batch.from_data_list([skel_graph] * data_length).to(device)
            if model_type == "vae":
                hatD = model.decode(latent_codes, src_batch, data_length)
            else:
                hatD, _ = model.decode_from_codes(latent_codes, src_batch, data_length)
            num_frames = data_length
        out_motion_list, _out_contact_list = hatD_recon_motion(
            hatD, src_batch, out_rep_cfg, ms_dict, num_frames
        )
    out_motion = out_motion_list[0]
    print(f" Output motion frames: {out_motion.num_frames()}")

    # 6. Save the BVH file.
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created output directory: {output_dir}")
    out_motion.fps = fps
    # Fix coordinate system differences before export.
    out_motion_fixed = fix_skeleton_coordinate_system(out_motion)
    print(" [Info] Applied coordinate-system fix (Y-Z axis conversion)")
    output_path = _resolve_output_path(output_dir, output_name)
    bvh.save(out_motion_fixed, output_path, rot_order="XYZ")
    print(f"Saved to: {output_path}")
    return output_path
def _read_path_list(txt_path, label):
    """Return [(path, stem), ...] read from a path-list file, or [] if it is missing.

    Lines are stripped; blank lines and lines starting with '#' are skipped.
    ``label`` is only used in the progress message.
    """
    if not os.path.exists(txt_path):
        print(f" Not found: {txt_path}")
        return []
    entries = []
    # utf-8-sig drops a leading byte-order mark (common in files edited on
    # Windows) that would otherwise be glued onto the first path.
    with open(txt_path, 'r', encoding='utf-8-sig') as f:
        for raw_line in f:
            path = raw_line.strip()
            if not path or path.startswith('#'):
                continue
            # The file stem doubles as the display/output name.
            entries.append((path, os.path.splitext(os.path.basename(path))[0]))
    print(f" Read {len(entries)} {label} files from {txt_path}")
    return entries


def scan_z_and_skeleton_files(visual_base_dir):
    """
    Read z and skeleton path lists from txt files and build all pairings.

    Args:
        visual_base_dir: visual root containing z.txt and skel.txt. Paths
            inside the lists are used as written (not resolved against
            visual_base_dir). A missing list file contributes no entries.

    Returns:
        pairs: list of tuple, [(z_path, z_name, skel_path, skel_name), ...]
            in z-major order (the Cartesian product of both lists)
    """
    z_files = _read_path_list(os.path.join(visual_base_dir, "z.txt"), "z")
    skel_files = _read_path_list(os.path.join(visual_base_dir, "skel.txt"), "skeleton")
    return [
        (z_path, z_name, skel_path, skel_name)
        for z_path, z_name in z_files
        for skel_path, skel_name in skel_files
    ]
def batch_decode_all(model_epoch, visual_base_dir=VISUAL_BASE_DIR,
                     output_dir=OUTPUT_DIR, device="cuda:0",
                     window_size=64, overlap=16, model_type="vae"):
    """
    Batch decode every z/skeleton pairing listed in txt files.

    The model is loaded once and shared by all pairs. A failure on one pair
    is recorded and does not stop the remaining pairs.

    Args:
        model_epoch: model checkpoint name
        visual_base_dir: visual root containing z.txt and skel.txt; default visual/
        output_dir: output directory; default visual/bvh_output
        device: compute device
        window_size: sliding-window size
        overlap: sliding-window overlap
        model_type: model type ("vae" or "rvq")

    Returns:
        results: list of per-pair dicts with keys z_name, skel_name,
            output_name, output_path, status ('success' or 'failed') and,
            for failures, error
    """
    from collections import Counter

    # Read txt files and build pairings.
    print("="*70)
    print("Reading path lists...")
    print(f" z.txt: {os.path.join(visual_base_dir, 'z.txt')}")
    print(f" skel.txt: {os.path.join(visual_base_dir, 'skel.txt')}")
    print("="*70)
    pairs = scan_z_and_skeleton_files(visual_base_dir)
    if not pairs:
        print("No z or skeleton files found. Check the directory configuration.")
        return []

    total = len(pairs)
    # Count distinct files by path: files in different folders may share a stem.
    z_count = len({pair[0] for pair in pairs})
    skel_count = len({pair[2] for pair in pairs})
    print(f"Found {z_count} z files x {skel_count} skeleton files = {total} pairs")
    print()

    # Show all pairings.
    output_names = [f"{z_name}_{skel_name}" for _, z_name, _, skel_name in pairs]
    print("Pairings to process:")
    for i, output_name in enumerate(output_names, 1):
        print(f" [{i}/{total}] {output_name}")
    # Outputs are named from file stems only, so repeated names overwrite each other.
    repeated = sorted(name for name, count in Counter(output_names).items() if count > 1)
    if repeated:
        print(f" Warning: {len(repeated)} output name(s) repeat and will be overwritten: "
              f"{', '.join(repeated)}")
    print("="*70)
    print()

    # Load the model once instead of once per pair. If this fails, fall back
    # to per-pair loading so every pair still records its own error.
    model_state = None
    try:
        print(f"Loading model: {model_epoch} ({model_type.upper()})")
        model_state = load_model_by_type(model_type, model_epoch, device)
    except Exception as e:
        print(f" Warning: shared model load failed ({e}); loading per pair instead")

    # Process each pair.
    results = []
    for i, ((z_path, z_name, skel_path, skel_name), output_name) in enumerate(
            zip(pairs, output_names), 1):
        print("="*70)
        print(f"Processing pair [{i}/{total}]: {output_name}")
        print("="*70)
        print(f" z file: {z_path}")
        print(f" skeleton file: {skel_path}")
        print()
        record = {'z_name': z_name, 'skel_name': skel_name, 'output_name': output_name}
        try:
            output_path = decode_z_skel_to_bvh(
                z_path=z_path,
                skel_path=skel_path,
                model_epoch=model_epoch,
                output_dir=output_dir,
                output_name=output_name,
                device=device,
                window_size=window_size,
                overlap=overlap,
                use_sliding_window=None,  # Select automatically.
                model_type=model_type,
                model_state=model_state,
            )
        except Exception as e:
            # Record the failure and continue with the remaining pairs.
            print(f"[{i}/{total}] Failed: {output_name}")
            print(f" Error: {str(e)}")
            record.update(output_path=None, status='failed', error=str(e))
        else:
            record.update(output_path=output_path, status='success')
            print(f"[{i}/{total}] Success: {output_name}")
        results.append(record)
        print()

    # Summary.
    print("="*70)
    print("Batch processing complete")
    print("="*70)
    success_count = sum(1 for r in results if r['status'] == 'success')
    failed_count = sum(1 for r in results if r['status'] == 'failed')
    print(f"Total pairs: {len(results)}")
    print(f" Success: {success_count}")
    print(f" Failed: {failed_count}")
    if failed_count > 0:
        print()
        print("Failed pairs:")
        for r in results:
            if r['status'] == 'failed':
                print(f" - {r['output_name']}: {r.get('error', 'Unknown error')}")
    print("="*70)
    return results
if __name__ == "__main__":
    # sys.exit is always available; the builtin exit() is injected by the
    # site module and may be missing (e.g. python -S or frozen builds).
    import sys
    import traceback

    parser = argparse.ArgumentParser(
        description="Batch decode latent vectors to BVH files",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="",
    )
    parser.add_argument("--model_type", type=str, default="vae", choices=["vae", "rvq"],
                        help="Model type: vae or rvq (default: vae)")
    parser.add_argument("--model_epoch", type=str, required=True,
                        help="Model checkpoint name, for example ckpt0")
    parser.add_argument("--visual_dir", type=str, default=VISUAL_BASE_DIR,
                        help=f"Visual root containing z.txt and skel.txt (default: {VISUAL_BASE_DIR})")
    parser.add_argument("--output_dir", type=str, default=OUTPUT_DIR,
                        help=f"Output directory (default: {OUTPUT_DIR})")
    parser.add_argument("--device", type=str, default="cuda:0",
                        help="Compute device (default: cuda:0)")
    parser.add_argument("--window_size", type=int, default=64,
                        help="Sliding-window size in frames (default: 64)")
    parser.add_argument("--overlap", type=int, default=16,
                        help="Sliding-window overlap in frames (default: 16)")
    args = parser.parse_args()
    # NOTE(review): the sliding decoder presumably advances by
    # window_size - overlap frames, so that stride must be positive -- confirm
    # against sata.utils.sliding_decode.
    if args.window_size <= 0:
        parser.error("--window_size must be positive")
    if not 0 <= args.overlap < args.window_size:
        parser.error("--overlap must satisfy 0 <= overlap < window_size")

    print("="*70)
    print("Batch latent-vector decoder (reads path lists from txt files)")
    print("="*70)
    print(f"Model type: {args.model_type.upper()}")
    print(f"Model: {args.model_epoch}")
    print(f"z path list: {os.path.join(args.visual_dir, 'z.txt')}")
    print(f"skeleton path list: {os.path.join(args.visual_dir, 'skel.txt')}")
    print(f"Output directory: {args.output_dir}")
    print(f"Device: {args.device}")
    print(f"Sliding window: size={args.window_size}, overlap={args.overlap}")
    print("="*70)
    print()
    try:
        results = batch_decode_all(
            model_epoch=args.model_epoch,
            visual_base_dir=args.visual_dir,
            output_dir=args.output_dir,
            device=args.device,
            window_size=args.window_size,
            overlap=args.overlap,
            model_type=args.model_type,
        )
    except Exception as e:
        print()
        print("="*70)
        print(f"Error: {str(e)}")
        print("="*70)
        traceback.print_exc()
        sys.exit(1)
    # Non-zero exit status when any pair failed, so calling scripts can detect it.
    if any(r['status'] == 'failed' for r in results):
        sys.exit(1)