| |
| """ |
| Generate preprocessed posterior features (z) from text for an external decoder. |
| |
| Unlike generate.py, this script does not post-process the generated features. |
| It saves the raw z features in a format that external decoders can read. |
| """ |
| import os |
| import numpy as np |
| import torch |
| from argparse import ArgumentParser |
| from utils.fixseed import fixseed |
| from utils import dist_util |
| from utils.model_util import create_gaussian_diffusion |
| from utils.sampler_util import ClassifierFreeSampleModel |
| from data_loaders.tensors import collate |
| from model.mdm import MDM |
|
|
| import json |
|
|
|
|
def sample_z_args(argv=None):
    """Parse command-line arguments for z sampling.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the process command line is parsed, exactly as
            before this parameter existed.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.

    Raises:
        SystemExit: Raised by argparse when ``--model_path`` is missing,
            an option cannot be converted to its type, or one of the
            range checks at the end of this function fails.
    """
    parser = ArgumentParser(description='Sample z from text using trained MDM model')

    # Model checkpoint.
    parser.add_argument("--model_path", required=True, type=str,
                        help="Path to trained model checkpoint")
    parser.add_argument("--use_ema", action='store_true',
                        help="Use EMA model if available")

    # Text input: a single prompt or a file of prompts.
    parser.add_argument("--text_prompt", default='', type=str,
                        help="A single text prompt to generate")
    parser.add_argument("--input_text", default='', type=str,
                        help="Path to a text file with prompts (one per line)")

    # Sampling parameters.
    parser.add_argument("--num_samples", default=1, type=int,
                        help="Number of samples to generate for --text_prompt "
                             "(ignored with --input_text, which yields one sample per line)")
    parser.add_argument("--num_repetitions", default=1, type=int,
                        help="Number of repetitions for each prompt")
    parser.add_argument("--motion_length", default=6.0, type=float,
                        help="Motion length in seconds")
    parser.add_argument("--guidance_param", default=2.5, type=float,
                        help="Classifier-free guidance scale")

    # Output.
    parser.add_argument("--output_dir", default='', type=str,
                        help="Output directory for results")
    parser.add_argument("--save_individual", action='store_true',
                        help="Save each sample as individual file")

    # Misc.
    parser.add_argument("--seed", default=10, type=int)
    parser.add_argument("--device", default=0, type=int)

    args = parser.parse_args(argv)

    # Fail at the command line instead of deep inside sampling: zero
    # repetitions would leave nothing to concatenate, and a non-positive
    # length would give a zero or negative frame count.
    if args.num_repetitions < 1:
        parser.error("--num_repetitions must be at least 1")
    if args.motion_length <= 0:
        parser.error("--motion_length must be positive")

    return args
|
|
|
|
def load_model_from_checkpoint(model_path, device, use_ema=False):
    """Build an MDM model and its diffusion object from a checkpoint.

    The training configuration is read from ``args.json`` located in the
    same directory as ``model_path``; the weights are read from
    ``model_path`` itself.

    Args:
        model_path: Path to the checkpoint file.
        device: Passed to ``torch.load`` as ``map_location``.
        use_ema: When True, prefer the checkpoint's ``'model_avg'`` entry.
            If that entry is absent a warning is printed and the regular
            weights are loaded instead.

    Returns:
        A ``(model, diffusion, args)`` tuple, where ``args`` is a plain
        object carrying every key of ``args.json`` as an attribute.

    Raises:
        FileNotFoundError: If ``args.json`` or the checkpoint is missing.
        AttributeError: If ``args.json`` lacks a key read without a default
            here (``dataset``, ``latent_dim``, ``layers``, ``arch``).
    """
    args_path = os.path.join(os.path.dirname(model_path), 'args.json')
    with open(args_path, 'r', encoding='utf-8') as f:
        model_args = json.load(f)

    # Expose the JSON keys as attributes, the access style used below and
    # by create_gaussian_diffusion.
    class Args:
        pass
    args = Args()
    for k, v in model_args.items():
        setattr(args, k, v)

    # Conditioning mode follows from the training configuration.
    if getattr(args, 'unconstrained', False):
        cond_mode = 'no_cond'
    elif args.dataset in ['kit', 'humanml', 'preprocessed_posterior']:
        cond_mode = 'text'
    else:
        cond_mode = 'action'

    # Feature dimensions; the fallbacks apply when args.json omits them.
    njoints = getattr(args, 'njoints', 512)
    nfeats = getattr(args, 'nfeats', 1)

    model_kwargs = {
        'modeltype': '',
        'njoints': njoints,
        'nfeats': nfeats,
        'num_actions': 1,
        'translation': True,
        'pose_rep': 'rot6d',
        'glob': True,
        'glob_rot': True,
        'latent_dim': args.latent_dim,
        'ff_size': 1024,
        'num_layers': args.layers,
        'num_heads': 4,
        'dropout': 0.1,
        'activation': 'gelu',
        'data_rep': 'hml_vec',
        'cond_mode': cond_mode,
        'cond_mask_prob': getattr(args, 'cond_mask_prob', 0.1),
        'action_emb': 'tensor',
        'arch': args.arch,
        'emb_trans_dec': getattr(args, 'emb_trans_dec', False),
        'clip_version': 'ViT-B/32',
        'dataset': args.dataset,
        'text_encoder_type': getattr(args, 'text_encoder_type', 'clip'),
        'pos_embed_max_len': getattr(args, 'pos_embed_max_len', 5000),
        'mask_frames': getattr(args, 'mask_frames', False),
        'pred_len': getattr(args, 'pred_len', 0),
        'context_len': getattr(args, 'context_len', 0),
        'emb_policy': 'add',
        'all_goal_joint_names': [],
        'multi_target_cond': getattr(args, 'multi_target_cond', False),
        'multi_encoder_type': getattr(args, 'multi_encoder_type', 'single'),
        'target_enc_layers': getattr(args, 'target_enc_layers', 1),
        'use_rot2xyz': False,
    }

    model = MDM(**model_kwargs)

    # NOTE(security): torch.load may unpickle arbitrary objects depending on
    # the installed torch version and its `weights_only` default; only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device)

    if use_ema and 'model_avg' in checkpoint:
        print("Loading EMA model weights...")
        state_dict = checkpoint['model_avg']
    else:
        if use_ema:
            # Previously this fell back silently, so the caller could not
            # tell that the EMA weights were not actually used.
            print("Warning: use_ema was requested but the checkpoint has no "
                  "'model_avg' entry; loading regular model weights instead.")
        # A checkpoint is either a wrapper dict with a 'model' entry or the
        # state dict itself.
        state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint

    # Drop the stored positional-encoding buffers before loading.
    # NOTE(review): presumably so the freshly built model keeps its own
    # buffers (sized by pos_embed_max_len) -- confirm.
    keys_to_delete = ['sequence_pos_encoder.pe', 'embed_timestep.sequence_pos_encoder.pe']
    for key in keys_to_delete:
        state_dict.pop(key, None)

    # strict=False tolerates mismatched keys; only their counts are reported.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    print(f"Loaded model. Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}")

    diffusion = create_gaussian_diffusion(args)

    return model, diffusion, args
