"""Batch-decode latent vectors (or RVQ code indices) into BVH files.

Supports sliding-window decoding and multiple input formats (.npy, .npz, .pt).
The Gradio demo calls ``decode_z_skel_to_bvh`` directly; the CLI entry point
keeps batch decoding available for local debugging.
"""
import argparse
import os
import pathlib
import sys
import traceback

import numpy as np
import torch
from fairmotion.data import bvh
from torch_geometric.data import Batch

# Required conversion helpers
from sata.conversions.graph_to_motion import hatD_recon_motion
from sata.skel_pose_graph import SkelPoseGraph
from sata.conversions.graph_to_motion import graph_2_skel
from fairmotion.core import motion as motion_class
from sata.utils.model_loading import load_model_by_type

# Reuse shared SATA motion-data helpers.
from sata.utils.motion_data import (
    SkelData,
    load_skeleton_and_tf_from_npz,
    create_graph_list_from_skeleton,
    fix_skeleton_coordinate_system,
)
from sata.utils.sliding_decode import decode_to_hatD_with_sliding_window

# Default directory configuration
VISUAL_BASE_DIR = os.path.join(os.path.dirname(__file__), "visual")
Z_DIR = os.path.join(VISUAL_BASE_DIR, "z")
SKEL_DIR = os.path.join(VISUAL_BASE_DIR, "skel", "processed")
SKEL_TF_DIR = os.path.join(VISUAL_BASE_DIR, "skel", "joint_text_features")
OUTPUT_DIR = os.path.join(VISUAL_BASE_DIR, "bvh_output")

# Path-list conventions used by the batch CLI:
#   visual/z.txt    - one latent/code file path per line
#   visual/skel.txt - one skeleton .npz path per line
# The tf path for each skeleton is derived by replacing "processed" with
# "joint_text_features"; decoded files are written to OUTPUT_DIR.

# Model types this module knows how to decode.
SUPPORTED_MODEL_TYPES = ("vae", "rvq")

# Keys probed, in order, when reading the payload out of an .npz file.
_NPZ_PAYLOAD_KEYS = ('z', 'z_pred', 'sample', 'code_idx', 'codes', 'idx_pred')


def _load_model_state(model_type, model_epoch, device):
    """Load a checkpoint via load_model_by_type, move it to device, set eval mode.

    Returns:
        (model, cfg, ms_dict) tuple suitable for decode_z_skel_to_bvh(model_state=...).
    """
    print(f"Loading model: {model_epoch} ({model_type.upper()})")
    model, cfg, ms_dict = load_model_by_type(model_type, model_epoch, device)
    model = model.to(device)
    model.eval()
    return model, cfg, ms_dict


def _load_latent_payload(z_path, model_type, device):
    """Load latent vectors (VAE) or code indices (RVQ) from a .npy/.npz/.pt file.

    Returns:
        A tensor on ``device``: float for "vae", long for "rvq".

    Raises:
        ValueError: unsupported file extension, or an .npz with no arrays.
        KeyError: a .pt file lacking the expected payload keys.
    """
    print(f"Loading {'latent vectors' if model_type == 'vae' else 'code_idx'}: {z_path}")
    if z_path.endswith('.npy'):
        data = torch.from_numpy(np.load(z_path))
    elif z_path.endswith('.npz'):
        # NpzFile keeps the zip archive open; the context manager closes it.
        with np.load(z_path) as data_loaded:
            if not data_loaded.files:
                raise ValueError(f"npz file contains no arrays: {z_path}")
            data_key = next((k for k in _NPZ_PAYLOAD_KEYS if k in data_loaded),
                            data_loaded.files[0])  # Fall back to the first key.
            print(f" Loaded key from npz: {data_key}")
            # Indexing an NpzFile reads the array into memory, so it stays
            # valid after the archive is closed.
            data = torch.from_numpy(data_loaded[data_key])
    elif z_path.endswith('.pt'):
        # NOTE(review): torch.load unpickles arbitrary objects; only load .pt
        # files from trusted sources (e.g. files produced by gen_dec.py).
        encoded_data = torch.load(z_path, map_location=device)
        if model_type == "vae":
            candidate_keys = ('z_pred', 'z', 'sample')
            missing_msg = "pt file does not contain any of: z_pred, z, sample"
        else:
            candidate_keys = ('idx_pred', 'code_idx')
            missing_msg = "pt file does not contain either idx_pred or code_idx"
        data_key = next((k for k in candidate_keys if k in encoded_data), None)
        if data_key is None:
            raise KeyError(missing_msg)
        data = encoded_data[data_key]
        print(f" Loaded key from pt: {data_key}")
        # Print metadata when present.
        if 'text' in encoded_data:
            print(f" Text metadata: {encoded_data['text']}")
        if 'm_len' in encoded_data:
            print(f" Original length: {encoded_data['m_len']}")
        if 'is_segment' in encoded_data and encoded_data['is_segment']:
            print(" Note: this file contains a segment")
    else:
        raise ValueError(f"Unsupported file format: {z_path}. Only .npy, .npz, and .pt are supported")

    # RVQ code indices must be integer; VAE latents are float.
    data = data.long() if model_type == "rvq" else data.float()
    return data.to(device)


