|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional |
|
|
|
|
|
import torch |
|
|
from megatron.core import parallel_state |
|
|
from torch import Tensor |
|
|
|
|
|
from cosmos_predict1.diffusion.conditioner import VideoExtendCondition |
|
|
from cosmos_predict1.diffusion.model.model_v2w import DiffusionV2WModel, broadcast_condition |
|
|
|
|
|
|
|
|
class DiffusionGen3CModel(DiffusionV2WModel):
    """``DiffusionV2WModel`` variant conditioned on encoded warped-frame buffers.

    On top of the parent model, the ``condition_state`` /
    ``condition_state_mask`` entries of the data batch are encoded with
    ``self.encode`` and attached to the condition objects as
    ``condition_video_pose``.
    """

    def __init__(self, config):
        """Build the model from ``config``.

        Args:
            config: Model configuration. Forwarded to the parent constructor
                and must expose ``frame_buffer_max``.
        """
        super().__init__(config)
        # Number of (frame, mask) buffer slots that ``encode_warped_frames``
        # zero-pads its output up to.
        self.frame_buffer_max = config.frame_buffer_max
        # NOTE(review): hard-coded; presumably the number of frames handled per
        # generated chunk -- confirm against the callers that read it.
        self.chunk_size = 121

    def encode_warped_frames(
        self,
        condition_state: torch.Tensor,
        condition_state_mask: torch.Tensor,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Encode each warped-frame buffer and its mask into one latent tensor.

        For every index ``i`` along axis 2, ``condition_state[:, :, i]`` and the
        matching mask slice are permuted with ``(0, 2, 1, 3, 4)``, cast to
        ``dtype`` and passed through ``self.encode``. The resulting latents are
        interleaved as ``[frame_0, mask_0, frame_1, mask_1, ...]``, zero-padded
        up to ``self.frame_buffer_max`` (frame, mask) pairs and concatenated
        along dim 1.

        Args:
            condition_state: 6-D tensor of warped frames; axis 2 enumerates the
                buffers. Presumably laid out as (B, T, N, C, H, W) -- TODO confirm.
            condition_state_mask: Mask tensor for ``condition_state``. It is
                tiled 3x along axis 3, so it presumably has a single channel
                there -- TODO confirm.
            dtype: dtype the slices are cast to before encoding.

        Returns:
            The concatenation along dim 1 of all frame/mask latents plus padding.

        Raises:
            AssertionError: If ``condition_state`` is not 6-dimensional.
            ValueError: If ``condition_state`` has no buffers along axis 2.
        """
        assert condition_state.dim() == 6
        num_buffers = condition_state.shape[2]
        if num_buffers == 0:
            # Without at least one encoded buffer there is no latent to
            # concatenate and no reference shape for the zero padding.
            raise ValueError(
                "condition_state must contain at least one buffer along dim 2, "
                f"got shape {tuple(condition_state.shape)}"
            )

        # Rescale the mask with ``x * 2 - 1`` (presumably {0, 1} -> {-1, 1}) and
        # tile it 3x along axis 3 before it goes through the same encoder as
        # the frames.
        condition_state_mask = (condition_state_mask * 2 - 1).repeat(1, 1, 1, 3, 1, 1)

        latent_condition = []
        for i in range(num_buffers):
            video_latent = self.encode(
                condition_state[:, :, i].permute(0, 2, 1, 3, 4).to(dtype)
            ).contiguous()
            mask_latent = self.encode(
                condition_state_mask[:, :, i].permute(0, 2, 1, 3, 4).to(dtype)
            ).contiguous()
            latent_condition.append(video_latent)
            latent_condition.append(mask_latent)

        # Fill the unused buffer slots with zeros shaped like the last encoded
        # pair. When more buffers than ``frame_buffer_max`` are supplied no
        # padding is added and all of them are kept.
        num_missing = self.frame_buffer_max - num_buffers
        if num_missing > 0:
            zero_video = torch.zeros_like(video_latent)
            zero_mask = torch.zeros_like(mask_latent)
            latent_condition.extend([zero_video, zero_mask] * num_missing)

        return torch.cat(latent_condition, dim=1)

    def _get_conditions(
        self,
        data_batch: dict,
        is_negative_prompt: bool = False,
        condition_latent: Optional[torch.Tensor] = None,
        num_condition_t: Optional[int] = None,
        add_input_frames_guidance: bool = False,
    ):
        """Get the conditions for the model.

        Args:
            data_batch: Input data dictionary. Must contain the keys
                ``"condition_state"`` and ``"condition_state_mask"``.
            is_negative_prompt: Whether to use negative prompting
            condition_latent: Conditioning frames tensor (B,C,T,H,W)
            num_condition_t: Number of frames to condition on
            add_input_frames_guidance: Whether to apply guidance to input frames.
                When truthy, ``uncondition.video_cond_bool`` is set to ``False``.

        Returns:
            condition: Input conditions, carrying the encoded warped-frame
                latent as ``condition_video_pose``
            uncondition: Conditions removed/reduced to minimum (unconditioned);
                its ``condition_video_pose`` is all zeros

        Raises:
            KeyError: If ``data_batch`` lacks one of the required keys.
            AssertionError: If ``gt_latent`` differs between the two conditions.
        """
        if is_negative_prompt:
            condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
        else:
            condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)

        # Encode the warped-frame buffers once; both branches share the result.
        latent_condition = self.encode_warped_frames(
            data_batch["condition_state"],
            data_batch["condition_state_mask"],
            self.tensor_kwargs["dtype"],
        )

        # Conditional branch: video conditioning on, real pose latent.
        condition.video_cond_bool = True
        condition = self.add_condition_video_indicator_and_video_input_mask(
            condition_latent, condition, num_condition_t
        )
        condition = self.add_condition_pose(latent_condition, condition)

        # Unconditional branch: video conditioning is switched off only when
        # guidance on the input frames is requested; the pose latent is always
        # replaced by zeros.
        uncondition.video_cond_bool = not add_input_frames_guidance
        uncondition = self.add_condition_video_indicator_and_video_input_mask(
            condition_latent, uncondition, num_condition_t
        )
        uncondition = self.add_condition_pose(latent_condition, uncondition, drop_out_latent=True)
        assert condition.gt_latent.allclose(uncondition.gt_latent)

        # NOTE(review): ``broadcast_condition`` presumably synchronises the
        # condition tensors across parallel ranks -- confirm in model_v2w.
        to_cp = self.net.is_context_parallel_enabled
        if parallel_state.is_initialized():
            condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp)
            uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp)

        return condition, uncondition

    def add_condition_pose(self, latent_condition: torch.Tensor, condition: VideoExtendCondition,
                           drop_out_latent: bool = False) -> VideoExtendCondition:
        """Add pose condition to the condition object. For camera control model

        Args:
            latent_condition (torch.Tensor): latent to store on the condition
                as ``condition_video_pose``, e.g. the output of
                ``encode_warped_frames``
            condition (VideoExtendCondition): condition object; it is modified
                in place
            drop_out_latent (bool): if True, store an all-zero tensor shaped
                like ``latent_condition`` instead of the latent itself

        Returns:
            VideoExtendCondition: updated condition object

        Raises:
            AssertionError: if context parallelism is enabled on ``self.net``
                while ``parallel_state`` is not initialized
        """
        pose_latent = latent_condition.contiguous()
        if drop_out_latent:
            pose_latent = torch.zeros_like(pose_latent)
        condition.condition_video_pose = pose_latent

        to_cp = self.net.is_context_parallel_enabled
        if parallel_state.is_initialized():
            condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp)
        else:
            assert not to_cp, "parallel_state is not initialized, context parallel should be turned off."

        return condition
|
|
|