Spaces:
Build error
Build error
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from dataclasses import dataclass | |
| from statistics import NormalDist | |
| from typing import Callable, Dict, Optional, Tuple, Union | |
| import numpy as np | |
| import torch | |
| from einops import rearrange | |
| from megatron.core import parallel_state | |
| from torch import Tensor | |
| from cosmos_predict1.diffusion.config.base.conditioner import VideoCondBoolConfig | |
| from cosmos_predict1.diffusion.functional.batch_ops import batch_mul | |
| from cosmos_predict1.diffusion.training.conditioner import DataType, VideoExtendCondition | |
| from cosmos_predict1.diffusion.training.context_parallel import cat_outputs_cp, split_inputs_cp | |
| from cosmos_predict1.diffusion.training.models.model import DiffusionModel as BaseModel | |
| from cosmos_predict1.diffusion.training.models.model import _broadcast, broadcast_condition | |
| from cosmos_predict1.diffusion.training.models.model_image import diffusion_fsdp_class_decorator | |
| from cosmos_predict1.utils import log, misc | |
| class VideoDenoisePrediction: | |
| x0: torch.Tensor # clean data prediction | |
| eps: Optional[torch.Tensor] = None # noise prediction | |
| logvar: Optional[torch.Tensor] = None # log variance of noise prediction, can be used a confidence / uncertainty | |
| net_in: Optional[torch.Tensor] = None # input to the network | |
| net_x0_pred: Optional[torch.Tensor] = None # prediction of x0 from the network | |
| xt: Optional[torch.Tensor] = None # input to the network, before muliply with c_in | |
| x0_pred_replaced: Optional[torch.Tensor] = None # x0 prediction with condition region replaced by gt_latent | |
| def normalize_condition_latent(condition_latent): | |
| """Normalize the condition latent tensor to have zero mean and unit variance | |
| Args: | |
| condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W | |
| """ | |
| condition_latent_2D = rearrange(condition_latent, "b c t h w -> b c t (h w)") | |
| mean = condition_latent_2D.mean(dim=-1) | |
| std = condition_latent_2D.std(dim=-1) | |
| # bct -> bct11 | |
| mean = mean.unsqueeze(-1).unsqueeze(-1) | |
| std = std.unsqueeze(-1).unsqueeze(-1) | |
| condition_latent = (condition_latent - mean) / std | |
| return condition_latent | |
| class ExtendDiffusionModel(BaseModel): | |
| def __init__(self, config): | |
| super().__init__(config) | |
| self.is_extend_model = True | |
| def get_data_and_condition( | |
| self, data_batch: dict[str, Tensor], num_condition_t: Union[int, None] = None | |
| ) -> Tuple[Tensor, VideoExtendCondition]: | |
| raw_state, latent_state, condition = super().get_data_and_condition(data_batch) | |
| if condition.data_type == DataType.VIDEO: | |
| if self.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i: | |
| latent_state = self.sample_tokens_start_from_p_or_i(latent_state) | |
| condition = self.add_condition_video_indicator_and_video_input_mask( | |
| latent_state, condition, num_condition_t=num_condition_t | |
| ) | |
| if self.config.conditioner.video_cond_bool.add_pose_condition: | |
| condition = self.add_condition_pose(data_batch, condition) | |
| log.debug(f"condition.data_type {condition.data_type}") | |
| return raw_state, latent_state, condition | |
| def draw_augment_sigma_and_epsilon( | |
| self, size: int, condition: VideoExtendCondition, p_mean: float, p_std: float, multiplier: float | |
| ) -> Tensor: | |
| is_video_batch = condition.data_type == DataType.VIDEO | |
| del condition | |
| batch_size = size[0] | |
| epsilon = torch.randn(size, **self.tensor_kwargs) | |
| gaussian_dist = NormalDist(mu=p_mean, sigma=p_std) | |
| cdf_vals = np.random.uniform(size=(batch_size)) | |
| samples_interval_gaussian = [gaussian_dist.inv_cdf(cdf_val) for cdf_val in cdf_vals] | |
| log_sigma = torch.tensor(samples_interval_gaussian, device="cuda") | |
| sigma_B = torch.exp(log_sigma).to(**self.tensor_kwargs) | |
| sigma_B = _broadcast(sigma_B * multiplier, to_tp=True, to_cp=is_video_batch) | |
| epsilon = _broadcast(epsilon, to_tp=True, to_cp=is_video_batch) | |
| return sigma_B, epsilon | |
| def augment_conditional_latent_frames( | |
| self, | |
| condition: VideoExtendCondition, | |
| cfg_video_cond_bool: VideoCondBoolConfig, | |
| gt_latent: Tensor, | |
| condition_video_augment_sigma_in_inference: float = 0.001, | |
| sigma: Tensor = None, | |
| seed_inference: int = 1, | |
| ) -> Union[VideoExtendCondition, Tensor]: | |
| """This function is used to augment the condition input with noise | |
| Args: | |
| condition (VideoExtendCondition): condition object | |
| condition_video_indicator: binary tensor indicating the region is condition(value=1) or generation(value=0). Bx1xTx1x1 tensor. | |
| condition_video_input_mask: input mask for the network input, indicating the condition region. B,1,T,H,W tensor. will be concat with the input for the network. | |
| cfg_video_cond_bool (VideoCondBoolConfig): video condition bool config | |
| gt_latent (Tensor): ground truth latent tensor in shape B,C,T,H,W | |
| condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference | |
| sigma (Tensor): noise level for the generation region | |
| Returns: | |
| VideoExtendCondition: updated condition object | |
| condition_video_augment_sigma: sigma for the condition region, feed to the network | |
| augment_latent (Tensor): augmented latent tensor in shape B,C,T,H,W | |
| """ | |
| if cfg_video_cond_bool.apply_corruption_to_condition_region == "noise_with_sigma": | |
| # Training only, sample sigma for the condition region | |
| augment_sigma, _ = self.draw_augment_sigma_and_epsilon( | |
| gt_latent.shape, | |
| condition, | |
| cfg_video_cond_bool.augment_sigma_sample_p_mean, | |
| cfg_video_cond_bool.augment_sigma_sample_p_std, | |
| cfg_video_cond_bool.augment_sigma_sample_multiplier, | |
| ) | |
| noise = torch.randn(*gt_latent.shape, **self.tensor_kwargs) | |
| elif cfg_video_cond_bool.apply_corruption_to_condition_region == "noise_with_sigma_fixed": | |
| # Inference only, use fixed sigma for the condition region | |
| log.debug( | |
| f"condition_video_augment_sigma_in_inference={condition_video_augment_sigma_in_inference}, sigma={sigma.flatten()[0]}" | |
| ) | |
| assert ( | |
| condition_video_augment_sigma_in_inference is not None | |
| ), "condition_video_augment_sigma_in_inference should be provided" | |
| augment_sigma = condition_video_augment_sigma_in_inference | |
| if augment_sigma >= sigma.flatten()[0]: | |
| # This is a inference trick! If the sampling sigma is smaller than the augment sigma, we will start denoising the condition region together. | |
| # This is achieved by setting all region as `generation`, i.e. value=0 | |
| log.debug("augment_sigma larger than sigma or other frame, remove condition") | |
| condition.condition_video_indicator = condition.condition_video_indicator * 0 | |
| augment_sigma = torch.tensor([augment_sigma], **self.tensor_kwargs) | |
| # Inference, use fixed seed | |
| noise = misc.arch_invariant_rand( | |
| gt_latent.shape, | |
| torch.float32, | |
| self.tensor_kwargs["device"], | |
| seed_inference, | |
| ) | |
| else: | |
| raise ValueError(f"does not support {cfg_video_cond_bool.apply_corruption_to_condition_region}") | |
| # Now apply the augment_sigma to the gt_latent | |
| augment_latent = gt_latent + noise * augment_sigma.view(-1, 1, 1, 1, 1) | |
| _, _, c_in_augment, c_noise_augment = self.scaling(sigma=augment_sigma) | |
| if cfg_video_cond_bool.condition_on_augment_sigma: # model takes augment_sigma as input | |
| if condition.condition_video_indicator.sum() > 0: # has condition frames | |
| condition.condition_video_augment_sigma = c_noise_augment | |
| else: # no condition frames | |
| condition.condition_video_augment_sigma = torch.zeros_like(c_noise_augment) | |
| # Multiply the whole latent with c_in_augment | |
| augment_latent_cin = batch_mul(augment_latent, c_in_augment) | |
| # Since the whole latent will multiply with c_in later, we devide the value to cancel the effect | |
| _, _, c_in, _ = self.scaling(sigma=sigma) | |
| augment_latent_cin = batch_mul(augment_latent_cin, 1 / c_in) | |
| return condition, augment_latent_cin | |
| def drop_out_condition_region( | |
| self, augment_latent: Tensor, noise_x: Tensor, cfg_video_cond_bool: VideoCondBoolConfig | |
| ) -> Tensor: | |
| """Use for CFG on input frames, we drop out the conditional region | |
| There are two option: | |
| 1. when we dropout, we set the region to be zero | |
| 2. when we dropout, we set the region to be noise_x | |
| """ | |
| # Unconditional case, use for cfg | |
| if cfg_video_cond_bool.cfg_unconditional_type == "zero_condition_region_condition_mask": | |
| # Set the condition location input to be zero | |
| augment_latent_drop = torch.zeros_like(augment_latent) | |
| elif cfg_video_cond_bool.cfg_unconditional_type == "noise_x_condition_region": | |
| # Set the condition location input to be noise_x, i.e., same as base model training | |
| augment_latent_drop = noise_x | |
| else: | |
| raise NotImplementedError( | |
| f"cfg_unconditional_type {cfg_video_cond_bool.cfg_unconditional_type} not implemented" | |
| ) | |
| return augment_latent_drop | |
| def denoise( | |
| self, | |
| noise_x: Tensor, | |
| sigma: Tensor, | |
| condition: VideoExtendCondition, | |
| condition_video_augment_sigma_in_inference: float = 0.001, | |
| seed_inference: int = 1, | |
| ) -> VideoDenoisePrediction: | |
| """ | |
| Denoise the noisy input tensor. | |
| Args: | |
| noise_x (Tensor): Noisy input tensor. | |
| sigma (Tensor): Noise level. | |
| condition (VideoExtendCondition): Condition for denoising. | |
| condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference | |
| Returns: | |
| Tensor: Denoised output tensor. | |
| """ | |
| if condition.data_type == DataType.IMAGE: | |
| pred = super().denoise(noise_x, sigma, condition) | |
| log.debug(f"hit image denoise, noise_x shape {noise_x.shape}, sigma shape {sigma.shape}", rank0_only=False) | |
| return VideoDenoisePrediction( | |
| x0=pred.x0, | |
| eps=pred.eps, | |
| logvar=pred.logvar, | |
| xt=noise_x, | |
| ) | |
| else: | |
| assert ( | |
| condition.gt_latent is not None | |
| ), f"find None gt_latent in condition, likely didn't call self.add_condition_video_indicator_and_video_input_mask when preparing the condition or this is a image batch but condition.data_type is wrong, get {noise_x.shape}" | |
| gt_latent = condition.gt_latent | |
| cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool | |
| condition_latent = gt_latent | |
| if cfg_video_cond_bool.normalize_condition_latent: | |
| condition_latent = normalize_condition_latent(condition_latent) | |
| # Augment the latent with different sigma value, and add the augment_sigma to the condition object if needed | |
| condition, augment_latent = self.augment_conditional_latent_frames( | |
| condition, | |
| cfg_video_cond_bool, | |
| condition_latent, | |
| condition_video_augment_sigma_in_inference, | |
| sigma, | |
| seed_inference=seed_inference, | |
| ) | |
| condition_video_indicator = condition.condition_video_indicator # [B, 1, T, 1, 1] | |
| if parallel_state.get_context_parallel_world_size() > 1: | |
| cp_group = parallel_state.get_context_parallel_group() | |
| condition_video_indicator = split_inputs_cp(condition_video_indicator, seq_dim=2, cp_group=cp_group) | |
| augment_latent = split_inputs_cp(augment_latent, seq_dim=2, cp_group=cp_group) | |
| gt_latent = split_inputs_cp(gt_latent, seq_dim=2, cp_group=cp_group) | |
| if not condition.video_cond_bool: | |
| # Unconditional case, drop out the condition region | |
| augment_latent = self.drop_out_condition_region(augment_latent, noise_x, cfg_video_cond_bool) | |
| # Compose the model input with condition region (augment_latent) and generation region (noise_x) | |
| new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x | |
| # Call the abse model | |
| denoise_pred = super().denoise(new_noise_xt, sigma, condition) | |
| x0_pred_replaced = condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0 | |
| if cfg_video_cond_bool.compute_loss_for_condition_region: | |
| # We also denoise the conditional region | |
| x0_pred = denoise_pred.x0 | |
| else: | |
| x0_pred = x0_pred_replaced | |
| return VideoDenoisePrediction( | |
| x0=x0_pred, | |
| eps=batch_mul(noise_x - x0_pred, 1.0 / sigma), | |
| logvar=denoise_pred.logvar, | |
| net_in=batch_mul(1.0 / torch.sqrt(self.sigma_data**2 + sigma**2), new_noise_xt), | |
| net_x0_pred=denoise_pred.x0, | |
| xt=new_noise_xt, | |
| x0_pred_replaced=x0_pred_replaced, | |
| ) | |
| def generate_samples_from_batch( | |
| self, | |
| data_batch: Dict, | |
| guidance: float = 1.5, | |
| seed: int = 1, | |
| state_shape: Tuple | None = None, | |
| n_sample: int | None = None, | |
| is_negative_prompt: bool = False, | |
| num_steps: int = 35, | |
| condition_latent: Union[torch.Tensor, None] = None, | |
| num_condition_t: Union[int, None] = None, | |
| condition_video_augment_sigma_in_inference: float = None, | |
| add_input_frames_guidance: bool = False, | |
| return_noise: bool = False, | |
| ) -> Tensor | Tuple[Tensor, Tensor]: | |
| """ | |
| Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. | |
| Different from the base model, this function support condition latent as input, it will create a differnt x0_fn if condition latent is given. | |
| If this feature is stablized, we could consider to move this function to the base model. | |
| Args: | |
| condition_latent (Optional[torch.Tensor]): latent tensor in shape B,C,T,H,W as condition to generate video. | |
| num_condition_t (Optional[int]): number of condition latent T, if None, will use the whole first half | |
| add_input_frames_guidance (bool): add guidance to the input frames, used for cfg on input frames | |
| return_noise (bool): return the initial noise or not, used for ODE pairs generation | |
| """ | |
| self._normalize_video_databatch_inplace(data_batch) | |
| self._augment_image_dim_inplace(data_batch) | |
| is_image_batch = self.is_image_batch(data_batch) | |
| if is_image_batch: | |
| log.debug("image batch, call base model generate_samples_from_batch") | |
| return super().generate_samples_from_batch( | |
| data_batch, | |
| guidance=guidance, | |
| seed=seed, | |
| state_shape=state_shape, | |
| n_sample=n_sample, | |
| is_negative_prompt=is_negative_prompt, | |
| num_steps=num_steps, | |
| ) | |
| if n_sample is None: | |
| input_key = self.input_image_key if is_image_batch else self.input_data_key | |
| n_sample = data_batch[input_key].shape[0] | |
| if state_shape is None: | |
| if is_image_batch: | |
| state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W | |
| else: | |
| log.debug(f"Default Video state shape is used. {self.state_shape}") | |
| state_shape = self.state_shape | |
| assert condition_latent is not None, "condition_latent should be provided" | |
| x0_fn = self.get_x0_fn_from_batch_with_condition_latent( | |
| data_batch, | |
| guidance, | |
| is_negative_prompt=is_negative_prompt, | |
| condition_latent=condition_latent, | |
| num_condition_t=num_condition_t, | |
| condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, | |
| add_input_frames_guidance=add_input_frames_guidance, | |
| seed_inference=seed, # Use for noise of augment sigma | |
| ) | |
| x_sigma_max = ( | |
| misc.arch_invariant_rand( | |
| (n_sample,) + tuple(state_shape), torch.float32, self.tensor_kwargs["device"], seed | |
| ) | |
| * self.sde.sigma_max | |
| ) | |
| if self.net.is_context_parallel_enabled: | |
| x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) | |
| samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max) | |
| if self.net.is_context_parallel_enabled: | |
| samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) | |
| if return_noise: | |
| if self.net.is_context_parallel_enabled: | |
| x_sigma_max = cat_outputs_cp(x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) | |
| return samples, x_sigma_max / self.sde.sigma_max | |
| return samples | |
| def get_x0_fn_from_batch_with_condition_latent( | |
| self, | |
| data_batch: Dict, | |
| guidance: float = 1.5, | |
| is_negative_prompt: bool = False, | |
| condition_latent: torch.Tensor = None, | |
| num_condition_t: Union[int, None] = None, | |
| condition_video_augment_sigma_in_inference: float = None, | |
| add_input_frames_guidance: bool = False, | |
| seed_inference: int = 1, | |
| ) -> Callable: | |
| """ | |
| Generates a callable function `x0_fn` based on the provided data batch and guidance factor. | |
| Different from the base model, this function support condition latent as input, it will add the condition information into the condition and uncondition object. | |
| This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. | |
| Args: | |
| - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` | |
| - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. | |
| - is_negative_prompt (bool): use negative prompt t5 in uncondition if true | |
| - condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W as condition to generate video. | |
| - num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" | |
| - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference | |
| - add_input_frames_guidance (bool): add guidance to the input frames, used for cfg on input frames | |
| Returns: | |
| - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin | |
| The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. | |
| """ | |
| if is_negative_prompt: | |
| condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) | |
| else: | |
| condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) | |
| condition.video_cond_bool = True | |
| condition = self.add_condition_video_indicator_and_video_input_mask( | |
| condition_latent, condition, num_condition_t | |
| ) | |
| if self.config.conditioner.video_cond_bool.add_pose_condition: | |
| condition = self.add_condition_pose(data_batch, condition) | |
| uncondition.video_cond_bool = False if add_input_frames_guidance else True | |
| uncondition = self.add_condition_video_indicator_and_video_input_mask( | |
| condition_latent, uncondition, num_condition_t | |
| ) | |
| if self.config.conditioner.video_cond_bool.add_pose_condition: | |
| uncondition = self.add_condition_pose(data_batch, uncondition) | |
| to_cp = self.net.is_context_parallel_enabled | |
| # For inference, check if parallel_state is initialized | |
| if parallel_state.is_initialized(): | |
| condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) | |
| uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) | |
| else: | |
| assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." | |
| def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: | |
| cond_x0 = self.denoise( | |
| noise_x, | |
| sigma, | |
| condition, | |
| condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, | |
| seed_inference=seed_inference, | |
| ).x0_pred_replaced | |
| uncond_x0 = self.denoise( | |
| noise_x, | |
| sigma, | |
| uncondition, | |
| condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, | |
| seed_inference=seed_inference, | |
| ).x0_pred_replaced | |
| return cond_x0 + guidance * (cond_x0 - uncond_x0) | |
| return x0_fn | |
| def add_condition_video_indicator_and_video_input_mask( | |
| self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None | |
| ) -> VideoExtendCondition: | |
| """Add condition_video_indicator and condition_video_input_mask to the condition object for video conditioning. | |
| condition_video_indicator is a binary tensor indicating the condition region in the latent state. 1x1xTx1x1 tensor. | |
| condition_video_input_mask will be concat with the input for the network. | |
| Args: | |
| latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W | |
| condition (VideoExtendCondition): condition object | |
| num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" | |
| Returns: | |
| VideoExtendCondition: updated condition object | |
| """ | |
| T = latent_state.shape[2] | |
| latent_dtype = latent_state.dtype | |
| condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( | |
| latent_dtype | |
| ) # 1 for condition region | |
| if self.config.conditioner.video_cond_bool.condition_location == "first_n": | |
| # Only in inference to decide the condition region | |
| assert num_condition_t is not None, "num_condition_t should be provided" | |
| assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" | |
| log.info( | |
| f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" | |
| ) | |
| condition_video_indicator[:, :, :num_condition_t] += 1.0 | |
| elif self.config.conditioner.video_cond_bool.condition_location == "first_random_n": | |
| # Only in training | |
| num_condition_t_max = self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_max | |
| assert ( | |
| num_condition_t_max <= T | |
| ), f"num_condition_t_max should be less than T, get {num_condition_t_max}, {T}" | |
| assert num_condition_t_max >= self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_min | |
| num_condition_t = torch.randint( | |
| self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_min, | |
| num_condition_t_max + 1, | |
| (1,), | |
| ).item() | |
| condition_video_indicator[:, :, :num_condition_t] += 1.0 | |
| elif self.config.conditioner.video_cond_bool.condition_location == "random": | |
| # Only in training | |
| condition_rate = self.config.conditioner.video_cond_bool.random_conditon_rate | |
| flag = torch.ones(1, 1, T, 1, 1, device=latent_state.device).type(latent_dtype) * condition_rate | |
| condition_video_indicator = torch.bernoulli(flag).type(latent_dtype).to(latent_state.device) | |
| else: | |
| raise NotImplementedError( | |
| f"condition_location {self.config.conditioner.video_cond_bool.condition_location} not implemented; training={self.training}" | |
| ) | |
| condition.gt_latent = latent_state | |
| condition.condition_video_indicator = condition_video_indicator | |
| B, C, T, H, W = latent_state.shape | |
| # Create additional input_mask channel, this will be concatenated to the input of the network | |
| # See design doc section (Implementation detail A.1 and A.2) for visualization | |
| ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) | |
| zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) | |
| assert condition.video_cond_bool is not None, "video_cond_bool should be set" | |
| # The input mask indicate whether the input is conditional region or not | |
| if condition.video_cond_bool: # Condition one given video frames | |
| condition.condition_video_input_mask = ( | |
| condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding | |
| ) | |
| else: # Unconditional case, use for cfg | |
| condition.condition_video_input_mask = zeros_padding | |
| to_cp = self.net.is_context_parallel_enabled | |
| # For inference, check if parallel_state is initialized | |
| if parallel_state.is_initialized(): | |
| condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) | |
| else: | |
| assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." | |
| return condition | |
| def add_condition_pose(self, data_batch: Dict, condition: VideoExtendCondition) -> VideoExtendCondition: | |
| """Add pose condition to the condition object. For camera control model | |
| Args: | |
| data_batch (Dict): data batch, with key "plucker_embeddings", in shape B,T,C,H,W | |
| latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W | |
| condition (VideoExtendCondition): condition object | |
| num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" | |
| Returns: | |
| VideoExtendCondition: updated condition object | |
| """ | |
| assert ( | |
| "plucker_embeddings" in data_batch or "plucker_embeddings_downsample" in data_batch.keys() | |
| ), f"plucker_embeddings should be in data_batch. only find {data_batch.keys()}" | |
| plucker_embeddings = ( | |
| data_batch["plucker_embeddings"] | |
| if "plucker_embeddings_downsample" not in data_batch.keys() | |
| else data_batch["plucker_embeddings_downsample"] | |
| ) | |
| condition.condition_video_pose = rearrange(plucker_embeddings, "b t c h w -> b c t h w").contiguous() | |
| to_cp = self.net.is_context_parallel_enabled | |
| # For inference, check if parallel_state is initialized | |
| if parallel_state.is_initialized(): | |
| condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) | |
| else: | |
| assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." | |
| return condition | |
| def sample_tokens_start_from_p_or_i(self, latent_state: torch.Tensor) -> torch.Tensor: | |
| """Sample the PPP... from the IPPP... sequence, only for video sequence | |
| Args: | |
| latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W | |
| Returns: | |
| torch.Tensor: sampled PPP tensor in shape B,C,T,H,W | |
| """ | |
| B, C, T, H, W = latent_state.shape | |
| latent_dtype = latent_state.dtype | |
| T_target = self.state_shape[1] | |
| latent_state_sample = torch.zeros((B, C, T_target, H, W), dtype=latent_dtype, device=latent_state.device) | |
| t_start = torch.randint(0, T - T_target + 1, (1,)) | |
| # broadcast to other device | |
| latent_state_sample = latent_state[:, :, t_start : t_start + T_target].contiguous() | |
| if parallel_state.is_initialized(): | |
| latent_state_sample = _broadcast(latent_state_sample, to_tp=True, to_cp=True) | |
| return latent_state_sample | |
| class FSDPExtendDiffusionModel(ExtendDiffusionModel): | |
| pass | |