| """ |
| Batch decode latent vectors to BVH files. |
| |
| Supports sliding-window decoding and multiple input formats. The Gradio demo |
| uses decode_z_skel_to_bvh directly; the CLI keeps batch decoding support for |
| local debugging. |
| """ |
|
|
| import argparse |
| import os |
| import numpy as np |
| import torch |
| from fairmotion.data import bvh |
| from torch_geometric.data import Batch |
|
|
| |
| from sata.conversions.graph_to_motion import hatD_recon_motion |
| from sata.skel_pose_graph import SkelPoseGraph |
| from sata.conversions.graph_to_motion import graph_2_skel |
| from fairmotion.core import motion as motion_class |
|
|
| from sata.utils.model_loading import load_model_by_type |
| |
| from sata.utils.motion_data import ( |
| SkelData, |
| load_skeleton_and_tf_from_npz, |
| create_graph_list_from_skeleton, |
| fix_skeleton_coordinate_system, |
| ) |
| from sata.utils.sliding_decode import decode_to_hatD_with_sliding_window |
|
|
| |
# Default on-disk layout, rooted at a "visual" folder next to this module.
_MODULE_DIR = os.path.dirname(__file__)
VISUAL_BASE_DIR = os.path.join(_MODULE_DIR, "visual")

# Sub-folders of the visual root: latent files, skeleton data, decoded output.
_SKEL_ROOT = os.path.join(VISUAL_BASE_DIR, "skel")
Z_DIR = os.path.join(VISUAL_BASE_DIR, "z")
SKEL_DIR = os.path.join(_SKEL_ROOT, "processed")
SKEL_TF_DIR = os.path.join(_SKEL_ROOT, "joint_text_features")
OUTPUT_DIR = os.path.join(VISUAL_BASE_DIR, "bvh_output")
|
|
| |
| |
| |
| |
|
|
|
|
_SUPPORTED_MODEL_TYPES = ("vae", "rvq")


def _npz_key_priority(model_type):
    """Return candidate npz keys, with the keys matching *model_type* first."""
    latent_keys = ['z', 'z_pred', 'sample']
    code_keys = ['code_idx', 'codes', 'idx_pred']
    if model_type == "rvq":
        return code_keys + latent_keys
    return latent_keys + code_keys


def _as_model_tensor(tensor, model_type, device):
    """Cast *tensor* to float (VAE latents) or long (RVQ indices) on *device*."""
    tensor = tensor.long() if model_type == "rvq" else tensor.float()
    return tensor.to(device)


def _load_latent_or_codes(z_path, model_type, device):
    """Load the latent/code tensor stored in a .npy, .npz, or .pt file."""
    if z_path.endswith('.npy'):
        return _as_model_tensor(torch.from_numpy(np.load(z_path)), model_type, device)

    if z_path.endswith('.npz'):
        # The context manager closes the underlying zip file once the array is read.
        with np.load(z_path) as archive:
            data_key = next(
                (key for key in _npz_key_priority(model_type) if key in archive),
                None,
            )
            if data_key is None:
                # No well-known key: fall back to the first array in the archive.
                if not archive.files:
                    raise ValueError(f"npz file contains no arrays: {z_path}")
                data_key = archive.files[0]
            print(f" Loaded key from npz: {data_key}")
            array = archive[data_key]
        return _as_model_tensor(torch.from_numpy(array), model_type, device)

    if z_path.endswith('.pt'):
        # NOTE(security): torch.load unpickles the file; only load trusted files.
        encoded_data = torch.load(z_path, map_location=device)

        if model_type == "vae":
            candidates = ('z_pred', 'z', 'sample')
            missing_msg = "pt file does not contain any of: z_pred, z, sample"
        else:
            candidates = ('idx_pred', 'code_idx')
            missing_msg = "pt file does not contain either idx_pred or code_idx"

        data_key = next((key for key in candidates if key in encoded_data), None)
        if data_key is None:
            raise KeyError(missing_msg)
        data = _as_model_tensor(encoded_data[data_key], model_type, device)
        print(f" Loaded key from pt: {data_key}")

        # Optional metadata is only reported, never used for decoding.
        if 'text' in encoded_data:
            print(f" Text metadata: {encoded_data['text']}")
        if 'm_len' in encoded_data:
            print(f" Original length: {encoded_data['m_len']}")
        if 'is_segment' in encoded_data and encoded_data['is_segment']:
            print(" Note: this file contains a segment")
        return data

    raise ValueError(f"Unsupported file format: {z_path}. Only .npy, .npz, and .pt are supported")


