| | import os |
| | import time |
| | from pathlib import Path |
| | from loguru import logger |
| | from datetime import datetime |
| |
|
| | from voyager.utils.file_utils import save_videos_grid |
| | from voyager.config import parse_args |
| | from voyager.inference import HunyuanVideoSampler |
| |
|
| |
|
def _build_conditioning_paths(input_path, num_frames=49):
    """Build the reference and per-frame conditioning file paths for ``predict``.

    Args:
        input_path: Directory containing ``ref_image.png``, ``ref_depth.exr``
            and a ``video_input/`` sub-directory holding per-frame files.
        num_frames: Number of per-frame entries to generate; frame indices run
            from ``0`` to ``num_frames - 1`` and are zero-padded to four digits.
            Defaults to 49, the value previously hard-coded in ``main``.

    Returns:
        A ``(ref_images, partial_cond, partial_mask)`` tuple:
        ``ref_images`` is a one-element list of ``(ref_image.png, ref_depth.exr)``;
        ``partial_cond`` is a list of ``(render_XXXX.png, depth_XXXX.exr)`` pairs;
        ``partial_mask`` is a list of ``(mask_XXXX.png, mask_XXXX.png)`` pairs.
    """
    video_dir = os.path.join(input_path, "video_input")
    ref_images = [(os.path.join(input_path, "ref_image.png"),
                   os.path.join(input_path, "ref_depth.exr"))]
    partial_cond = [
        (os.path.join(video_dir, f"render_{j:04d}.png"),
         os.path.join(video_dir, f"depth_{j:04d}.exr"))
        for j in range(num_frames)
    ]
    # NOTE(review): both elements of each mask pair are the same file, exactly
    # as in the original call; presumably predict() expects a pair per frame
    # here — confirm the duplication is intentional.
    mask_paths = [os.path.join(video_dir, f"mask_{j:04d}.png") for j in range(num_frames)]
    partial_mask = [(mask, mask) for mask in mask_paths]
    return ref_images, partial_cond, partial_mask


def main():
    """Generate videos with HunyuanVideoSampler from CLI arguments and save them.

    Parses arguments with ``parse_args``, loads the sampler from
    ``args.model_base``, runs ``predict`` with conditioning files located under
    ``args.input_path``, and, on local rank 0 only, writes every returned sample
    to ``save_path`` as an ``.mp4`` file via ``save_videos_grid``.

    Raises:
        ValueError: If ``args.model_base`` does not exist.
    """
    args = parse_args()
    print(args)
    models_root_path = Path(args.model_base)
    if not models_root_path.exists():
        raise ValueError(f"`models_root` does not exist: {models_root_path}")

    save_path = (args.save_path if args.save_path_suffix == ""
                 else f"{args.save_path}_{args.save_path_suffix}")
    # exist_ok=True already tolerates an existing directory, so no separate
    # (racy) existence check is needed.
    os.makedirs(save_path, exist_ok=True)

    hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)
    # Continue with the sampler's copy of the arguments, as the original did;
    # presumably from_pretrained may adjust them — confirm.
    args = hunyuan_video_sampler.args

    ref_images, partial_cond, partial_mask = _build_conditioning_paths(args.input_path)

    outputs = hunyuan_video_sampler.predict(
        prompt=args.prompt,
        height=args.video_size[0],
        width=args.video_size[1],
        video_length=args.video_length,
        seed=args.seed,
        negative_prompt=args.neg_prompt,
        infer_steps=args.infer_steps,
        guidance_scale=args.cfg_scale,
        num_videos_per_prompt=args.num_videos,
        flow_shift=args.flow_shift,
        batch_size=args.batch_size,
        embedded_guidance_scale=args.embedded_cfg_scale,
        i2v_mode=args.i2v_mode,
        i2v_resolution=args.i2v_resolution,
        i2v_image_path=args.i2v_image_path,
        i2v_condition_type=args.i2v_condition_type,
        i2v_stability=args.i2v_stability,
        ulysses_degree=args.ulysses_degree,
        ring_degree=args.ring_degree,
        ref_images=ref_images,
        partial_cond=partial_cond,
        partial_mask=partial_mask,
    )
    samples = outputs['samples']

    # Only write files from local rank 0 (or when LOCAL_RANK is unset, i.e. a
    # non-distributed run) so parallel workers do not save duplicates.
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        for i, sample in enumerate(samples):
            # Add a leading dimension of size 1 before handing to save_videos_grid.
            sample = sample.unsqueeze(0)
            # NOTE(review): ':' in the timestamp is not a valid filename
            # character on Windows; kept for output-name compatibility.
            time_flag = datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
            # Strip '/' so the prompt cannot introduce sub-directories.
            prompt_tag = outputs['prompts'][i][:100].replace('/', '')
            cur_save_path = os.path.join(
                save_path, f"{time_flag}_seed{outputs['seeds'][i]}_{prompt_tag}.mp4")
            save_videos_grid(sample, cur_save_path, fps=24)
            logger.info(f'Sample save to: {cur_save_path}')
| |
|
| |
|
if __name__ == "__main__":
    # Run generation only when executed as a script, not when imported.
    main()
| |
|