from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import BaseOutput, is_torchvision_available, logging
from ..modeling_utils import ModelMixin
from ..transformers.transformer_cosmos import (
CosmosEmbedding,
CosmosLearnablePositionalEmbed,
CosmosPatchEmbed,
CosmosRotaryPosEmbed,
CosmosTransformerBlock,
)
if is_torchvision_available():
from torchvision import transforms
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class CosmosControlNetOutput(BaseOutput):
"""
Output of [`CosmosControlNetModel`].
Args:
control_block_samples (`list[torch.Tensor]`):
List of control block activations to be injected into transformer blocks.
"""
control_block_samples: List[torch.Tensor]
class CosmosControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
ControlNet for Cosmos Transfer2.5.
This model duplicates the shared embedding modules from the transformer (patch_embed, time_embed,
learnable_pos_embed, img_context_proj) to enable proper CPU offloading. The forward() method computes everything
internally from raw inputs.
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["patch_embed", "patch_embed_base", "time_embed"]
_no_split_modules = ["CosmosTransformerBlock"]
_keep_in_fp32_modules = ["learnable_pos_embed"]
@register_to_config
def __init__(
self,
n_controlnet_blocks: int = 4,
in_channels: int = 130,
latent_channels: int = 18, # base latent channels (latents + condition_mask) + padding_mask
model_channels: int = 2048,
num_attention_heads: int = 32,
attention_head_dim: int = 128,
mlp_ratio: float = 4.0,
text_embed_dim: int = 1024,
adaln_lora_dim: int = 256,
patch_size: Tuple[int, int, int] = (1, 2, 2),
max_size: Tuple[int, int, int] = (128, 240, 240),
rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
extra_pos_embed_type: str | None = None,
img_context_dim_in: int | None = None,
img_context_dim_out: int = 2048,
use_crossattn_projection: bool = False,
crossattn_proj_in_channels: int = 1024,
encoder_hidden_states_channels: int = 1024,
):
super().__init__()
self.patch_embed = CosmosPatchEmbed(in_channels, model_channels, patch_size, bias=False)
self.patch_embed_base = CosmosPatchEmbed(latent_channels, model_channels, patch_size, bias=False)
self.time_embed = CosmosEmbedding(model_channels, model_channels)
self.learnable_pos_embed = None
if extra_pos_embed_type == "learnable":
self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
hidden_size=model_channels,
max_size=max_size,
patch_size=patch_size,
)
self.img_context_proj = None
if img_context_dim_in is not None and img_context_dim_in > 0:
self.img_context_proj = nn.Sequential(
nn.Linear(img_context_dim_in, img_context_dim_out, bias=True),
nn.GELU(),
)
# Cross-attention projection for text embeddings (same as transformer)
self.crossattn_proj = None
if use_crossattn_projection:
self.crossattn_proj = nn.Sequential(
nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True),
nn.GELU(),
)
# RoPE for both control and base latents
self.rope = CosmosRotaryPosEmbed(
hidden_size=attention_head_dim, max_size=max_size, patch_size=patch_size, rope_scale=rope_scale
)
self.control_blocks = nn.ModuleList(
[
CosmosTransformerBlock(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
cross_attention_dim=text_embed_dim,
mlp_ratio=mlp_ratio,
adaln_lora_dim=adaln_lora_dim,
qk_norm="rms_norm",
out_bias=False,
img_context=img_context_dim_in is not None and img_context_dim_in > 0,
before_proj=(block_idx == 0),
after_proj=True,
)
for block_idx in range(n_controlnet_blocks)
]
)
self.gradient_checkpointing = False
def _expand_conditioning_scale(self, conditioning_scale: float | list[float]) -> List[float]:
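        # Broadcasts a scalar scale to every control block; a list is repeated or trimmed to match.
        # With 4 blocks, for example: 0.5 -> [0.5, 0.5, 0.5, 0.5]; [1.0, 0.5] -> [1.0, 0.5, 1.0, 0.5].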
if isinstance(conditioning_scale, list):
scales = conditioning_scale
else:
scales = [conditioning_scale] * len(self.control_blocks)
        if len(scales) != len(self.control_blocks):
            logger.warning(
                "Received %d control scales, but the control network defines %d blocks. "
                "Scales will be repeated or trimmed to match.",
                len(scales),
                len(self.control_blocks),
            )
            scales = (scales * len(self.control_blocks))[: len(self.control_blocks)]
return scales
def forward(
self,
controls_latents: torch.Tensor,
latents: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: Union[Optional[torch.Tensor], Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
condition_mask: torch.Tensor,
conditioning_scale: float | list[float] = 1.0,
padding_mask: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
fps: int | None = None,
return_dict: bool = True,
) -> Union[CosmosControlNetOutput, Tuple[List[torch.Tensor]]]:
"""
Forward pass for the ControlNet.
Args:
controls_latents: Control signal latents [B, C, T, H, W]
latents: Base latents from the noising process [B, C, T, H, W]
timestep: Diffusion timestep tensor
encoder_hidden_states: Tuple of (text_context, img_context) or text_context
condition_mask: Conditioning mask [B, 1, T, H, W]
conditioning_scale: Scale factor(s) for control outputs
padding_mask: Padding mask [B, 1, H, W] or None
attention_mask: Optional attention mask or None
fps: Frames per second for RoPE or None
return_dict: Whether to return a CosmosControlNetOutput or a tuple
Returns:
CosmosControlNetOutput or tuple of control tensors
"""
B, C, T, H, W = controls_latents.shape
# 1. Prepare control latents
control_hidden_states = controls_latents
vace_in_channels = self.config.in_channels - 1
if control_hidden_states.shape[1] < vace_in_channels - 1:
pad_C = vace_in_channels - 1 - control_hidden_states.shape[1]
control_hidden_states = torch.cat(
[
control_hidden_states,
torch.zeros(
(B, pad_C, T, H, W), dtype=control_hidden_states.dtype, device=control_hidden_states.device
),
],
dim=1,
)
if condition_mask is not None:
control_hidden_states = torch.cat([control_hidden_states, condition_mask], dim=1)
else:
control_hidden_states = torch.cat(
[control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1
)
        if padding_mask is None:
            # No padding information was provided; fall back to an all-zeros (no padding) mask.
            padding_mask = controls_latents.new_zeros((1, 1, H, W))
        padding_mask_resized = transforms.functional.resize(
            padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
        )
control_hidden_states = torch.cat(
[control_hidden_states, padding_mask_resized.unsqueeze(2).repeat(B, 1, T, 1, 1)], dim=1
)
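        # control_hidden_states now stacks (in_channels - 2) control channels, the condition mask, and the
        # padding mask along dim=1, matching the `in_channels` expected by `patch_embed`.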
# 2. Prepare base latents (same processing as transformer.forward)
base_hidden_states = latents
if condition_mask is not None:
base_hidden_states = torch.cat([base_hidden_states, condition_mask], dim=1)
base_padding_mask = transforms.functional.resize(
padding_mask, list(base_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
base_hidden_states = torch.cat(
[base_hidden_states, base_padding_mask.unsqueeze(2).repeat(B, 1, T, 1, 1)], dim=1
)
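        # base_hidden_states now stacks the base latents, the condition mask, and the padding mask along dim=1,
        # matching the `latent_channels` expected by `patch_embed_base` (see the comment in `__init__`).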
# 3. Generate positional embeddings (shared for both)
image_rotary_emb = self.rope(control_hidden_states, fps=fps)
extra_pos_emb = self.learnable_pos_embed(control_hidden_states) if self.learnable_pos_embed else None
# 4. Patchify control latents
control_hidden_states = self.patch_embed(control_hidden_states)
control_hidden_states = control_hidden_states.flatten(1, 3)
# 5. Patchify base latents
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = T // p_t
post_patch_height = H // p_h
post_patch_width = W // p_w
base_hidden_states = self.patch_embed_base(base_hidden_states)
base_hidden_states = base_hidden_states.flatten(1, 3)
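        # Both sequences are now flattened to `[B, (T // p_t) * (H // p_h) * (W // p_w), model_channels]`;
        # `base_hidden_states` is handed to every control block below via its `latents` argument.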
# 6. Time embeddings
if timestep.ndim == 1:
temb, embedded_timestep = self.time_embed(base_hidden_states, timestep)
elif timestep.ndim == 5:
batch_size, _, num_frames, _, _ = latents.shape
assert timestep.shape == (batch_size, 1, num_frames, 1, 1), (
f"Expected timestep to have shape [B, 1, T, 1, 1], but got {timestep.shape}"
)
timestep_flat = timestep.flatten()
temb, embedded_timestep = self.time_embed(base_hidden_states, timestep_flat)
temb, embedded_timestep = (
x.view(batch_size, post_patch_num_frames, 1, 1, -1)
.expand(-1, -1, post_patch_height, post_patch_width, -1)
.flatten(1, 3)
for x in (temb, embedded_timestep)
)
else:
raise ValueError(f"Expected timestep to have shape [B, 1, T, 1, 1] or [T], but got {timestep.shape}")
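        # In the per-frame branch above, the frame-wise embeddings are expanded over the spatial patch grid so
        # that every token belonging to a frame shares that frame's time embedding.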
# 7. Process encoder hidden states
if isinstance(encoder_hidden_states, tuple):
text_context, img_context = encoder_hidden_states
else:
text_context = encoder_hidden_states
img_context = None
# Apply cross-attention projection to text context
if self.crossattn_proj is not None:
text_context = self.crossattn_proj(text_context)
# Apply cross-attention projection to image context (if provided)
if img_context is not None and self.img_context_proj is not None:
img_context = self.img_context_proj(img_context)
# Combine text and image context into a single tuple
if self.config.img_context_dim_in is not None and self.config.img_context_dim_in > 0:
processed_encoder_hidden_states = (text_context, img_context)
else:
processed_encoder_hidden_states = text_context
# 8. Prepare attention mask
if attention_mask is not None:
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, S]
# 9. Run control blocks
scales = self._expand_conditioning_scale(conditioning_scale)
result = []
for block_idx, (block, scale) in enumerate(zip(self.control_blocks, scales)):
if torch.is_grad_enabled() and self.gradient_checkpointing:
control_hidden_states, control_proj = self._gradient_checkpointing_func(
block,
control_hidden_states,
processed_encoder_hidden_states,
embedded_timestep,
temb,
image_rotary_emb,
extra_pos_emb,
attention_mask,
None, # controlnet_residual
base_hidden_states,
block_idx,
)
else:
control_hidden_states, control_proj = block(
hidden_states=control_hidden_states,
encoder_hidden_states=processed_encoder_hidden_states,
embedded_timestep=embedded_timestep,
temb=temb,
image_rotary_emb=image_rotary_emb,
extra_pos_emb=extra_pos_emb,
attention_mask=attention_mask,
controlnet_residual=None,
latents=base_hidden_states,
block_idx=block_idx,
)
result.append(control_proj * scale)
if not return_dict:
return (result,)
return CosmosControlNetOutput(control_block_samples=result)