from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import BaseOutput, is_torchvision_available, logging
from ..modeling_utils import ModelMixin
from ..transformers.transformer_cosmos import (
    CosmosEmbedding,
    CosmosLearnablePositionalEmbed,
    CosmosPatchEmbed,
    CosmosRotaryPosEmbed,
    CosmosTransformerBlock,
)


if is_torchvision_available():
    from torchvision import transforms


logger = logging.get_logger(__name__)


@dataclass
class CosmosControlNetOutput(BaseOutput):
    """
    Output of [`CosmosControlNetModel`].

    Args:
        control_block_samples (`List[torch.Tensor]`):
            List of control block activations to be injected into transformer blocks.
    """

    control_block_samples: List[torch.Tensor]


class CosmosControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    ControlNet for Cosmos Transfer2.5.

    This model keeps its own copies of the embedding modules that are shared with the transformer (`patch_embed`,
    `time_embed`, `learnable_pos_embed`, `img_context_proj`) so that CPU offloading works correctly. The `forward()`
    method computes everything it needs internally from raw inputs.
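
    Example (a minimal sketch with random inputs and the default config; real checkpoints and shapes may differ):

    ```py
    >>> import torch
    >>>
    >>> controlnet = CosmosControlNetModel()
    >>> batch_size, num_frames, height, width = 1, 1, 32, 32
    >>> controls_latents = torch.randn(batch_size, 128, num_frames, height, width)
    >>> latents = torch.randn(batch_size, 16, num_frames, height, width)
    >>> condition_mask = torch.zeros(batch_size, 1, num_frames, height, width)
    >>> padding_mask = torch.zeros(1, 1, height, width)
    >>> encoder_hidden_states = torch.randn(batch_size, 512, 1024)
    >>> timestep = torch.rand(batch_size)
    >>> with torch.no_grad():
    ...     output = controlnet(
    ...         controls_latents=controls_latents,
    ...         latents=latents,
    ...         timestep=timestep,
    ...         encoder_hidden_states=encoder_hidden_states,
    ...         condition_mask=condition_mask,
    ...         padding_mask=padding_mask,
    ...     )
    >>> num_controls = len(output.control_block_samples)  # one residual per control block (4 with the default config)
    ```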
| """ |
|
|
| _supports_gradient_checkpointing = True |
| _skip_layerwise_casting_patterns = ["patch_embed", "patch_embed_base", "time_embed"] |
| _no_split_modules = ["CosmosTransformerBlock"] |
| _keep_in_fp32_modules = ["learnable_pos_embed"] |
|
|
    @register_to_config
    def __init__(
        self,
        n_controlnet_blocks: int = 4,
        in_channels: int = 130,
        latent_channels: int = 18,
        model_channels: int = 2048,
        num_attention_heads: int = 32,
        attention_head_dim: int = 128,
        mlp_ratio: float = 4.0,
        text_embed_dim: int = 1024,
        adaln_lora_dim: int = 256,
        patch_size: Tuple[int, int, int] = (1, 2, 2),
        max_size: Tuple[int, int, int] = (128, 240, 240),
        rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
        extra_pos_embed_type: Optional[str] = None,
        img_context_dim_in: Optional[int] = None,
        img_context_dim_out: int = 2048,
        use_crossattn_projection: bool = False,
        crossattn_proj_in_channels: int = 1024,
        encoder_hidden_states_channels: int = 1024,
    ):
        super().__init__()

        # Patch embedding for the control stream (control latents + condition mask + padding mask).
        self.patch_embed = CosmosPatchEmbed(in_channels, model_channels, patch_size, bias=False)

        # Patch embedding for the base latent stream, mirroring the transformer's own patch embedding.
        self.patch_embed_base = CosmosPatchEmbed(latent_channels, model_channels, patch_size, bias=False)
        self.time_embed = CosmosEmbedding(model_channels, model_channels)

        self.learnable_pos_embed = None
        if extra_pos_embed_type == "learnable":
            self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
                hidden_size=model_channels,
                max_size=max_size,
                patch_size=patch_size,
            )

        self.img_context_proj = None
        if img_context_dim_in is not None and img_context_dim_in > 0:
            self.img_context_proj = nn.Sequential(
                nn.Linear(img_context_dim_in, img_context_dim_out, bias=True),
                nn.GELU(),
            )

        self.crossattn_proj = None
        if use_crossattn_projection:
            self.crossattn_proj = nn.Sequential(
                nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True),
                nn.GELU(),
            )

        self.rope = CosmosRotaryPosEmbed(
            hidden_size=attention_head_dim, max_size=max_size, patch_size=patch_size, rope_scale=rope_scale
        )

        self.control_blocks = nn.ModuleList(
            [
                CosmosTransformerBlock(
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    cross_attention_dim=text_embed_dim,
                    mlp_ratio=mlp_ratio,
                    adaln_lora_dim=adaln_lora_dim,
                    qk_norm="rms_norm",
                    out_bias=False,
                    img_context=img_context_dim_in is not None and img_context_dim_in > 0,
                    before_proj=(block_idx == 0),
                    after_proj=True,
                )
                for block_idx in range(n_controlnet_blocks)
            ]
        )

        self.gradient_checkpointing = False

    def _expand_conditioning_scale(self, conditioning_scale: Union[float, List[float]]) -> List[float]:
        """
        Expands `conditioning_scale` into one scale per control block. A scalar is repeated for every block; a list
        is tiled or trimmed to the number of blocks (a warning is emitted if it is too short).
        """
        if isinstance(conditioning_scale, list):
            scales = conditioning_scale
        else:
            scales = [conditioning_scale] * len(self.control_blocks)

        if len(scales) < len(self.control_blocks):
            logger.warning(
                "Received %d control scales, but the control network defines %d blocks. "
                "Scales will be trimmed or repeated to match.",
                len(scales),
                len(self.control_blocks),
            )
        scales = (scales * len(self.control_blocks))[: len(self.control_blocks)]
        return scales

    def forward(
        self,
        controls_latents: torch.Tensor,
        latents: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: Union[Optional[torch.Tensor], Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
        condition_mask: torch.Tensor,
        conditioning_scale: Union[float, List[float]] = 1.0,
        padding_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        fps: Optional[int] = None,
        return_dict: bool = True,
    ) -> Union[CosmosControlNetOutput, Tuple[List[torch.Tensor]]]:
| """ |
| Forward pass for the ControlNet. |
| |
| Args: |
| controls_latents: Control signal latents [B, C, T, H, W] |
| latents: Base latents from the noising process [B, C, T, H, W] |
| timestep: Diffusion timestep tensor |
| encoder_hidden_states: Tuple of (text_context, img_context) or text_context |
| condition_mask: Conditioning mask [B, 1, T, H, W] |
| conditioning_scale: Scale factor(s) for control outputs |
| padding_mask: Padding mask [B, 1, H, W] or None |
| attention_mask: Optional attention mask or None |
| fps: Frames per second for RoPE or None |
| return_dict: Whether to return a CosmosControlNetOutput or a tuple |
| |
| Returns: |
| CosmosControlNetOutput or tuple of control tensors |
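
        The accepted `timestep` and `conditioning_scale` forms, with illustrative values only:

        ```py
        >>> import torch
        >>>
        >>> batch_size, num_frames = 2, 8
        >>> timestep = torch.rand(batch_size)  # one noise level per batch entry
        >>> timestep = torch.rand(batch_size, 1, num_frames, 1, 1)  # one noise level per latent frame
        >>> conditioning_scale = 0.8  # same scale for every control block
        >>> conditioning_scale = [1.0, 0.75, 0.5, 0.25]  # one scale per control block
        ```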
| """ |
| B, C, T, H, W = controls_latents.shape |
|
|
| |
| control_hidden_states = controls_latents |
| vace_in_channels = self.config.in_channels - 1 |
| if control_hidden_states.shape[1] < vace_in_channels - 1: |
| pad_C = vace_in_channels - 1 - control_hidden_states.shape[1] |
| control_hidden_states = torch.cat( |
| [ |
| control_hidden_states, |
| torch.zeros( |
| (B, pad_C, T, H, W), dtype=control_hidden_states.dtype, device=control_hidden_states.device |
| ), |
| ], |
| dim=1, |
| ) |
|
|
| if condition_mask is not None: |
| control_hidden_states = torch.cat([control_hidden_states, condition_mask], dim=1) |
| else: |
| control_hidden_states = torch.cat( |
| [control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1 |
| ) |
|
|
| padding_mask_resized = transforms.functional.resize( |
| padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST |
| ) |
| control_hidden_states = torch.cat( |
| [control_hidden_states, padding_mask_resized.unsqueeze(2).repeat(B, 1, T, 1, 1)], dim=1 |
| ) |
|
|
| |
        # Assemble the base latent stream (base latents + condition mask + padding mask), mirroring the transformer.
        base_hidden_states = latents
        if condition_mask is not None:
            base_hidden_states = torch.cat([base_hidden_states, condition_mask], dim=1)

        base_padding_mask = transforms.functional.resize(
            padding_mask, list(base_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
        )
        base_hidden_states = torch.cat(
            [base_hidden_states, base_padding_mask.unsqueeze(2).repeat(B, 1, T, 1, 1)], dim=1
        )

        # Rotary and (optional) learnable positional embeddings are computed from the unpatchified control stream.
        image_rotary_emb = self.rope(control_hidden_states, fps=fps)
        extra_pos_emb = (
            self.learnable_pos_embed(control_hidden_states) if self.learnable_pos_embed is not None else None
        )

        # Patchify both streams and flatten the spatio-temporal axes into a single sequence dimension.
        control_hidden_states = self.patch_embed(control_hidden_states)
        control_hidden_states = control_hidden_states.flatten(1, 3)

        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = T // p_t
        post_patch_height = H // p_h
        post_patch_width = W // p_w

        base_hidden_states = self.patch_embed_base(base_hidden_states)
        base_hidden_states = base_hidden_states.flatten(1, 3)

        # Timestep embedding: either one timestep per batch entry, or one per latent frame (broadcast over space).
        if timestep.ndim == 1:
            temb, embedded_timestep = self.time_embed(base_hidden_states, timestep)
        elif timestep.ndim == 5:
            batch_size, _, num_frames, _, _ = latents.shape
            assert timestep.shape == (batch_size, 1, num_frames, 1, 1), (
                f"Expected timestep to have shape [B, 1, T, 1, 1], but got {timestep.shape}"
            )
            timestep_flat = timestep.flatten()
            temb, embedded_timestep = self.time_embed(base_hidden_states, timestep_flat)
            temb, embedded_timestep = (
                x.view(batch_size, post_patch_num_frames, 1, 1, -1)
                .expand(-1, -1, post_patch_height, post_patch_width, -1)
                .flatten(1, 3)
                for x in (temb, embedded_timestep)
            )
        else:
            raise ValueError(f"Expected timestep to have shape [B, 1, T, 1, 1] or [T], but got {timestep.shape}")

        # Split the conditioning context into text and (optional) image parts and apply the optional projections.
        if isinstance(encoder_hidden_states, tuple):
            text_context, img_context = encoder_hidden_states
        else:
            text_context = encoder_hidden_states
            img_context = None

        if self.crossattn_proj is not None:
            text_context = self.crossattn_proj(text_context)

        if img_context is not None and self.img_context_proj is not None:
            img_context = self.img_context_proj(img_context)

        if self.config.img_context_dim_in is not None and self.config.img_context_dim_in > 0:
            processed_encoder_hidden_states = (text_context, img_context)
        else:
            processed_encoder_hidden_states = text_context

        if attention_mask is not None:
            # Broadcast to [B, 1, 1, S] so the mask can be applied inside the attention layers.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)

        # Run the control blocks, collecting one scaled residual per block.
        scales = self._expand_conditioning_scale(conditioning_scale)
        result = []
        for block_idx, (block, scale) in enumerate(zip(self.control_blocks, scales)):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                control_hidden_states, control_proj = self._gradient_checkpointing_func(
                    block,
                    control_hidden_states,
                    processed_encoder_hidden_states,
                    embedded_timestep,
                    temb,
                    image_rotary_emb,
                    extra_pos_emb,
                    attention_mask,
                    None,
                    base_hidden_states,
                    block_idx,
                )
            else:
                control_hidden_states, control_proj = block(
                    hidden_states=control_hidden_states,
                    encoder_hidden_states=processed_encoder_hidden_states,
                    embedded_timestep=embedded_timestep,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    extra_pos_emb=extra_pos_emb,
                    attention_mask=attention_mask,
                    controlnet_residual=None,
                    latents=base_hidden_states,
                    block_idx=block_idx,
                )
            result.append(control_proj * scale)

        if not return_dict:
            return (result,)

        return CosmosControlNetOutput(control_block_samples=result)