Spaces:
Build error
Build error
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| from contextlib import contextmanager | |
| from typing import Tuple, Union | |
| import einops | |
| import numpy as np | |
| import torch | |
| import torchvision | |
| import torchvision.transforms.functional as transforms_F | |
| from matplotlib import pyplot as plt | |
| from cosmos_predict1.diffusion.training.models.extend_model import ExtendDiffusionModel | |
| from cosmos_predict1.utils import log | |
| from cosmos_predict1.utils.easy_io import easy_io | |
| """This file contains functions needed for long video generation, | |
| * function `generate_video_from_batch_with_loop` is used by `single_gpu_sep20` | |
| """ | |
@contextmanager
def switch_config_for_inference(model):
    """Temporarily switch the video-conditioning config to inference settings.

    While the context is active:
      * ``condition_location`` is left alone if it is already ``"first_n"`` or
        ``"first_and_last_1"``; any other value is replaced by ``"first_n"``.
      * ``apply_corruption_to_condition_region`` is mapped
        ``"gaussian_blur"`` -> ``"clean"`` and
        ``"noise_with_sigma"`` -> ``"noise_with_sigma_fixed"``; other values are
        left unchanged.

    On exit (including when the body raises) both fields are restored to the
    values they had before entering the context.

    Args:
        model (ExtendDiffusionModel): video generation model whose
            ``config.conditioner.video_cond_bool`` is modified in place.

    Yields:
        None
    """
    video_cond_cfg = model.config.conditioner.video_cond_bool

    # Remember the untouched values so that they can be restored verbatim.
    original_condition_location = video_cond_cfg.condition_location
    original_corruption = video_cond_cfg.apply_corruption_to_condition_region

    # Only "first_n" and "first_and_last_1" are kept; everything else falls back to "first_n".
    if original_condition_location in ("first_n", "first_and_last_1"):
        inference_condition_location = original_condition_location
    else:
        inference_condition_location = "first_n"

    try:
        log.info(
            "Change the condition_location to 'first_n' for inference, and apply_corruption_to_condition_region to False"
        )
        video_cond_cfg.condition_location = inference_condition_location
        if original_corruption == "gaussian_blur":
            video_cond_cfg.apply_corruption_to_condition_region = "clean"
        elif original_corruption == "noise_with_sigma":
            video_cond_cfg.apply_corruption_to_condition_region = "noise_with_sigma_fixed"
        # Yield control back to the calling context
        yield
    finally:
        # Restore the true original settings, not the inference override.
        log.info(
            f"Restore the original condition_location {original_condition_location}, apply_corruption_to_condition_region {original_corruption}"
        )
        video_cond_cfg.condition_location = original_condition_location
        video_cond_cfg.apply_corruption_to_condition_region = original_corruption
