| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import gc |
| import os |
| from argparse import ArgumentParser |
| from typing import List |
|
|
| import imageio |
| import nemo.lightning as nl |
| import numpy as np |
| import torch |
| from einops import rearrange |
| from huggingface_hub import snapshot_download |
| from megatron.core.inference.common_inference_params import CommonInferenceParams |
| from megatron.core.inference.engines.mcore_engine import MCoreEngine |
| from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import ( |
| SimpleTextGenerationController, |
| ) |
| from nemo.collections.llm.inference.base import _setup_trainer_and_restore_model |
| from nemo.lightning import io |
| from nemo.lightning.ckpt_utils import ckpt_to_context_subdir |
|
|
| from cosmos1.models.autoregressive.nemo.utils import run_diffusion_decoder_model |
| from discrete_video import DiscreteVideoFSQJITTokenizer |
| from cosmos1.models.autoregressive.utils.inference import load_vision_input |
| from .presets import presets as guardrail_presets |
| from .log import log |
|
|
# Disable the TorchScript tensor-expression fuser before loading the JIT
# tokenizer checkpoints. NOTE(review): presumably a workaround for a fuser
# incompatibility with the .jit tokenizer graphs — confirm before removing.
torch._C._jit_set_texpr_fuser_enabled(False)

# Downsampling factors of the discrete video tokenizer, presumably
# [temporal, height, width] per the DV8x16x16 model name — TODO confirm.
TOKENIZER_COMPRESSION_FACTOR = [8, 16, 16]
# Pixel-frame length of one generated clip (used as pixel_chunk_duration).
NUM_CONTEXT_FRAMES = 33
# Number of conditioning frames taken from a video input (1 is used for images).
NUM_INPUT_FRAMES_VIDEO = 9
# Latent grid [T, H, W] produced by the tokenizer for one clip; used to
# flatten/unflatten the token sequence around generation.
LATENT_SHAPE = [5, 40, 64]
# Resolution the input is loaded at, presumably [height, width] — TODO confirm
# against load_vision_input.
DATA_RESOLUTION = [640, 1024]
|
|
|
|
class CosmosMCoreTokenizerWrappper:
    """Minimal tokenizer shim handed to the MCore text-generation controller.

    The autoregressive model consumes discrete video token ids directly, so
    no real text tokenization is needed: ``tokenize`` and ``detokenize`` are
    identity mappings over lists of token ids.
    """

    def __init__(self):
        # No underlying text tokenizer exists; prompts arrive pre-tokenized.
        self.tokenizer = None
        # Sentinel end-of-document id (negative so it cannot collide with
        # real token ids).
        self.eod = -1
        self.vocab_size = 64000

    def tokenize(self, prompt: List[int]):
        """Identity pass-through: the prompt is already a list of token ids."""
        return prompt

    def detokenize(self, tokens: List[int], remove_special_tokens: bool = False):
        """Identity pass-through; ``remove_special_tokens`` is ignored."""
        return tokens
