# Provenance (Hugging Face Hub file-view header captured together with the source):
#   uploaded by multimodalart (HF Staff) -- "Upload 247 files"
#   commit 7758cff (verified), file size 18.2 kB
import torch
import torch.nn as nn
import numbers
from .modules import RMSNorm, SelfAttention, CrossAttention, Mlp,MMdual_attention,MMsingle_attention,MMfour_attention
from einops import rearrange, repeat
def modulate(x, shift, scale):
    """Apply the affine modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` each get a singleton axis inserted at position 1
    before the arithmetic, so they broadcast across dim 1 of ``x``
    (presumably the token/sequence axis -- confirm against callers).
    """
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (1 + scale_b) + shift_b
def _basic_init(module):
    """Per-module initializer intended for ``nn.Module.apply``.

    ``nn.Linear`` layers get Xavier-uniform weights and, when a bias exists,
    a zero bias. Every other module type is left untouched.
    """
    if not isinstance(module, nn.Linear):
        return
    nn.init.xavier_uniform_(module.weight)
    bias = module.bias
    if bias is not None:
        nn.init.constant_(bias, 0)
#################################################################################
# Core DiT Model #
#################################################################################
class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.

    Two gated residual branches are applied in order: self-attention
    (``attn1``) and an MLP. The conditioning ``c`` is projected into six
    chunks along dim 1: shift/scale/gate for each branch.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        norm_type = block_kwargs.get("norm_type", "rms_norm")
        assert norm_type in ("layer_norm", "rms_norm")
        if norm_type == "layer_norm":
            norm_cls = nn.LayerNorm
        else:
            norm_cls = RMSNorm

        def new_norm():
            return norm_cls(hidden_size, elementwise_affine=False, eps=1e-6)

        def approx_gelu():
            return nn.GELU(approximate="tanh")

        # NOTE: block_kwargs (including any "norm_type" entry) is forwarded
        # unchanged to the attention module.
        self.norm1 = new_norm()
        self.attn1 = SelfAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = new_norm()
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=approx_gelu,
            drop=0,
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def initialize_weights(self):
        """Run the basic Linear init, then zero the adaLN output projection."""
        self.apply(_basic_init)
        # With weight and bias at zero, every shift/scale/gate starts at zero.
        ada_out = self.adaLN_modulation[-1]
        nn.init.constant_(ada_out.weight, 0)
        nn.init.constant_(ada_out.bias, 0)

    def forward(self, x, c, mask=None, freqs_cis=None):
        """Return ``x`` after the gated attention and MLP residual updates.

        ``mask`` and ``freqs_cis`` are passed positionally to ``attn1``.
        """
        (
            shift_attn,
            scale_attn,
            gate_attn,
            shift_ff,
            scale_ff,
            gate_ff,
        ) = self.adaLN_modulation(c).chunk(6, dim=1)

        attn_in = modulate(self.norm1(x), shift_attn, scale_attn)
        x = x + gate_attn.unsqueeze(1) * self.attn1(attn_in, mask, freqs_cis)

        ff_in = modulate(self.norm2(x), shift_ff, scale_ff)
        x = x + gate_ff.unsqueeze(1) * self.mlp(ff_in)
        return x