def _joint_text_feature_path(skel_path):
    """Derive the joint-text-feature npz path from a skeleton npz path.

    The "processed" directory component is swapped for "joint_text_features".
    When no such component is found the skeleton path itself is returned.
    """
    tf_npz_path = skel_path.replace('/processed/', '/joint_text_features/')
    if tf_npz_path == skel_path:
        # Fallback for paths that do not use "/" separators around "processed".
        import pathlib
        path_obj = pathlib.Path(skel_path)
        parent = str(path_obj.parent)
        if 'processed' in parent:
            new_parent = parent.replace('processed', 'joint_text_features')
            tf_npz_path = os.path.join(new_parent, path_obj.name)
    return tf_npz_path


def decode_z_skel_to_bvh(z_path, skel_path, model_epoch, output_dir,
                         output_name="decoded_motion.bvh", device="cuda:0",
                         window_size=64, overlap=16, use_sliding_window=None,
                         model_type="vae", model_state=None):
    """
    Decode a latent/code file with a target skeleton and save it as BVH.

    Args:
        z_path: path to a latent/code file (.npy, .npz, or .pt)
            - VAE: latent vectors [T, z_dim]
            - RVQ: code_idx [T, Q]
        skel_path: target skeleton .npz path. The matching tf file path is
            derived by swapping "processed" for "joint_text_features".
        model_epoch: model checkpoint name, for example "ckpt0"
        output_dir: output directory path (created when missing)
        output_name: output file name; ".bvh" is appended only when the name
            does not already end with it. Default "decoded_motion.bvh".
        device: compute device, default "cuda:0"
        window_size: sliding-window size, default 64
        overlap: sliding-window overlap, default 16
        use_sliding_window: None selects sliding-window decoding automatically
            when the sequence is longer than window_size. Sequences that fit
            in one window are always decoded in a single pass.
        model_type: "vae" or "rvq", default "vae"
        model_state: optional preloaded (model, cfg, ms_dict); loaded on demand when None

    Returns:
        output_path: saved BVH path

    Raises:
        ValueError: unsupported model_type, input file format, or skeleton
            format, or an empty latent/code sequence.
        KeyError: a .pt file lacks every expected key for model_type.
    """
    # Fail fast: an unknown model type can never be decoded below.
    if model_type not in _SUPPORTED_MODEL_TYPES:
        raise ValueError(
            f"Unsupported model_type: {model_type!r}. Expected 'vae' or 'rvq'"
        )

    if model_state is not None:
        model, cfg, ms_dict = model_state
        print(f"[Model] Reusing preloaded model ({model_type.upper()})")
    else:
        print(f"Loading model: {model_epoch} ({model_type.upper()})")
        model, cfg, ms_dict = load_model_by_type(model_type, model_epoch, device)
    model = model.to(device)
    model.eval()

    print(f"Loading {'latent vectors' if model_type == 'vae' else 'code_idx'}: {z_path}")
    latent_codes = _load_latent_or_codes(z_path, model_type, device)
    data_length = latent_codes.shape[0]
    if model_type == "vae":
        print(f" z shape: {latent_codes.shape}")
    else:
        print(f" code_idx shape: {latent_codes.shape}")
    if data_length == 0:
        raise ValueError(f"Empty latent/code sequence: {z_path}")

    if not skel_path.endswith('.npz'):
        raise ValueError(f"Unsupported skeleton format: {skel_path}. Only .npz is supported")

    print(f"Loading skeleton from NPZ: {skel_path}")
    tf_npz_path = _joint_text_feature_path(skel_path)
    skel_data = load_skeleton_and_tf_from_npz(skel_path, tf_npz_path)
    print(f" Skeleton joints: {skel_data.lo.shape[0]}")
    print(f" tf shape: {skel_data.tf.shape}")

    skel_graph = SkelPoseGraph(skel_data, None)

    # NOTE(review): the result is not used below; the call is kept in case
    # graph_2_skel validates the skeleton graph — confirm before removing.
    saved_skel = graph_2_skel(Batch.from_data_list([skel_graph]).to(device), 1)[0]

    if use_sliding_window is None:
        use_sliding_window = (data_length > window_size)

    print("Decoding...")
    out_rep_cfg = cfg["representation"]["out"]

    with torch.no_grad():
        if use_sliding_window and data_length > window_size:
            print(f" Long sequence ({data_length} frames); using sliding-window decoding")

            from sata.utils.motion_data import create_graph_list_from_single_graph
            src_graphs_list = create_graph_list_from_single_graph(skel_graph, data_length)

            hatD, recon_batch, recon_frames, _num_nodes_per_frame = \
                decode_to_hatD_with_sliding_window(
                    model, latent_codes, src_graphs_list,
                    data_length, window_size, overlap, device, model_type
                )
        else:
            print(f" Short sequence ({data_length} frames); using single-pass decoding")

            # One copy of the skeleton graph per frame.
            recon_batch = Batch.from_data_list([skel_graph] * data_length).to(device)
            recon_frames = data_length
            if model_type == "vae":
                hatD = model.decode(latent_codes, recon_batch, data_length)
            else:
                hatD, _ = model.decode_from_codes(latent_codes, recon_batch, data_length)

        out_motion_list, _out_contact_list = hatD_recon_motion(
            hatD, recon_batch, out_rep_cfg, ms_dict, recon_frames
        )
        out_motion = out_motion_list[0]

    print(f" Output motion frames: {out_motion.num_frames()}")

    if not os.path.exists(output_dir):
        # exist_ok guards against another process creating it in between.
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created output directory: {output_dir}")

    out_motion.fps = 20

    out_motion_fixed = fix_skeleton_coordinate_system(out_motion)
    print(" [Info] Applied coordinate-system fix (Y-Z axis conversion)")

    # Avoid a doubled extension (e.g. the default "decoded_motion.bvh").
    file_name = output_name if output_name.lower().endswith('.bvh') else output_name + '.bvh'
    output_path = os.path.join(output_dir, file_name)
    bvh.save(out_motion_fixed, output_path, rot_order="XYZ")

    print(f"Saved to: {output_path}")

    return output_path