|
|
|
|
def main(args):
    """Generate a video continuation from a single image/video prompt.

    Pipeline: load the visual prompt, encode it to discrete latent tokens,
    autoregressively generate the remaining tokens with the NeMo/MCore
    engine, decode the tokens back to pixels (optionally refined by a
    diffusion decoder), run the optional guardrail, and save an mp4.
    """
    # Images condition on one frame; videos on NUM_INPUT_FRAMES_VIDEO frames.
    num_input_frames = 1 if args.input_type == "image" else NUM_INPUT_FRAMES_VIDEO

    vision_input_dict = load_vision_input(
        input_type=args.input_type,
        batch_input_path=None,
        input_image_or_video_path=args.input_image_or_video_path,
        data_resolution=DATA_RESOLUTION,
        num_input_frames=num_input_frames,
    )

    # Single-input mode (batch_input_path=None): take the sole entry.
    vision_input = list(vision_input_dict.values())[0].cuda()

    # Everything beyond the conditioning latent frames must be generated:
    # 1 latent frame of context for images, 2 for videos.
    T, H, W = LATENT_SHAPE
    latent_context_t_size = 1 if args.input_type == "image" else 2
    num_tokens_to_generate = int(np.prod([T - latent_context_t_size, H, W]))

    # Resolve Hugging Face model ids to local snapshot paths for the
    # TorchScript encoder/decoder.
    if args.encoder_path == "nvidia/Cosmos-1.0-Tokenizer-DV8x16x16":
        args.encoder_path = os.path.join(snapshot_download(args.encoder_path), "encoder.jit")
    if args.decoder_path == "nvidia/Cosmos-1.0-Tokenizer-DV8x16x16":
        args.decoder_path = os.path.join(snapshot_download(args.decoder_path), "decoder.jit")
    video_tokenizer = DiscreteVideoFSQJITTokenizer(
        enc_fp=args.encoder_path,
        dec_fp=args.decoder_path,
        name="discrete_video_fsq",
        pixel_chunk_duration=NUM_CONTEXT_FRAMES,
        latent_chunk_duration=T,
    ).cuda()

    # Encode pixels to FSQ codes, then flatten the (T, H, W) latent grid into
    # one token sequence per batch element.
    quantized_out, _ = video_tokenizer.encode(vision_input, pixel_chunk_duration=None)
    indices = video_tokenizer.fsq_quantizer.codes_to_indices(quantized_out.permute(0, 2, 3, 4, 1))
    indices = rearrange(indices, "B T H W -> B (T H W)")
    # Keep only the conditioning prefix; the tail is what the AR model fills in.
    video_tokens = [indices[0][0:-num_tokens_to_generate].tolist()]

    # Resolve the AR model id to the local NeMo checkpoint directory.
    if args.ar_model_dir in ["nvidia/Cosmos-1.0-Autoregressive-4B", "nvidia/Cosmos-1.0-Autoregressive-12B"]:
        args.ar_model_dir = os.path.join(snapshot_download(args.ar_model_dir, allow_patterns=["nemo/*"]), "nemo")
    model: io.TrainerContext = io.load_context(path=ckpt_to_context_subdir(args.ar_model_dir), subpath="model")

    # Single-GPU, inference-only strategy: no model parallelism, no optimizer.
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        context_parallel_size=1,
        sequence_parallel=False,
        setup_optimizers=False,
        store_optimizer_states=False,
    )

    trainer = nl.Trainer(
        accelerator="gpu",
        devices=1,
        num_nodes=1,
        strategy=strategy,
        num_sanity_val_steps=0,
        plugins=nl.MegatronMixedPrecision(
            precision="bf16-mixed",
            params_dtype=torch.bfloat16,
            pipeline_dtype=torch.bfloat16,
            autocast_enabled=False,
            grad_reduce_in_fp32=False,
        ),
    )
    # Must happen before get_inference_wrapper: restores the checkpoint weights.
    _setup_trainer_and_restore_model(path=args.ar_model_dir, trainer=trainer, model=model)

    inference_wrapped_model = model.get_inference_wrapper(torch.bfloat16, inference_batch_times_seqlen_threshold=1000)

    # The "tokenizer" is an identity wrapper — prompts are already token ids.
    text_generation_controller = SimpleTextGenerationController(
        inference_wrapped_model=inference_wrapped_model, tokenizer=CosmosMCoreTokenizerWrappper()
    )

    mcore_engine = MCoreEngine(text_generation_controller=text_generation_controller, max_batch_size=1)

    common_inference_params = CommonInferenceParams(
        temperature=args.temperature, top_p=args.top_p, num_tokens_to_generate=num_tokens_to_generate
    )

    log.info(f"Running Inference to generate {num_tokens_to_generate} tokens. This will take some time. ")
    results = mcore_engine.generate(
        prompts=video_tokens,
        add_BOS=False,
        encoder_prompts=None,
        common_inference_params=common_inference_params,
    )

    # Splice the generated tokens into the prompt: positions the engine padded
    # with -1 (the eod sentinel of the wrapper tokenizer) receive the
    # generated token ids.
    result = list(results)[0]
    prompt_tokens = torch.tensor(result.prompt_tokens).cuda()
    prompt_tokens[prompt_tokens == -1] = result.generated_tokens

    # Restore the (T, H, W) latent layout expected by the video decoder.
    indices_tensor = prompt_tokens.unsqueeze(dim=0)
    indices_tensor = rearrange(
        indices_tensor,
        "B (T H W) -> B T H W",
        T=LATENT_SHAPE[0],
        H=LATENT_SHAPE[1],
        W=LATENT_SHAPE[2],
    )

    # Decode/save only on device 0.
    if torch.cuda.current_device() == 0:
        log.info("Running diffusion model on the generated result")
        video_decoded = video_tokenizer.decode(indices_tensor.cuda())
        # Map decoder output to [0, 1]; assumes decode emits values in
        # [-1, 1] — TODO confirm against DiscreteVideoFSQJITTokenizer.
        out_video = (video_decoded * 0.5 + 0.5).clamp_(0, 1)

        if not args.disable_diffusion_decoder:
            # Free the AR model and tokenizer before the diffusion decoder is
            # loaded, to reduce peak GPU memory.
            del model
            del inference_wrapped_model
            del video_tokenizer
            model = None
            inference_wrapped_model = None
            video_tokenizer = None
            gc.collect()
            torch.cuda.empty_cache()

            out_video = run_diffusion_decoder_model(
                indices_tensor_cur_batch=[indices_tensor.squeeze()], out_videos_cur_batch=out_video
            )

        # Convert to uint8 frames for imageio; the permute(1, 2, 3, 0)
        # suggests a channels-first (C, T, H, W) layout being moved to
        # (T, H, W, C) — TODO confirm.
        out_video = out_video[0].detach().clone()
        output_video = (out_video * 255).to(torch.uint8).permute(1, 2, 3, 0).cpu().numpy()

        if args.guardrail_dir:
            log.info("Running guardrails on the generated video")
            if args.guardrail_dir == "nvidia/Cosmos-1.0-Guardrail":
                args.guardrail_dir = snapshot_download(args.guardrail_dir)
            video_guardrail = guardrail_presets.create_video_guardrail_runner(checkpoint_dir=args.guardrail_dir)
            output_video = guardrail_presets.run_video_guardrail(output_video, video_guardrail)
            # run_video_guardrail returns None when the content is blocked.
            if output_video is None:
                raise ValueError("Guardrail blocked world generation.")

        imageio.mimsave(
            args.video_save_name,
            output_video,
            fps=25,
        )

        log.info(f"Saved to {args.video_save_name}")