class MMSingleStreamBlock(nn.Module):
    """
    A multimodal DiT block operating on a single token stream.

    One fused linear layer (``qkv_xs``) produces both the QKV features fed to
    ``attn1`` and the MLP hidden features. The attention output and the
    activated MLP features are concatenated on dim 2 and projected back to
    ``hidden_size`` by ``linear2``. A single shift/scale/gate triple from
    ``adaLN_modulation_xs`` conditions the whole block.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        norm_type = block_kwargs.get("norm_type", "rms_norm")
        assert norm_type in ("layer_norm", "rms_norm")
        if norm_type == "layer_norm":
            norm_cls = nn.LayerNorm
        else:
            norm_cls = RMSNorm

        def new_norm():
            return norm_cls(hidden_size, elementwise_affine=False, eps=1e-6)

        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        self.norm1 = new_norm()
        self.attn1 = MMsingle_attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        # norm2..norm4 are not used by forward() in this class; they are kept
        # so the module's attribute layout stays as it was.
        self.norm2 = new_norm()
        self.norm3 = new_norm()
        self.norm4 = new_norm()
        # Fused input projection: 3 * hidden_size for QKV plus the MLP hidden width.
        self.qkv_xs = nn.Linear(hidden_size, 3 * hidden_size + mlp_hidden_dim, bias=True)
        # Output projection over [attention output, activated MLP features].
        self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size)
        self.mlp_act = nn.GELU(approximate="tanh")
        self.adaLN_modulation_xs = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 3 * hidden_size, bias=True),
        )
        self.hidden_size = hidden_size
        self.mlp_hidden_dim = mlp_hidden_dim

    def initialize_weights(self):
        """Run the basic Linear init, then zero the adaLN output projection."""
        self.apply(_basic_init)
        ada_out = self.adaLN_modulation_xs[-1]
        nn.init.constant_(ada_out.weight, 0)
        nn.init.constant_(ada_out.bias, 0)

    def forward(self, seq_len, x, c, mask=None, freqs_cis=None, freqs_cis2=None, causal=False):
        """Return ``x`` plus the gated fused attention/MLP update.

        ``seq_len``, ``mask``, ``causal``, ``freqs_cis`` and ``freqs_cis2`` are
        handed to ``attn1`` as-is; their meaning is defined by that module.
        """
        shift, scale, gate = self.adaLN_modulation_xs(c).chunk(3, dim=1)

        fused = self.qkv_xs(modulate(self.norm1(x), shift, scale))
        qkv, mlp_hidden = torch.split(
            fused, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
        )

        attn_out = self.attn1(
            seq_len,
            qkv,
            mask,
            causal=causal,
            freqs_cis=freqs_cis,
            freqs_cis2=freqs_cis2,
        )
        merged = torch.cat((attn_out, self.mlp_act(mlp_hidden)), 2)
        update = self.linear2(merged)
        return x + gate.unsqueeze(1) * update
class MMfourStreamBlock(nn.Module):
    """
    A multimodal DiT block with separate modulation for four token streams.

    The streams are ``x`` plus ``y1``, ``y2`` and ``y3`` (named "audio" in
    the attribute names). Each stream has its own adaLN projection, its own
    pre-attention and pre-MLP norm, and its own MLP; all four share the joint
    attention module ``attn1``.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        norm_type = block_kwargs.get("norm_type", "rms_norm")
        assert norm_type in ("layer_norm", "rms_norm")
        if norm_type == "layer_norm":
            norm_cls = nn.LayerNorm
        else:
            norm_cls = RMSNorm

        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        def new_norm():
            return norm_cls(hidden_size, elementwise_affine=False, eps=1e-6)

        def approx_gelu():
            return nn.GELU(approximate="tanh")

        def new_mlp():
            return Mlp(
                in_features=hidden_size,
                hidden_features=mlp_hidden_dim,
                act_layer=approx_gelu,
                drop=0,
            )

        def new_adaln():
            return nn.Sequential(
                nn.SiLU(),
                nn.Linear(hidden_size, 6 * hidden_size, bias=True),
            )

        # Construction order matches the attribute order used historically so
        # parameter registration and init-time RNG use are unchanged.
        self.norm1 = new_norm()
        self.attn1 = MMfour_attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        # norm1..norm4: pre-attention norms for x, y1, y2, y3.
        self.norm2 = new_norm()
        self.norm3 = new_norm()
        self.norm4 = new_norm()
        # norm5..norm8: pre-MLP norms for x, y1, y2, y3.
        self.norm5 = new_norm()
        self.norm6 = new_norm()
        self.norm7 = new_norm()
        self.norm8 = new_norm()

        self.xs_mlp = new_mlp()
        self.audio_mlp1 = new_mlp()
        self.audio_mlp2 = new_mlp()
        self.audio_mlp3 = new_mlp()

        self.adaLN_modulation_xs = new_adaln()
        self.adaLN_modulation_audio1 = new_adaln()
        self.adaLN_modulation_audio2 = new_adaln()
        self.adaLN_modulation_audio3 = new_adaln()

    def initialize_weights(self):
        """Run the basic Linear init, then zero all four adaLN output projections."""
        self.apply(_basic_init)
        adaln_heads = (
            self.adaLN_modulation_xs,
            self.adaLN_modulation_audio1,
            self.adaLN_modulation_audio2,
            self.adaLN_modulation_audio3,
        )
        for head in adaln_heads:
            nn.init.constant_(head[-1].weight, 0)
            nn.init.constant_(head[-1].bias, 0)

    def forward(self, x, c, y1, y2, y3, mask=None, freqs_cis=None, freqs_cis2=None, causal=False):
        """Return the updated ``(x, y1, y2, y3)`` after joint attention and per-stream MLPs.

        ``mask``, ``causal``, ``freqs_cis`` and ``freqs_cis2`` are forwarded to
        ``attn1`` unchanged.
        """
        # Each projection yields (shift, scale, gate) for attention, then for the MLP.
        (
            shift_attn_x, scale_attn_x, gate_attn_x,
            shift_mlp_x, scale_mlp_x, gate_mlp_x,
        ) = self.adaLN_modulation_xs(c).chunk(6, dim=1)
        (
            shift_attn_a1, scale_attn_a1, gate_attn_a1,
            shift_mlp_a1, scale_mlp_a1, gate_mlp_a1,
        ) = self.adaLN_modulation_audio1(c).chunk(6, dim=1)
        (
            shift_attn_a2, scale_attn_a2, gate_attn_a2,
            shift_mlp_a2, scale_mlp_a2, gate_mlp_a2,
        ) = self.adaLN_modulation_audio2(c).chunk(6, dim=1)
        (
            shift_attn_a3, scale_attn_a3, gate_attn_a3,
            shift_mlp_a3, scale_mlp_a3, gate_mlp_a3,
        ) = self.adaLN_modulation_audio3(c).chunk(6, dim=1)

        # Joint attention over the four modulated streams.
        x_in = modulate(self.norm1(x), shift_attn_x, scale_attn_x)
        y1_in = modulate(self.norm2(y1), shift_attn_a1, scale_attn_a1)
        y2_in = modulate(self.norm3(y2), shift_attn_a2, scale_attn_a2)
        y3_in = modulate(self.norm4(y3), shift_attn_a3, scale_attn_a3)
        x_attn, y1_attn, y2_attn, y3_attn = self.attn1(
            x_in,
            y1_in,
            y2_in,
            y3_in,
            mask,
            causal=causal,
            freqs_cis=freqs_cis,
            freqs_cis2=freqs_cis2,
        )

        # Gated attention residuals.
        x = x + gate_attn_x.unsqueeze(1) * x_attn
        y1 = y1 + gate_attn_a1.unsqueeze(1) * y1_attn
        y2 = y2 + gate_attn_a2.unsqueeze(1) * y2_attn
        y3 = y3 + gate_attn_a3.unsqueeze(1) * y3_attn

        # Gated per-stream MLP residuals.
        x_ff = self.xs_mlp(modulate(self.norm5(x), shift_mlp_x, scale_mlp_x))
        x = x + gate_mlp_x.unsqueeze(1) * x_ff
        y1_ff = self.audio_mlp1(modulate(self.norm6(y1), shift_mlp_a1, scale_mlp_a1))
        y1 = y1 + gate_mlp_a1.unsqueeze(1) * y1_ff
        y2_ff = self.audio_mlp2(modulate(self.norm7(y2), shift_mlp_a2, scale_mlp_a2))
        y2 = y2 + gate_mlp_a2.unsqueeze(1) * y2_ff
        y3_ff = self.audio_mlp3(modulate(self.norm8(y3), shift_mlp_a3, scale_mlp_a3))
        y3 = y3 + gate_mlp_a3.unsqueeze(1) * y3_ff
        return x, y1, y2, y3