|
|
|
|
def _read_path_list(txt_path, label):
    """Read a path-list file and return [(path, stem), ...].

    Blank lines and lines starting with '#' are skipped. *label* is only used
    in the progress message. A missing file yields an empty list.
    """
    if not os.path.exists(txt_path):
        print(f" Not found: {txt_path}")
        return []

    entries = []
    # utf-8-sig drops a leading BOM (written by some editors), which
    # str.strip() would otherwise leave glued to the first path.
    with open(txt_path, 'r', encoding='utf-8-sig') as f:
        for raw_line in f:
            path = raw_line.strip()
            if not path or path.startswith('#'):
                continue
            stem = os.path.splitext(os.path.basename(path))[0]
            entries.append((path, stem))
    print(f" Read {len(entries)} {label} files from {txt_path}")
    return entries


def scan_z_and_skeleton_files(visual_base_dir):
    """
    Read z and skeleton path lists from txt files and build all pairings.

    Each list file holds one path per line; blank lines and lines starting
    with '#' are ignored. Names are the file basenames without extension.

    Args:
        visual_base_dir: visual root containing z.txt and skel.txt

    Returns:
        pairs: list of tuple, [(z_path, z_name, skel_path, skel_name), ...];
            the full cross product, empty when either list is missing or empty
    """
    z_files = _read_path_list(os.path.join(visual_base_dir, "z.txt"), "z")
    skel_files = _read_path_list(os.path.join(visual_base_dir, "skel.txt"), "skeleton")

    return [
        (z_path, z_name, skel_path, skel_name)
        for z_path, z_name in z_files
        for skel_path, skel_name in skel_files
    ]
|
|
|
|
def batch_decode_all(model_epoch, visual_base_dir=VISUAL_BASE_DIR,
                     output_dir=OUTPUT_DIR, device="cuda:0",
                     window_size=64, overlap=16, model_type="vae"):
    """
    Batch decode every z/skeleton pairing listed in txt files.

    The checkpoint is loaded once, on the first pair, and reused for the
    remaining pairs. A failure on one pair (including a failed model load) is
    recorded in the results and does not stop the batch.

    Args:
        model_epoch: model checkpoint name
        visual_base_dir: visual root containing z.txt and skel.txt; default visual/
        output_dir: output directory; default visual/bvh_output
        device: compute device
        window_size: sliding-window size
        overlap: sliding-window overlap
        model_type: model type

    Returns:
        results: list of per-pair result dictionaries with keys z_name,
            skel_name, output_name, output_path, status ('success' or
            'failed'), plus error for failed pairs. Empty when no pairs exist.
    """
    separator = "=" * 70

    print(separator)
    print("Reading path lists...")
    print(f" z.txt: {os.path.join(visual_base_dir, 'z.txt')}")
    print(f" skel.txt: {os.path.join(visual_base_dir, 'skel.txt')}")
    print(separator)

    pairs = scan_z_and_skeleton_files(visual_base_dir)

    if not pairs:
        print("No z or skeleton files found. Check the directory configuration.")
        return []

    total = len(pairs)
    z_count = len({pair[1] for pair in pairs})
    skel_count = len({pair[3] for pair in pairs})
    print(f"Found {z_count} z files x {skel_count} skeleton files = {total} pairs")
    print()

    print("Pairings to process:")
    for i, (_z_path, z_name, _skel_path, skel_name) in enumerate(pairs, 1):
        print(f" [{i}/{total}] {z_name}_{skel_name}")
    print(separator)
    print()

    results = []
    # Loaded lazily inside the per-pair try so a load failure is reported per
    # pair; once loaded it is reused instead of reloading for each pair.
    model_state = None
    for i, (z_path, z_name, skel_path, skel_name) in enumerate(pairs, 1):
        output_name = f"{z_name}_{skel_name}"

        print(separator)
        print(f"Processing pair [{i}/{total}]: {output_name}")
        print(separator)
        print(f" z file: {z_path}")
        print(f" skeleton file: {skel_path}")
        print()

        try:
            if model_state is None:
                print(f"Loading model: {model_epoch} ({model_type.upper()})")
                model_state = load_model_by_type(model_type, model_epoch, device)

            output_path = decode_z_skel_to_bvh(
                z_path=z_path,
                skel_path=skel_path,
                model_epoch=model_epoch,
                output_dir=output_dir,
                output_name=output_name,
                device=device,
                window_size=window_size,
                overlap=overlap,
                use_sliding_window=None,
                model_type=model_type,
                model_state=model_state,
            )
        except Exception as e:
            # Deliberate best-effort: one bad pair must not abort the batch.
            print(f"[{i}/{total}] Failed: {output_name}")
            print(f" Error: {str(e)}")
            results.append({
                'z_name': z_name,
                'skel_name': skel_name,
                'output_name': output_name,
                'output_path': None,
                'status': 'failed',
                'error': str(e),
            })
        else:
            results.append({
                'z_name': z_name,
                'skel_name': skel_name,
                'output_name': output_name,
                'output_path': output_path,
                'status': 'success',
            })
            print(f"[{i}/{total}] Success: {output_name}")

        print()

    print(separator)
    print("Batch processing complete")
    print(separator)
    failures = [r for r in results if r['status'] == 'failed']
    success_count = sum(1 for r in results if r['status'] == 'success')
    print(f"Total pairs: {len(results)}")
    print(f" Success: {success_count}")
    print(f" Failed: {len(failures)}")

    if failures:
        print()
        print("Failed pairs:")
        for r in failures:
            print(f" - {r['output_name']}: {r.get('error', 'Unknown error')}")

    print(separator)

    return results