|
|
|
|
def _load_prompts(args):
    """Return the list of prompts selected by --text_prompt or --input_text.

    With --input_text, ``args.num_samples`` is overwritten with the number
    of non-blank lines read from the file.
    """
    if args.text_prompt != '':
        if args.num_samples < 1:
            raise ValueError("--num_samples must be at least 1 when using --text_prompt")
        return [args.text_prompt] * args.num_samples

    if args.input_text != '':
        # Explicit exception rather than `assert`, which is stripped under -O.
        if not os.path.exists(args.input_text):
            raise FileNotFoundError(f"Input text file not found: {args.input_text}")
        with open(args.input_text, 'r', encoding='utf-8') as f:
            texts = [line.strip() for line in f if line.strip()]
        if not texts:
            # An empty prompt list would yield a zero-size batch downstream.
            raise ValueError(f"Input text file contains no prompts: {args.input_text}")
        args.num_samples = len(texts)
        return texts

    raise ValueError("Must provide either --text_prompt or --input_text")


def _save_results(args, all_samples, all_texts, all_lengths):
    """Write results.npz, optional per-sample files and sample_config.json.

    Fills in ``args.output_dir`` from the checkpoint path when it is empty
    and returns the path of the written ``results.npz``.
    """
    if args.output_dir == '':
        checkpoint_dir = os.path.dirname(args.model_path)
        model_name = os.path.basename(checkpoint_dir)
        niter = os.path.basename(args.model_path).replace('model', '').replace('.pt', '')
        args.output_dir = os.path.join(checkpoint_dir,
                                       f'sampled_z_{model_name}_{niter}_seed{args.seed}')

    os.makedirs(args.output_dir, exist_ok=True)

    # Everything in one archive.
    results_path = os.path.join(args.output_dir, 'results.npz')
    print(f"\nSaving results to {results_path}")
    np.savez(results_path,
             z=all_samples,
             texts=np.array(all_texts, dtype=str),
             lengths=all_lengths,
             motion_length=np.array(args.motion_length),
             guidance_param=np.array(args.guidance_param),
             model_path=np.array(args.model_path),
             )

    # Optionally also one .npz (z + text + length) and one bare .npy per sample.
    if args.save_individual:
        individual_dir = os.path.join(args.output_dir, 'individual')
        os.makedirs(individual_dir, exist_ok=True)

        for i, (z, text, length) in enumerate(zip(all_samples, all_texts, all_lengths)):
            sample_path = os.path.join(individual_dir, f'sample_{i:04d}.npz')
            np.savez(sample_path,
                     z=z,
                     text=np.array(text),
                     length=np.array(length),
                     )

            z_path = os.path.join(individual_dir, f'z_{i:04d}.npy')
            np.save(z_path, z)

        print(f"Saved {all_samples.shape[0]} individual files to {individual_dir}")

    # Record the settings that produced this output.
    config_path = os.path.join(args.output_dir, 'sample_config.json')
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump({
            'model_path': args.model_path,
            'num_samples': args.num_samples,
            'num_repetitions': args.num_repetitions,
            'motion_length': args.motion_length,
            'guidance_param': args.guidance_param,
            'seed': args.seed,
            'use_ema': args.use_ema,
            'sample_shape': list(all_samples.shape),
        }, f, indent=4)

    return results_path


