from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import BaseOutput, is_torchvision_available, logging
from ..modeling_utils import ModelMixin
from ..transformers.transformer_cosmos import (
CosmosEmbedding,
CosmosLearnablePositionalEmbed,
CosmosPatchEmbed,
CosmosRotaryPosEmbed,
CosmosTransformerBlock,
)
if is_torchvision_available():
from torchvision import transforms
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class CosmosControlNetOutput(BaseOutput):
"""
Output of [`CosmosControlNetModel`].
Args:
        control_block_samples (`List[torch.Tensor]`):
            List of per-block control activations, already multiplied by their conditioning scales, to be injected
            into the corresponding transformer blocks.
"""
control_block_samples: List[torch.Tensor]
class CosmosControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
ControlNet for Cosmos Transfer2.5.
This model duplicates the shared embedding modules from the transformer (patch_embed, time_embed,
learnable_pos_embed, img_context_proj) to enable proper CPU offloading. The forward() method computes everything
internally from raw inputs.
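
    Example (an illustrative sketch rather than a tested recipe; the checkpoint path and `subfolder` layout below
    are placeholders, not a published repository):

    ```python
    >>> import torch

    >>> # Build with the default Transfer2.5 configuration ...
    >>> controlnet = CosmosControlNetModel()

    >>> # ... or load converted weights through the standard ModelMixin API (placeholder id, hypothetical layout).
    >>> controlnet = CosmosControlNetModel.from_pretrained(
    ...     "<path-or-repo-id>", subfolder="controlnet", torch_dtype=torch.bfloat16
    ... )
    ```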
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["patch_embed", "patch_embed_base", "time_embed"]
_no_split_modules = ["CosmosTransformerBlock"]
_keep_in_fp32_modules = ["learnable_pos_embed"]
@register_to_config
def __init__(
self,
n_controlnet_blocks: int = 4,
in_channels: int = 130,
latent_channels: int = 18, # base latent channels (latents + condition_mask) + padding_mask
model_channels: int = 2048,
num_attention_heads: int = 32,
attention_head_dim: int = 128,
mlp_ratio: float = 4.0,
text_embed_dim: int = 1024,
adaln_lora_dim: int = 256,
patch_size: Tuple[int, int, int] = (1, 2, 2),
max_size: Tuple[int, int, int] = (128, 240, 240),
rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
        extra_pos_embed_type: Optional[str] = None,
        img_context_dim_in: Optional[int] = None,
img_context_dim_out: int = 2048,
use_crossattn_projection: bool = False,
crossattn_proj_in_channels: int = 1024,
encoder_hidden_states_channels: int = 1024,
):
super().__init__()
self.patch_embed = CosmosPatchEmbed(in_channels, model_channels, patch_size, bias=False)
self.patch_embed_base = CosmosPatchEmbed(latent_channels, model_channels, patch_size, bias=False)
self.time_embed = CosmosEmbedding(model_channels, model_channels)
self.learnable_pos_embed = None
if extra_pos_embed_type == "learnable":
self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
hidden_size=model_channels,
max_size=max_size,
patch_size=patch_size,
)
self.img_context_proj = None
if img_context_dim_in is not None and img_context_dim_in > 0:
self.img_context_proj = nn.Sequential(
nn.Linear(img_context_dim_in, img_context_dim_out, bias=True),
nn.GELU(),
)
# Cross-attention projection for text embeddings (same as transformer)
self.crossattn_proj = None
if use_crossattn_projection:
self.crossattn_proj = nn.Sequential(
nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True),
nn.GELU(),
)
# RoPE for both control and base latents
self.rope = CosmosRotaryPosEmbed(
hidden_size=attention_head_dim, max_size=max_size, patch_size=patch_size, rope_scale=rope_scale
)
self.control_blocks = nn.ModuleList(
[
CosmosTransformerBlock(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
cross_attention_dim=text_embed_dim,
mlp_ratio=mlp_ratio,
adaln_lora_dim=adaln_lora_dim,
qk_norm="rms_norm",
out_bias=False,
img_context=img_context_dim_in is not None and img_context_dim_in > 0,
before_proj=(block_idx == 0),
after_proj=True,
)
for block_idx in range(n_controlnet_blocks)
]
)
self.gradient_checkpointing = False
    def _expand_conditioning_scale(self, conditioning_scale: Union[float, List[float]]) -> List[float]:
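        """Broadcast a scalar or per-block list of conditioning scales to exactly one scale per control block.

        For example, with the default four control blocks, ``1.0`` becomes ``[1.0, 1.0, 1.0, 1.0]`` and a short
        list such as ``[0.5, 1.0]`` is repeated and trimmed to ``[0.5, 1.0, 0.5, 1.0]``.
        """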
if isinstance(conditioning_scale, list):
scales = conditioning_scale
else:
scales = [conditioning_scale] * len(self.control_blocks)
        if len(scales) != len(self.control_blocks):
logger.warning(
"Received %d control scales, but control network defines %d blocks. "
"Scales will be trimmed or repeated to match.",
len(scales),
len(self.control_blocks),
)
scales = (scales * len(self.control_blocks))[: len(self.control_blocks)]
return scales
def forward(
self,
controls_latents: torch.Tensor,
latents: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: Union[Optional[torch.Tensor], Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
condition_mask: torch.Tensor,
        conditioning_scale: Union[float, List[float]] = 1.0,
        padding_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        fps: Optional[int] = None,
return_dict: bool = True,
) -> Union[CosmosControlNetOutput, Tuple[List[torch.Tensor]]]:
"""
Forward pass for the ControlNet.
Args:
            controls_latents: Control signal latents [B, C, T, H, W]
            latents: Base latents from the noising process [B, C, T, H, W]
            timestep: Diffusion timestep tensor, either one timestep per sample [B] or per-frame timesteps
                [B, 1, T, 1, 1]
            encoder_hidden_states: Tuple of (text_context, img_context) or text_context alone
            condition_mask: Conditioning mask [B, 1, T, H, W]
            conditioning_scale: Scale factor(s) applied to the per-block control outputs
            padding_mask: Spatial padding mask [1, 1, H, W], broadcast across the batch
            attention_mask: Optional attention mask [B, S] or None
            fps: Frames per second used by the rotary positional embedding, or None
            return_dict: Whether to return a CosmosControlNetOutput or a tuple
Returns:
CosmosControlNetOutput or tuple of control tensors
"""
B, C, T, H, W = controls_latents.shape
# 1. Prepare control latents
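        # Channel accounting for the control stream: the raw control latents are zero-padded up to
        # (in_channels - 2) channels, then condition_mask (+1 channel) and the resized padding_mask (+1 channel)
        # are appended, so the result matches `patch_embed`'s in_channels (130 by default). The padding_mask is
        # expected as a single [1, 1, H, W] map that is repeated across the batch.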
control_hidden_states = controls_latents
vace_in_channels = self.config.in_channels - 1
if control_hidden_states.shape[1] < vace_in_channels - 1:
pad_C = vace_in_channels - 1 - control_hidden_states.shape[1]
control_hidden_states = torch.cat(
[
control_hidden_states,
torch.zeros(
(B, pad_C, T, H, W), dtype=control_hidden_states.dtype, device=control_hidden_states.device
),
],
dim=1,
)
if condition_mask is not None:
control_hidden_states = torch.cat([control_hidden_states, condition_mask], dim=1)
else:
control_hidden_states = torch.cat(
[control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1
)
padding_mask_resized = transforms.functional.resize(
padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
control_hidden_states = torch.cat(
[control_hidden_states, padding_mask_resized.unsqueeze(2).repeat(B, 1, T, 1, 1)], dim=1
)
# 2. Prepare base latents (same processing as transformer.forward)
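        # The base stream is built the same way: noisy latents plus condition_mask plus the resized padding_mask,
        # matching `patch_embed_base`'s latent_channels (18 by default).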
base_hidden_states = latents
if condition_mask is not None:
base_hidden_states = torch.cat([base_hidden_states, condition_mask], dim=1)
base_padding_mask = transforms.functional.resize(
padding_mask, list(base_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
base_hidden_states = torch.cat(
[base_hidden_states, base_padding_mask.unsqueeze(2).repeat(B, 1, T, 1, 1)], dim=1
)
# 3. Generate positional embeddings (shared for both)
image_rotary_emb = self.rope(control_hidden_states, fps=fps)
        extra_pos_emb = (
            self.learnable_pos_embed(control_hidden_states) if self.learnable_pos_embed is not None else None
        )
# 4. Patchify control latents
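        # The patch embed returns a [B, T/p_t, H/p_h, W/p_w, D] grid; flatten(1, 3) turns that patch grid into a
        # token sequence of shape [B, S, D] for the transformer blocks.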
control_hidden_states = self.patch_embed(control_hidden_states)
control_hidden_states = control_hidden_states.flatten(1, 3)
# 5. Patchify base latents
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = T // p_t
post_patch_height = H // p_h
post_patch_width = W // p_w
base_hidden_states = self.patch_embed_base(base_hidden_states)
base_hidden_states = base_hidden_states.flatten(1, 3)
# 6. Time embeddings
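        # Two timestep layouts are supported: a single timestep per sample ([B], ndim == 1) shared by all tokens,
        # or per-frame timesteps ([B, 1, T, 1, 1], ndim == 5) that are embedded per frame and then broadcast
        # across the spatial patch grid.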
if timestep.ndim == 1:
temb, embedded_timestep = self.time_embed(base_hidden_states, timestep)
elif timestep.ndim == 5:
batch_size, _, num_frames, _, _ = latents.shape
assert timestep.shape == (batch_size, 1, num_frames, 1, 1), (
f"Expected timestep to have shape [B, 1, T, 1, 1], but got {timestep.shape}"
)
timestep_flat = timestep.flatten()
temb, embedded_timestep = self.time_embed(base_hidden_states, timestep_flat)
temb, embedded_timestep = (
x.view(batch_size, post_patch_num_frames, 1, 1, -1)
.expand(-1, -1, post_patch_height, post_patch_width, -1)
.flatten(1, 3)
for x in (temb, embedded_timestep)
)
else:
raise ValueError(f"Expected timestep to have shape [B, 1, T, 1, 1] or [T], but got {timestep.shape}")
# 7. Process encoder hidden states
if isinstance(encoder_hidden_states, tuple):
text_context, img_context = encoder_hidden_states
else:
text_context = encoder_hidden_states
img_context = None
# Apply cross-attention projection to text context
if self.crossattn_proj is not None:
text_context = self.crossattn_proj(text_context)
# Apply cross-attention projection to image context (if provided)
if img_context is not None and self.img_context_proj is not None:
img_context = self.img_context_proj(img_context)
# Combine text and image context into a single tuple
if self.config.img_context_dim_in is not None and self.config.img_context_dim_in > 0:
processed_encoder_hidden_states = (text_context, img_context)
else:
processed_encoder_hidden_states = text_context
# 8. Prepare attention mask
if attention_mask is not None:
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, S]
# 9. Run control blocks
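        # Each block returns an updated control state plus a projected residual (`control_proj`). The scaled
        # residuals are collected so the caller can inject them into the matching blocks of the base transformer
        # (the transformer side consumes them through the blocks' `controlnet_residual` argument).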
scales = self._expand_conditioning_scale(conditioning_scale)
result = []
for block_idx, (block, scale) in enumerate(zip(self.control_blocks, scales)):
if torch.is_grad_enabled() and self.gradient_checkpointing:
control_hidden_states, control_proj = self._gradient_checkpointing_func(
block,
control_hidden_states,
processed_encoder_hidden_states,
embedded_timestep,
temb,
image_rotary_emb,
extra_pos_emb,
attention_mask,
None, # controlnet_residual
base_hidden_states,
block_idx,
)
else:
control_hidden_states, control_proj = block(
hidden_states=control_hidden_states,
encoder_hidden_states=processed_encoder_hidden_states,
embedded_timestep=embedded_timestep,
temb=temb,
image_rotary_emb=image_rotary_emb,
extra_pos_emb=extra_pos_emb,
attention_mask=attention_mask,
controlnet_residual=None,
latents=base_hidden_states,
block_idx=block_idx,
)
result.append(control_proj * scale)
if not return_dict:
return (result,)
return CosmosControlNetOutput(control_block_samples=result)