| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | import os |
| |
|
| | import imageio |
| | import torch |
| |
|
| | from cosmos1.models.autoregressive.inference.world_generation_pipeline import ARBaseGenerationPipeline |
| | from cosmos1.models.autoregressive.utils.inference import add_common_arguments, load_vision_input, validate_args |
| | from .log import log |
| |
|
| |
|
def parse_args():
    """Build the command-line parser for this demo and parse the arguments.

    The parser first receives the arguments registered by
    ``add_common_arguments`` and is then extended with two script-specific
    options:

    - ``--ar_model_dir``: string, defaults to ``Cosmos-1.0-Autoregressive-4B``.
    - ``--input_type``: either ``image`` or ``video`` (default ``video``).

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    cli = argparse.ArgumentParser(description="Video to world generation demo script")

    # Register the common arguments before the script-specific ones.
    add_common_arguments(cli)

    cli.add_argument("--ar_model_dir", type=str, default="Cosmos-1.0-Autoregressive-4B")
    cli.add_argument(
        "--input_type",
        type=str,
        default="video",
        choices=["image", "video"],
        help="Type of input",
    )
    return cli.parse_args()
| |
|
| |
|
def main(args):
    """Run the base (vision-only) world generation demo.

    Steps performed:
      1. Validate the arguments via ``validate_args`` and obtain the sampling
         configuration it returns.
      2. Construct an ``ARBaseGenerationPipeline`` from the checkpoint and
         offloading options.
      3. Load the input image(s)/video(s) with ``load_vision_input``.
      4. Generate one output video per input and write it as an MP4 file
         at 25 fps.

    Args:
        args (argparse.Namespace): Parsed command-line arguments. The
            attributes read here are ``checkpoint_dir``, ``ar_model_dir``,
            ``disable_diffusion_decoder``, ``offload_guardrail_models``,
            ``offload_diffusion_decoder``, ``offload_ar_model``,
            ``offload_tokenizer``, ``input_type``, ``batch_input_path``,
            ``input_image_or_video_path``, ``data_resolution``,
            ``num_input_frames``, ``seed``, ``video_save_folder`` and
            ``video_save_name``.

    Output files are written into ``args.video_save_folder``. When
    ``args.input_image_or_video_path`` is set, the file is named
    ``<video_save_name>.mp4``; otherwise it is named after the position of the
    input in the loaded collection (``0.mp4``, ``1.mp4``, ...).

    When ``pipeline.generate`` returns ``None`` a critical message is logged
    (guardrail block) and processing moves on to the next input.
    """
    mode = "base"
    sampling_cfg = validate_args(args, mode)

    # Set up the base generation pipeline.
    generator = ARBaseGenerationPipeline(
        inference_type=mode,
        checkpoint_dir=args.checkpoint_dir,
        checkpoint_name=args.ar_model_dir,
        disable_diffusion_decoder=args.disable_diffusion_decoder,
        offload_guardrail_models=args.offload_guardrail_models,
        offload_diffusion_decoder=args.offload_diffusion_decoder,
        offload_network=args.offload_ar_model,
        offload_tokenizer=args.offload_tokenizer,
    )

    # Collection of loaded inputs, looked up by the name yielded on iteration.
    vision_inputs = load_vision_input(
        input_type=args.input_type,
        batch_input_path=args.batch_input_path,
        input_image_or_video_path=args.input_image_or_video_path,
        data_resolution=args.data_resolution,
        num_input_frames=args.num_input_frames,
    )

    for idx, input_filename in enumerate(vision_inputs):
        conditioning = vision_inputs[input_filename]

        log.info(f"Run with image or video path: {input_filename}")
        generated = generator.generate(
            inp_vid=conditioning,
            num_input_frames=args.num_input_frames,
            seed=args.seed,
            sampling_config=sampling_cfg,
        )
        if generated is None:
            # Nothing to save for this input; carry on with the next one.
            log.critical("Guardrail blocked base generation.")
            continue

        # A single explicit input uses the configured name; batch inputs are
        # numbered by their position.
        file_name = f"{args.video_save_name}.mp4" if args.input_image_or_video_path else f"{idx}.mp4"
        out_vid_path = os.path.join(args.video_save_folder, file_name)

        imageio.mimsave(out_vid_path, generated, fps=25)
        log.info(f"Saved video to {out_vid_path}")
| |
|
| |
|
if __name__ == "__main__":
    # Switch off the TorchScript tensor-expression fuser before anything else
    # runs. NOTE(review): the reason is not visible in this file - confirm
    # against the project's history before removing.
    torch._C._jit_set_texpr_fuser_enabled(False)

    # Parse the command line and hand the namespace straight to the demo.
    main(parse_args())
| |
|