# This code is based on https://github.com/openai/guided-diffusion
"""
Generate preprocessed posterior features (z) from text for an external decoder.

Unlike generate.py, this script does not post-process the generated features.
It saves the raw z features in a format that external decoders can read.
"""
import json
import os
from argparse import ArgumentParser

import numpy as np
import torch

from utils.fixseed import fixseed
from utils import dist_util
from utils.model_util import create_gaussian_diffusion
from utils.sampler_util import ClassifierFreeSampleModel
from data_loaders.tensors import collate
from model.mdm import MDM


def sample_z_args():
    """Parse command-line arguments for z sampling.

    Returns:
        The ``argparse`` namespace produced by ``parser.parse_args()``.
    """
    parser = ArgumentParser(description='Sample z from text using trained MDM model')

    # Model options
    parser.add_argument("--model_path", required=True, type=str,
                        help="Path to trained model checkpoint")
    parser.add_argument("--use_ema", action='store_true',
                        help="Use EMA model if available")

    # Input options
    parser.add_argument("--text_prompt", default='', type=str,
                        help="A single text prompt to generate")
    parser.add_argument("--input_text", default='', type=str,
                        help="Path to a text file with prompts (one per line)")

    # Sampling options
    parser.add_argument("--num_samples", default=1, type=int,
                        help="Number of samples per prompt")
    parser.add_argument("--num_repetitions", default=1, type=int,
                        help="Number of repetitions for each prompt")
    parser.add_argument("--motion_length", default=6.0, type=float,
                        help="Motion length in seconds")
    parser.add_argument("--guidance_param", default=2.5, type=float,
                        help="Classifier-free guidance scale")

    # Output options
    parser.add_argument("--output_dir", default='', type=str,
                        help="Output directory for results")
    parser.add_argument("--save_individual", action='store_true',
                        help="Save each sample as individual file")

    # Misc options
    parser.add_argument("--seed", default=10, type=int)
    parser.add_argument("--device", default=0, type=int)

    args = parser.parse_args()
    return args


def load_model_from_checkpoint(model_path, device, use_ema=False):
    """Load a model from a checkpoint.

    The training arguments are read from ``args.json`` located next to the
    checkpoint file. The returned model is neither moved to ``device`` nor put
    in eval mode; the caller is responsible for both.

    Args:
        model_path: Path to the checkpoint file.
        device: Passed to ``torch.load`` as ``map_location``.
        use_ema: If true and the checkpoint has a ``'model_avg'`` entry, load
            those weights; otherwise the ``'model'`` entry (or the checkpoint
            itself) is used and a warning is printed.

    Returns:
        Tuple ``(model, diffusion, args)`` where ``args`` is a simple attribute
        object populated from ``args.json``.
    """
    # Load args.
    args_path = os.path.join(os.path.dirname(model_path), 'args.json')
    with open(args_path, 'r') as f:
        model_args = json.load(f)

    # Create a simple args object.
    class Args:
        pass

    args = Args()
    for k, v in model_args.items():
        setattr(args, k, v)

    # Determine conditioning mode.
    if getattr(args, 'unconstrained', False):
        cond_mode = 'no_cond'
    elif args.dataset in ['kit', 'humanml', 'preprocessed_posterior']:
        cond_mode = 'text'
    else:
        cond_mode = 'action'

    # Get model parameters.
    njoints = getattr(args, 'njoints', 512)
    nfeats = getattr(args, 'nfeats', 1)

    # NOTE: several hyper-parameters below (ff_size, num_heads, dropout, ...)
    # are hard-coded rather than read from args.json.
    model_kwargs = {
        'modeltype': '',
        'njoints': njoints,
        'nfeats': nfeats,
        'num_actions': 1,
        'translation': True,
        'pose_rep': 'rot6d',
        'glob': True,
        'glob_rot': True,
        'latent_dim': args.latent_dim,
        'ff_size': 1024,
        'num_layers': args.layers,
        'num_heads': 4,
        'dropout': 0.1,
        'activation': 'gelu',
        'data_rep': 'hml_vec',
        'cond_mode': cond_mode,
        'cond_mask_prob': getattr(args, 'cond_mask_prob', 0.1),
        'action_emb': 'tensor',
        'arch': args.arch,
        'emb_trans_dec': getattr(args, 'emb_trans_dec', False),
        'clip_version': 'ViT-B/32',
        'dataset': args.dataset,
        'text_encoder_type': getattr(args, 'text_encoder_type', 'clip'),
        'pos_embed_max_len': getattr(args, 'pos_embed_max_len', 5000),
        'mask_frames': getattr(args, 'mask_frames', False),
        'pred_len': getattr(args, 'pred_len', 0),
        'context_len': getattr(args, 'context_len', 0),
        'emb_policy': 'add',
        'all_goal_joint_names': [],
        'multi_target_cond': getattr(args, 'multi_target_cond', False),
        'multi_encoder_type': getattr(args, 'multi_encoder_type', 'single'),
        'target_enc_layers': getattr(args, 'target_enc_layers', 1),
        'use_rot2xyz': False,
    }

    # Create the model without SMPL-backed xyz conversion; this script saves raw z only.
    model = MDM(**model_kwargs)

    # Load weights.
    # SECURITY NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    if use_ema and 'model_avg' in state_dict:
        print("Loading EMA model weights...")
        state_dict = state_dict['model_avg']
    else:
        if use_ema:
            # Make the fallback visible instead of silently ignoring --use_ema.
            print("Warning: EMA weights requested but 'model_avg' not found in "
                  "checkpoint; using regular model weights.")
        if 'model' in state_dict:
            state_dict = state_dict['model']

    # Remove unused keys.
    keys_to_delete = ['sequence_pos_encoder.pe', 'embed_timestep.sequence_pos_encoder.pe']
    for key in keys_to_delete:
        state_dict.pop(key, None)

    # strict=False: mismatched keys are only reported, not treated as errors.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    print(f"Loaded model. Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}")

    # Create diffusion.
    diffusion = create_gaussian_diffusion(args)

    return model, diffusion, args


