# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """ | |
| A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. | |
| """ | |

from typing import List, Optional, Tuple

import torch
from einops import rearrange
from torch import nn
from torch.distributed import ProcessGroup, get_process_group_ranks
from torchvision import transforms

from cosmos_predict1.diffusion.conditioner import DataType
from cosmos_predict1.diffusion.module.attention import get_normalization
from cosmos_predict1.diffusion.module.blocks import (
    FinalLayer,
    GeneralDITTransformerBlock,
    PatchEmbed,
    TimestepEmbedding,
    Timesteps,
)
from cosmos_predict1.diffusion.module.position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb
from cosmos_predict1.utils import log


class GeneralDIT(nn.Module):
| """ | |
| A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. | |
| Args: | |
| max_img_h (int): Maximum height of the input images. | |
| max_img_w (int): Maximum width of the input images. | |
| max_frames (int): Maximum number of frames in the video sequence. | |
| in_channels (int): Number of input channels (e.g., RGB channels for color images). | |
| out_channels (int): Number of output channels. | |
| patch_spatial (tuple): Spatial resolution of patches for input processing. | |
| patch_temporal (int): Temporal resolution of patches for input processing. | |
| concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding. | |
| block_config (str): Configuration of the transformer block. See Notes for supported block types. | |
| model_channels (int): Base number of channels used throughout the model. | |
| num_blocks (int): Number of transformer blocks. | |
| num_heads (int): Number of heads in the multi-head attention layers. | |
| mlp_ratio (float): Expansion ratio for MLP blocks. | |
| block_x_format (str): Format of input tensor for transformer blocks ('BTHWD' or 'THWBD'). | |
| crossattn_emb_channels (int): Number of embedding channels for cross-attention. | |
| use_cross_attn_mask (bool): Whether to use mask in cross-attention. | |
| pos_emb_cls (str): Type of positional embeddings. | |
| pos_emb_learnable (bool): Whether positional embeddings are learnable. | |
| pos_emb_interpolation (str): Method for interpolating positional embeddings. | |
| affline_emb_norm (bool): Whether to normalize affine embeddings. | |
| use_adaln_lora (bool): Whether to use AdaLN-LoRA. | |
| adaln_lora_dim (int): Dimension for AdaLN-LoRA. | |
| rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE. | |
| rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE. | |
| rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE. | |
| extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings. | |
| extra_per_block_abs_pos_emb_type (str): Type of extra per-block positional embeddings. | |
| extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings. | |
| extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings. | |
| extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings. | |
| Notes: | |
| Supported block types in block_config: | |
| * cross_attn, ca: Cross attention | |
| * full_attn: Full attention on all flattened tokens | |
| * mlp, ff: Feed forward block | |
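
    Example:
        A minimal construction sketch (hypothetical sizes, not taken from the project configs).
        Note that the default ``pos_emb_cls="sincos"`` is rejected by ``build_pos_embed``, so
        ``"rope3d"`` is passed explicitly, and ``block_x_format="THWBD"`` matches the layout
        assumed by ``forward``:

            model = GeneralDIT(
                max_img_h=240, max_img_w=240, max_frames=128,
                in_channels=16, out_channels=16,
                patch_spatial=2, patch_temporal=1,
                model_channels=768, num_blocks=10, num_heads=16,
                block_x_format="THWBD", pos_emb_cls="rope3d",
            )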
| """ | |

    def __init__(
        self,
        max_img_h: int,
        max_img_w: int,
        max_frames: int,
        in_channels: int,
        out_channels: int,
        patch_spatial: int,
        patch_temporal: int,
        concat_padding_mask: bool = True,
        # attention settings
        block_config: str = "FA-CA-MLP",
        model_channels: int = 768,
        num_blocks: int = 10,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        block_x_format: str = "BTHWD",
        # cross attention settings
        crossattn_emb_channels: int = 1024,
        use_cross_attn_mask: bool = False,
        # positional embedding settings
        pos_emb_cls: str = "sincos",
        pos_emb_learnable: bool = False,
        pos_emb_interpolation: str = "crop",
        affline_emb_norm: bool = False,  # whether or not to normalize the affine embedding
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        rope_h_extrapolation_ratio: float = 1.0,
        rope_w_extrapolation_ratio: float = 1.0,
        rope_t_extrapolation_ratio: float = 1.0,
        extra_per_block_abs_pos_emb: bool = True,
        extra_per_block_abs_pos_emb_type: str = "learnable",
        extra_h_extrapolation_ratio: float = 1.0,
        extra_w_extrapolation_ratio: float = 1.0,
        extra_t_extrapolation_ratio: float = 1.0,
    ) -> None:
        super().__init__()
        self.max_img_h = max_img_h
        self.max_img_w = max_img_w
        self.max_frames = max_frames
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_spatial = patch_spatial
        self.patch_temporal = patch_temporal
        self.num_heads = num_heads
        self.num_blocks = num_blocks
        self.model_channels = model_channels
        self.use_cross_attn_mask = use_cross_attn_mask
        self.concat_padding_mask = concat_padding_mask
        # positional embedding settings
        self.pos_emb_cls = pos_emb_cls
        self.pos_emb_learnable = pos_emb_learnable
        self.pos_emb_interpolation = pos_emb_interpolation
        self.affline_emb_norm = affline_emb_norm
        self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
        self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
        self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
        self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
        self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower()
        self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
        self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
        self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio

        self.build_patch_embed()
        self.build_pos_embed()
        self.cp_group = None
        self.block_x_format = block_x_format
        self.use_adaln_lora = use_adaln_lora
        self.adaln_lora_dim = adaln_lora_dim
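        # Hedged note (added for orientation): TimestepEmbedding is assumed here to return a tuple of
        # (embedding, optional AdaLN-LoRA tensor); forward_before_blocks unpacks self.t_embedder(...)
        # exactly that way. Timesteps(model_channels) produces the timestep features it consumes.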
        self.t_embedder = nn.Sequential(
            Timesteps(model_channels),
            TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora),
        )

        self.blocks = nn.ModuleDict()
        for idx in range(num_blocks):
            self.blocks[f"block{idx}"] = GeneralDITTransformerBlock(
                x_dim=model_channels,
                context_dim=crossattn_emb_channels,
                num_heads=num_heads,
                block_config=block_config,
                mlp_ratio=mlp_ratio,
                x_format=self.block_x_format,
                use_adaln_lora=use_adaln_lora,
                adaln_lora_dim=adaln_lora_dim,
            )

        self.build_decode_head()

        if self.affline_emb_norm:
            log.debug("Building affine embedding normalization layer")
            self.affline_norm = get_normalization("R", model_channels)
        else:
            self.affline_norm = nn.Identity()

        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize timestep embedding
        nn.init.normal_(self.t_embedder[1].linear_1.weight, std=0.02)
        if self.t_embedder[1].linear_1.bias is not None:
            nn.init.constant_(self.t_embedder[1].linear_1.bias, 0)
        nn.init.normal_(self.t_embedder[1].linear_2.weight, std=0.02)
        if self.t_embedder[1].linear_2.bias is not None:
            nn.init.constant_(self.t_embedder[1].linear_2.bias, 0)

        # Zero-out adaLN modulation layers in DiT blocks:
        for transformer_block in self.blocks.values():
            for block in transformer_block.blocks:
                nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
                if block.adaLN_modulation[-1].bias is not None:
                    nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    def build_decode_head(self):
        self.final_layer = FinalLayer(
            hidden_size=self.model_channels,
            spatial_patch_size=self.patch_spatial,
            temporal_patch_size=self.patch_temporal,
            out_channels=self.out_channels,
            use_adaln_lora=self.use_adaln_lora,
            adaln_lora_dim=self.adaln_lora_dim,
        )

    def build_patch_embed(self):
        (
            concat_padding_mask,
            in_channels,
            patch_spatial,
            patch_temporal,
            model_channels,
        ) = (
            self.concat_padding_mask,
            self.in_channels,
            self.patch_spatial,
            self.patch_temporal,
            self.model_channels,
        )
        in_channels = in_channels + 1 if concat_padding_mask else in_channels
        self.x_embedder = PatchEmbed(
            spatial_patch_size=patch_spatial,
            temporal_patch_size=patch_temporal,
            in_channels=in_channels,
            out_channels=model_channels,
            bias=False,
        )

    def build_pos_embed(self):
        if self.pos_emb_cls == "rope3d":
            cls_type = VideoRopePosition3DEmb
        else:
            raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")

        log.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
        kwargs = dict(
            model_channels=self.model_channels,
            len_h=self.max_img_h // self.patch_spatial,
            len_w=self.max_img_w // self.patch_spatial,
            len_t=self.max_frames // self.patch_temporal,
            is_learnable=self.pos_emb_learnable,
            interpolation=self.pos_emb_interpolation,
            head_dim=self.model_channels // self.num_heads,
            h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
            w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
            t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
        )
        self.pos_embedder = cls_type(
            **kwargs,
        )

        assert self.extra_per_block_abs_pos_emb is True, "extra_per_block_abs_pos_emb must be True"
        if self.extra_per_block_abs_pos_emb:
            assert self.extra_per_block_abs_pos_emb_type in [
                "learnable",
            ], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}"
            kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
            kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
            kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
            self.extra_pos_embedder = LearnablePosEmbAxis(**kwargs)

    def prepare_embedded_sequence(
        self,
        x_B_C_T_H_W: torch.Tensor,
        fps: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        latent_condition: Optional[torch.Tensor] = None,
        latent_condition_sigma: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
| """ | |
| Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. | |
| Args: | |
| x_B_C_T_H_W (torch.Tensor): video | |
| fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. | |
| If None, a default value (`self.base_fps`) will be used. | |
| padding_mask (Optional[torch.Tensor]): current it is not used | |
| Returns: | |
| Tuple[torch.Tensor, Optional[torch.Tensor]]: | |
| - A tensor of shape (B, T, H, W, D) with the embedded sequence. | |
| - An optional positional embedding tensor, returned only if the positional embedding class | |
| (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. | |
| Notes: | |
| - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. | |
| - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. | |
| - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using | |
| the `self.pos_embedder` with the shape [T, H, W]. | |
| - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the | |
| `self.pos_embedder` with the fps tensor. | |
| - Otherwise, the positional embeddings are generated without considering fps. | |
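
        Example:
            A hedged shape sketch (hypothetical sizes, added for orientation): with patch_spatial=2
            and patch_temporal=1, an input of shape (1, 16, 16, 64, 64) is patchified into an
            x_B_T_H_W_D of shape (1, 16, 32, 32, model_channels); the extra per-block embedding,
            when returned, has the same shape as that token tensor.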
| """ | |
        if self.concat_padding_mask:
            padding_mask = transforms.functional.resize(
                padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
            )
            x_B_C_T_H_W = torch.cat(
                [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
            )
        x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)

        if self.extra_per_block_abs_pos_emb:
            extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps)
        else:
            extra_pos_emb = None

        if "rope" in self.pos_emb_cls.lower():
            return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps), extra_pos_emb

        if "fps_aware" in self.pos_emb_cls:
            x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps)  # [B, T, H, W, D]
        else:
            x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D)  # [B, T, H, W, D]

        return x_B_T_H_W_D, None, extra_pos_emb

    def decoder_head(
        self,
        x_B_T_H_W_D: torch.Tensor,
        emb_B_D: torch.Tensor,
        crossattn_emb: torch.Tensor,
        origin_shape: Tuple[int, int, int, int, int],  # [B, C, T, H, W]
        crossattn_mask: Optional[torch.Tensor] = None,
        adaln_lora_B_3D: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        del crossattn_emb, crossattn_mask
        B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape
        x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D")
        x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D)
        # This is to ensure x_BT_HW_D has the correct shape because
        # when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D).
        x_BT_HW_D = x_BT_HW_D.view(
            B * T_before_patchify // self.patch_temporal,
            H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial,
            -1,
        )
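        # Un-patchify sketch (hypothetical sizes, added for orientation): with patch_spatial=2,
        # patch_temporal=1 and an original latent of (B, C, T, H, W) = (1, 16, 16, 64, 64),
        # final_layer emits p1 * p2 * t * C = 2 * 2 * 1 * 16 = 64 features per token, and the
        # rearrange below folds them back to the original (1, 16, 16, 64, 64) layout.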
        x_B_D_T_H_W = rearrange(
            x_BT_HW_D,
            "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
            p1=self.patch_spatial,
            p2=self.patch_spatial,
            H=H_before_patchify // self.patch_spatial,
            W=W_before_patchify // self.patch_spatial,
            t=self.patch_temporal,
            B=B,
        )
        return x_B_D_T_H_W

    def forward_before_blocks(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        crossattn_emb: torch.Tensor,
        crossattn_mask: Optional[torch.Tensor] = None,
        fps: Optional[torch.Tensor] = None,
        image_size: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        scalar_feature: Optional[torch.Tensor] = None,
        data_type: Optional[DataType] = DataType.VIDEO,
        latent_condition: Optional[torch.Tensor] = None,
        latent_condition_sigma: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
| """ | |
| Args: | |
| x: (B, C, T, H, W) tensor of spatial-temp inputs | |
| timesteps: (B, ) tensor of timesteps | |
| crossattn_emb: (B, N, D) tensor of cross-attention embeddings | |
| crossattn_mask: (B, N) tensor of cross-attention masks | |
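
        Returns:
            dict: Intermediate tensors consumed by the transformer blocks, keyed as
            "x", "affline_emb_B_D", "crossattn_emb", "crossattn_mask", "rope_emb_L_1_1_D",
            "adaln_lora_B_3D", "original_shape", and "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D".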
| """ | |
        del kwargs
        assert isinstance(
            data_type, DataType
        ), f"Expected DataType, got {type(data_type)}. We need to discuss this flag later."
        original_shape = x.shape
        x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
            x,
            fps=fps,
            padding_mask=padding_mask,
            latent_condition=latent_condition,
            latent_condition_sigma=latent_condition_sigma,
        )

        # logging affline scale information
        affline_scale_log_info = {}

        timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten())
        affline_emb_B_D = timesteps_B_D
        affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach()

        if scalar_feature is not None:
            raise NotImplementedError("Scalar feature is not implemented yet.")

        affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach()
        affline_emb_B_D = self.affline_norm(affline_emb_B_D)

        if self.use_cross_attn_mask:
            crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool)  # [B, 1, 1, length]
        else:
            crossattn_mask = None

        if self.blocks["block0"].x_format == "THWBD":
            x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D")
            if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
                extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange(
                    extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D"
                )
            crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D")
            if crossattn_mask:
                crossattn_mask = rearrange(crossattn_mask, "B M -> M B")
        elif self.blocks["block0"].x_format == "BTHWD":
            x = x_B_T_H_W_D
        else:
| raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") | |
        output = {
            "x": x,
            "affline_emb_B_D": affline_emb_B_D,
            "crossattn_emb": crossattn_emb,
            "crossattn_mask": crossattn_mask,
            "rope_emb_L_1_1_D": rope_emb_L_1_1_D,
            "adaln_lora_B_3D": adaln_lora_B_3D,
            "original_shape": original_shape,
            "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
        }
        return output

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        crossattn_emb: torch.Tensor,
        crossattn_mask: Optional[torch.Tensor] = None,
        fps: Optional[torch.Tensor] = None,
        image_size: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        scalar_feature: Optional[torch.Tensor] = None,
        data_type: Optional[DataType] = DataType.VIDEO,
        latent_condition: Optional[torch.Tensor] = None,
        latent_condition_sigma: Optional[torch.Tensor] = None,
        condition_video_augment_sigma: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor | List[torch.Tensor] | Tuple[torch.Tensor, List[torch.Tensor]]:
| """ | |
| Args: | |
| x: (B, C, T, H, W) tensor of spatial-temp inputs | |
| timesteps: (B, ) tensor of timesteps | |
| crossattn_emb: (B, N, D) tensor of cross-attention embeddings | |
| crossattn_mask: (B, N) tensor of cross-attention masks | |
| condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to | |
| augment condition input, the lvg model will condition on the condition_video_augment_sigma value; | |
| we need forward_before_blocks pass to the forward_before_blocks function. | |
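
        Example:
            A hedged call sketch (hypothetical shapes, added for orientation; assumes the model was
            built with in_channels=16, patch_spatial=2, patch_temporal=1, block_x_format="THWBD"
            and pos_emb_cls="rope3d"):

                out = model(
                    x=torch.randn(1, 16, 16, 64, 64),
                    timesteps=torch.rand(1),
                    crossattn_emb=torch.randn(1, 512, 1024),
                    fps=torch.tensor([24.0]),
                    padding_mask=torch.zeros(1, 1, 64, 64),
                )
                # out has shape (1, out_channels, 16, 64, 64)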
| """ | |
        inputs = self.forward_before_blocks(
            x=x,
            timesteps=timesteps,
            crossattn_emb=crossattn_emb,
            crossattn_mask=crossattn_mask,
            fps=fps,
            image_size=image_size,
            padding_mask=padding_mask,
            scalar_feature=scalar_feature,
            data_type=data_type,
            latent_condition=latent_condition,
            latent_condition_sigma=latent_condition_sigma,
            condition_video_augment_sigma=condition_video_augment_sigma,
            **kwargs,
        )
        x, affline_emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = (
            inputs["x"],
            inputs["affline_emb_B_D"],
            inputs["crossattn_emb"],
            inputs["crossattn_mask"],
            inputs["rope_emb_L_1_1_D"],
            inputs["adaln_lora_B_3D"],
            inputs["original_shape"],
        )
        extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"]

        if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
            assert (
                x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
            ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"

        for _, block in self.blocks.items():
            assert (
                self.blocks["block0"].x_format == block.x_format
| ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}" | |
            x = block(
                x,
                affline_emb_B_D,
                crossattn_emb,
                crossattn_mask,
                rope_emb_L_1_1_D=rope_emb_L_1_1_D,
                adaln_lora_B_3D=adaln_lora_B_3D,
                extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
            )

        x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")

        x_B_D_T_H_W = self.decoder_head(
            x_B_T_H_W_D=x_B_T_H_W_D,
            emb_B_D=affline_emb_B_D,
            crossattn_emb=None,
            origin_shape=original_shape,
            crossattn_mask=None,
            adaln_lora_B_3D=adaln_lora_B_3D,
        )

        return x_B_D_T_H_W

    def enable_context_parallel(self, cp_group: ProcessGroup):
        cp_ranks = get_process_group_ranks(cp_group)
        cp_size = len(cp_ranks)
        # Set these attributes for splitting the data after embedding.
        self.cp_group = cp_group
        # Set these attributes for computing the loss.
        self.cp_size = cp_size

        self.pos_embedder.enable_context_parallel(cp_group)
        if self.extra_per_block_abs_pos_emb:
            self.extra_pos_embedder.enable_context_parallel(cp_group)

        # Loop through the model to set up context parallel.
        for block in self.blocks.values():
            for layer in block.blocks:
                if layer.block_type in ["mlp", "ff", "cross_attn", "ca"]:
                    continue
                elif layer.block.attn.backend == "transformer_engine":
                    layer.block.attn.attn_op.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream())

        log.debug(f"[CP] Enable context parallelism with size {cp_size}")

    def disable_context_parallel(self):
        self.cp_group = None
        self.cp_size = None

        self.pos_embedder.disable_context_parallel()
        if self.extra_per_block_abs_pos_emb:
            self.extra_pos_embedder.disable_context_parallel()

        # Loop through the model to disable context parallel.
        for block in self.blocks.values():
            for layer in block.blocks:
                if layer.block_type in ["mlp", "ff"]:
                    continue
                elif layer.block_type in ["cross_attn", "ca"]:
                    continue
                else:
                    layer.block.attn.attn_op.cp_group = None
                    layer.block.attn.attn_op.cp_ranks = None
                    layer.block.attn.attn_op.cp_stream = None

        log.debug("[CP] Disable context parallelism.")

    def is_context_parallel_enabled(self):
        return self.cp_group is not None
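

# Hedged usage sketch for context parallelism (hypothetical; assumes torch.distributed is already
# initialized and that the attention blocks use the transformer_engine backend):
#
#   import torch.distributed as dist
#   cp_group = dist.new_group(ranks=[0, 1])
#   model.enable_context_parallel(cp_group)
#   ...
#   model.disable_context_parallel()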