# Provenance (file-viewer page header, kept as comments so the module parses):
#   path: SATA/src/mdm/sample/sample_z.py
#   last commit: 5221c8c ("Initial commit", zzysteve)
# This code is based on https://github.com/openai/guided-diffusion
"""
Generate preprocessed posterior features (z) from text for an external decoder.
Unlike generate.py, this script does not post-process the generated features.
It saves the raw z features in a format that external decoders can read.
"""
import os
import numpy as np
import torch
from argparse import ArgumentParser
from utils.fixseed import fixseed
from utils import dist_util
from utils.model_util import create_gaussian_diffusion
from utils.sampler_util import ClassifierFreeSampleModel
from data_loaders.tensors import collate
from model.mdm import MDM
import json
def sample_z_args():
    """Parse the command line for z sampling.

    Options are registered from a declarative table, in the order they are
    listed in ``--help``: model, input, sampling, output and misc options.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    option_table = (
        # Model options
        ("--model_path", {"required": True, "type": str,
                          "help": "Path to trained model checkpoint"}),
        ("--use_ema", {"action": 'store_true',
                       "help": "Use EMA model if available"}),
        # Input options
        ("--text_prompt", {"default": '', "type": str,
                           "help": "A single text prompt to generate"}),
        ("--input_text", {"default": '', "type": str,
                          "help": "Path to a text file with prompts (one per line)"}),
        # Sampling options
        ("--num_samples", {"default": 1, "type": int,
                           "help": "Number of samples per prompt"}),
        ("--num_repetitions", {"default": 1, "type": int,
                               "help": "Number of repetitions for each prompt"}),
        ("--motion_length", {"default": 6.0, "type": float,
                             "help": "Motion length in seconds"}),
        ("--guidance_param", {"default": 2.5, "type": float,
                              "help": "Classifier-free guidance scale"}),
        # Output options
        ("--output_dir", {"default": '', "type": str,
                          "help": "Output directory for results"}),
        ("--save_individual", {"action": 'store_true',
                               "help": "Save each sample as individual file"}),
        # Misc options
        ("--seed", {"default": 10, "type": int}),
        ("--device", {"default": 0, "type": int}),
    )
    cli = ArgumentParser(description='Sample z from text using trained MDM model')
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def load_model_from_checkpoint(model_path, device, use_ema=False):
    """Rebuild an MDM model and its diffusion process from a checkpoint.

    The training arguments are read from ``args.json`` in the same directory
    as ``model_path``. They select the conditioning mode, the architecture
    hyper-parameters passed to ``MDM`` and the diffusion built by
    ``create_gaussian_diffusion``.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load``.
        device: ``map_location`` used when loading the checkpoint tensors.
        use_ema: If True and the checkpoint has a ``'model_avg'`` entry, load
            those (EMA) weights instead of ``'model'``. If the entry is
            missing, a warning is printed and the regular weights are used.

    Returns:
        Tuple ``(model, diffusion, args)``. ``args`` exposes every key of
        ``args.json`` as an attribute. The model is returned on whatever
        device ``MDM`` constructs it on; callers move it with ``.to()``.

    Raises:
        FileNotFoundError: If ``args.json`` does not exist next to the checkpoint.
        AttributeError: If ``args.json`` lacks a key that is read without a
            default (``dataset``, ``latent_dim``, ``layers``, ``arch``).
    """
    # Local import keeps this function self-contained.
    from types import SimpleNamespace

    # Training-time arguments are stored next to the checkpoint file.
    args_path = os.path.join(os.path.dirname(model_path), 'args.json')
    with open(args_path, 'r', encoding='utf-8') as f:
        args = SimpleNamespace(**json.load(f))

    # Determine conditioning mode.
    if getattr(args, 'unconstrained', False):
        cond_mode = 'no_cond'
    elif args.dataset in ['kit', 'humanml', 'preprocessed_posterior']:
        cond_mode = 'text'
    else:
        cond_mode = 'action'

    # njoints/nfeats fix the per-frame layout of the sampled tensor. The
    # caller uses njoints as the feature width D of the saved z.
    # NOTE(review): the 512 default presumably matches the latent size used
    # for preprocessed-posterior training; confirm against args.json.
    njoints = getattr(args, 'njoints', 512)
    nfeats = getattr(args, 'nfeats', 1)
    model_kwargs = {
        'modeltype': '',
        'njoints': njoints,
        'nfeats': nfeats,
        'num_actions': 1,
        'translation': True,
        'pose_rep': 'rot6d',
        'glob': True,
        'glob_rot': True,
        'latent_dim': args.latent_dim,
        'ff_size': 1024,
        'num_layers': args.layers,
        'num_heads': 4,
        'dropout': 0.1,
        'activation': 'gelu',
        'data_rep': 'hml_vec',
        'cond_mode': cond_mode,
        'cond_mask_prob': getattr(args, 'cond_mask_prob', 0.1),
        'action_emb': 'tensor',
        'arch': args.arch,
        'emb_trans_dec': getattr(args, 'emb_trans_dec', False),
        'clip_version': 'ViT-B/32',
        'dataset': args.dataset,
        'text_encoder_type': getattr(args, 'text_encoder_type', 'clip'),
        'pos_embed_max_len': getattr(args, 'pos_embed_max_len', 5000),
        'mask_frames': getattr(args, 'mask_frames', False),
        'pred_len': getattr(args, 'pred_len', 0),
        'context_len': getattr(args, 'context_len', 0),
        'emb_policy': 'add',
        'all_goal_joint_names': [],
        'multi_target_cond': getattr(args, 'multi_target_cond', False),
        'multi_encoder_type': getattr(args, 'multi_encoder_type', 'single'),
        'target_enc_layers': getattr(args, 'target_enc_layers', 1),
        # No SMPL-backed xyz conversion: this script saves raw z only.
        'use_rot2xyz': False,
    }
    model = MDM(**model_kwargs)

    # Pick the weight set. Checkpoints may nest weights under 'model' /
    # 'model_avg'; otherwise the loaded object is used as the state dict.
    checkpoint = torch.load(model_path, map_location=device)
    if use_ema and 'model_avg' in checkpoint:
        print("Loading EMA model weights...")
        state_dict = checkpoint['model_avg']
    else:
        if use_ema:
            # Surface the fallback instead of silently ignoring --use_ema.
            print("Warning: EMA weights requested but checkpoint has no 'model_avg'; "
                  "loading regular model weights instead.")
        state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint

    # Drop positional-encoding buffers so the model keeps its own freshly
    # built ones. NOTE(review): presumably because their length depends on
    # pos_embed_max_len and may not match the checkpoint; confirm.
    for key in ('sequence_pos_encoder.pe', 'embed_timestep.sequence_pos_encoder.pe'):
        state_dict.pop(key, None)
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    print(f"Loaded model. Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}")

    diffusion = create_gaussian_diffusion(args)
    return model, diffusion, args
def _load_prompts(args):
    """Return the text prompts requested on the command line.

    ``--text_prompt`` takes precedence and is repeated ``args.num_samples``
    times. Otherwise every non-blank (stripped) line of ``--input_text``
    becomes one prompt, and ``args.num_samples`` is overwritten with the
    prompt count.

    Raises:
        FileNotFoundError: If ``--input_text`` points to a missing file.
        ValueError: If neither option is given, or no prompts result.
    """
    if args.text_prompt != '':
        texts = [args.text_prompt] * args.num_samples
    elif args.input_text != '':
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if not os.path.exists(args.input_text):
            raise FileNotFoundError(f"Input text file not found: {args.input_text}")
        with open(args.input_text, 'r', encoding='utf-8') as f:
            texts = [line.strip() for line in f if line.strip()]
        args.num_samples = len(texts)
    else:
        raise ValueError("Must provide either --text_prompt or --input_text")
    if not texts:
        # A zero-sized batch would otherwise fail much later inside the sampler.
        raise ValueError("No text prompts to sample: the prompt file is empty "
                         "or --num_samples is less than 1")
    return texts


def _default_output_dir(model_path, seed):
    """Build ``<ckpt dir>/sampled_z_<run dir name>_<ckpt suffix>_seed<seed>``."""
    model_name = os.path.basename(os.path.dirname(model_path))
    niter = os.path.basename(model_path).replace('model', '').replace('.pt', '')
    return os.path.join(os.path.dirname(model_path),
                        f'sampled_z_{model_name}_{niter}_seed{seed}')


def _build_model_kwargs(model, texts, n_frames, guidance_param, device):
    """Collate the prompts into the ``model_kwargs`` dict passed to the sampler.

    Tensors in ``model_kwargs['y']`` are moved to ``device``. A per-sample
    ``'scale'`` is added when classifier-free guidance is enabled, and
    ``'text_embed'`` is precomputed when the underlying model has ``encode_text``.
    """
    batch_size = len(texts)
    collate_args = [{'inp': torch.zeros(n_frames), 'tokens': None, 'lengths': n_frames, 'text': txt}
                    for txt in texts]
    _, model_kwargs = collate(collate_args)
    model_kwargs['y'] = {key: val.to(device) if torch.is_tensor(val) else val
                         for key, val in model_kwargs['y'].items()}
    # Add guidance scale.
    if guidance_param != 1:
        model_kwargs['y']['scale'] = torch.ones(batch_size, device=device) * guidance_param
    # The guidance wrapper is assumed to expose the wrapped MDM as ``.model``.
    actual_model = model.model if hasattr(model, 'model') else model
    if hasattr(actual_model, 'encode_text') and 'text' in model_kwargs['y']:
        print("Pre-encoding text...")
        # Inference only: no autograd graph is needed for the embeddings.
        with torch.no_grad():
            model_kwargs['y']['text_embed'] = actual_model.encode_text(model_kwargs['y']['text'])
    return model_kwargs


def _sample_raw_z(diffusion, model, motion_shape, model_kwargs, texts, n_frames, num_repetitions):
    """Run the diffusion sampler ``num_repetitions`` times on one prompt batch.

    Returns:
        ``(z, texts, lengths)``:
        - ``z``: the concatenated sampler outputs, reshaped from (N, D, 1, T) to (N, T, D).
        - ``texts``: the prompt list repeated once per repetition.
        - ``lengths``: a numpy array holding ``n_frames`` for every sample.
        No post-processing is applied to ``z``.
    """
    batch_size = len(texts)
    chunks, out_texts, out_lengths = [], [], []
    for rep_i in range(num_repetitions):
        print(f"\n### Sampling [repetition #{rep_i + 1}/{num_repetitions}]")
        sample = diffusion.p_sample_loop(
            model,
            motion_shape,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            skip_timesteps=0,
            init_image=None,
            progress=True,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )
        # Raw sampler output, shape (batch_size, njoints, nfeats, n_frames).
        chunks.append(sample.cpu().numpy())
        out_texts.extend(texts)
        out_lengths.extend([n_frames] * batch_size)
        print(f"Generated {len(texts)} samples")
    z = np.concatenate(chunks, axis=0)  # (N, D, 1, T)
    # squeeze(2) requires nfeats == 1; numpy raises ValueError otherwise.
    z = z.squeeze(2).transpose(0, 2, 1)  # (N, T, D)
    return z, out_texts, np.array(out_lengths)


def _save_results(args, z, texts, lengths):
    """Write results.npz, optional per-sample files and sample_config.json.

    Returns:
        Path of the main ``results.npz`` file.
    """
    results_path = os.path.join(args.output_dir, 'results.npz')
    print(f"\nSaving results to {results_path}")
    np.savez(results_path,
             z=z,                                   # (N, T, D)
             texts=np.array(texts, dtype=str),      # (N,) string array
             lengths=lengths,                       # (N,)
             motion_length=np.array(args.motion_length),
             guidance_param=np.array(args.guidance_param),
             model_path=np.array(args.model_path),
             )
    if args.save_individual:
        individual_dir = os.path.join(args.output_dir, 'individual')
        os.makedirs(individual_dir, exist_ok=True)
        for i, (z_i, text_i, length_i) in enumerate(zip(z, texts, lengths)):
            # One npz with metadata plus one npy holding only z (T, D).
            np.savez(os.path.join(individual_dir, f'sample_{i:04d}.npz'),
                     z=z_i,
                     text=np.array(text_i),
                     length=np.array(length_i),
                     )
            np.save(os.path.join(individual_dir, f'z_{i:04d}.npy'), z_i)
        print(f"Saved {z.shape[0]} individual files to {individual_dir}")
    config_path = os.path.join(args.output_dir, 'sample_config.json')
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump({
            'model_path': args.model_path,
            'num_samples': args.num_samples,
            'num_repetitions': args.num_repetitions,
            'motion_length': args.motion_length,
            'guidance_param': args.guidance_param,
            'seed': args.seed,
            'use_ema': args.use_ema,
            'sample_shape': list(z.shape),
        }, f, indent=4)
    return results_path


def main():
    """Sample raw z features for text prompts and save them to disk.

    Returns:
        The output directory the results were written to.

    Raises:
        ValueError: For an invalid prompt setup or ``--num_repetitions < 1``.
        FileNotFoundError: If ``--input_text`` does not exist.
    """
    args = sample_z_args()
    if args.num_repetitions < 1:
        # np.concatenate on zero repetitions would fail only after sampling setup.
        raise ValueError(f"--num_repetitions must be >= 1, got {args.num_repetitions}")
    fixseed(args.seed)
    dist_util.setup_dist(args.device)
    device = dist_util.dev()

    texts = _load_prompts(args)
    print(f"Loaded {len(texts)} text prompts")

    print(f"Loading model from {args.model_path}...")
    model, diffusion, _ = load_model_from_checkpoint(
        args.model_path, device, use_ema=args.use_ema
    )
    model.to(device)
    model.eval()
    # Set classifier-free guidance.
    if args.guidance_param != 1:
        model = ClassifierFreeSampleModel(model)

    # Frame count: seconds * fps, truncated toward zero by int().
    fps = 20  # Default FPS
    n_frames = int(args.motion_length * fps)
    batch_size = len(texts)
    motion_shape = (batch_size,
                    model.njoints if hasattr(model, 'njoints') else model.model.njoints,
                    model.nfeats if hasattr(model, 'nfeats') else model.model.nfeats,
                    n_frames)
    print(f"Motion shape: {motion_shape}")

    model_kwargs = _build_model_kwargs(model, texts, n_frames, args.guidance_param, device)
    all_samples, all_texts, all_lengths = _sample_raw_z(
        diffusion, model, motion_shape, model_kwargs, texts, n_frames, args.num_repetitions
    )
    print(f"\nTotal samples: {all_samples.shape[0]}")
    print(f"Sample shape: {all_samples.shape} # (N, T, D)")

    if args.output_dir == '':
        args.output_dir = _default_output_dir(args.model_path, args.seed)
    os.makedirs(args.output_dir, exist_ok=True)
    results_path = _save_results(args, all_samples, all_texts, all_lengths)

    print(f"\nDone! Results saved to {args.output_dir}")
    print("\nTo load the results in your decoder:")
    print(" import numpy as np")
    print(f" data = np.load('{results_path}')")
    print(" z = data['z'] # shape: (N, T, D), dtype: float32")
    print(" texts = data['texts'] # shape: (N,), dtype: str")
    print(" lengths = data['lengths'] # shape: (N,), dtype: int")
    print("\n # For individual samples:")
    print(" single_z = z[0] # shape: (T, D)")
    return args.output_dir


if __name__ == "__main__":
    main()