|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy |
|
|
from dataclasses import dataclass |
|
|
from typing import Literal, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from megatron.core.jit import jit_fuser |
|
|
from megatron.core.tensor_parallel.layers import ColumnParallelLinear |
|
|
from megatron.core.transformer.attention import ( |
|
|
CrossAttention, |
|
|
CrossAttentionSubmodules, |
|
|
SelfAttention, |
|
|
SelfAttentionSubmodules, |
|
|
) |
|
|
from megatron.core.transformer.cuda_graphs import CudaGraphManager |
|
|
from megatron.core.transformer.enums import AttnMaskType |
|
|
from megatron.core.transformer.identity_op import IdentityOp |
|
|
from megatron.core.transformer.mlp import MLP, MLPSubmodules |
|
|
from megatron.core.transformer.module import MegatronModule |
|
|
from megatron.core.transformer.spec_utils import ModuleSpec, build_module |
|
|
from megatron.core.transformer.transformer_config import TransformerConfig |
|
|
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules |
|
|
from megatron.core.utils import make_viewless_tensor |
|
|
|
|
|
from nemo.collections.diffusion.models.dit.dit_attention import ( |
|
|
FluxSingleAttention, |
|
|
JointSelfAttention, |
|
|
JointSelfAttentionSubmodules, |
|
|
) |
|
|
|
|
|
try: |
|
|
from megatron.core.transformer.custom_layers.transformer_engine import ( |
|
|
TEColumnParallelLinear, |
|
|
TEDotProductAttention, |
|
|
TENorm, |
|
|
TERowParallelLinear, |
|
|
) |
|
|
except ImportError: |
|
|
from nemo.utils import logging |
|
|
|
|
|
logging.warning( |
|
|
"Failed to import Transformer Engine dependencies. " |
|
|
"`from megatron.core.transformer.custom_layers.transformer_engine import *`" |
|
|
"If using NeMo Run, this is expected. Otherwise, please verify the Transformer Engine installation." |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class DiTWithAdaLNSubmodules(TransformerLayerSubmodules):
    """
    Submodules for DiT with AdaLN.

    Extends the base TransformerLayer submodules with the attention variants
    used by ``DiTLayerWithAdaLN``. Fields default to ``IdentityOp`` (i.e.
    unused) unless a layer spec overrides them.
    """

    # Temporal self-attention over the time axis (IdentityOp when unused).
    temporal_self_attention: Union[ModuleSpec, type] = IdentityOp
    # Full (spatio-temporal) self-attention over the whole sequence.
    full_self_attention: Union[ModuleSpec, type] = IdentityOp
|
|
|
|
|
|
|
|
@dataclass
class STDiTWithAdaLNSubmodules(TransformerLayerSubmodules):
    """
    Submodules for STDiT with AdaLN.

    Extends the base TransformerLayer submodules with the spatial/temporal/full
    attention variants used by ``STDiTLayerWithAdaLN``. Fields default to
    ``IdentityOp`` (i.e. unused) unless a layer spec overrides them.
    """

    # Self-attention restricted to the spatial axis.
    spatial_self_attention: Union[ModuleSpec, type] = IdentityOp
    # Self-attention restricted to the temporal axis.
    temporal_self_attention: Union[ModuleSpec, type] = IdentityOp
    # Full (spatio-temporal) self-attention over the whole sequence.
    full_self_attention: Union[ModuleSpec, type] = IdentityOp
|
|
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """
    Root-mean-square layer normalization (RMSNorm).

    Normalizes the last dimension of the input by its root mean square and
    applies a learned per-channel gain.

    Args:
        hidden_size: Size of the normalized (last) dimension.
        config: Accepted for build/spec signature compatibility; unused here.
        eps: Small constant added to the mean square for numerical stability.
    """

    def __init__(self, hidden_size: int, config, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), via rsqrt to avoid a division.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # Normalize in float32 for stability, then cast back to input dtype.
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight
|
|
|
|
|
|
|
|
class AdaLN(MegatronModule):
    """
    Adaptive Layer Normalization Module for DiT.

    Projects a conditioning (timestep) embedding through SiLU + a
    zero-initialized linear layer into ``n_adaln_chunks`` modulation tensors
    (shift/scale/gate groups), and provides fused helpers that apply them
    around attention/MLP sub-blocks.

    Args:
        config: Transformer config (hidden size, layernorm epsilon, ...).
        n_adaln_chunks: Number of modulation tensors produced per forward.
        norm: Normalization class for the primary layernorm (TENorm or a
            torch norm such as ``nn.LayerNorm``).
        modulation_bias: Whether the modulation projection uses a bias.
        use_second_norm: If True, adds a second plain LayerNorm (``ln2``)
            selectable via ``layernorm_idx=1``.
    """

    def __init__(
        self,
        config: TransformerConfig,
        n_adaln_chunks=9,
        norm=nn.LayerNorm,
        modulation_bias=False,
        use_second_norm=False,
    ):
        super().__init__(config)
        if norm == TENorm:
            # TENorm signature is (config, hidden_size, eps); torch norms take
            # (hidden_size, ...) and are used without learned affine params.
            self.ln = norm(config, config.hidden_size, config.layernorm_epsilon)
        else:
            self.ln = norm(config.hidden_size, elementwise_affine=False, eps=self.config.layernorm_epsilon)
        self.n_adaln_chunks = n_adaln_chunks
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            ColumnParallelLinear(
                config.hidden_size,
                self.n_adaln_chunks * config.hidden_size,
                config=config,
                init_method=nn.init.normal_,
                bias=modulation_bias,
                gather_output=True,
            ),
        )
        self.use_second_norm = use_second_norm
        if self.use_second_norm:
            self.ln2 = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6)
        # Zero-init the projection so modulation starts as identity (adaLN-Zero).
        nn.init.constant_(self.adaLN_modulation[-1].weight, 0)

        setattr(self.adaLN_modulation[-1].weight, "sequence_parallel", config.sequence_parallel)

    @jit_fuser
    def forward(self, timestep_emb):
        """Return ``n_adaln_chunks`` modulation tensors derived from ``timestep_emb``."""
        output, bias = self.adaLN_modulation(timestep_emb)
        # ColumnParallelLinear returns (output, bias) where bias may be None.
        # Truth-testing a multi-element tensor raises a RuntimeError, so the
        # presence of a separate bias must be checked against None.
        output = output + bias if bias is not None else output
        return output.chunk(self.n_adaln_chunks, dim=-1)

    @jit_fuser
    def modulate(self, x, shift, scale):
        """Apply the affine modulation ``x * (1 + scale) + shift``."""
        return x * (1 + scale) + shift

    @jit_fuser
    def scale_add(self, residual, x, gate):
        """Gated residual add: ``residual + gate * x``."""
        return residual + gate * x

    @jit_fuser
    def modulated_layernorm(self, x, shift, scale, layernorm_idx=0):
        """Layernorm ``x`` (``ln2`` when ``layernorm_idx == 1`` and enabled), then modulate."""
        if self.use_second_norm and layernorm_idx == 1:
            layernorm = self.ln2
        else:
            layernorm = self.ln

        # Cast back in case the norm computed/returned a promoted dtype.
        input_layernorm_output = layernorm(x).type_as(x)

        return self.modulate(input_layernorm_output, shift, scale)

    @jit_fuser
    def scaled_modulated_layernorm(self, residual, x, gate, shift, scale, layernorm_idx=0):
        """Fuse a gated residual add with the subsequent modulated layernorm.

        Returns:
            Tuple of (updated hidden states, modulated layernorm output).
        """
        hidden_states = self.scale_add(residual, x, gate)
        shifted_pre_mlp_layernorm_output = self.modulated_layernorm(hidden_states, shift, scale, layernorm_idx)
        return hidden_states, shifted_pre_mlp_layernorm_output