class MMDoubleStreamBlock(nn.Module):
    """
    A multimodal DiT block with separate modulation for two token streams.

    ``x`` and ``y`` (named "audio" in the attribute names) each have their own
    adaLN projection, norms and MLP, and share the joint attention module
    ``attn1``.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        norm_type = block_kwargs.get("norm_type", "rms_norm")
        assert norm_type in ("layer_norm", "rms_norm")
        if norm_type == "layer_norm":
            norm_cls = nn.LayerNorm
        else:
            norm_cls = RMSNorm

        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        def new_norm():
            return norm_cls(hidden_size, elementwise_affine=False, eps=1e-6)

        def approx_gelu():
            return nn.GELU(approximate="tanh")

        def new_mlp():
            return Mlp(
                in_features=hidden_size,
                hidden_features=mlp_hidden_dim,
                act_layer=approx_gelu,
                drop=0,
            )

        def new_adaln():
            return nn.Sequential(
                nn.SiLU(),
                nn.Linear(hidden_size, 6 * hidden_size, bias=True),
            )

        # norm1/norm2: pre-attention norms for x / y.
        self.norm1 = new_norm()
        self.attn1 = MMdual_attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = new_norm()
        # norm3/norm4: pre-MLP norms for x / y.
        self.norm3 = new_norm()
        self.norm4 = new_norm()
        self.xs_mlp = new_mlp()
        self.audio_mlp = new_mlp()
        self.adaLN_modulation_xs = new_adaln()
        self.adaLN_modulation_audio = new_adaln()

    def initialize_weights(self):
        """Run the basic Linear init, then zero both adaLN output projections."""
        self.apply(_basic_init)
        for head in (self.adaLN_modulation_xs, self.adaLN_modulation_audio):
            nn.init.constant_(head[-1].weight, 0)
            nn.init.constant_(head[-1].bias, 0)

    def forward(self, seq_len, x, c, y, mask=None, freqs_cis=None, freqs_cis2=None, causal=False):
        """Return the updated ``(x, y)`` after joint attention and per-stream MLPs.

        ``seq_len``, ``mask``, ``causal``, ``freqs_cis`` and ``freqs_cis2`` are
        forwarded to ``attn1`` unchanged.
        """
        (
            shift_attn_x, scale_attn_x, gate_attn_x,
            shift_mlp_x, scale_mlp_x, gate_mlp_x,
        ) = self.adaLN_modulation_xs(c).chunk(6, dim=1)
        (
            shift_attn_y, scale_attn_y, gate_attn_y,
            shift_mlp_y, scale_mlp_y, gate_mlp_y,
        ) = self.adaLN_modulation_audio(c).chunk(6, dim=1)

        # Joint attention over the two modulated streams.
        x_in = modulate(self.norm1(x), shift_attn_x, scale_attn_x)
        y_in = modulate(self.norm2(y), shift_attn_y, scale_attn_y)
        x_attn, y_attn = self.attn1(
            seq_len,
            x_in,
            y_in,
            mask,
            causal=causal,
            freqs_cis=freqs_cis,
            freqs_cis2=freqs_cis2,
        )
        x = x + gate_attn_x.unsqueeze(1) * x_attn
        y = y + gate_attn_y.unsqueeze(1) * y_attn

        # Gated per-stream MLP residuals.
        x_ff = self.xs_mlp(modulate(self.norm3(x), shift_mlp_x, scale_mlp_x))
        x = x + gate_mlp_x.unsqueeze(1) * x_ff
        y_ff = self.audio_mlp(modulate(self.norm4(y), shift_mlp_y, scale_mlp_y))
        y = y + gate_mlp_y.unsqueeze(1) * y_ff
        return x, y
class CrossDiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning that
    contains both self-attention and cross-attention.

    Three gated residual branches are applied in order: self-attention
    (``attn1``), cross-attention against ``y`` (``attn2``), and an MLP. The
    conditioning ``c`` is projected into nine chunks along dim 1:
    shift/scale/gate for each branch.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        norm_type = block_kwargs.get("norm_type", "rms_norm")
        assert norm_type in ("layer_norm", "rms_norm")
        if norm_type == "layer_norm":
            norm_cls = nn.LayerNorm
        else:
            norm_cls = RMSNorm

        def new_norm():
            return norm_cls(hidden_size, elementwise_affine=False, eps=1e-6)

        def approx_gelu():
            return nn.GELU(approximate="tanh")

        self.norm1 = new_norm()
        self.attn1 = SelfAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = new_norm()
        self.attn2 = CrossAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm3 = new_norm()
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=approx_gelu,
            drop=0,
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 9 * hidden_size, bias=True),
        )

    def initialize_weights(self):
        """Run the basic Linear init, then zero the adaLN output projection."""
        self.apply(_basic_init)
        ada_out = self.adaLN_modulation[-1]
        nn.init.constant_(ada_out.weight, 0)
        nn.init.constant_(ada_out.bias, 0)

    def forward(self, x, c, y, mask=None):
        """Return ``x`` after the self-attention, cross-attention and MLP updates.

        The same ``mask`` object is passed to both attention modules.
        """
        (
            shift_self, scale_self, gate_self,
            shift_cross, scale_cross, gate_cross,
            shift_ff, scale_ff, gate_ff,
        ) = self.adaLN_modulation(c).chunk(9, dim=1)

        self_in = modulate(self.norm1(x), shift_self, scale_self)
        x = x + gate_self.unsqueeze(1) * self.attn1(self_in, mask)

        cross_in = modulate(self.norm2(x), shift_cross, scale_cross)
        x = x + gate_cross.unsqueeze(1) * self.attn2(cross_in, y, mask)

        ff_in = modulate(self.norm3(x), shift_ff, scale_ff)
        x = x + gate_ff.unsqueeze(1) * self.mlp(ff_in)
        return x
class SelfBlock(nn.Module):
    """
    A pre-norm self-attention residual block without conditioning.

    There is no adaLN modulation and no MLP here: the block computes
    ``x + attn2(norm2(x), mask)``. ``mlp_ratio`` is accepted by the
    constructor but not used.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        norm_type = block_kwargs.get("norm_type", "rms_norm")
        assert norm_type in ("layer_norm", "rms_norm")
        if norm_type == "layer_norm":
            norm_cls = nn.LayerNorm
        else:
            norm_cls = RMSNorm
        self.norm2 = norm_cls(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn2 = SelfAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)

    def initialize_weights(self):
        """Apply the basic Linear init; there are no adaLN layers to zero here."""
        self.apply(_basic_init)

    def forward(self, x, y, mask=None):
        """Return ``x`` plus its self-attention update.

        ``y`` is accepted but never read by this block.
        """
        attended = self.attn2(self.norm2(x), mask)
        return x + attended
