# MMEdit / models/dit/mmdit_layers.py
# Commit c14d03d by CocoBro: "init space"
from typing import Optional
from typing import Union
import torch
from einops import rearrange
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
import torch
from torch import nn
from torch.nn import functional as F
from .modules import RMSNorm
# https://github.com/facebookresearch/DiT
# Ref: https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
# Ref: https://github.com/lucidrains/rotary-embedding-torch
def compute_rope_rotations(length: int,
                           dim: int,
                           theta: int,
                           *,
                           freq_scaling: float = 1.0,
                           device: Union[torch.device, str] = 'cpu') -> Tensor:
    """Precompute the 2x2 RoPE rotation matrix for every (position, frequency) pair.

    Args:
        length: number of positions to generate (positions 0 .. length-1).
        dim: rotary feature size; must be even because each channel pair
            shares one frequency.
        theta: RoPE base; frequency k is ``theta ** (-2k / dim)``.
        freq_scaling: multiplier applied to every frequency.
        device: device on which the tensors are created.

    Returns:
        float32 tensor of shape ``(1, length, dim // 2, 2, 2)`` whose trailing
        2x2 block is ``[[cos, -sin], [sin, cos]]`` of ``position * frequency``.
    """
    assert dim % 2 == 0
    # Computed in full float32 regardless of any enclosing CUDA autocast region.
    with torch.amp.autocast(device_type='cuda', enabled=False):
        positions = torch.arange(length, dtype=torch.float32, device=device)
        exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim
        inv_freq = (1.0 / (theta**exponents)) * freq_scaling
        angles = torch.outer(positions, inv_freq)  # (length, dim // 2)
        cos, sin = torch.cos(angles), torch.sin(angles)
        flat = torch.stack([cos, -sin, sin, cos], dim=-1)  # (length, dim // 2, 4)
        # Split the trailing 4 into a row-major 2x2 and add a leading broadcast axis.
        return flat.unflatten(-1, (2, 2)).unsqueeze(0)
def apply_rope(x: Tensor, rot: Tensor) -> Tensor:
    """Rotate consecutive channel pairs of ``x`` by the 2x2 matrices in ``rot``.

    Args:
        x: tensor whose (even-sized) last dimension holds interleaved pairs
            ``(x0, x1)``. When called from ``SelfAttention`` this is a
            ``(batch, heads, tokens, head_dim)`` query/key tensor.
        rot: rotation matrices shaped like the output of
            ``compute_rope_rotations``, i.e. ``(1, tokens, head_dim // 2, 2, 2)``;
            must broadcast against ``x`` as described above.

    Returns:
        A single tensor with the same shape and dtype as ``x``. The rotation
        itself is computed in float32.
    """
    # Run in float32 even inside a CUDA autocast region.
    with torch.amp.autocast(device_type='cuda', enabled=False):
        _x = x.float()
        # (..., D) -> (..., D/2, 1, 2): each pair becomes a row vector that
        # broadcasts against the 2x2 matrix rows in ``rot``.
        _x = _x.reshape(*_x.shape[:-1], -1, 1, 2)
        # out_i = rot[i, 0] * x0 + rot[i, 1] * x1, i.e. a matrix-vector product.
        x_out = rot[..., 0] * _x[..., 0] + rot[..., 1] * _x[..., 1]
        return x_out.reshape(*x.shape).to(dtype=x.dtype)
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A fixed sinusoidal embedding of size ``frequency_embedding_size`` is
    projected to ``dim`` features by a two-layer SiLU MLP.
    """

    def __init__(self, dim, frequency_embedding_size, max_period):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )
        self.dim = dim
        self.max_period = max_period
        assert dim % 2 == 0, 'dim must be even.'

        with torch.autocast('cuda', enabled=False):
            # 1. Compute the final frequency tensor. For an even size this equals
            #    arange(0, frequency_embedding_size, 2); for an odd size only
            #    frequency_embedding_size // 2 frequencies are kept so that the
            #    cos/sin embedding (plus one zero column) matches the MLP input.
            half = frequency_embedding_size // 2
            initial_freqs = 1.0 / (10000**(torch.arange(0, 2 * half, 2, dtype=torch.float32) /
                                           frequency_embedding_size))
            freq_scale = 10000 / max_period
            freqs_tensor = freq_scale * initial_freqs

            # 2. Register as a non-persistent buffer: it follows .to(device)
            #    but is not saved in the state_dict.
            self.register_buffer('freqs', freqs_tensor, persistent=False)

    def timestep_embedding(self, t):
        """
        Create sinusoidal timestep embeddings.

        :param t: a Tensor of timesteps, typically 1-D with one entry per batch
                  element. Values may be fractional. A trailing embedding axis
                  is appended, so any input shape is accepted.
        :return: a float32 Tensor of shape ``t.shape + (frequency_embedding_size,)``
                 laid out as ``[cos(t * f), sin(t * f)]``, with one zero column
                 appended when ``frequency_embedding_size`` is odd.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        args = t.float().unsqueeze(-1) * self.freqs
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        missing = self.mlp[0].in_features - embedding.shape[-1]
        if missing > 0:
            # Odd frequency_embedding_size: zero-pad to the MLP input width.
            pad = embedding.new_zeros((*embedding.shape[:-1], missing))
            embedding = torch.cat([embedding, pad], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t)
        # Keep the caller's floating dtype (e.g. bf16 timesteps with a bf16 model).
        # Integer timesteps must not be used as the target dtype: casting cos/sin
        # values to an integer type truncates them and nn.Linear rejects integer
        # inputs, so fall back to the MLP's parameter dtype.
        target_dtype = t.dtype if t.is_floating_point() else self.mlp[0].weight.dtype
        t_emb = self.mlp(t_freq.to(target_dtype))
        return t_emb