|
|
|
|
def _main():
    """Parse CLI arguments, run the batch decode, and return an exit code.

    Returns:
        0 when every pair decoded successfully (or no pairs were found),
        1 when any pair failed or an unexpected error escaped the batch.
    """
    parser = argparse.ArgumentParser(description="Batch decode latent vectors to BVH files", formatter_class=argparse.RawDescriptionHelpFormatter, epilog="")

    parser.add_argument("--model_type", type=str, default="vae", choices=["vae", "rvq"], help="Model type: vae or rvq (default: vae)")
    parser.add_argument("--model_epoch", type=str, required=True, help="Model checkpoint name, for example ckpt0")
    parser.add_argument("--visual_dir", type=str, default=VISUAL_BASE_DIR, help=f"Visual root containing z.txt and skel.txt (default: {VISUAL_BASE_DIR})")
    parser.add_argument("--output_dir", type=str, default=OUTPUT_DIR, help=f"Output directory (default: {OUTPUT_DIR})")
    parser.add_argument("--device", type=str, default="cuda:0", help="Compute device (default: cuda:0)")
    parser.add_argument("--window_size", type=int, default=64, help="Sliding-window size in frames (default: 64)")
    parser.add_argument("--overlap", type=int, default=16, help="Sliding-window overlap in frames (default: 16)")

    args = parser.parse_args()

    separator = "=" * 70
    print(separator)
    print("Batch latent-vector decoder (reads path lists from txt files)")
    print(separator)
    print(f"Model type: {args.model_type.upper()}")
    print(f"Model: {args.model_epoch}")
    print(f"z path list: {os.path.join(args.visual_dir, 'z.txt')}")
    print(f"skeleton path list: {os.path.join(args.visual_dir, 'skel.txt')}")
    print(f"Output directory: {args.output_dir}")
    print(f"Device: {args.device}")
    print(f"Sliding window: size={args.window_size}, overlap={args.overlap}")
    print(separator)
    print()

    try:
        results = batch_decode_all(
            model_epoch=args.model_epoch,
            visual_base_dir=args.visual_dir,
            output_dir=args.output_dir,
            device=args.device,
            window_size=args.window_size,
            overlap=args.overlap,
            model_type=args.model_type,
        )
    except Exception as e:
        # Top-level boundary: report with traceback and signal failure.
        print()
        print(separator)
        print(f"Error: {str(e)}")
        print(separator)
        import traceback
        traceback.print_exc()
        return 1

    # Non-zero exit status when any pair failed, so callers/CI can detect it.
    if any(r['status'] == 'failed' for r in results):
        return 1
    return 0


if __name__ == "__main__":
    # SystemExit instead of the site-provided exit(), which is meant for
    # interactive sessions and is unavailable under `python -S`.
    raise SystemExit(_main())
|
|
|
|