| | import os |
| | import sys |
| |
|
| | import torch |
| | from lightning import seed_everything |
| | from safetensors.torch import load_file as load_safetensors |
| |
|
| | from ldf_utils.initialize import compare_statedict_and_parameters, instantiate, load_config |
| |
|
| | |
| | os.environ["TOKENIZERS_PARALLELISM"] = "false" |
| |
|
| |
|
def _resolve_config_dir():
    """Return the directory that relative checkpoint paths resolve against.

    Mirrors the CLI: if ``--config <path>`` is on ``sys.argv``, paths are
    relative to the config file's directory; otherwise to the CWD.
    NOTE(review): this re-parses sys.argv by hand rather than reusing the
    argparse result from __main__ — assumed intentional so the loader works
    when imported; confirm against callers.
    """
    if '--config' in sys.argv:
        config_idx = sys.argv.index('--config') + 1
        return os.path.dirname(os.path.abspath(sys.argv[config_idx]))
    return os.getcwd()


def _resolve_path(path, config_dir):
    """Join *path* onto *config_dir* unless it is already absolute."""
    return path if os.path.isabs(path) else os.path.join(config_dir, path)


def _load_and_verify(module, ckpt_path, device, label):
    """Load safetensors weights into *module*, verify, and prepare for eval.

    strict=True makes any architecture/checkpoint key mismatch fail loudly
    instead of silently skipping weights. The module is moved to *device*
    and switched to eval mode in place.
    """
    state_dict = load_safetensors(ckpt_path)
    module.load_state_dict(state_dict, strict=True)
    print(f"Loaded {label} from {ckpt_path}")
    # Sanity check that the state dict is consistent with the module's
    # registered parameters and buffers (project-level invariant check).
    compare_statedict_and_parameters(
        state_dict=module.state_dict(),
        named_parameters=module.named_parameters(),
        named_buffers=module.named_buffers(),
    )
    module.to(device)
    module.eval()
    return module


def load_model_from_config():
    """Build and load the VAE and the generative model from the config.

    Reads the config via ``load_config()`` (which consumes ``--config`` from
    ``sys.argv``), seeds all RNGs, instantiates both networks, loads their
    safetensors checkpoints (resolving relative paths against the config
    directory), and returns them on the selected device in eval mode.

    Returns:
        tuple: ``(vae, model)``, both ready for inference.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.set_float32_matmul_precision("high")
    cfg = load_config()
    seed_everything(cfg.seed)

    config_dir = _resolve_config_dir()

    # --- VAE ---
    vae = instantiate(
        target=cfg.test_vae.target,
        cfg=None,
        hfstyle=False,
        **cfg.test_vae.params,
    )
    vae_path = _resolve_path(cfg.test_vae_ckpt, config_dir)
    _load_and_verify(vae, vae_path, device, "VAE model")

    # --- Main model ---
    model_params = dict(cfg.model.params)
    # Paths embedded in the model params may also be config-relative;
    # only rewrite them when present and non-empty (same truthiness
    # check as before — empty strings are left untouched).
    for key in ('checkpoint_path', 'tokenizer_path'):
        if model_params.get(key):
            model_params[key] = _resolve_path(model_params[key], config_dir)

    model = instantiate(
        target=cfg.model.target, cfg=None, hfstyle=False, **model_params
    )
    model_path = _resolve_path(cfg.test_ckpt, config_dir)
    _load_and_verify(model, model_path, device, "model")

    return vae, model
| |
|
| |
|
@torch.inference_mode()
def generate_feature_stream(
    model, feature_length, text, feature_text_end=None, num_denoise_steps=None
):
    """
    Streaming interface for feature generation
    Args:
        model: Loaded model
        feature_length: List[int], generation length for each sample
        text: List[str] or List[List[str]], text prompts
        feature_text_end: List[List[int]], time points where text ends (if text is list of list)
        num_denoise_steps: Number of denoising steps
    Yields:
        dict: Contains "generated" (current generated feature segment)
    """
    # Assemble the batch payload expected by the model's streaming API.
    batch = {
        "feature_length": torch.tensor(feature_length),
        "text": text,
    }
    # Optional per-prompt text end times are only attached when supplied.
    if feature_text_end is not None:
        batch["feature_text_end"] = feature_text_end

    # Forward every step produced by the model's generator unchanged.
    yield from model.stream_generate(batch, num_denoise_steps=num_denoise_steps)
| |
|
| |
|
| | if __name__ == "__main__": |
| | import argparse |
| |
|
| | parser = argparse.ArgumentParser() |
| | parser.add_argument("--config", type=str, required=True, help="Path to config") |
| | parser.add_argument( |
| | "--text", type=str, default="a person walks forward", help="Text prompt" |
| | ) |
| | parser.add_argument("--length", type=int, default=120, help="Motion length") |
| | parser.add_argument( |
| | "--output", type=str, default="output.mp4", help="Output video path" |
| | ) |
| | parser.add_argument( |
| | "--num_denoise_steps", type=int, default=None, help="Number of denoising steps" |
| | ) |
| | args = parser.parse_args() |
| |
|
| | print("Loading model...") |
| | vae, model = load_model_from_config() |
| |
|
| |
|