def _derive_tf_path(skel_path):
    """Map a skeleton .npz path to its joint-text-feature .npz path.

    First replaces a '/processed/' segment with '/joint_text_features/'. If
    that changes nothing (e.g. non-'/' separators), replaces 'processed' in the
    parent directory string instead. Returns skel_path unchanged if neither
    rule applies.
    """
    tf_npz_path = skel_path.replace('/processed/', '/joint_text_features/')
    if tf_npz_path == skel_path:
        path_obj = pathlib.Path(skel_path)
        parent = str(path_obj.parent)
        if 'processed' in parent:
            tf_npz_path = os.path.join(parent.replace('processed', 'joint_text_features'),
                                       path_obj.name)
    return tf_npz_path


def decode_z_skel_to_bvh(z_path, skel_path, model_epoch, output_dir,
                         output_name="decoded_motion.bvh", device="cuda:0",
                         window_size=64, overlap=16, use_sliding_window=None,
                         model_type="vae", model_state=None, fps=20):
    """Decode a latent/code file with a target skeleton and save it as BVH.

    Args:
        z_path: path to a latent/code file (.npy, .npz, or .pt)
            - VAE: latent vectors [T, z_dim]
            - RVQ: code_idx [T, Q]
        skel_path: target skeleton .npz path. The matching tf file is derived
            by replacing "processed" with "joint_text_features".
        model_epoch: model checkpoint name, for example "ckpt0"
        output_dir: output directory path (created if missing)
        output_name: output file name; ".bvh" is appended unless already present
        device: compute device, default "cuda:0"
        window_size: sliding-window size, default 64
        overlap: sliding-window overlap, default 16
        use_sliding_window: force sliding-window decoding on/off; None selects
            automatically (on when the sequence is longer than window_size).
            Sequences not longer than window_size always decode in one pass.
        model_type: "vae" or "rvq", default "vae"
        model_state: optional preloaded (model, cfg, ms_dict); loaded on demand when None
        fps: frame rate written to the BVH, default 20

    Returns:
        output_path: saved BVH path

    Raises:
        ValueError: unsupported model_type, latent file format, or skeleton format.
        KeyError: a .pt file lacking the expected payload keys.
    """
    if model_type not in SUPPORTED_MODEL_TYPES:
        raise ValueError(f"Unsupported model_type: {model_type!r}. "
                         f"Expected one of {SUPPORTED_MODEL_TYPES}")

    # 1. Load the model or reuse a preloaded model state.
    if model_state is not None:
        model, cfg, ms_dict = model_state
        print(f"[Model] Reusing preloaded model ({model_type.upper()})")
    else:
        model, cfg, ms_dict = _load_model_state(model_type, model_epoch, device)
    # Also make sure a preloaded model is on the requested device in eval mode.
    model = model.to(device)
    model.eval()

    # 2. Load latent vectors (VAE) or code indices (RVQ).
    payload = _load_latent_payload(z_path, model_type, device)
    data_length = payload.shape[0]
    print(f" {'z' if model_type == 'vae' else 'code_idx'} shape: {payload.shape}")

    # 3. Load the skeleton (NPZ only) and its joint text features.
    if not skel_path.endswith('.npz'):
        raise ValueError(f"Unsupported skeleton format: {skel_path}. Only .npz is supported")
    print(f"Loading skeleton from NPZ: {skel_path}")
    tf_npz_path = _derive_tf_path(skel_path)
    if tf_npz_path == skel_path:
        print(f" [Warn] Could not derive a joint_text_features path from {skel_path}; "
              "using the skeleton file itself")
    skel_data = load_skeleton_and_tf_from_npz(skel_path, tf_npz_path)
    print(f" Skeleton joints: {skel_data.lo.shape[0]}")
    print(f" tf shape: {skel_data.tf.shape}")
    skel_graph = SkelPoseGraph(skel_data, None)

    # 4. Decide whether to use sliding-window decoding.
    if use_sliding_window is None:
        use_sliding_window = data_length > window_size

    # 5. Decode to motion.
    print("Decoding...")
    out_rep_cfg = cfg["representation"]["out"]
    with torch.no_grad():
        if use_sliding_window and data_length > window_size:
            print(f" Long sequence ({data_length} frames); using sliding-window decoding")
            # Imported lazily: only needed on the sliding-window path.
            from sata.utils.motion_data import create_graph_list_from_single_graph
            src_graphs_list = create_graph_list_from_single_graph(skel_graph, data_length)
            hatD, src_batch, num_frames, _num_nodes_per_frame = \
                decode_to_hatD_with_sliding_window(
                    model, payload, src_graphs_list, data_length,
                    window_size, overlap, device, model_type
                )
        else:
            print(f" Short sequence ({data_length} frames); using single-pass decoding")
            # One copy of the skeleton graph per frame.
            src_batch = Batch.from_data_list([skel_graph] * data_length).to(device)
            if model_type == "vae":
                hatD = model.decode(payload, src_batch, data_length)
            else:
                hatD, _ = model.decode_from_codes(payload, src_batch, data_length)
            num_frames = data_length

        out_motion_list, _out_contact_list = hatD_recon_motion(
            hatD, src_batch, out_rep_cfg, ms_dict, num_frames
        )
    out_motion = out_motion_list[0]
    print(f" Output motion frames: {out_motion.num_frames()}")

    # 6. Save the BVH file.
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created output directory: {output_dir}")

    out_motion.fps = fps
    # Fix coordinate system differences.
    out_motion_fixed = fix_skeleton_coordinate_system(out_motion)
    print(" [Info] Applied coordinate-system fix (Y-Z axis conversion)")

    # Avoid "name.bvh.bvh" when the caller already passed an extension
    # (as the default output_name does).
    stem = output_name[:-4] if output_name.lower().endswith('.bvh') else output_name
    output_path = os.path.join(output_dir, stem + '.bvh')
    bvh.save(out_motion_fixed, output_path, rot_order="XYZ")
    print(f"Saved to: {output_path}")
    return output_path