def _load_prompts(text_prompt, input_text, num_samples):
    """Return the list of text prompts selected by the command-line options.

    ``text_prompt`` takes precedence and is repeated ``num_samples`` times;
    otherwise non-empty, stripped lines are read from the file ``input_text``.

    Raises:
        ValueError: If neither source is given or the resulting list is empty.
        FileNotFoundError: If ``input_text`` does not exist.
    """
    if text_prompt != '':
        if num_samples < 1:
            raise ValueError(f"--num_samples must be >= 1, got {num_samples}")
        return [text_prompt] * num_samples

    if input_text != '':
        # Raise instead of assert: asserts are stripped under `python -O`.
        if not os.path.exists(input_text):
            raise FileNotFoundError(f"Input text file not found: {input_text}")
        # Explicit encoding so prompts decode the same way on every platform.
        with open(input_text, 'r', encoding='utf-8') as f:
            texts = [line.strip() for line in f if line.strip()]
        if not texts:
            raise ValueError(f"Input text file contains no prompts: {input_text}")
        return texts

    raise ValueError("Must provide either --text_prompt or --input_text")


def _save_results(args, all_samples, all_texts, all_lengths):
    """Write results.npz, optional per-sample files and sample_config.json.

    Fills in ``args.output_dir`` (next to the checkpoint) when it is empty.

    Args:
        args: Parsed command-line arguments.
        all_samples: Array of shape (N, T, D).
        all_texts: List of N prompt strings.
        all_lengths: Array of N frame counts.

    Returns:
        Path of the written ``results.npz`` file.
    """
    if args.output_dir == '':
        model_name = os.path.basename(os.path.dirname(args.model_path))
        niter = os.path.basename(args.model_path).replace('model', '').replace('.pt', '')
        args.output_dir = os.path.join(os.path.dirname(args.model_path),
                                       f'sampled_z_{model_name}_{niter}_seed{args.seed}')
    os.makedirs(args.output_dir, exist_ok=True)

    # Save the main result file in npz format.
    results_path = os.path.join(args.output_dir, 'results.npz')
    print(f"\nSaving results to {results_path}")
    np.savez(
        results_path,
        z=all_samples,  # (N, T, D)
        texts=np.array(all_texts, dtype=str),  # (N,) string array
        lengths=all_lengths,  # (N,)
        motion_length=np.array(args.motion_length),
        guidance_param=np.array(args.guidance_param),
        model_path=np.array(args.model_path),
    )

    # Optionally save individual files.
    if args.save_individual:
        individual_dir = os.path.join(args.output_dir, 'individual')
        os.makedirs(individual_dir, exist_ok=True)
        for i, (z_i, text_i, length_i) in enumerate(zip(all_samples, all_texts, all_lengths)):
            # Save npz with all metadata.
            sample_path = os.path.join(individual_dir, f'sample_{i:04d}.npz')
            np.savez(
                sample_path,
                z=z_i,  # (T, D)
                text=np.array(text_i),
                length=np.array(length_i),
            )
            # Save npy with z only.
            z_path = os.path.join(individual_dir, f'z_{i:04d}.npy')
            np.save(z_path, z_i)  # (T, D)
        print(f"Saved {all_samples.shape[0]} individual files to {individual_dir}")

    # Save sampling config.
    config_path = os.path.join(args.output_dir, 'sample_config.json')
    with open(config_path, 'w') as f:
        json.dump({
            'model_path': args.model_path,
            'num_samples': args.num_samples,
            'num_repetitions': args.num_repetitions,
            'motion_length': args.motion_length,
            'guidance_param': args.guidance_param,
            'seed': args.seed,
            'use_ema': args.use_ema,
            'sample_shape': list(all_samples.shape),
        }, f, indent=4)

    return results_path