class ChannelLastConv1d(nn.Conv1d):
    """``nn.Conv1d`` that consumes and produces ``(batch, length, channels)`` tensors.

    The input is permuted to the ``(batch, channels, length)`` layout that
    ``nn.Conv1d`` expects, convolved, and permuted back.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channels_first = x.permute(0, 2, 1)
        convolved = super().forward(channels_first)
        return convolved.permute(0, 2, 1)
# https://github.com/Stability-AI/sd3-ref
class MLP(nn.Module):
    """SwiGLU feed-forward layer computing ``w2(silu(w1(x)) * w3(x))``.

    Args:
        dim: input and output feature size.
        hidden_dim: nominal hidden size. It is scaled by 2/3 (truncated to an
            int) and then rounded up to a multiple of ``multiple_of``.
        multiple_of: granularity of the effective hidden size.

    Attributes:
        w1: bias-free ``nn.Linear(dim, hidden)`` gate projection.
        w2: bias-free ``nn.Linear(hidden, dim)`` output projection.
        w3: bias-free ``nn.Linear(dim, hidden)`` value projection.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
    ):
        super().__init__()
        scaled = int(2 * hidden_dim / 3)
        # Round up to the next multiple of `multiple_of`.
        inner = multiple_of * ((scaled + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, inner, bias=False)
        self.w2 = nn.Linear(inner, dim, bias=False)
        self.w3 = nn.Linear(dim, inner, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
class ConvMLP(nn.Module):
    """SwiGLU feed-forward whose projections are channel-last 1-D convolutions.

    Computes ``w2(silu(w1(x)) * w3(x))`` on ``(batch, length, channels)`` input.

    Args:
        dim: input and output channel count.
        hidden_dim: nominal hidden size. It is scaled by 2/3 (truncated to an
            int) and then rounded up to a multiple of ``multiple_of``.
        multiple_of: granularity of the effective hidden size.
        kernel_size: kernel size shared by all three convolutions.
        padding: padding shared by all three convolutions.

    Attributes:
        w1: bias-free ``ChannelLastConv1d(dim, hidden)`` gate projection.
        w2: bias-free ``ChannelLastConv1d(hidden, dim)`` output projection.
        w3: bias-free ``ChannelLastConv1d(dim, hidden)`` value projection.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        kernel_size: int = 3,
        padding: int = 1,
    ):
        super().__init__()
        scaled = int(2 * hidden_dim / 3)
        # Round up to the next multiple of `multiple_of`.
        inner = multiple_of * ((scaled + multiple_of - 1) // multiple_of)

        conv_kwargs = dict(bias=False, kernel_size=kernel_size, padding=padding)
        self.w1 = ChannelLastConv1d(dim, inner, **conv_kwargs)
        self.w2 = ChannelLastConv1d(inner, dim, **conv_kwargs)
        self.w3 = ChannelLastConv1d(dim, inner, **conv_kwargs)

    def forward(self, x):
        gated = F.silu(self.w1(x)) * self.w3(x)
        return self.w2(gated)
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """Apply adaLN-style affine modulation, ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` must broadcast against ``x``.
    """
    gain = 1 + scale
    return x * gain + shift
def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """Run scaled dot-product attention and merge the heads into the channel axis.

    No mask or dropout is applied. Inputs use the ``(batch, heads, tokens,
    head_dim)`` layout that the output rearrange expects; the result is a
    contiguous ``(batch, tokens, heads * head_dim)`` tensor.
    """
    # Training crashes without these .contiguous() calls (cuDNN limitation).
    # Believed related to https://github.com/pytorch/pytorch/issues/133974,
    # which was unresolved at the time of writing.
    q_c, k_c, v_c = q.contiguous(), k.contiguous(), v.contiguous()
    per_head = F.scaled_dot_product_attention(q_c, k_c, v_c)
    merged = rearrange(per_head, 'b h n d -> b n (h d)')
    return merged.contiguous()
class SelfAttention(nn.Module):
    """Multi-head QKV projection with per-head RMSNorm on q/k and optional RoPE.

    ``pre_attention`` only produces the per-head (q, k, v), so callers can
    concatenate several streams before running ``attention`` jointly.
    ``forward`` runs single-stream attention end to end.
    """

    def __init__(self, dim: int, nheads: int):
        super().__init__()
        self.dim = dim
        self.nheads = nheads

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.q_norm = RMSNorm(dim // nheads)
        self.k_norm = RMSNorm(dim // nheads)

        # The fused projection is laid out as (h, d, j), with j (q/k/v)
        # innermost; split it into per-head tensors.
        self.split_into_heads = Rearrange('b n (h d j) -> b h n d j',
                                          h=nheads,
                                          d=dim // nheads,
                                          j=3)

    def pre_attention(
            self, x: torch.Tensor,
            rot: Optional[torch.Tensor] = None
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project ``x`` into per-head q, k, v.

        Args:
            x: ``(batch, tokens, dim)`` input.
            rot: optional RoPE rotations (as from ``compute_rope_rotations``)
                applied to q and k. ``None`` skips RoPE.

        Returns:
            ``(q, k, v)``, each of shape ``(batch, heads, tokens, dim // heads)``.
        """
        qkv = self.qkv(x)
        q, k, v = self.split_into_heads(qkv).chunk(3, dim=-1)
        q = q.squeeze(-1)
        k = k.squeeze(-1)
        v = v.squeeze(-1)
        q = self.q_norm(q)
        k = self.k_norm(k)

        if rot is not None:
            q = apply_rope(q, rot)
            k = apply_rope(k, rot)

        return q, k, v

    def forward(
        self,
        x: torch.Tensor,  # batch_size * n_tokens * n_channels
        rot: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Single-stream self-attention; returns ``(batch, tokens, dim)``."""
        # Forward `rot` explicitly: calling pre_attention(x) alone used to
        # raise TypeError because `rot` had no default there.
        q, k, v = self.pre_attention(x, rot)
        out = attention(q, k, v)
        return out
class MMDitSingleBlock(nn.Module):
    """One stream of an MM-DiT block with adaLN modulation.

    ``pre_attention`` modulates the stream and projects it to (q, k, v), so
    that several streams can attend jointly. ``post_attention`` applies the
    gated output projection and the gated feed-forward. With ``pre_only=True``
    the block only produces q/k/v, and ``post_attention`` returns its input
    unchanged.

    ``kernel_size == 1`` uses linear layers. Any other value uses channel-last
    1-D convolutions with the given ``padding``.
    """

    def __init__(self,
                 dim: int,
                 nhead: int,
                 mlp_ratio: float = 4.0,
                 pre_only: bool = False,
                 kernel_size: int = 7,
                 padding: int = 3):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = SelfAttention(dim, nhead)

        self.pre_only = pre_only
        if pre_only:
            # Only shift/scale for the attention input are needed.
            self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True))
        else:
            if kernel_size == 1:
                self.linear1 = nn.Linear(dim, dim)
            else:
                self.linear1 = ChannelLastConv1d(dim, dim, kernel_size=kernel_size, padding=padding)
            self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)

            if kernel_size == 1:
                self.ffn = MLP(dim, int(dim * mlp_ratio))
            else:
                self.ffn = ConvMLP(dim,
                                   int(dim * mlp_ratio),
                                   kernel_size=kernel_size,
                                   padding=padding)

            # shift/scale/gate for attention, then shift/scale/gate for the FFN.
            self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))

    def pre_attention(self, x: torch.Tensor, c: torch.Tensor, rot: Optional[torch.Tensor]):
        """Modulate ``x`` with the adaLN condition and project it to q, k, v.

        Args:
            x: ``(BS, N, D)`` stream tokens.
            c: conditioning, either ``(BS, D)``, ``(BS, 1, D)`` or per-token
                ``(BS, N, D)``.
            rot: optional RoPE rotations forwarded to ``SelfAttention``.

        Returns:
            ``((q, k, v), (gate_msa, shift_mlp, scale_mlp, gate_mlp))``. The
            second tuple is all ``None`` when ``pre_only``.
        """
        modulation = self.adaLN_modulation(c)
        if modulation.dim() == 2:
            # A (BS, D) condition would broadcast its batch axis against x's
            # token axis; insert the token axis so it is (BS, 1, D).
            modulation = modulation.unsqueeze(1)

        if self.pre_only:
            (shift_msa, scale_msa) = modulation.chunk(2, dim=-1)
            gate_msa = shift_mlp = scale_mlp = gate_mlp = None
        else:
            (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
             gate_mlp) = modulation.chunk(6, dim=-1)

        x = modulate(self.norm1(x), shift_msa, scale_msa)
        q, k, v = self.attn.pre_attention(x, rot)
        return (q, k, v), (gate_msa, shift_mlp, scale_mlp, gate_mlp)

    def post_attention(self, x: torch.Tensor, attn_out: torch.Tensor, c: tuple[torch.Tensor]):
        """Apply the gated attention residual and the gated FFN residual.

        ``c`` is the ``(gate_msa, shift_mlp, scale_mlp, gate_mlp)`` tuple
        returned by ``pre_attention``. This is a no-op when ``pre_only``.
        """
        if self.pre_only:
            return x

        (gate_msa, shift_mlp, scale_mlp, gate_mlp) = c
        x = x + self.linear1(attn_out) * gate_msa
        r = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + self.ffn(r) * gate_mlp

        return x

    # NOTE: within this file, forward appears unused; JointBlock_AT drives
    # pre_attention/post_attention directly.
    def forward(self, x: torch.Tensor, cond: torch.Tensor,
                rot: Optional[torch.Tensor]) -> torch.Tensor:
        # x: BS * N * D
        # cond: BS * D (or BS * 1 * D / BS * N * D)
        x_qkv, x_conditions = self.pre_attention(x, cond, rot)
        attn_out = attention(*x_qkv)
        x = self.post_attention(x, attn_out, x_conditions)

        return x
class JointBlock_AT(nn.Module):
    """
    Audio + text-only joint block (the CLIP branch has been removed).

    The latent (audio) stream and the text stream each have their own
    MMDitSingleBlock. Their q/k/v are concatenated along the token axis and
    attended jointly, then each stream applies its own post-attention update.
    Returns (latent, text_f).
    """

    def __init__(self, dim: int, nhead: int, mlp_ratio: float = 4.0, pre_only: bool = False):
        super().__init__()
        self.pre_only = pre_only
        self.latent_block = MMDitSingleBlock(dim,
                                             nhead,
                                             mlp_ratio,
                                             pre_only=False,
                                             kernel_size=3,
                                             padding=1)
        # text_block still honours pre_only: when set, it only contributes
        # q/k/v to the joint attention and its post-attention update is skipped.
        self.text_block = MMDitSingleBlock(dim, nhead, mlp_ratio, pre_only=pre_only, kernel_size=1)

    def forward(self, latent: torch.Tensor, text_f: torch.Tensor,
                global_c: torch.Tensor, extended_c: torch.Tensor,
                latent_rot: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        # latent:     (B, N_latent, D)
        # text_f:     (B, N_text, D)
        # global_c:   (B, 1, D) or (B, D)
        # extended_c: (B, N_latent, D) or (B, 1, D)
        x_qkv, x_mod = self.latent_block.pre_attention(latent, extended_c, latent_rot)
        # Text tokens get no RoPE. NOTE(review): presumably the upstream
        # (audio-LLM) text features already carry positional information — confirm.
        t_qkv, t_mod = self.text_block.pre_attention(text_f, global_c, rot=None)

        latent_len = latent.shape[1]

        # Concatenate latent + text only. q/k/v are (B, H, N, d), so dim=2 is the token axis.
        joint_qkv = [torch.cat([x_part, t_part], dim=2) for x_part, t_part in zip(x_qkv, t_qkv)]
        attn_out = attention(*joint_qkv)  # (B, latent_len + text_len, D)

        x_attn_out = attn_out[:, :latent_len]  # (B, latent_len, D)
        t_attn_out = attn_out[:, latent_len:]  # (B, text_len, D)

        latent = self.latent_block.post_attention(latent, x_attn_out, x_mod)
        if not self.pre_only:
            text_f = self.text_block.post_attention(text_f, t_attn_out, t_mod)

        return latent, text_f

    # Sketch of a masked variant (TODO: revise the masking logic).
    # NOTE(review): not usable as written. In this file `attention()` takes no
    # `key_mask` argument and `post_attention()` takes no `query_mask` argument.
    # def forward(self, latent, text_f, global_c, extended_c, latent_rot,
    #             latent_mask: torch.Tensor, text_mask: torch.Tensor):
    #     # latent_mask: (B, N_latent) {0,1}
    #     # text_mask:   (B, N_text)   {0,1}
    #     x_qkv, x_mod = self.latent_block.pre_attention(latent, extended_c, latent_rot)
    #     t_qkv, t_mod = self.text_block.pre_attention(text_f, global_c, rot=None)
    #     latent_len = latent.shape[1]
    #     text_len = text_f.shape[1]
    #     # 1) Concatenate q/k/v
    #     joint_qkv = [torch.cat([x_qkv[i], t_qkv[i]], dim=2) for i in range(3)]  # assumes the token dim is 2
    #     # 2) Build the key mask for the concatenated sequence
    #     key_mask = torch.cat([latent_mask, text_mask], dim=1).bool()  # (B, N_total)
    #     # 3) Call attention (requires attention to support key_mask).
    #     #    If it does not, mask the logits with -inf inside it yourself
    #     #    (the example this originally referred to is not in this file).
    #     attn_out = attention(*joint_qkv, key_mask=key_mask)  # (B, N_total, D)
    #     # 4) Split back into the two streams
    #     x_attn_out = attn_out[:, :latent_len, :]
    #     t_attn_out = attn_out[:, latent_len:, :]
    #     # 5) Mask the query-side outputs so padded queries do not write back
    #     x_attn_out = x_attn_out * latent_mask.unsqueeze(-1)  # (B, N_latent, D)
    #     t_attn_out = t_attn_out * text_mask.unsqueeze(-1)    # (B, N_text, D)
    #     # 6) post_attention must ALSO apply the query mask to the residual and
    #     #    FFN updates (the section this referred to is not in this file).
    #     latent = self.latent_block.post_attention(latent, x_attn_out, x_mod,
    #                                               query_mask=latent_mask)
    #     if not self.text_block.pre_only:
    #         text_f = self.text_block.post_attention(text_f, t_attn_out, t_mod,
    #                                                 query_mask=text_mask)
    #     return latent, text_f
class FinalBlock(nn.Module):
    """Output head: adaLN-modulated LayerNorm followed by a kernel-7 channel-last conv.

    Args:
        dim: feature size of ``latent`` and of the conditioning ``c``.
        out_dim: number of output channels produced by the convolution.
    """

    def __init__(self, dim, out_dim):
        super().__init__()
        # SiLU + Linear producing a (shift, scale) pair from the condition.
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True))
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.conv = ChannelLastConv1d(dim, out_dim, kernel_size=7, padding=3)

    def forward(self, latent, c):
        modulation = self.adaLN_modulation(c)
        shift, scale = modulation.chunk(2, dim=-1)
        normed = modulate(self.norm(latent), shift, scale)
        return self.conv(normed)