import gc
import importlib
import math
import os
from typing import List

import torch
import torchvision
from huggingface_hub import snapshot_download

from .ar_config_inference_inference_config import DiffusionDecoderSamplingConfig
from .ar_diffusion_decoder_inference import diffusion_decoder_process_tokens
from .ar_diffusion_decoder_model import LatentDiffusionDecoderModel
from .df_inference_inference_utils import (
    load_network_model,
    load_tokenizer_model,
    skip_init_linear,
)
from .log import log
from .config_helper import get_config_module, override

TOKENIZER_COMPRESSION_FACTOR = [8, 16, 16]
DATA_RESOLUTION_SUPPORTED = [640, 1024]
NUM_CONTEXT_FRAMES = 33
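# Rough sanity check of what the constants above imply (an assumption, not taken
# from this file): with a [8, 16, 16] (temporal, height, width) compression factor,
# a NUM_CONTEXT_FRAMES x 640 x 1024 clip corresponds to roughly
# (1 + 32 / 8) x (640 / 16) x (1024 / 16) = 5 x 40 x 64 token positions, assuming
# the first frame is encoded separately as is common for causal video tokenizers.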


def resize_input(video: torch.Tensor, resolution: list[int]):
    r"""Perform aspect-ratio-preserving resizing followed by center cropping.

    This brings the input video to the target resolution.

    Args:
        video (torch.Tensor): Input video tensor of shape (T, C, H, W)
        resolution (list[int]): Target resolution as [height, width]

    Returns:
        torch.Tensor: The resized and center-cropped video
    """
    orig_h, orig_w = video.shape[2], video.shape[3]
    target_h, target_w = resolution

    scaling_ratio = max((target_w / orig_w), (target_h / orig_h))
    resizing_shape = (int(math.ceil(scaling_ratio * orig_h)), int(math.ceil(scaling_ratio * orig_w)))
    video_resized = torchvision.transforms.functional.resize(video, resizing_shape)
    video_cropped = torchvision.transforms.functional.center_crop(video_resized, resolution)
    return video_cropped
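# Illustrative usage of resize_input (a sketch; the input size below is made up):
#
#   clip = torch.rand(NUM_CONTEXT_FRAMES, 3, 480, 854)    # (T, C, H, W)
#   out = resize_input(clip, DATA_RESOLUTION_SUPPORTED)   # -> (33, 3, 640, 1024)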


def read_input_videos(input_video: str) -> torch.Tensor:
    """Read an input video and return it as a torch tensor.

    The video is normalized to [-1, 1], truncated or padded to NUM_CONTEXT_FRAMES
    frames, and resized/center-cropped to DATA_RESOLUTION_SUPPORTED.

    Args:
        input_video (str): Path to an .mp4 file

    Returns:
        torch.Tensor: Video tensor of shape (1, C, T, H, W)
    """
    video, _, _ = torchvision.io.read_video(input_video)
    # Scale pixel values from [0, 255] to [-1, 1].
    video = video.float() / 255.0
    video = video * 2 - 1

    if video.shape[0] >= NUM_CONTEXT_FRAMES:
        video = video[0:NUM_CONTEXT_FRAMES, :, :, :]
    else:
        log.info(f"Video doesn't have {NUM_CONTEXT_FRAMES} frames. Padding the video with the last frame.")
        nframes_in_video = video.shape[0]
        video = torch.cat(
            (video, video[-1, :, :, :].unsqueeze(0).repeat(NUM_CONTEXT_FRAMES - nframes_in_video, 1, 1, 1)),
            dim=0,
        )

    video = video[0:NUM_CONTEXT_FRAMES, :, :, :]
    # (T, H, W, C) -> (T, C, H, W), then resize/crop to the supported resolution.
    video = video.permute(0, 3, 1, 2)
    video = resize_input(video, DATA_RESOLUTION_SUPPORTED)
    # (T, C, H, W) -> (1, C, T, H, W)
    return video.transpose(0, 1).unsqueeze(0)
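# Illustrative usage of read_input_videos (a sketch; the path is a placeholder):
#
#   batch = read_input_videos("input.mp4")   # -> (1, 3, 33, 640, 1024), values in [-1, 1]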


def run_diffusion_decoder_model(indices_tensor_cur_batch: List[torch.Tensor], out_videos_cur_batch):
    """Run the 7B diffusion decoder model to enhance the generated output.

    Args:
        indices_tensor_cur_batch (List[torch.Tensor]): The index tensors, i.e. prompt + generated tokens
        out_videos_cur_batch (torch.Tensor): The decoded output videos of shape [bs, 3, 33, 640, 1024]

    Returns:
        The enhanced output video produced by the diffusion decoder
    """
    # Download the diffusion decoder checkpoint and tokenizers from the Hugging Face Hub.
    diffusion_decoder_ckpt_path = snapshot_download("nvidia/Cosmos-1.0-Diffusion-7B-Decoder-DV8x16x16ToCV8x8x8")
    dd_tokenizer_dir = snapshot_download("nvidia/Cosmos-1.0-Tokenizer-CV8x8x8")
    tokenizer_corruptor_dir = snapshot_download("nvidia/Cosmos-1.0-Tokenizer-DV8x16x16")

    diffusion_decoder_model = load_model_by_config(
        config_job_name="DD_FT_7Bv1_003_002_tokenizer888_spatch2_discrete_cond_on_token",
        config_file="cosmos1/models/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py",
        model_class=LatentDiffusionDecoderModel,
        encoder_path=os.path.join(tokenizer_corruptor_dir, "encoder.jit"),
        decoder_path=os.path.join(tokenizer_corruptor_dir, "decoder.jit"),
    )
    load_network_model(diffusion_decoder_model, os.path.join(diffusion_decoder_ckpt_path, "model.pt"))
    load_tokenizer_model(diffusion_decoder_model, dd_tokenizer_dir)

    # Load the pre-computed generic T5 prompt embedding and mask used as conditioning.
    generic_prompt = dict()
    aux_vars = torch.load(os.path.join(diffusion_decoder_ckpt_path, "aux_vars.pt"), weights_only=True)
    generic_prompt["context"] = aux_vars["context"].cuda()
    generic_prompt["context_mask"] = aux_vars["context_mask"].cuda()

    output_video = diffusion_decoder_process_tokens(
        model=diffusion_decoder_model,
        indices_tensor=indices_tensor_cur_batch,
        dd_sampling_config=DiffusionDecoderSamplingConfig(),
        original_video_example=out_videos_cur_batch[0],
        t5_emb_batch=[generic_prompt["context"]],
    )

    # Free GPU memory held by the diffusion decoder.
    del diffusion_decoder_model
    diffusion_decoder_model = None
    gc.collect()
    torch.cuda.empty_cache()

    return output_video
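# Illustrative call (a sketch; variable names and shapes are assumptions based on
# the docstring above, not on callers in this file):
#
#   enhanced = run_diffusion_decoder_model(
#       indices_tensor_cur_batch=[token_indices],   # one tensor of prompt + generated token ids
#       out_videos_cur_batch=decoded_videos,        # [bs, 3, 33, 640, 1024]
#   )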


def load_model_by_config(
    config_job_name,
    config_file="projects/cosmos_video/config/config.py",
    model_class=LatentDiffusionDecoderModel,
    encoder_path=None,
    decoder_path=None,
):
    """Instantiate a model from a config file and an experiment name.

    Args:
        config_job_name: Experiment name selected via config override
        config_file: Path to the Python config file exposing make_config()
        model_class: Model class to instantiate from config.model
        encoder_path: Optional path to the tokenizer-corruptor encoder (.jit)
        decoder_path: Optional path to the tokenizer-corruptor decoder (.jit)
    """
    config_module = get_config_module(config_file)
    config = importlib.import_module(config_module).make_config()

    config = override(config, ["--", f"experiment={config_job_name}"])

    # Point the tokenizer corruptor at the downloaded encoder/decoder before
    # validating and freezing the config.
    if encoder_path:
        config.model.tokenizer_corruptor["enc_fp"] = encoder_path
    if decoder_path:
        config.model.tokenizer_corruptor["dec_fp"] = decoder_path

    # Check that the config is valid, then freeze it.
    config.validate()
    config.freeze()

    # Build the model, skipping linear-layer weight initialization; weights are
    # loaded later from the checkpoint.
    with skip_init_linear():
        model = model_class(config.model)
    return model