def main():
    """Sample raw z features for the requested prompts and save them to disk.

    Returns:
        The output directory the results were written to.

    Raises:
        ValueError: On missing/empty prompts, ``--num_repetitions < 1`` or a
            ``--motion_length`` that yields zero frames.
        FileNotFoundError: If the ``--input_text`` file does not exist.
    """
    args = sample_z_args()
    fixseed(args.seed)
    dist_util.setup_dist(args.device)
    device = dist_util.dev()

    # Load text prompts.
    texts = _load_prompts(args.text_prompt, args.input_text, args.num_samples)
    args.num_samples = len(texts)
    print(f"Loaded {len(texts)} text prompts")

    # Compute frame count.
    fps = 20  # Default FPS
    n_frames = int(args.motion_length * fps)

    # Validate cheap options before the (expensive) model load.
    if n_frames < 1:
        raise ValueError(
            f"--motion_length {args.motion_length} yields {n_frames} frames at {fps} fps; "
            "need at least 1")
    if args.num_repetitions < 1:
        raise ValueError(f"--num_repetitions must be >= 1, got {args.num_repetitions}")

    # Load model.
    print(f"Loading model from {args.model_path}...")
    model, diffusion, model_args = load_model_from_checkpoint(
        args.model_path, device, use_ema=args.use_ema
    )
    model.to(device)
    model.eval()

    # Set classifier-free guidance.
    if args.guidance_param != 1:
        model = ClassifierFreeSampleModel(model)

    # Prepare inputs.
    batch_size = len(texts)
    # The guidance wrapper may not expose njoints/nfeats itself; fall back to
    # the wrapped model in that case.
    motion_shape = (batch_size,
                    model.njoints if hasattr(model, 'njoints') else model.model.njoints,
                    model.nfeats if hasattr(model, 'nfeats') else model.model.nfeats,
                    n_frames)
    print(f"Motion shape: {motion_shape}")

    # Build model_kwargs.
    collate_args = [{'inp': torch.zeros(n_frames), 'tokens': None, 'lengths': n_frames, 'text': txt}
                    for txt in texts]
    _, model_kwargs = collate(collate_args)
    model_kwargs['y'] = {key: val.to(device) if torch.is_tensor(val) else val
                         for key, val in model_kwargs['y'].items()}

    # Add guidance scale.
    if args.guidance_param != 1:
        model_kwargs['y']['scale'] = torch.ones(batch_size, device=device) * args.guidance_param

    # Pre-encode text when the model supports it.
    actual_model = model.model if hasattr(model, 'model') else model
    if hasattr(actual_model, 'encode_text') and 'text' in model_kwargs['y']:
        print("Pre-encoding text...")
        model_kwargs['y']['text_embed'] = actual_model.encode_text(model_kwargs['y']['text'])

    # Sample.
    all_samples = []
    all_texts = []
    all_lengths = []

    sample_fn = diffusion.p_sample_loop

    for rep_i in range(args.num_repetitions):
        print(f"\n### Sampling [repetition #{rep_i + 1}/{args.num_repetitions}]")

        sample = sample_fn(
            model,
            motion_shape,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            skip_timesteps=0,
            init_image=None,
            progress=True,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )

        # Do not post-process here; save raw z features directly.
        # sample shape: (batch_size, njoints, nfeats, n_frames)
        all_samples.append(sample.cpu().numpy())
        all_texts.extend(texts)
        all_lengths.extend([n_frames] * batch_size)

        print(f"Generated {len(texts)} samples")

    # Merge all results.
    all_samples = np.concatenate(all_samples, axis=0)  # (N, D, 1, T)
    all_lengths = np.array(all_lengths)

    # Convert shape: (N, D, 1, T) -> (N, T, D).
    # squeeze(2) requires the nfeats axis to have size 1.
    all_samples = all_samples.squeeze(2)  # (N, D, T)
    all_samples = all_samples.transpose(0, 2, 1)  # (N, T, D)

    print(f"\nTotal samples: {all_samples.shape[0]}")
    print(f"Sample shape: {all_samples.shape} # (N, T, D)")

    # Save results.
    results_path = _save_results(args, all_samples, all_texts, all_lengths)

    print(f"\nDone! Results saved to {args.output_dir}")
    print("\nTo load the results in your decoder:")
    print(" import numpy as np")
    print(f" data = np.load('{results_path}')")
    print(" z = data['z'] # shape: (N, T, D), dtype: float32")
    print(" texts = data['texts'] # shape: (N,), dtype: str")
    print(" lengths = data['lengths'] # shape: (N,), dtype: int")
    print("\n # For individual samples:")
    print(" single_z = z[0] # shape: (T, D)")

    return args.output_dir


if __name__ == "__main__":
    main()