def _read_path_list(txt_path, label):
    """Read (path, stem) pairs from a txt file, skipping blank and '#' lines.

    A missing file is reported and yields an empty list.
    """
    entries = []
    if not os.path.exists(txt_path):
        print(f" Not found: {txt_path}")
        return entries
    with open(txt_path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            path = raw_line.strip()
            if not path or path.startswith('#'):
                continue  # Skip empty lines and comments.
            # Use the file stem as the display name.
            entries.append((path, os.path.splitext(os.path.basename(path))[0]))
    print(f" Read {len(entries)} {label} files from {txt_path}")
    return entries


def scan_z_and_skeleton_files(visual_base_dir):
    """Read z and skeleton path lists from txt files and build all pairings.

    Args:
        visual_base_dir: visual root containing z.txt and skel.txt

    Returns:
        pairs: Cartesian product as a list of
            (z_path, z_name, skel_path, skel_name) tuples.
    """
    z_files = _read_path_list(os.path.join(visual_base_dir, "z.txt"), "z")
    skel_files = _read_path_list(os.path.join(visual_base_dir, "skel.txt"), "skeleton")
    return [
        (z_path, z_name, skel_path, skel_name)
        for z_path, z_name in z_files
        for skel_path, skel_name in skel_files
    ]


def batch_decode_all(model_epoch, visual_base_dir=VISUAL_BASE_DIR, output_dir=OUTPUT_DIR,
                     device="cuda:0", window_size=64, overlap=16, model_type="vae"):
    """Batch decode every z/skeleton pairing listed in txt files.

    The model is loaded once and shared across pairs. If loading fails, the
    failure is recorded for that pair and loading is retried on the next one.

    Args:
        model_epoch: model checkpoint name
        visual_base_dir: visual root containing z.txt and skel.txt; default visual/
        output_dir: output directory; default visual/bvh_output
        device: compute device
        window_size: sliding-window size
        overlap: sliding-window overlap
        model_type: model type

    Returns:
        results: list of per-pair result dicts with keys z_name, skel_name,
            output_name, output_path, status ('success'/'failed') and, on
            failure, error.
    """
    # Read txt files and build pairings.
    print("=" * 70)
    print("Reading path lists...")
    print(f" z.txt: {os.path.join(visual_base_dir, 'z.txt')}")
    print(f" skel.txt: {os.path.join(visual_base_dir, 'skel.txt')}")
    print("=" * 70)

    pairs = scan_z_and_skeleton_files(visual_base_dir)
    if not pairs:
        print("No z or skeleton files found. Check the directory configuration.")
        return []

    # Count distinct source files by path (stems may collide across folders).
    z_count = len({p[0] for p in pairs})
    skel_count = len({p[2] for p in pairs})
    print(f"Found {z_count} z files x {skel_count} skeleton files = {len(pairs)} pairs")
    print()

    # Show all pairings.
    print("Pairings to process:")
    for i, (_z_path, z_name, _skel_path, skel_name) in enumerate(pairs, 1):
        print(f" [{i}/{len(pairs)}] {z_name}_{skel_name}")
    print("=" * 70)
    print()

    # Process each pair.
    results = []
    model_state = None  # Loaded lazily on the first pair, then reused.
    for i, (z_path, z_name, skel_path, skel_name) in enumerate(pairs, 1):
        output_name = f"{z_name}_{skel_name}"
        print("=" * 70)
        print(f"Processing pair [{i}/{len(pairs)}]: {output_name}")
        print("=" * 70)
        print(f" z file: {z_path}")
        print(f" skeleton file: {skel_path}")
        print()

        try:
            if model_state is None:
                model_state = _load_model_state(model_type, model_epoch, device)
            output_path = decode_z_skel_to_bvh(
                z_path=z_path,
                skel_path=skel_path,
                model_epoch=model_epoch,
                output_dir=output_dir,
                output_name=output_name,
                device=device,
                window_size=window_size,
                overlap=overlap,
                use_sliding_window=None,  # Select automatically.
                model_type=model_type,
                model_state=model_state,
            )
            results.append({
                'z_name': z_name,
                'skel_name': skel_name,
                'output_name': output_name,
                'output_path': output_path,
                'status': 'success',
            })
            print(f"[{i}/{len(pairs)}] Success: {output_name}")
        except Exception as e:
            # Top-level per-pair boundary: record and continue with the next pair.
            print(f"[{i}/{len(pairs)}] Failed: {output_name}")
            print(f" Error: {e}")
            results.append({
                'z_name': z_name,
                'skel_name': skel_name,
                'output_name': output_name,
                'output_path': None,
                'status': 'failed',
                'error': str(e),
            })
        print()

    # Summary.
    print("=" * 70)
    print("Batch processing complete")
    print("=" * 70)
    success_count = sum(1 for r in results if r['status'] == 'success')
    failed = [r for r in results if r['status'] == 'failed']
    print(f"Total pairs: {len(results)}")
    print(f" Success: {success_count}")
    print(f" Failed: {len(failed)}")
    if failed:
        print()
        print("Failed pairs:")
        for r in failed:
            print(f" - {r['output_name']}: {r.get('error', 'Unknown error')}")
    print("=" * 70)
    return results


def main():
    """CLI entry point: batch-decode all pairs; exit with status 1 on any failure."""
    parser = argparse.ArgumentParser(description="Batch decode latent vectors to BVH files",
                                     formatter_class=argparse.RawDescriptionHelpFormatter,
                                     epilog="")
    parser.add_argument("--model_type", type=str, default="vae",
                        choices=list(SUPPORTED_MODEL_TYPES),
                        help="Model type: vae or rvq (default: vae)")
    parser.add_argument("--model_epoch", type=str, required=True,
                        help="Model checkpoint name, for example ckpt0")
    parser.add_argument("--visual_dir", type=str, default=VISUAL_BASE_DIR,
                        help=f"Visual root containing z.txt and skel.txt (default: {VISUAL_BASE_DIR})")
    parser.add_argument("--output_dir", type=str, default=OUTPUT_DIR,
                        help=f"Output directory (default: {OUTPUT_DIR})")
    parser.add_argument("--device", type=str, default="cuda:0",
                        help="Compute device (default: cuda:0)")
    parser.add_argument("--window_size", type=int, default=64,
                        help="Sliding-window size in frames (default: 64)")
    parser.add_argument("--overlap", type=int, default=16,
                        help="Sliding-window overlap in frames (default: 16)")
    args = parser.parse_args()

    print("=" * 70)
    print("Batch latent-vector decoder (reads path lists from txt files)")
    print("=" * 70)
    print(f"Model type: {args.model_type.upper()}")
    print(f"Model: {args.model_epoch}")
    print(f"z path list: {os.path.join(args.visual_dir, 'z.txt')}")
    print(f"skeleton path list: {os.path.join(args.visual_dir, 'skel.txt')}")
    print(f"Output directory: {args.output_dir}")
    print(f"Device: {args.device}")
    print(f"Sliding window: size={args.window_size}, overlap={args.overlap}")
    print("=" * 70)
    print()

    try:
        results = batch_decode_all(
            model_epoch=args.model_epoch,
            visual_base_dir=args.visual_dir,
            output_dir=args.output_dir,
            device=args.device,
            window_size=args.window_size,
            overlap=args.overlap,
            model_type=args.model_type,
        )
    except Exception as e:
        print()
        print("=" * 70)
        print(f"Error: {e}")
        print("=" * 70)
        traceback.print_exc()
        sys.exit(1)

    # Set exit status from the batch result.
    if any(r['status'] == 'failed' for r in results):
        sys.exit(1)


if __name__ == "__main__":
    main()