|
|
|
|
|
|
|
|
class AdaLNContinuous(MegatronModule):
    '''
    A variant of AdaLN used for flux models.

    Computes shift/scale from a conditioning embedding and applies them to a
    normalized input in a single forward pass (no gate, no residual).

    Args:
        config: Transformer config providing ``hidden_size``.
        conditioning_embedding_dim: Dimension of the conditioning embedding.
        modulation_bias: Whether the modulation projection (and layer norm) use bias.
        norm_type: Either "layer_norm" or "rms_norm".
    '''

    def __init__(
        self,
        config: TransformerConfig,
        conditioning_embedding_dim: int,
        modulation_bias: bool = True,
        norm_type: str = "layer_norm",
    ):
        super().__init__(config)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(conditioning_embedding_dim, config.hidden_size * 2, bias=modulation_bias)
        )
        if norm_type == "layer_norm":
            self.norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6, bias=modulation_bias)
        elif norm_type == "rms_norm":
            # RMSNorm (defined above) requires config as its second positional
            # argument; calling it without config raised a TypeError.
            self.norm = RMSNorm(config.hidden_size, config, eps=1e-6)
        else:
            raise ValueError("Unknown normalization type {}".format(norm_type))

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` and modulate it with shift/scale from the conditioning embedding."""
        emb = self.adaLN_modulation(conditioning_embedding)
        # The projection output is ordered [scale, shift] along dim 1.
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = self.norm(x) * (1 + scale) + shift
        return x
|
|
|
|
|
|
|
|
class STDiTLayerWithAdaLN(TransformerLayer):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.

    Spatial-Temporal DiT with Adaptive Layer Normalization.

    Runs four attention sub-blocks in sequence (spatial self-attention, full
    self-attention, cross-attention, temporal self-attention) followed by an
    MLP, each wrapped in AdaLN shift/scale/gate modulation.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: TransformerLayerSubmodules,
        layer_number: int = 1,
        hidden_dropout: float = None,
        position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute",
    ):
        # Swap cross/spatial attention for IdentityOp before calling the parent
        # constructor, so the parent does not build them; they are built below
        # with a config that disables context parallelism.
        def _replace_no_cp_submodules(submodules):
            modified_submods = copy.deepcopy(submodules)
            modified_submods.cross_attention = IdentityOp
            modified_submods.spatial_self_attention = IdentityOp
            return modified_submods

        modified_submods = _replace_no_cp_submodules(submodules)
        super().__init__(
            config=config, submodules=modified_submods, layer_number=layer_number, hidden_dropout=hidden_dropout
        )

        # Config clone with context parallelism and TP-comm overlap disabled,
        # used for the submodules built outside the parent class.
        sa_cp_override_config = copy.deepcopy(config)
        sa_cp_override_config.context_parallel_size = 1
        sa_cp_override_config.tp_comm_overlap = False
        self.spatial_self_attention = build_module(
            submodules.spatial_self_attention, config=sa_cp_override_config, layer_number=layer_number
        )
        self.cross_attention = build_module(
            submodules.cross_attention,
            config=sa_cp_override_config,
            layer_number=layer_number,
        )

        # Temporal and full attention keep the original (possibly CP-enabled) config.
        self.temporal_self_attention = build_module(
            submodules.temporal_self_attention,
            config=self.config,
            layer_number=layer_number,
        )

        self.full_self_attention = build_module(
            submodules.full_self_attention,
            config=self.config,
            layer_number=layer_number,
        )

        # One shift/scale/gate triple per call; see NOTE in forward().
        self.adaLN = AdaLN(config=self.config, n_adaln_chunks=3)

    def forward(
        self,
        hidden_states,
        attention_mask,
        context=None,
        context_mask=None,
        rotary_pos_emb=None,
        inference_params=None,
        packed_seq_params=None,
    ):
        """Run the STDiT sub-block sequence.

        Args:
            hidden_states: Input tensor of shape [s, b, h].
            attention_mask: Repurposed slot carrying the AdaLN conditioning
                (timestep) embedding — NOT an attention mask.
            context: Key/value states for cross-attention.
            context_mask: Attention mask for cross-attention.

        Returns:
            Tuple of (output hidden states, context).
        """
        # NOTE(review): the attention_mask argument slot carries the timestep
        # embedding used for AdaLN modulation.
        timestep_emb = attention_mask

        # --- Spatial self-attention sub-block ---
        # NOTE(review): self.adaLN(timestep_emb) is re-evaluated before every
        # sub-block; with n_adaln_chunks=3 each call produces the same
        # (shift, scale, gate) triple, so all sub-blocks share one modulation.
        shift_sa, scale_sa, gate_sa = self.adaLN(timestep_emb)

        pre_spatial_attn_layernorm_output_ada = self.adaLN.modulated_layernorm(
            hidden_states, shift=shift_sa, scale=scale_sa
        )

        attention_output, _ = self.spatial_self_attention(
            pre_spatial_attn_layernorm_output_ada,
            attention_mask=None,
        )

        # --- Full self-attention sub-block ---
        shift_full, scale_full, gate_full = self.adaLN(timestep_emb)

        # Gated residual add of the previous attention output, fused with the
        # modulated layernorm that feeds the next sub-block.
        hidden_states, pre_full_attn_layernorm_output_ada = self.adaLN.scaled_modulated_layernorm(
            residual=hidden_states,
            x=attention_output,
            gate=gate_sa,
            shift=shift_full,
            scale=scale_full,
        )

        attention_output, _ = self.full_self_attention(
            pre_full_attn_layernorm_output_ada,
            attention_mask=None,
        )

        # --- Cross-attention sub-block ---
        shift_ca, scale_ca, gate_ca = self.adaLN(timestep_emb)

        hidden_states, pre_cross_attn_layernorm_output_ada = self.adaLN.scaled_modulated_layernorm(
            residual=hidden_states,
            x=attention_output,
            gate=gate_full,
            shift=shift_ca,
            scale=scale_ca,
        )

        attention_output, _ = self.cross_attention(
            pre_cross_attn_layernorm_output_ada,
            attention_mask=context_mask,
            key_value_states=context,
        )

        # --- Temporal self-attention sub-block ---
        shift_ta, scale_ta, gate_ta = self.adaLN(timestep_emb)

        hidden_states, pre_temporal_attn_layernorm_output_ada = self.adaLN.scaled_modulated_layernorm(
            residual=hidden_states,
            x=attention_output,
            gate=gate_ca,
            shift=shift_ta,
            scale=scale_ta,
        )

        attention_output, _ = self.temporal_self_attention(
            pre_temporal_attn_layernorm_output_ada,
            attention_mask=None,
        )

        # --- MLP sub-block ---
        shift_mlp, scale_mlp, gate_mlp = self.adaLN(timestep_emb)

        hidden_states, pre_mlp_layernorm_output_ada = self.adaLN.scaled_modulated_layernorm(
            residual=hidden_states,
            x=attention_output,
            gate=gate_ta,
            shift=shift_mlp,
            scale=scale_mlp,
        )

        mlp_output, _ = self.mlp(pre_mlp_layernorm_output_ada)
        hidden_states = self.adaLN.scale_add(residual=hidden_states, x=mlp_output, gate=gate_mlp)

        # Return a viewless tensor so downstream code does not hold on to the
        # intermediate view graph (grad graph is kept).
        output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True)

        return output, context
|
|
|
|
|
|
|
|
class DiTLayerWithAdaLN(TransformerLayer):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.

    DiT with Adaptive Layer Normalization.

    Runs full self-attention, optional cross-attention, and an MLP, each
    wrapped in AdaLN shift/scale/gate modulation (9 chunks with
    cross-attention, 6 without).
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: TransformerLayerSubmodules,
        layer_number: int = 1,
        hidden_dropout: float = None,
        position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute",
    ):
        # Swap cross-attention for IdentityOp before calling the parent
        # constructor; it is built below with context parallelism disabled.
        def _replace_no_cp_submodules(submodules):
            modified_submods = copy.deepcopy(submodules)
            modified_submods.cross_attention = IdentityOp

            return modified_submods

        modified_submods = _replace_no_cp_submodules(submodules)
        super().__init__(
            config=config, submodules=modified_submods, layer_number=layer_number, hidden_dropout=hidden_dropout
        )

        # Build cross-attention (if requested by the spec) with CP disabled
        # and TP-comm overlap off.
        if submodules.cross_attention != IdentityOp:
            cp_override_config = copy.deepcopy(config)
            cp_override_config.context_parallel_size = 1
            cp_override_config.tp_comm_overlap = False
            self.cross_attention = build_module(
                submodules.cross_attention,
                config=cp_override_config,
                layer_number=layer_number,
            )
        else:
            self.cross_attention = None

        self.full_self_attention = build_module(
            submodules.full_self_attention,
            config=self.config,
            layer_number=layer_number,
        )

        # 3 modulation tensors per sub-block: attn (+ cross-attn) + MLP.
        self.adaLN = AdaLN(config=self.config, n_adaln_chunks=9 if self.cross_attention else 6)

    def forward(
        self,
        hidden_states,
        attention_mask,
        context=None,
        context_mask=None,
        rotary_pos_emb=None,
        inference_params=None,
        packed_seq_params=None,
    ):
        """Run full self-attention, optional cross-attention, and MLP.

        Args:
            hidden_states: Input tensor of shape [s, b, h].
            attention_mask: Repurposed slot carrying the AdaLN conditioning
                (timestep) embedding — NOT an attention mask.
            context: Key/value states for cross-attention.
            context_mask: Attention mask for cross-attention.
            packed_seq_params: Optional dict with 'self_attention' and
                'cross_attention' entries for packed-sequence handling.

        Returns:
            Tuple of (output hidden states, context).
        """
        # NOTE(review): the attention_mask argument slot carries the timestep
        # embedding used for AdaLN modulation.
        timestep_emb = attention_mask

        # One shift/scale/gate triple per sub-block, produced in a single
        # projection and split here.
        if self.cross_attention:
            shift_full, scale_full, gate_full, shift_ca, scale_ca, gate_ca, shift_mlp, scale_mlp, gate_mlp = (
                self.adaLN(timestep_emb)
            )
        else:
            shift_full, scale_full, gate_full, shift_mlp, scale_mlp, gate_mlp = self.adaLN(timestep_emb)

        # --- Full self-attention sub-block ---
        pre_full_attn_layernorm_output_ada = self.adaLN.modulated_layernorm(
            hidden_states, shift=shift_full, scale=scale_full
        )

        attention_output, _ = self.full_self_attention(
            pre_full_attn_layernorm_output_ada,
            attention_mask=None,
            packed_seq_params=None if packed_seq_params is None else packed_seq_params['self_attention'],
        )

        # --- Optional cross-attention sub-block ---
        if self.cross_attention:
            # Gated residual add of the self-attention output, fused with the
            # modulated layernorm that feeds cross-attention.
            hidden_states, pre_cross_attn_layernorm_output_ada = self.adaLN.scaled_modulated_layernorm(
                residual=hidden_states,
                x=attention_output,
                gate=gate_full,
                shift=shift_ca,
                scale=scale_ca,
            )

            attention_output, _ = self.cross_attention(
                pre_cross_attn_layernorm_output_ada,
                attention_mask=context_mask,
                key_value_states=context,
                packed_seq_params=None if packed_seq_params is None else packed_seq_params['cross_attention'],
            )

        # --- MLP sub-block ---
        # The gate matches whichever attention output is being folded in.
        hidden_states, pre_mlp_layernorm_output_ada = self.adaLN.scaled_modulated_layernorm(
            residual=hidden_states,
            x=attention_output,
            gate=gate_ca if self.cross_attention else gate_full,
            shift=shift_mlp,
            scale=scale_mlp,
        )

        mlp_output, _ = self.mlp(pre_mlp_layernorm_output_ada)
        hidden_states = self.adaLN.scale_add(residual=hidden_states, x=mlp_output, gate=gate_mlp)

        # Return a viewless tensor so downstream code does not hold on to the
        # intermediate view graph (grad graph is kept).
        output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True)

        return output, context
|
|
|
|
|
|
|
|
class DiTLayer(TransformerLayer):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.

    Original DiT layer implementation from [https://arxiv.org/pdf/2212.09748].
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: TransformerLayerSubmodules,
        layer_number: int = 1,
        mlp_ratio: int = 4,
        n_adaln_chunks: int = 6,
        modulation_bias: bool = True,
    ):
        # DiT sizes the MLP hidden dimension relative to the model width.
        config.ffn_hidden_size = int(mlp_ratio * config.hidden_size)
        super().__init__(config=config, submodules=submodules, layer_number=layer_number)

        self.adaLN = AdaLN(
            config=config, n_adaln_chunks=n_adaln_chunks, modulation_bias=modulation_bias, use_second_norm=True
        )

    def forward(
        self,
        hidden_states,
        attention_mask,
        context=None,
        context_mask=None,
        rotary_pos_emb=None,
        inference_params=None,
        packed_seq_params=None,
    ):
        # The attention_mask slot carries the conditioning embedding for AdaLN.
        conditioning = attention_mask

        # One shift/scale/gate triple each for the attention and MLP sub-blocks.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN(conditioning)

        # Attention sub-block: modulated LN -> self-attention -> gated residual.
        attn_input = self.adaLN.modulated_layernorm(
            hidden_states, shift=shift_msa, scale=scale_msa, layernorm_idx=0
        )
        attn_out, attn_bias = self.self_attention(attn_input, attention_mask=None)
        hidden_states = self.adaLN.scale_add(hidden_states, x=(attn_out + attn_bias), gate=gate_msa)

        # MLP sub-block: modulated LN (second norm) -> MLP -> gated residual.
        residual = hidden_states
        mlp_input = self.adaLN.modulated_layernorm(
            hidden_states, shift=shift_mlp, scale=scale_mlp, layernorm_idx=1
        )
        mlp_out, mlp_bias = self.mlp(mlp_input)
        hidden_states = self.adaLN.scale_add(residual, x=(mlp_out + mlp_bias), gate=gate_mlp)

        return hidden_states, context
|
|
|
|
|
|
|
|
class MMDiTLayer(TransformerLayer):
    """A multi-modal transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.

    MMDiT layer implementation from [https://arxiv.org/pdf/2403.03206].

    Processes two streams — ``hidden_states`` (image) and
    ``encoder_hidden_states`` (context/text) — through a joint self-attention
    and per-stream AdaLN-modulated MLPs.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: TransformerLayerSubmodules,
        layer_number: int = 1,
        context_pre_only: bool = False,
    ):
        hidden_size = config.hidden_size
        super().__init__(config=config, submodules=submodules, layer_number=layer_number)

        if config.enable_cuda_graph:
            self.cudagraph_manager = CudaGraphManager(config, share_cudagraph_io_buffers=False)

        # AdaLN for the image (hidden_states) stream: 2 sub-blocks x 3 tensors.
        self.adaln = AdaLN(config, modulation_bias=True, n_adaln_chunks=6, use_second_norm=True)

        # When context_pre_only, the context stream gets only a continuous
        # AdaLN and its output is dropped in forward(); otherwise it mirrors
        # the image stream.
        self.context_pre_only = context_pre_only
        context_norm_type = "ada_norm_continuous" if context_pre_only else "ada_norm_zero"

        if context_norm_type == "ada_norm_continuous":
            self.adaln_context = AdaLNContinuous(config, hidden_size, modulation_bias=True, norm_type="layer_norm")
        elif context_norm_type == "ada_norm_zero":
            self.adaln_context = AdaLN(config, modulation_bias=True, n_adaln_chunks=6, use_second_norm=True)
        else:
            # Unreachable given the ternary above; kept as a guard.
            raise ValueError(
                f"Unknown context_norm_type: {context_norm_type}, "
                f"currently only support `ada_norm_continous`, `ada_norm_zero`"
            )

        # Build the context MLP with context parallelism disabled and
        # TP-comm overlap off.
        cp_override_config = copy.deepcopy(config)
        cp_override_config.context_parallel_size = 1
        cp_override_config.tp_comm_overlap = False

        if not context_pre_only:
            self.context_mlp = build_module(
                submodules.mlp,
                config=cp_override_config,
            )
        else:
            self.context_mlp = None

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        attention_mask=None,
        context=None,
        context_mask=None,
        rotary_pos_emb=None,
        inference_params=None,
        packed_seq_params=None,
        emb=None,
    ):
        """Run the joint attention + per-stream MLP block.

        Args:
            hidden_states: Image-stream tensor.
            encoder_hidden_states: Context-stream tensor.
            emb: Conditioning embedding driving the AdaLN modulation.

        Returns:
            Tuple of (hidden_states, encoder_hidden_states); the second is
            None when ``context_pre_only`` is set.
        """
        # Modulation for the image stream.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaln(emb)

        norm_hidden_states = self.adaln.modulated_layernorm(
            hidden_states, shift=shift_msa, scale=scale_msa, layernorm_idx=0
        )
        if self.context_pre_only:
            # Continuous AdaLN: normalize + modulate in one call (no gates).
            norm_encoder_hidden_states = self.adaln_context(encoder_hidden_states, emb)
        else:
            # Separate modulation set for the context stream.
            c_shift_msa, c_scale_msa, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.adaln_context(emb)
            norm_encoder_hidden_states = self.adaln_context.modulated_layernorm(
                encoder_hidden_states, shift=c_shift_msa, scale=c_scale_msa, layernorm_idx=0
            )

        # Joint self-attention over both streams; returns one output per stream.
        attention_output, encoder_attention_output = self.self_attention(
            norm_hidden_states,
            attention_mask=attention_mask,
            key_value_states=None,
            additional_hidden_states=norm_encoder_hidden_states,
            rotary_pos_emb=rotary_pos_emb,
        )
        hidden_states = self.adaln.scale_add(hidden_states, x=attention_output, gate=gate_msa)
        norm_hidden_states = self.adaln.modulated_layernorm(
            hidden_states, shift=shift_mlp, scale=scale_mlp, layernorm_idx=1
        )

        mlp_output, mlp_output_bias = self.mlp(norm_hidden_states)
        hidden_states = self.adaln.scale_add(hidden_states, x=(mlp_output + mlp_output_bias), gate=gate_mlp)

        if self.context_pre_only:
            # The context stream is consumed here and not propagated further.
            encoder_hidden_states = None
        else:
            encoder_hidden_states = self.adaln_context.scale_add(
                encoder_hidden_states, x=encoder_attention_output, gate=c_gate_msa
            )
            norm_encoder_hidden_states = self.adaln_context.modulated_layernorm(
                encoder_hidden_states, shift=c_shift_mlp, scale=c_scale_mlp, layernorm_idx=1
            )

            context_mlp_output, context_mlp_output_bias = self.context_mlp(norm_encoder_hidden_states)
            encoder_hidden_states = self.adaln.scale_add(
                encoder_hidden_states, x=(context_mlp_output + context_mlp_output_bias), gate=c_gate_mlp
            )

        return hidden_states, encoder_hidden_states

    def __call__(self, *args, **kwargs):
        # When CUDA graphs are enabled, dispatch through the graph manager;
        # otherwise bypass MegatronModule-level __call__ handling by starting
        # the MRO lookup after MegatronModule.
        if hasattr(self, 'cudagraph_manager'):
            return self.cudagraph_manager(self, args, kwargs)
        return super(MegatronModule, self).__call__(*args, **kwargs)
|
|
|
|
|
|
|
|
class FluxSingleTransformerBlock(TransformerLayer):
    """
    Flux Single Transformer Block.

    Single transformer layer mathematically equivalent to original Flux single transformer.

    This layer is re-implemented with megatron-core and also altered in structure for better performance.

    MLP and self-attention run in parallel on the same normalized input, and
    their sum is folded back into the residual with a single AdaLN gate.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: TransformerLayerSubmodules,
        layer_number: int = 1,
        mlp_ratio: int = 4,
        n_adaln_chunks: int = 3,
        modulation_bias: bool = True,
    ):
        # NOTE(review): unlike DiTLayer, mlp_ratio is accepted but never
        # applied to config.ffn_hidden_size here — confirm this is intentional.
        super().__init__(config=config, submodules=submodules, layer_number=layer_number)

        if config.enable_cuda_graph:
            self.cudagraph_manager = CudaGraphManager(config, share_cudagraph_io_buffers=False)
        # One shift/scale/gate triple modulates the whole (parallel) block.
        self.adaln = AdaLN(
            config=config, n_adaln_chunks=n_adaln_chunks, modulation_bias=modulation_bias, use_second_norm=False
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        context=None,
        context_mask=None,
        rotary_pos_emb=None,
        inference_params=None,
        packed_seq_params=None,
        emb=None,
    ):
        """Run the parallel MLP + attention block; ``emb`` drives AdaLN."""
        residual = hidden_states

        shift, scale, gate = self.adaln(emb)

        norm_hidden_states = self.adaln.modulated_layernorm(hidden_states, shift=shift, scale=scale)

        # MLP and attention both consume the same normalized input.
        mlp_hidden_states, mlp_bias = self.mlp(norm_hidden_states)

        # NOTE(review): FluxSingleAttention appears to return the output
        # tensor directly (no (output, bias) tuple) — confirm against its
        # definition in dit_attention.
        attention_output = self.self_attention(
            norm_hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb
        )

        hidden_states = mlp_hidden_states + mlp_bias + attention_output

        hidden_states = self.adaln.scale_add(residual, x=hidden_states, gate=gate)

        # Second return slot (context) is unused by this block.
        return hidden_states, None

    def __call__(self, *args, **kwargs):
        # When CUDA graphs are enabled, dispatch through the graph manager;
        # otherwise bypass MegatronModule-level __call__ handling by starting
        # the MRO lookup after MegatronModule.
        if hasattr(self, 'cudagraph_manager'):
            return self.cudagraph_manager(self, args, kwargs)
        return super(MegatronModule, self).__call__(*args, **kwargs)
|
|
|
|
|
|
|
|
def get_stdit_adaln_block_with_transformer_engine_spec() -> ModuleSpec:
    """Build the module spec for an STDiT AdaLN block backed by Transformer Engine.

    Wires spatial/temporal/full self-attention and cross-attention (all with
    padding mask type and TE QK layernorms) plus a TE-backed MLP into
    ``STDiTLayerWithAdaLN``.
    """
    params = {"attn_mask_type": AttnMaskType.padding}
    return ModuleSpec(
        module=STDiTLayerWithAdaLN,
        submodules=STDiTWithAdaLNSubmodules(
            spatial_self_attention=ModuleSpec(
                module=SelfAttention,
                params=params,
                submodules=SelfAttentionSubmodules(
                    linear_qkv=TEColumnParallelLinear,
                    core_attention=TEDotProductAttention,
                    linear_proj=TERowParallelLinear,
                    q_layernorm=TENorm,
                    k_layernorm=TENorm,
                ),
            ),
            temporal_self_attention=ModuleSpec(
                module=SelfAttention,
                params=params,
                submodules=SelfAttentionSubmodules(
                    linear_qkv=TEColumnParallelLinear,
                    core_attention=TEDotProductAttention,
                    linear_proj=TERowParallelLinear,
                    q_layernorm=TENorm,
                    k_layernorm=TENorm,
                ),
            ),
            full_self_attention=ModuleSpec(
                module=SelfAttention,
                params=params,
                submodules=SelfAttentionSubmodules(
                    linear_qkv=TEColumnParallelLinear,
                    core_attention=TEDotProductAttention,
                    linear_proj=TERowParallelLinear,
                    q_layernorm=TENorm,
                    k_layernorm=TENorm,
                ),
            ),
            cross_attention=ModuleSpec(
                module=CrossAttention,
                params=params,
                submodules=CrossAttentionSubmodules(
                    linear_q=TEColumnParallelLinear,
                    linear_kv=TEColumnParallelLinear,
                    core_attention=TEDotProductAttention,
                    linear_proj=TERowParallelLinear,
                    q_layernorm=TENorm,
                    k_layernorm=TENorm,
                ),
            ),
            mlp=ModuleSpec(
                module=MLP,
                submodules=MLPSubmodules(
                    linear_fc1=TEColumnParallelLinear,
                    linear_fc2=TERowParallelLinear,
                ),
            ),
        ),
    )
|
|
|
|
|
|
|
|
def get_dit_adaln_block_with_transformer_engine_spec(attn_mask_type=AttnMaskType.padding) -> ModuleSpec:
    """Build the ``DiTLayerWithAdaLN`` spec (TE backend) with self- and cross-attention.

    Both attention variants use RMSNorm QK layernorms and the given mask type.
    """
    attn_params = {"attn_mask_type": attn_mask_type}
    full_self_attention = ModuleSpec(
        module=SelfAttention,
        params=attn_params,
        submodules=SelfAttentionSubmodules(
            linear_qkv=TEColumnParallelLinear,
            core_attention=TEDotProductAttention,
            linear_proj=TERowParallelLinear,
            q_layernorm=RMSNorm,
            k_layernorm=RMSNorm,
        ),
    )
    cross_attention = ModuleSpec(
        module=CrossAttention,
        params=attn_params,
        submodules=CrossAttentionSubmodules(
            linear_q=TEColumnParallelLinear,
            linear_kv=TEColumnParallelLinear,
            core_attention=TEDotProductAttention,
            linear_proj=TERowParallelLinear,
            q_layernorm=RMSNorm,
            k_layernorm=RMSNorm,
        ),
    )
    mlp = ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=TEColumnParallelLinear,
            linear_fc2=TERowParallelLinear,
        ),
    )
    return ModuleSpec(
        module=DiTLayerWithAdaLN,
        submodules=DiTWithAdaLNSubmodules(
            full_self_attention=full_self_attention,
            cross_attention=cross_attention,
            mlp=mlp,
        ),
    )
|
|
|
|
|
|
|
|
def get_official_dit_adaln_block_with_transformer_engine_spec() -> ModuleSpec:
    """Build the ``DiTLayerWithAdaLN`` spec matching the official DiT layout.

    No cross-attention, no QK layernorms, and ``no_mask`` attention masking.
    """
    full_self_attention = ModuleSpec(
        module=SelfAttention,
        params={"attn_mask_type": AttnMaskType.no_mask},
        submodules=SelfAttentionSubmodules(
            linear_qkv=TEColumnParallelLinear,
            core_attention=TEDotProductAttention,
            linear_proj=TERowParallelLinear,
        ),
    )
    mlp = ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=TEColumnParallelLinear,
            linear_fc2=TERowParallelLinear,
        ),
    )
    return ModuleSpec(
        module=DiTLayerWithAdaLN,
        submodules=DiTWithAdaLNSubmodules(
            full_self_attention=full_self_attention,
            mlp=mlp,
        ),
    )
|
|
|
|
|
|
|
|
def get_mm_dit_block_with_transformer_engine_spec() -> ModuleSpec:
    """Build the ``MMDiTLayer`` spec with TE-backed joint attention and MLP."""
    joint_attention = ModuleSpec(
        module=JointSelfAttention,
        params={"attn_mask_type": AttnMaskType.no_mask},
        submodules=JointSelfAttentionSubmodules(
            linear_qkv=TEColumnParallelLinear,
            added_linear_qkv=TEColumnParallelLinear,
            core_attention=TEDotProductAttention,
            linear_proj=TERowParallelLinear,
        ),
    )
    mlp = ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=TEColumnParallelLinear,
            linear_fc2=TERowParallelLinear,
        ),
    )
    return ModuleSpec(
        module=MMDiTLayer,
        submodules=TransformerLayerSubmodules(
            self_attention=joint_attention,
            mlp=mlp,
        ),
    )
|
|
|
|
|
|
|
|
def get_flux_single_transformer_engine_spec() -> ModuleSpec:
    """Build the ``FluxSingleTransformerBlock`` spec with TE-backed submodules."""
    attention = ModuleSpec(
        module=FluxSingleAttention,
        params={"attn_mask_type": AttnMaskType.no_mask},
        submodules=SelfAttentionSubmodules(
            linear_qkv=TEColumnParallelLinear,
            core_attention=TEDotProductAttention,
            q_layernorm=TENorm,
            k_layernorm=TENorm,
            linear_proj=TERowParallelLinear,
        ),
    )
    mlp = ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=TEColumnParallelLinear,
            linear_fc2=TERowParallelLinear,
        ),
    )
    return ModuleSpec(
        module=FluxSingleTransformerBlock,
        submodules=TransformerLayerSubmodules(
            self_attention=attention,
            mlp=mlp,
        ),
    )
|
|
|
|
|
|
|
|
def get_flux_double_transformer_engine_spec() -> ModuleSpec:
    """Build the Flux double-stream (``MMDiTLayer``) spec with TE submodules.

    Joint self-attention carries QK layernorms on both the primary and the
    added (context) projections.
    """
    joint_attention = ModuleSpec(
        module=JointSelfAttention,
        params={"attn_mask_type": AttnMaskType.no_mask},
        submodules=JointSelfAttentionSubmodules(
            q_layernorm=TENorm,
            k_layernorm=TENorm,
            added_q_layernorm=TENorm,
            added_k_layernorm=TENorm,
            linear_qkv=TEColumnParallelLinear,
            added_linear_qkv=TEColumnParallelLinear,
            core_attention=TEDotProductAttention,
            linear_proj=TERowParallelLinear,
        ),
    )
    mlp = ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=TEColumnParallelLinear,
            linear_fc2=TERowParallelLinear,
        ),
    )
    return ModuleSpec(
        module=MMDiTLayer,
        submodules=TransformerLayerSubmodules(
            self_attention=joint_attention,
            mlp=mlp,
        ),
    )
|
|
|
|
|
|
|
|
|
|
|
|