class CrossBlock(nn.Module):
    """
    A pre-norm cross-attention residual block without conditioning.

    There is no adaLN modulation and no MLP here: the block computes
    ``x + attn2(norm2(x), y, mask)``. ``mlp_ratio`` is accepted by the
    constructor but not used.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        norm_type = block_kwargs.get("norm_type", "rms_norm")
        assert norm_type in ("layer_norm", "rms_norm")
        if norm_type == "layer_norm":
            norm_cls = nn.LayerNorm
        else:
            norm_cls = RMSNorm
        self.norm2 = norm_cls(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn2 = CrossAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)

    def initialize_weights(self):
        """Apply the basic Linear init; there are no adaLN layers to zero here."""
        self.apply(_basic_init)

    def forward(self, x, y, mask=None):
        """Return ``x`` plus its cross-attention update against ``y``."""
        attended = self.attn2(self.norm2(x), y, mask)
        return x + attended
class FinalLayer(nn.Module):
    """
    The final layer of DiT.

    It normalizes ``x``, applies a shift/scale modulation derived from the
    conditioning ``c``, and projects to ``out_channels`` with a linear layer.
    """

    def __init__(self, hidden_size, out_channels, norm_type="rms_norm"):
        super().__init__()
        assert norm_type in ("layer_norm", "rms_norm")
        if norm_type == "layer_norm":
            norm_cls = nn.LayerNorm
        else:
            norm_cls = RMSNorm
        self.norm_final = norm_cls(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def initialize_weights(self):
        """Run the basic Linear init, then zero the adaLN and output projections."""
        self.apply(_basic_init)
        # Zero-out output layers, in the order: adaLN projection, then linear.
        for layer in (self.adaLN_modulation[-1], self.linear):
            nn.init.constant_(layer.weight, 0)
            nn.init.constant_(layer.bias, 0)

    def forward(self, x, c):
        """Return the linear projection of the normalized, modulated ``x``."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)