def main():
    """Sample raw z features from text prompts and save them to disk.

    Parses the command line, loads the checkpoint, runs the diffusion
    sampling loop ``--num_repetitions`` times over all prompts, and writes
    the stacked results as ``(N, T, D)`` arrays without post-processing.

    Returns:
        The output directory the results were written to.

    Raises:
        ValueError: If no prompt source is given, the prompt file has no
            non-blank lines, or the repetition / sample / frame counts are
            not positive, or the model's ``nfeats`` is not 1.
        FileNotFoundError: If the ``--input_text`` file does not exist.
    """
    args = sample_z_args()
    fixseed(args.seed)

    dist_util.setup_dist(args.device)
    device = dist_util.dev()

    # Reject settings that could only fail later, before the model is loaded.
    if args.num_repetitions < 1:
        raise ValueError("--num_repetitions must be at least 1")

    # Frame rate is fixed here; the frame count follows from the requested duration.
    fps = 20
    n_frames = int(args.motion_length * fps)
    if n_frames < 1:
        raise ValueError(
            f"--motion_length {args.motion_length} yields no frames at {fps} fps")

    texts = _load_prompts(args)
    print(f"Loaded {len(texts)} text prompts")

    print(f"Loading model from {args.model_path}...")
    model, diffusion, model_args = load_model_from_checkpoint(
        args.model_path, device, use_ema=args.use_ema
    )
    model.to(device)
    model.eval()

    # A guidance scale of exactly 1 skips the classifier-free wrapper.
    if args.guidance_param != 1:
        model = ClassifierFreeSampleModel(model)

    # All prompts are sampled together as one batch.
    batch_size = len(texts)
    motion_shape = (batch_size,
                    model.njoints if hasattr(model, 'njoints') else model.model.njoints,
                    model.nfeats if hasattr(model, 'nfeats') else model.model.nfeats,
                    n_frames)

    print(f"Motion shape: {motion_shape}")

    # Build the conditioning dict through the project's collate helper.
    collate_args = [{'inp': torch.zeros(n_frames), 'tokens': None, 'lengths': n_frames, 'text': txt}
                    for txt in texts]
    _, model_kwargs = collate(collate_args)
    model_kwargs['y'] = {key: val.to(device) if torch.is_tensor(val) else val
                         for key, val in model_kwargs['y'].items()}

    if args.guidance_param != 1:
        model_kwargs['y']['scale'] = torch.ones(batch_size, device=device) * args.guidance_param

    # Encode the prompts once so every repetition reuses the same embedding.
    actual_model = model.model if hasattr(model, 'model') else model
    if hasattr(actual_model, 'encode_text') and 'text' in model_kwargs['y']:
        print("Pre-encoding text...")
        # Sampling needs no gradients; avoid building an autograd graph here.
        with torch.no_grad():
            model_kwargs['y']['text_embed'] = actual_model.encode_text(model_kwargs['y']['text'])

    all_samples = []
    all_texts = []
    all_lengths = []

    sample_fn = diffusion.p_sample_loop

    for rep_i in range(args.num_repetitions):
        print(f"\n### Sampling [repetition #{rep_i + 1}/{args.num_repetitions}]")

        sample = sample_fn(
            model,
            motion_shape,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            skip_timesteps=0,
            init_image=None,
            progress=True,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )

        # The sample is stored raw, with no post-processing.
        all_samples.append(sample.detach().cpu().numpy())
        all_texts.extend(texts)
        all_lengths.extend([n_frames] * batch_size)

        print(f"Generated {len(texts)} samples")

    all_samples = np.concatenate(all_samples, axis=0)
    all_lengths = np.array(all_lengths)

    # NOTE(review): assumes each sample has shape motion_shape, i.e.
    # (batch, njoints, nfeats, n_frames) -- confirm against p_sample_loop.
    # The feature axis is dropped and the result transposed to (N, T, D).
    if all_samples.shape[2] != 1:
        raise ValueError(
            f"Expected nfeats == 1 to export (N, T, D) features, got shape {all_samples.shape}")
    all_samples = all_samples.squeeze(2)
    all_samples = all_samples.transpose(0, 2, 1)

    print(f"\nTotal samples: {all_samples.shape[0]}")
    print(f"Sample shape: {all_samples.shape} # (N, T, D)")

    results_path = _save_results(args, all_samples, all_texts, all_lengths)

    print(f"\nDone! Results saved to {args.output_dir}")
    print("\nTo load the results in your decoder:")
    print(" import numpy as np")
    print(f" data = np.load('{results_path}')")
    print(" z = data['z'] # shape: (N, T, D), dtype: float32")
    print(" texts = data['texts'] # shape: (N,), dtype: str")
    print(" lengths = data['lengths'] # shape: (N,), dtype: int")
    print("\n # For individual samples:")
    print(" single_z = z[0] # shape: (T, D)")

    return args.output_dir
|
|
|
|
# Run the sampler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|