|
|
|
|
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--input_type", type=str, default="video", help="Type of input", choices=["image", "video"])
    parser.add_argument(
        "--input_image_or_video_path", required=True, type=str, help="The path to the input video to run inference"
    )
    parser.add_argument(
        "--video_save_name", default="./nemo_generated_video.mp4", type=str, help="The path to generated video"
    )
    parser.add_argument(
        "--ar_model_dir",
        default="nvidia/Cosmos-1.0-Autoregressive-4B",
        type=str,
        help="The path to the nemo autoregressive model",
    )
    parser.add_argument(
        "--encoder_path", default="nvidia/Cosmos-1.0-Tokenizer-DV8x16x16", type=str, help="The path to encoder"
    )
    parser.add_argument(
        "--decoder_path", default="nvidia/Cosmos-1.0-Tokenizer-DV8x16x16", type=str, help="The path to the decoder"
    )
    parser.add_argument(
        "--guardrail_dir", default="nvidia/Cosmos-1.0-Guardrail", type=str, help="The path to the guardrails"
    )
    parser.add_argument("--top_p", default=0.8, type=float, help="The top_p inference parameter ")
    # BUG FIX: temperature is a continuous sampling parameter fed to
    # CommonInferenceParams. The previous `type=int` rejected fractional
    # values such as --temperature 0.8. Now float, matching --top_p; the
    # default 1.0 is value-equal to the old default of 1.
    parser.add_argument("--temperature", default=1.0, type=float, help="Sampling temperature")
    parser.add_argument("--disable_diffusion_decoder", action="store_true", help="Disable diffusion decoder")

    args = parser.parse_args()

    main(args)
|
|