Spaces:
Build error
Build error
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import Optional | |
| import torch | |
| from megatron.core import parallel_state | |
| from torch import Tensor | |
| from cosmos_predict1.diffusion.conditioner import VideoExtendCondition | |
| from cosmos_predict1.diffusion.model.model_t2w import DiffusionT2WModel, broadcast_condition | |
| from cosmos_predict1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp | |
| from cosmos_predict1.utils import log, misc | |
class DiffusionV2WModel(DiffusionT2WModel):
    """Video-to-world diffusion model.

    Extends ``DiffusionT2WModel`` with conditioning on the first
    ``num_condition_t`` latent frames of a given video. At every denoising step
    those frames are injected (noise-augmented) into the model input, and in
    that region the network output is replaced by a value derived from the
    ground-truth latent.
    """

    # Fallback noise level for condition augmentation when a caller passes
    # ``condition_augment_sigma=None``; matches the default of
    # ``_augment_noise_with_latent``.
    _DEFAULT_CONDITION_AUGMENT_SIGMA = 0.001

    def __init__(self, config):
        super().__init__(config)

    def add_condition_video_indicator_and_video_input_mask(
        self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Optional[int] = None
    ) -> VideoExtendCondition:
        """Attach conditioning masks to a ``VideoExtendCondition`` (mutated in place).

        The first ``num_condition_t`` frames along the time axis (dim 2) are
        marked as the condition region.

        Args:
            latent_state: Input latent tensor of shape (B, C, T, H, W).
            condition: Condition object to update. Its ``video_cond_bool`` must
                already be set.
            num_condition_t: Number of leading latent frames to condition on.
                Required (and must be ``<= T``) despite the ``None`` default.

        Returns:
            The same ``condition`` object with these attributes set:
            - ``gt_latent``: ``latent_state`` itself.
            - ``condition_video_indicator``: (1, 1, T, 1, 1) tensor, 1 on the
              condition frames and 0 elsewhere.
            - ``condition_video_input_mask``: (B, 1, T, H, W) mask fed to the
              network; the broadcast indicator when ``video_cond_bool`` is true,
              all zeros otherwise (the unconditional case used for CFG).

        Raises:
            AssertionError: If ``num_condition_t`` is None or greater than T, or
                if ``condition.video_cond_bool`` is None.
        """
        B, C, T, H, W = latent_state.shape

        # Only used in inference to decide the condition region.
        assert num_condition_t is not None, "num_condition_t should be provided"
        assert (
            num_condition_t <= T
        ), f"num_condition_t should be less than or equal to T, got {num_condition_t}, {T}"
        log.debug(
            f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}"
        )

        # 1 for the condition region, 0 for the generation region.
        condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device, dtype=latent_state.dtype)
        condition_video_indicator[:, :, :num_condition_t] = 1.0

        condition.gt_latent = latent_state
        condition.condition_video_indicator = condition_video_indicator

        # Extra input-mask channel; it is concatenated to the network input.
        assert condition.video_cond_bool is not None, "video_cond_bool should be set"
        if condition.video_cond_bool:
            # Condition on the given video frames: broadcast the (1, 1, T, 1, 1)
            # indicator to a full (B, 1, T, H, W) mask.
            condition.condition_video_input_mask = condition_video_indicator * torch.ones(
                (B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device
            )
        else:
            # Unconditional case (used for CFG): no frame is flagged as condition input.
            condition.condition_video_input_mask = torch.zeros(
                (B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device
            )
        return condition

    def generate_samples_from_batch(
        self,
        data_batch: dict,
        guidance: float = 1.5,
        seed: int = 1,
        state_shape: tuple | None = None,
        n_sample: int | None = 1,
        is_negative_prompt: bool = False,
        num_steps: int = 35,
        condition_latent: Optional[torch.Tensor] = None,
        num_condition_t: Optional[int] = None,
        condition_augment_sigma: Optional[float] = None,
        add_input_frames_guidance: bool = False,
    ) -> Tensor:
        """Generate video samples conditioned on the leading frames of ``condition_latent``.

        Args:
            data_batch: Input data dictionary passed to the conditioner.
            guidance: Classifier-free guidance scale; the combined prediction is
                ``cond + guidance * (cond - uncond)``.
            seed: Seed for the condition-augmentation noise.
                NOTE(review): the initial noise ``xt`` is drawn with the global
                torch RNG (``torch.randn``), so callers must seed that RNG
                themselves for fully reproducible output.
            state_shape: Per-sample latent shape (presumably (C, T, H, W)); if
                None, falls back to ``self.state_shape``.
            n_sample: Number of samples to generate; if None, uses the batch
                size of ``condition_latent``.
            is_negative_prompt: Build conditions with
                ``conditioner.get_condition_with_negative_prompt`` instead of
                ``conditioner.get_condition_uncondition``.
            num_steps: Number of denoising steps.
            condition_latent: Conditioning latent of shape (B, C, T, H, W). Required.
            num_condition_t: Number of leading latent frames to condition on.
            condition_augment_sigma: Noise level used to augment the condition
                frames; if None, 0.001 is used.
            add_input_frames_guidance: If True, the unconditional branch gets an
                all-zero input mask, so guidance is also applied to the input frames.

        Returns:
            Denoised latent samples of shape ``(n_sample, *state_shape)``,
            gathered along dim 2 when context parallelism is enabled.
        """
        assert condition_latent is not None, "condition_latent should be provided"
        if state_shape is None:
            # NOTE(review): assumes the parent DiffusionT2WModel exposes
            # ``state_shape`` (the docstring default) — confirm.
            state_shape = self.state_shape
        if n_sample is None:
            n_sample = condition_latent.shape[0]
        if condition_augment_sigma is None:
            # Without this fallback, ``None >= sigma`` in _augment_noise_with_latent raises.
            condition_augment_sigma = self._DEFAULT_CONDITION_AUGMENT_SIGMA

        condition, uncondition = self._get_conditions(
            data_batch, is_negative_prompt, condition_latent, num_condition_t, add_input_frames_guidance
        )
        self.scheduler.set_timesteps(num_steps)

        xt = torch.randn(size=(n_sample,) + tuple(state_shape), **self.tensor_kwargs) * self.scheduler.init_noise_sigma

        to_cp = self.net.is_context_parallel_enabled
        if to_cp:
            xt = split_inputs_cp(x=xt, seq_dim=2, cp_group=self.net.cp_group)

        for t in self.scheduler.timesteps:
            self.scheduler._init_step_index(t)
            sigma = self.scheduler.sigmas[self.scheduler.step_index].to(**self.tensor_kwargs)

            # Form the model input: noise in the generation region, augmented latent in the condition region.
            xt = xt.to(**self.tensor_kwargs)
            new_xt, latent, indicator = self._augment_noise_with_latent(
                xt, sigma, condition, condition_augment_sigma=condition_augment_sigma, seed=seed
            )
            new_xt = new_xt.to(**self.tensor_kwargs)
            new_xt_scaled = self.scheduler.scale_model_input(new_xt, timestep=t)

            # Classifier-free guided network prediction.
            t = t.to(**self.tensor_kwargs)
            net_output_cond = self.net(x=new_xt_scaled, timesteps=t, **condition.to_dict())
            net_output_uncond = self.net(x=new_xt_scaled, timesteps=t, **uncondition.to_dict())
            net_output = net_output_cond + guidance * (net_output_cond - net_output_uncond)

            # In the condition region, replace the prediction with the output that maps back to the latent.
            latent_unscaled = self._reverse_precondition_output(latent, xt=new_xt, sigma=sigma)
            new_output = indicator * latent_unscaled + (1 - indicator) * net_output

            # Compute the previous noisy sample x_t -> x_t-1.
            xt = self.scheduler.step(new_output, t, new_xt).prev_sample

        samples = xt
        if to_cp:
            samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group)
        return samples

    def _get_conditions(
        self,
        data_batch: dict,
        is_negative_prompt: bool = False,
        condition_latent: Optional[torch.Tensor] = None,
        num_condition_t: Optional[int] = None,
        add_input_frames_guidance: bool = False,
    ):
        """Build the conditional and unconditional inputs for the model.

        Args:
            data_batch: Input data dictionary passed to the conditioner.
            is_negative_prompt: Use ``conditioner.get_condition_with_negative_prompt``
                instead of ``conditioner.get_condition_uncondition``.
            condition_latent: Conditioning latent of shape (B, C, T, H, W).
            num_condition_t: Number of leading latent frames to condition on.
            add_input_frames_guidance: If True, the unconditional branch gets an
                all-zero input mask (``video_cond_bool=False``).

        Returns:
            ``(condition, uncondition)``, both carrying the video indicator and
            input mask; broadcast across the parallel group when
            ``parallel_state`` is initialized.
        """
        if is_negative_prompt:
            condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
        else:
            condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)

        condition.video_cond_bool = True
        condition = self.add_condition_video_indicator_and_video_input_mask(
            condition_latent, condition, num_condition_t
        )

        uncondition.video_cond_bool = not add_input_frames_guidance
        uncondition = self.add_condition_video_indicator_and_video_input_mask(
            condition_latent, uncondition, num_condition_t
        )
        assert condition.gt_latent.allclose(uncondition.gt_latent)

        # For inference, only broadcast when parallel_state is initialized.
        to_cp = self.net.is_context_parallel_enabled
        if parallel_state.is_initialized():
            condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp)
            uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp)

        return condition, uncondition

    def _augment_noise_with_latent(
        self,
        xt: Tensor,
        sigma: Tensor,
        condition: VideoExtendCondition,
        condition_augment_sigma: Optional[float] = 0.001,
        seed: int = 1,
    ) -> tuple[Tensor, Tensor, Tensor]:
        """Inject noise-augmented condition frames into the current noisy sample.

        Args:
            xt: Current noisy sample for the generation region.
            sigma: Noise level of the current denoising step.
            condition: Condition object; reads ``gt_latent`` (B, C, T, H, W)
                and ``condition_video_indicator`` (1 on condition frames, 0 elsewhere).
            condition_augment_sigma: Noise level for augmenting the condition
                frames; None falls back to 0.001.
            seed: Seed for the augmentation noise.

        Returns:
            new_xt: ``indicator * augmented_latent + (1 - indicator) * xt``.
            latent: ``condition.gt_latent`` (un-augmented).
            indicator: ``condition.condition_video_indicator``, all zeros when
                ``condition_augment_sigma >= sigma``.
            Under context parallelism, all three are split along dim 2.
        """
        augment_sigma = condition_augment_sigma
        if augment_sigma is None:
            augment_sigma = self._DEFAULT_CONDITION_AUGMENT_SIGMA

        latent = condition.gt_latent
        indicator = condition.condition_video_indicator
        if augment_sigma >= sigma:
            # The condition frames would be at least as noisy as the generation
            # region at this step, so stop injecting / replacing them.
            indicator = torch.zeros_like(indicator)

        # Noise is drawn for the full (unsplit) latent so every CP rank sees the same values.
        noise = misc.arch_invariant_rand(
            latent.shape,
            torch.float32,
            self.tensor_kwargs["device"],
            seed,
        )
        augment_latent = latent + noise * augment_sigma
        # Precondition at augment_sigma, then undo the input scaling for the
        # current sigma; the later scale_model_input(…) at this step presumably
        # re-applies it, leaving the condition region scaled as if at augment_sigma.
        augment_latent = self.scheduler.precondition_inputs(augment_latent, augment_sigma)
        augment_latent_unscaled = self._reverse_precondition_input(augment_latent, sigma)

        if self.net.is_context_parallel_enabled:
            latent = split_inputs_cp(condition.gt_latent, seq_dim=2, cp_group=self.net.cp_group)
            indicator = split_inputs_cp(indicator, seq_dim=2, cp_group=self.net.cp_group)
            augment_latent_unscaled = split_inputs_cp(augment_latent_unscaled, seq_dim=2, cp_group=self.net.cp_group)

        # Compose the model input from the condition region (augmented latent) and the generation region (xt).
        new_xt = indicator * augment_latent_unscaled + (1 - indicator) * xt
        return new_xt, latent, indicator

    def _reverse_precondition_input(self, xt: Tensor, sigma: Tensor) -> Tensor:
        """Undo input preconditioning: return ``xt / c_in(sigma)``.

        ``c_in(sigma) = 1 / sqrt(sigma**2 + sigma_data**2)``. ``sigma_data`` is
        read from the scheduler config, the same source
        ``_reverse_precondition_output`` uses, because the scaling being undone
        comes from the scheduler.
        """
        sigma_data = self.scheduler.config.sigma_data
        c_in = 1 / ((sigma**2 + sigma_data**2) ** 0.5)
        return xt / c_in

    def _reverse_precondition_output(self, latent: Tensor, xt: Tensor, sigma: Tensor) -> Tensor:
        """Solve the EDM output preconditioning for the raw network output.

        With ``c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)`` and
        ``c_out = sigma * sigma_data / sqrt(sigma**2 + sigma_data**2)``, return
        ``F`` such that ``c_skip * xt + c_out * F == latent``.
        """
        sigma_data = self.scheduler.config.sigma_data
        c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
        c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
        return (latent - c_skip * xt) / c_out