def visualize_latent_tensor_bcthw(tensor, nrow=1, show_norm=False, save_fig_path=None):
    """Debug function to display a latent tensor as a grid of images.

    The channel axis is reduced (mean, and optionally L2 norm) so that every
    latent frame becomes a single-channel image; frames are tiled ``nrow`` per
    row. The temporal size must be divisible by ``nrow`` because of the
    ``(t n)`` split in the rearrange pattern.

    Args:
        tensor (torch.Tensor): tensor in shape BCTHW
        nrow (int): number of images per row
        show_norm (bool): whether to display the norm of the tensor
        save_fig_path (str): path to save the visualization of the channel mean;
            nothing is saved when empty or None
    """
    log.info(
        f"display latent tensor shape {tensor.shape}, max={tensor.max()}, min={tensor.min()}, mean={tensor.mean()}, std={tensor.std()}"
    )
    tensor = tensor.float().cpu().detach()
    # Channels last, frames tiled into one (rows, cols) canvas.
    tensor = einops.rearrange(tensor, "b c (t n) h w -> (b t h) (n w) c", n=nrow)
    tensor_mean = tensor.mean(-1)
    tensor_norm = tensor.norm(dim=-1)
    log.info(f"tensor_norm, tensor_mean {tensor_norm.shape}, {tensor_mean.shape}")

    plt.figure(figsize=(20, 20))
    plt.imshow(tensor_mean)
    plt.title(f"mean {tensor_mean.mean()}, std {tensor_mean.std()}")
    if save_fig_path:
        # os.makedirs("") raises FileNotFoundError, so only create a directory
        # when the path actually has a directory component.
        save_dir = os.path.dirname(save_fig_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        log.info(f"save to {os.path.abspath(save_fig_path)}")
        plt.savefig(save_fig_path, bbox_inches="tight", pad_inches=0)
    plt.show()

    if show_norm:
        plt.figure(figsize=(20, 20))
        plt.imshow(tensor_norm)
        plt.show()
def visualize_tensor_bcthw(tensor: torch.Tensor, nrow=4, save_fig_path=None):
    """Debug function to display a tensor as a grid of images.

    Args:
        tensor (torch.Tensor): tensor in shape BCTHW
        nrow (int): number of images per row, used for both the displayed and
            the saved grid
        save_fig_path (str): path to save the visualization; nothing is saved
            when None

    Raises:
        AssertionError: if the maximum value is >= 200, which indicates the data
            range is likely wrong.
    """
    log.info(f"display {tensor.shape}, {tensor.max()}, {tensor.min()}")
    assert tensor.max() < 200, f"tensor max {tensor.max()} > 200, the data range is likely wrong"
    tensor = tensor.float().cpu().detach()
    # Fold time into the batch axis so every frame is one image of the grid.
    tensor = einops.rearrange(tensor, "b c t h w -> (b t) c h w")
    grid = torchvision.utils.make_grid(tensor, nrow=nrow)
    if save_fig_path is not None:
        # os.makedirs("") raises FileNotFoundError, so skip it for bare file names.
        save_dir = os.path.dirname(save_fig_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        log.info(f"save to {os.path.abspath(save_fig_path)}")
        # Forward nrow so the saved file has the same layout as the displayed grid
        # (save_image builds its own grid and would otherwise use its default).
        torchvision.utils.save_image(tensor, save_fig_path, nrow=nrow)
    # display the grid
    plt.figure(figsize=(20, 20))
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()
def compute_num_frames_condition(model: "ExtendDiffusionModel", num_of_latent_overlap: int, downsample_factor=8) -> int:
    """Convert a number of overlapping latent frames into a number of pixel frames.

    The VAE is ``model.tokenizer.video_vae`` when the tokenizer has that
    attribute, otherwise ``model.tokenizer`` itself.

    * Causal VAE (attribute ``is_casual`` truthy, or missing): every complete
      latent chunk contributes ``pixel_chunk_duration`` pixel frames; in a
      partial chunk the first latent frame contributes one pixel frame and each
      further latent frame contributes ``downsample_factor`` pixel frames.
    * Otherwise: every latent frame contributes ``downsample_factor`` frames.

    Args:
        model (ExtendDiffusionModel): Video generation model
        num_of_latent_overlap (int): Number of latent frames to overlap
        downsample_factor (int): Downsample factor for temporal reduce

    Returns:
        int: Number of condition frames in output space
    """
    tokenizer = model.tokenizer
    vae = tokenizer.video_vae if hasattr(tokenizer, "video_vae") else tokenizer

    # The attribute name is spelled "is_casual" in this lookup; a VAE without it is treated as causal.
    if not getattr(vae, "is_casual", True):
        return num_of_latent_overlap * downsample_factor

    full_chunks, leftover = divmod(num_of_latent_overlap, vae.latent_chunk_duration)
    num_frames_condition = full_chunks * vae.pixel_chunk_duration
    if leftover == 1:
        num_frames_condition += 1
    elif leftover > 1:
        num_frames_condition += 1 + (leftover - 1) * downsample_factor
    return num_frames_condition
def read_video_or_image_into_frames_BCTHW(
    input_path: str,
    input_path_format: str = None,
    H: int = None,
    W: int = None,
    normalize: bool = True,
    max_frames: int = -1,
    also_return_fps: bool = False,
) -> torch.Tensor:
    """Load a video or a single image and lay it out as (1, C, T, H, W).

    Paths ending in ``.png``/``.jpg``/``.jpeg`` are treated as one image (a
    fourth channel is alpha-blended onto a white background, ``fps`` is 0);
    everything else is expected to load as ``(frames, meta_data)``.

    When ``normalize`` is True the pixel values are scaled with
    ``x / 128.0 - 1.0``, converted to a bfloat16 torch tensor, optionally
    resized to ``(H, W)`` and moved to ``"cuda"``. When it is False the data is
    only stacked and rearranged.

    Args:
        input_path (str): path to the input video or image, end with .mp4 or .png or .jpg
        input_path_format (str): forwarded to ``easy_io.load`` as ``file_format``
        H (int): height to resize the video
        W (int): width to resize the video
        normalize (bool): whether to scale, convert and move the frames as described above
        max_frames (int): keep only the first ``max_frames`` frames; -1 keeps all
        also_return_fps (bool): additionally return the frame rate

    Returns:
        The frames in shape (1, C, T, H, W); ``(frames, fps)`` when
        ``also_return_fps`` is True.
    """
    log.info(f"Reading video from {input_path}")
    loaded_data = easy_io.load(input_path, file_format=input_path_format, backend_args=None)

    if input_path.endswith((".png", ".jpg", ".jpeg")):
        image = np.array(loaded_data)  # HWC, [0,255]
        if image.shape[-1] > 3:
            # RGBA: composite over white using the alpha channel scaled to [0, 1].
            rgb = image[..., :3]
            alpha = image[..., 3] / 255.0
            white = np.ones_like(rgb) * 255
            image = (rgb * alpha[..., None] + white * (1 - alpha[..., None])).astype(np.uint8)
        frames = [image]
        fps = 0
    else:
        frames, meta_data = loaded_data
        fps = int(meta_data.get("fps"))

    if max_frames != -1:
        frames = frames[:max_frames]

    stacked = np.stack(frames, axis=0)
    input_tensor = einops.rearrange(stacked, "t h w c -> t c h w")
    if normalize:
        input_tensor = input_tensor / 128.0 - 1.0
        input_tensor = torch.from_numpy(input_tensor).bfloat16()  # TCHW
        log.info(f"Raw data shape: {input_tensor.shape}")
        resize_requested = H is not None and W is not None
        if resize_requested:
            input_tensor = transforms_F.resize(
                input_tensor,
                size=(H, W),  # type: ignore
                interpolation=transforms_F.InterpolationMode.BICUBIC,
                antialias=True,
            )

    # Add the leading batch axis of size one and move channels before time.
    input_tensor = einops.rearrange(input_tensor, "(b t) c h w -> b c t h w", b=1)
    if normalize:
        input_tensor = input_tensor.to("cuda")
    log.info(f"Load shape {input_tensor.shape} value {input_tensor.min()}, {input_tensor.max()}")

    return (input_tensor, fps) if also_return_fps else input_tensor
def create_condition_latent_from_input_frames(
    model: "ExtendDiffusionModel",
    input_frames: torch.Tensor,
    num_frames_condition: int = 25,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Create condition latent for video generation.

    The last ``num_frames_condition`` frames of ``input_frames`` are placed at
    the start of a clip, the clip is zero-padded to the VAE's
    ``pixel_chunk_duration`` and then encoded. When the model's
    ``condition_location`` is ``"first_and_last_1"`` the first
    ``num_frames_condition`` frames open the clip, the last
    ``num_frames_condition`` frames close it, and the two parts (the first
    ``pixel_chunk_duration`` frames and the remainder) are encoded separately
    and concatenated along time.

    Args:
        model (ExtendDiffusionModel): Video generation model
        input_frames (torch.Tensor): Video tensor in shape (1,C,T,H,W), range [-1, 1]
        num_frames_condition (int): Number of condition frames; 0 yields an
            all-zero clip

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: the condition latent in shape
        B,C,T,H,W and the padded pixel clip that was fed to the encoder.

    Raises:
        AssertionError: if ``input_frames`` has fewer than
            ``num_frames_condition`` frames, or ``num_frames_condition`` exceeds
            the VAE's ``pixel_chunk_duration``.
    """
    B, C, T, H, W = input_frames.shape
    # Dynamically access the VAE: use tokenizer.video_vae if it exists, otherwise use tokenizer directly
    vae = model.tokenizer.video_vae if hasattr(model.tokenizer, "video_vae") else model.tokenizer
    num_frames_encode = vae.pixel_chunk_duration
    log.info(
        f"num_frames_encode not set, set it based on pixel chunk duration and model state shape: {num_frames_encode}"
    )
    log.info(
        f"Create condition latent from input frames {input_frames.shape}, value {input_frames.min()}, {input_frames.max()}, dtype {input_frames.dtype}"
    )

    assert (
        input_frames.shape[2] >= num_frames_condition
    ), f"input_frames not enough for condition, require at least {num_frames_condition}, got {input_frames.shape[2]}, {input_frames.shape}"
    assert (
        num_frames_encode >= num_frames_condition
    ), f"num_frames_encode should be larger than num_frames_condition, got {num_frames_encode}, {num_frames_condition}"

    is_first_and_last = model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1"

    # Slice from an explicit start index: a negative-stop slice such as
    # x[-n:] returns *all* frames when n == 0 instead of none.
    tail_start = T - num_frames_condition

    if is_first_and_last:
        condition_frames_first = input_frames[:, :, :num_frames_condition]
        condition_frames_last = input_frames[:, :, tail_start:]
        padding_frames = condition_frames_first.new_zeros(B, C, num_frames_encode + 1 - 2 * num_frames_condition, H, W)
        encode_input_frames = torch.cat([condition_frames_first, padding_frames, condition_frames_last], dim=2)
    else:
        # Put the conditional frames at the beginning of the video, and pad the end with zeros
        condition_frames = input_frames[:, :, tail_start:]
        padding_frames = condition_frames.new_zeros(B, C, num_frames_encode - num_frames_condition, H, W)
        encode_input_frames = torch.cat([condition_frames, padding_frames], dim=2)

    log.info(
        f"create latent with input shape {encode_input_frames.shape} including padding {num_frames_encode - num_frames_condition} at the end"
    )
    if hasattr(model, "n_views"):
        # Multi-view models: fold the view axis from the batch into time before encoding.
        encode_input_frames = einops.rearrange(encode_input_frames, "(B V) C T H W -> B C (V T) H W", V=model.n_views)

    if is_first_and_last:
        latent1 = model.encode(encode_input_frames[:, :, :num_frames_encode])  # BCTHW
        latent2 = model.encode(encode_input_frames[:, :, num_frames_encode:])
        latent = torch.cat([latent1, latent2], dim=2)  # BCTHW
    else:
        latent = model.encode(encode_input_frames)
    return latent, encode_input_frames
def get_condition_latent(
    model: "ExtendDiffusionModel",
    conditioned_image_or_video_path: str,
    num_of_latent_condition: int = 4,
    state_shape: list[int] = None,
    input_path_format: str = None,
    frame_index: int = 0,
    frame_stride: int = 1,
):
    """Build the condition latent for a model from an image or video file.

    Args:
        model (ExtendDiffusionModel): Video generation model
        conditioned_image_or_video_path (str): path of the conditioning image or video
        num_of_latent_condition (int): number of latent frames used as condition;
            0 returns an all-zero latent without reading the file
        state_shape (list[int]): latent state shape; defaults to ``model.state_shape``.
            Its last two entries are multiplied by the spatial compression
            factor to obtain the pixel height and width.
        input_path_format (str): forwarded to the file reader
        frame_index (int): only used when ``condition_location`` is
            ``"first_and_last_1"``; selects frames ``frame_index * frame_stride``
            and ``(frame_index + 1) * frame_stride``
        frame_stride (int): see ``frame_index``

    Returns:
        tuple: ``(condition_latent, input_frames)``. ``condition_latent`` is a
        bfloat16 tensor; ``input_frames`` is the pixel tensor that was encoded,
        or None when ``num_of_latent_condition`` is 0.
    """
    if state_shape is None:
        state_shape = model.state_shape

    if num_of_latent_condition == 0:
        log.info("No condition latent needed, return empty latent")
        # Unpack instead of list-concatenating so that tuples and other
        # sequences are accepted for state_shape as well.
        condition_latent = torch.zeros([1, *state_shape]).to(torch.bfloat16).cuda()
        return condition_latent, None

    # NOTE(review): compression factors are read from model.vae here, while other
    # helpers in this file go through model.tokenizer — confirm both agree.
    H, W = (
        state_shape[-2] * model.vae.spatial_compression_factor,
        state_shape[-1] * model.vae.spatial_compression_factor,
    )
    input_frames = read_video_or_image_into_frames_BCTHW(
        conditioned_image_or_video_path,
        input_path_format=input_path_format,
        H=H,
        W=W,
    )
    if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1":
        # Keep exactly two frames: the start frame and the frame one stride later.
        start_frame = frame_index * frame_stride
        end_frame = (frame_index + 1) * frame_stride
        input_frames = torch.cat(
            [input_frames[:, :, start_frame : start_frame + 1], input_frames[:, :, end_frame : end_frame + 1]], dim=2
        ).contiguous()  # BCTHW

    num_frames_condition = compute_num_frames_condition(
        model, num_of_latent_condition, downsample_factor=model.vae.temporal_compression_factor
    )
    condition_latent, _ = create_condition_latent_from_input_frames(model, input_frames, num_frames_condition)
    condition_latent = condition_latent.to(torch.bfloat16)
    return condition_latent, input_frames
def generate_video_from_batch_with_loop(
    model: ExtendDiffusionModel,
    state_shape: list[int],
    is_negative_prompt: bool,
    data_batch: dict,
    condition_latent: torch.Tensor,
    # hyper-parameters for inference
    num_of_loops: int,
    num_of_latent_overlap_list: list[int],
    guidance: float,
    num_steps: int,
    seed: int,
    add_input_frames_guidance: bool = False,
    augment_sigma_list: list[float] = None,
    data_batch_list: Union[None, list[dict]] = None,
    visualize: bool = False,
    save_fig_path: str = None,
    skip_reencode: int = 0,
    return_noise: bool = False,
) -> Union[Tuple[np.ndarray, list, list], Tuple[np.ndarray, list, list, torch.Tensor]]:
    """Generate a long video auto-regressively, one clip per loop iteration.

    Each iteration samples a clip conditioned on ``condition_latent``, then
    derives the condition latent of the next iteration either from the tail of
    the sampled latent (``skip_reencode``) or by decoding the clip and
    re-encoding its last frames.

    Args:
        model (ExtendDiffusionModel)
        state_shape (list): shape of the state tensor; ``state_shape[1]`` is the
            temporal length the condition latent is zero-padded to
        is_negative_prompt (bool): whether to use negative prompt
        data_batch (dict): data batch for video generation, reused for every loop
            when ``data_batch_list`` is None
        condition_latent (torch.Tensor): condition latent in shape BCTHW
        num_of_loops (int): number of loops to generate video
        num_of_latent_overlap_list (list[int]): list number of latent frames to overlap between clips, different clips can have different overlap. Entry ``i`` is used for loop ``i``; for the following clip the last entry is reused when the list is exhausted
        guidance (float): The guidance scale to use during sample generation
        num_steps (int): number of steps for diffusion sampling
        seed (int): random seed for sampling; loop ``i`` uses ``seed + i``
        add_input_frames_guidance (bool): must be False, image guidance is not supported here
        augment_sigma_list (list): list of sigma value for the condition corruption at different clip, used when apply_corruption_to_condition_region is "noise_with_sigma" or "noise_with_sigma_fixed". Defaults to the value stored in the model config; the last entry is reused when the list is shorter than the number of loops
        data_batch_list (list): list of data batch for video generation, one per loop, to support multiple prompts in auto-regressive generation. default is None
        visualize (bool): whether to visualize the latent and grid, default is False
        save_fig_path (str): directory in which the visualizations are saved, required when visualize is True
        skip_reencode (int): when truthy, use the last latent frames of the sample as the next condition instead of re-encoding decoded frames, default is 0
        return_noise (bool): whether to return the initial noise used for sampling, used for ODE pairs generation. Default is False

    Returns:
        np.ndarray: generated video in shape THWC, range [0, 255]
        list: list of condition latent, each in shape BCTHW
        list: list of sample latent, each in shape BCTHW
        torch.Tensor: noise returned by the sampler in the last loop (only if return_noise is True)
    """
    if data_batch_list is None:
        data_batch_list = [data_batch for _ in range(num_of_loops)]
    if visualize:
        assert save_fig_path is not None, "save_fig_path should be set when visualize is True"

    # Generate video with loop
    condition_latent_list = []
    decode_latent_list = []  # list collect the latent token to be decoded at the end
    sample_latent = []
    grid_list = []

    augment_sigma_list = (
        model.config.conditioner.video_cond_bool.apply_corruption_to_condition_region_sigma_value
        if augment_sigma_list is None
        else augment_sigma_list
    )

    for i in range(num_of_loops):
        num_of_latent_overlap_i = num_of_latent_overlap_list[i]
        # Overlap of the *next* clip; falls back to the last entry once the list is exhausted.
        num_of_latent_overlap_i_plus_1 = (
            num_of_latent_overlap_list[i + 1]
            if i + 1 < len(num_of_latent_overlap_list)
            else num_of_latent_overlap_list[-1]
        )
        if condition_latent.shape[2] < state_shape[1]:
            # Padding condition latent to state shape
            log.info(f"Padding condition latent {condition_latent.shape} to state shape {state_shape}")
            b, c, t, h, w = condition_latent.shape
            condition_latent = torch.cat(
                [
                    condition_latent,
                    condition_latent.new_zeros(b, c, state_shape[1] - t, h, w),
                ],
                dim=2,
            ).contiguous()
            log.info(f"after padding, condition latent shape {condition_latent.shape}")
        log.info(f"Generate video loop {i} / {num_of_loops}")
        if visualize:
            log.info(f"Visualize condition latent {i}")
            visualize_latent_tensor_bcthw(
                condition_latent[:, :, :4].float(),
                nrow=4,
                save_fig_path=os.path.join(save_fig_path, f"loop_{i:02d}_condition_latent_first_4.png"),
            )  # BCTHW
        condition_latent_list.append(condition_latent)

        # Pick the corruption sigma of this clip, reusing the last one when the list is too short.
        if i < len(augment_sigma_list):
            condition_video_augment_sigma_in_inference = augment_sigma_list[i]
            log.info(f"condition_video_augment_sigma_in_inference {condition_video_augment_sigma_in_inference}")
        else:
            condition_video_augment_sigma_in_inference = augment_sigma_list[-1]
        assert not add_input_frames_guidance, "add_input_frames_guidance should be False, not supported"

        sample = model.generate_samples_from_batch(
            data_batch_list[i],
            guidance=guidance,
            state_shape=state_shape,
            num_steps=num_steps,
            is_negative_prompt=is_negative_prompt,
            seed=seed + i,
            condition_latent=condition_latent,
            num_condition_t=num_of_latent_overlap_i,
            condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
            return_noise=return_noise,
        )
        if return_noise:
            # With return_noise the sampler result is unpacked as (sample, noise).
            sample, noise = sample

        if visualize:
            log.info(f"Visualize sampled latent {i} 4-8 frames")
            visualize_latent_tensor_bcthw(
                sample[:, :, 4:8].float(),
                nrow=4,
                save_fig_path=os.path.join(save_fig_path, f"loop_{i:02d}_sample_latent_last_4.png"),
            )  # BCTHW
            diff_between_sample_and_condition = (sample - condition_latent)[:, :, :num_of_latent_overlap_i]
            log.info(
                f"Visualize diff between sample and condition latent {i} first 4 frames {diff_between_sample_and_condition.mean()}"
            )

        sample_latent.append(sample)
        T = condition_latent.shape[2]
        assert num_of_latent_overlap_i <= T, f"num_of_latent_overlap should be < T, get {num_of_latent_overlap_i}, {T}"

        if model.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i:
            # Latents are collected and decoded once after the loop; this path never
            # produces decoded frames inside the loop, hence skip_reencode is required.
            assert skip_reencode, "skip_reencode should be turned on when sample_tokens_start_from_p_or_i is True"
            if i == 0:
                decode_latent_list.append(sample)
            else:
                # Drop the latent frames that overlap with the previous clip.
                decode_latent_list.append(sample[:, :, num_of_latent_overlap_i:])
        else:
            # Interpolator mode. Decode the first and last as an image.
            if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1":
                grid_BCTHW_1 = (1.0 + model.decode(sample[:, :, :-1, ...])).clamp(0, 2) / 2  # [B, 3, T-1, H, W], [0, 1]
                grid_BCTHW_2 = (1.0 + model.decode(sample[:, :, -1:, ...])).clamp(0, 2) / 2  # [B, 3, 1, H, W], [0, 1]
                grid_BCTHW = torch.cat([grid_BCTHW_1, grid_BCTHW_2], dim=2)  # [B, 3, T, H, W], [0, 1]
            else:
                grid_BCTHW = (1.0 + model.decode(sample)).clamp(0, 2) / 2  # [B, 3, T, H, W], [0, 1]

            if visualize:
                log.info(f"Visualize grid {i}")
                visualize_tensor_bcthw(
                    grid_BCTHW.float(), nrow=5, save_fig_path=os.path.join(save_fig_path, f"loop_{i:02d}_grid.png")
                )
            grid_np_THWC = (
                (grid_BCTHW[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy().astype(np.uint8)
            )  # THW3, range [0, 255]

            # Post-process the output: cut the conditional frames from the output if it's not the first loop
            num_cond_frames = compute_num_frames_condition(
                model, num_of_latent_overlap_i_plus_1, downsample_factor=model.tokenizer.temporal_compression_factor
            )
            if i == 0:
                new_grid_np_THWC = grid_np_THWC  # First output, dont cut the conditional frames
            else:
                new_grid_np_THWC = grid_np_THWC[
                    num_cond_frames:
                ]  # Remove the conditional frames from the output, since it's overlapped with previous loop
            grid_list.append(new_grid_np_THWC)

            # Prepare the next loop: re-compute the condition latent
            if hasattr(model, "n_views"):
                grid_BCTHW = einops.rearrange(grid_BCTHW, "B C (V T) H W -> (B V) C T H W", V=model.n_views)
            condition_frame_input = grid_BCTHW[:, :, -num_cond_frames:] * 2 - 1  # BCTHW, range [0, 1] to [-1, 1]

        if skip_reencode:
            # Use the last num_of_latent_overlap latent token as condition latent
            log.info(f"Skip re-encode the condition frames, use the last {num_of_latent_overlap_i_plus_1} latent token")
            condition_latent = sample[:, :, -num_of_latent_overlap_i_plus_1:]
        else:
            # Re-encode the condition frames to get the new condition latent
            condition_latent, _ = create_condition_latent_from_input_frames(
                model, condition_frame_input, num_frames_condition=num_cond_frames
            )  # BCTHW
            condition_latent = condition_latent.to(torch.bfloat16)

    # save videos
    if model.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i:
        # decode all video together
        decode_latent_list = torch.cat(decode_latent_list, dim=2)
        grid_BCTHW = (1.0 + model.decode(decode_latent_list)).clamp(0, 2) / 2  # [B, 3, T, H, W], [0, 1]
        video_THWC = (
            (grid_BCTHW[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy().astype(np.uint8)
        )  # THW3, range [0, 255]
    else:
        video_THWC = np.concatenate(grid_list, axis=0)  # THW3, range [0, 255]

    if return_noise:
        # Only the noise of the last loop iteration is returned.
        return video_THWC, condition_latent_list, sample_latent, noise
    return video_THWC, condition_latent_list, sample_latent