# Snapshot metadata: revision c14d03d, 15,632 bytes.
from typing import Optional
from typing import Union
import torch
from einops import rearrange
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
import torch
from torch import nn
from torch.nn import functional as F
from .modules import RMSNorm
# https://github.com/facebookresearch/DiT
# Ref: https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
# Ref: https://github.com/lucidrains/rotary-embedding-torch
def compute_rope_rotations(length: int,
                           dim: int,
                           theta: int,
                           *,
                           freq_scaling: float = 1.0,
                           device: Union[torch.device, str] = 'cpu') -> Tensor:
    """Build per-position 2x2 rotation matrices for rotary position embeddings.

    Frequency ``i`` (for ``i`` in ``0 .. dim // 2 - 1``) is
    ``freq_scaling / theta ** (2 * i / dim)``; the angle at position ``p`` is
    ``p * frequency``.

    Returns:
        A float32 tensor of shape ``(1, length, dim // 2, 2, 2)`` whose last two
        axes hold ``[[cos, -sin], [sin, cos]]`` of each angle.
    """
    assert dim % 2 == 0
    # Keep the angle computation in float32 even inside a CUDA autocast region.
    with torch.amp.autocast(device_type='cuda', enabled=False):
        positions = torch.arange(length, dtype=torch.float32, device=device)
        exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim
        inv_freq = 1.0 / (theta**exponents)
        inv_freq *= freq_scaling
        angles = torch.einsum('..., f -> ... f', positions, inv_freq)
        cos, sin = torch.cos(angles), torch.sin(angles)
        # Row-major 2x2 entries: (0,0)=cos, (0,1)=-sin, (1,0)=sin, (1,1)=cos.
        entries = torch.stack([cos, -sin, sin, cos], dim=-1)
        # (length, dim // 2, 4) -> (1, length, dim // 2, 2, 2)
        return entries.unflatten(-1, (2, 2)).unsqueeze(0)
def apply_rope(x: Tensor, rot: Tensor) -> Tensor:
    """Rotate consecutive channel pairs of ``x`` by the 2x2 matrices in ``rot``.

    Channels ``(x[..., 2p], x[..., 2p + 1])`` form pair ``p``; each pair is
    multiplied by the matrix ``rot[..., p, :, :]``.

    Args:
        x: Tensor whose last dimension holds the interleaved pairs (so it must
            be even), e.g. ``(B, H, N, D)`` queries or keys.
        rot: Rotation matrices on its last two axes, broadcastable against the
            pairs of ``x``; for ``x`` of shape ``(B, H, N, D)`` this is the
            ``(1, N, D // 2, 2, 2)`` layout produced by ``compute_rope_rotations``.

    Returns:
        A single tensor with the same shape and dtype as ``x``. The rotation
        itself is computed in float32.
    """
    with torch.amp.autocast(device_type='cuda', enabled=False):
        _x = x.float()
        # (..., D) -> (..., D // 2, 1, 2); the singleton axis broadcasts against
        # the output-row axis of the 2x2 matrices in ``rot``.
        _x = _x.view(*_x.shape[:-1], -1, 1, 2)
        # out[..., p, i] = rot[..., p, i, 0] * x0 + rot[..., p, i, 1] * x1
        x_out = rot[..., 0] * _x[..., 0] + rot[..., 1] * _x[..., 1]
        return x_out.reshape(*x.shape).to(dtype=x.dtype)
class TimestepEmbedder(nn.Module):
    """Embed scalar (possibly fractional) timesteps into ``dim``-sized vectors.

    Each timestep is expanded into a fixed sinusoidal feature of width
    ``frequency_embedding_size`` (cosines followed by sines), which is then fed
    through ``Linear -> SiLU -> Linear``.

    Args:
        dim: Output embedding size.
        frequency_embedding_size: Width of the sinusoidal feature. Must be even,
            because the feature is built from ``frequency_embedding_size // 2``
            cos/sin pairs and has to match the first Linear layer's input width.
        max_period: Rescales the frequencies by ``10000 / max_period`` relative
            to the ``10000``-based geometric schedule.
    """

    def __init__(self, dim, frequency_embedding_size, max_period):
        super().__init__()
        # An odd size would produce 2 * ceil(size / 2) features and break the
        # first Linear layer at forward time, so reject it up front.
        assert frequency_embedding_size % 2 == 0, 'frequency_embedding_size must be even.'
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )
        self.dim = dim
        self.max_period = max_period
        with torch.autocast('cuda', enabled=False):
            # Compute the final frequency tensor in float32.
            exponents = (torch.arange(0, frequency_embedding_size, 2, dtype=torch.float32) /
                         frequency_embedding_size)
            initial_freqs = 1.0 / (10000**exponents)
            freq_scale = 10000 / max_period
            freqs_tensor = freq_scale * initial_freqs
        # Register as a non-persistent buffer: it follows .to()/.cuda() but is
        # not stored in the state_dict (it is derived from constructor args).
        self.register_buffer('freqs', freqs_tensor, persistent=False)

    def timestep_embedding(self, t):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N timesteps, one per batch element.
                  These may be fractional.
        :return: an (N, frequency_embedding_size) floating-point Tensor laid out
                 as [cos(t * freqs), sin(t * freqs)].
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        args = t[:, None].float() * self.freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        return embedding

    def forward(self, t):
        """Return the ``(N, dim)`` embedding of the 1-D timestep tensor ``t``."""
        t_freq = self.timestep_embedding(t)
        # Follow the caller's floating precision (e.g. bf16 timesteps). Integer
        # timesteps keep the float features: casting cos/sin values to an
        # integer dtype would truncate them and make the Linear layers fail.
        if t.is_floating_point():
            t_freq = t_freq.to(t.dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
class ChannelLastConv1d(nn.Conv1d):
    """``nn.Conv1d`` that consumes and produces ``(batch, length, channels)`` tensors."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Conv1d expects channels on axis 1: move them there, convolve, move back.
        channels_first = x.movedim(-1, 1)
        y = super().forward(channels_first)
        return y.movedim(1, -1)
# https://github.com/Stability-AI/sd3-ref
class MLP(nn.Module):
    """SwiGLU feed-forward block computing ``w2(silu(w1(x)) * w3(x))``.

    Args:
        dim: Input and output feature size.
        hidden_dim: Nominal hidden size. It is scaled by 2/3 (truncated to an
            int) and then rounded up to the next multiple of ``multiple_of``.
        multiple_of: Granularity of the effective hidden size.

    Attributes:
        w1: Bias-free ``nn.Linear(dim, hidden)`` feeding the SiLU gate.
        w2: Bias-free ``nn.Linear(hidden, dim)`` output projection.
        w3: Bias-free ``nn.Linear(dim, hidden)`` value branch multiplied by the gate.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
    ):
        super().__init__()
        scaled = int(2 * hidden_dim / 3)
        n_chunks = (scaled + multiple_of - 1) // multiple_of
        inner = multiple_of * n_chunks
        self.w1 = nn.Linear(dim, inner, bias=False)
        self.w2 = nn.Linear(inner, dim, bias=False)
        self.w3 = nn.Linear(dim, inner, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
class ConvMLP(nn.Module):
    """SwiGLU feed-forward block built from channel-last 1-D convolutions.

    Computes ``w2(silu(w1(x)) * w3(x))`` where every projection is a bias-free
    ``ChannelLastConv1d``, so ``x`` is laid out as ``(batch, length, dim)``.

    Args:
        dim: Input and output channel count.
        hidden_dim: Nominal hidden size. It is scaled by 2/3 (truncated to an
            int) and then rounded up to the next multiple of ``multiple_of``.
        multiple_of: Granularity of the effective hidden size.
        kernel_size: Kernel size shared by all three convolutions.
        padding: Padding shared by all three convolutions.

    Attributes:
        w1: ``dim -> hidden`` convolution feeding the SiLU gate.
        w2: ``hidden -> dim`` output convolution.
        w3: ``dim -> hidden`` value-branch convolution.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        kernel_size: int = 3,
        padding: int = 1,
    ):
        super().__init__()
        scaled = int(2 * hidden_dim / 3)
        inner = multiple_of * ((scaled + multiple_of - 1) // multiple_of)
        conv_kwargs = dict(bias=False, kernel_size=kernel_size, padding=padding)
        self.w1 = ChannelLastConv1d(dim, inner, **conv_kwargs)
        self.w2 = ChannelLastConv1d(inner, dim, **conv_kwargs)
        self.w3 = ChannelLastConv1d(dim, inner, **conv_kwargs)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """Apply adaptive affine modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` are combined with ``x`` by ordinary broadcasting, so
    callers must pass shapes that broadcast against ``x`` on the intended axes.
    """
    gain = scale + 1
    return shift + x * gain
def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """Scaled dot-product attention with the head axis merged into the output.

    Args:
        q, k, v: ``(batch, heads, tokens, head_dim)`` tensors.

    Returns:
        A contiguous ``(batch, tokens, heads * head_dim)`` tensor.
    """
    # Training crashes without these contiguous() calls due to a cuDNN
    # limitation, believed to be https://github.com/pytorch/pytorch/issues/133974
    # (unresolved at the time of writing).
    q, k, v = (t.contiguous() for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)
    b, h, n, d = out.shape
    # (b, h, n, d) -> (b, n, h, d) -> (b, n, h * d)
    return out.transpose(1, 2).reshape(b, n, h * d).contiguous()
class SelfAttention(nn.Module):
    """Multi-head self-attention with RMSNorm applied to per-head queries and keys.

    There is no output projection here: ``forward`` returns the merged-head
    attention output directly.

    Args:
        dim: Model width; split into ``nheads`` heads of ``dim // nheads`` channels.
        nheads: Number of attention heads.
    """

    def __init__(self, dim: int, nheads: int):
        super().__init__()
        self.dim = dim
        self.nheads = nheads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.q_norm = RMSNorm(dim // nheads)
        self.k_norm = RMSNorm(dim // nheads)
        # The fused projection's channel axis is laid out as (head, head_dim, qkv)
        # with the q/k/v selector innermost.
        self.split_into_heads = Rearrange('b n (h d j) -> b h n d j',
                                          h=nheads,
                                          d=dim // nheads,
                                          j=3)

    def pre_attention(
            self, x: torch.Tensor,
            rot: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project ``x`` into per-head ``(q, k, v)``.

        Args:
            x: ``(batch, n_tokens, dim)`` input.
            rot: Optional rotation tensor passed to ``apply_rope`` for q and k;
                ``None`` skips RoPE.

        Returns:
            ``(q, k, v)``, each ``(batch, nheads, n_tokens, dim // nheads)``.
        """
        qkv = self.qkv(x)
        q, k, v = self.split_into_heads(qkv).chunk(3, dim=-1)
        # Drop the size-1 qkv selector axis left over from chunking.
        q = q.squeeze(-1)
        k = k.squeeze(-1)
        v = v.squeeze(-1)
        q = self.q_norm(q)
        k = self.k_norm(k)
        if rot is not None:
            q = apply_rope(q, rot)
            k = apply_rope(k, rot)
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,  # batch_size * n_tokens * n_channels
        rot: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run self-attention on ``x``; returns ``(batch, n_tokens, dim)``."""
        # pre_attention needs ``rot`` explicitly; calling it with ``x`` alone raised TypeError.
        q, k, v = self.pre_attention(x, rot)
        out = attention(q, k, v)
        return out
class MMDitSingleBlock(nn.Module):
    """One adaLN-modulated token stream of an MM-DiT joint block.

    The block is split into ``pre_attention`` (modulate, then project to q/k/v)
    and ``post_attention`` (gated residual, then gated feed-forward). This split
    lets a caller run one attention over the concatenated tokens of several
    streams.

    Args:
        dim: Model width.
        nhead: Number of attention heads.
        mlp_ratio: Nominal feed-forward expansion passed to ``MLP``/``ConvMLP``.
        pre_only: If True, only the pre-attention half exists. The modulation
            yields just ``(shift, scale)`` and ``post_attention`` returns ``x``
            unchanged.
        kernel_size: ``1`` selects ``nn.Linear``/``MLP`` for the post-attention
            layers; any other value selects ``ChannelLastConv1d``/``ConvMLP``
            with this kernel size.
        padding: Convolution padding; only used when ``kernel_size != 1``.
    """

    def __init__(self,
                 dim: int,
                 nhead: int,
                 mlp_ratio: float = 4.0,
                 pre_only: bool = False,
                 kernel_size: int = 7,
                 padding: int = 3):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = SelfAttention(dim, nhead)
        self.pre_only = pre_only
        if pre_only:
            # Only (shift_msa, scale_msa) are needed before attention.
            self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True))
        else:
            if kernel_size == 1:
                self.linear1 = nn.Linear(dim, dim)
            else:
                self.linear1 = ChannelLastConv1d(dim, dim, kernel_size=kernel_size, padding=padding)
            self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
            if kernel_size == 1:
                self.ffn = MLP(dim, int(dim * mlp_ratio))
            else:
                self.ffn = ConvMLP(dim,
                                   int(dim * mlp_ratio),
                                   kernel_size=kernel_size,
                                   padding=padding)
            # Six chunks: shift/scale/gate for attention, then shift/scale/gate for the FFN.
            self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))

    def pre_attention(self, x: torch.Tensor, c: torch.Tensor, rot: Optional[torch.Tensor]):
        """Modulate ``x`` by the condition ``c`` and project it to ``(q, k, v)``.

        Args:
            x: ``(BS, N, D)`` tokens.
            c: Condition. Its adaLN chunks are broadcast against ``x`` by
                ``modulate``, so it should be ``(BS, 1, D)`` or ``(BS, N, D)``.
                A ``(BS, D)`` tensor would broadcast along the token axis
                instead of the batch axis.
            rot: Optional RoPE rotations forwarded to ``SelfAttention.pre_attention``.

        Returns:
            ``((q, k, v), (gate_msa, shift_mlp, scale_mlp, gate_mlp))``. The
            second tuple holds ``None`` entries when ``pre_only`` is True.
        """
        modulation = self.adaLN_modulation(c)
        if self.pre_only:
            (shift_msa, scale_msa) = modulation.chunk(2, dim=-1)
            gate_msa = shift_mlp = scale_mlp = gate_mlp = None
        else:
            (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
             gate_mlp) = modulation.chunk(6, dim=-1)
        x = modulate(self.norm1(x), shift_msa, scale_msa)
        q, k, v = self.attn.pre_attention(x, rot)
        return (q, k, v), (gate_msa, shift_mlp, scale_mlp, gate_mlp)

    def post_attention(self, x: torch.Tensor, attn_out: torch.Tensor,
                       c: tuple[Optional[torch.Tensor], ...]):
        """Apply the gated attention residual and the gated feed-forward residual.

        Args:
            x: The block input ``(BS, N, D)`` (before ``norm1``).
            attn_out: Attention output for these tokens, ``(BS, N, D)``.
            c: The ``(gate_msa, shift_mlp, scale_mlp, gate_mlp)`` tuple returned
                by ``pre_attention``.

        Returns:
            The updated tokens, or ``x`` unchanged when ``pre_only`` is True.
        """
        if self.pre_only:
            return x
        (gate_msa, shift_mlp, scale_mlp, gate_mlp) = c
        x = x + self.linear1(attn_out) * gate_msa
        r = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + self.ffn(r) * gate_mlp
        return x

    # NOTE(review): this forward appears unused in this file; JointBlock_AT calls
    # pre_attention/post_attention directly.
    def forward(self, x: torch.Tensor, cond: torch.Tensor,
                rot: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the block on a single stream: pre-attention, self-attention, post-attention."""
        # x: BS * N * D
        # cond: BS * 1 * D (or BS * N * D); see pre_attention
        x_qkv, x_conditions = self.pre_attention(x, cond, rot)
        attn_out = attention(*x_qkv)
        x = self.post_attention(x, attn_out, x_conditions)
        return x
class JointBlock_AT(nn.Module):
    """Audio + text joint block (the CLIP branch of the original joint block is removed).

    The latent (audio) stream and the text stream each have their own
    ``MMDitSingleBlock``. Their q/k/v are concatenated along the token axis,
    attended jointly, and the output is split back per stream.
    ``forward`` returns ``(latent, text_f)``.

    Args:
        dim: Model width for both streams.
        nhead: Number of attention heads for both streams.
        mlp_ratio: Feed-forward expansion for both streams.
        pre_only: Applied to the text stream only; the latent stream is never
            pre-only. When True, text tokens still take part in the joint
            attention, but the text update is skipped and ``text_f`` is
            returned unchanged.
    """

    def __init__(self, dim: int, nhead: int, mlp_ratio: float = 4.0, pre_only: bool = False):
        super().__init__()
        self.pre_only = pre_only
        self.latent_block = MMDitSingleBlock(dim,
                                             nhead,
                                             mlp_ratio,
                                             pre_only=False,
                                             kernel_size=3,
                                             padding=1)
        # The text block keeps the pre_only option (pre-only AdaLN on the last layer).
        self.text_block = MMDitSingleBlock(dim, nhead, mlp_ratio, pre_only=pre_only, kernel_size=1)

    def forward(self, latent: torch.Tensor, text_f: torch.Tensor,
                global_c: torch.Tensor, extended_c: torch.Tensor,
                latent_rot: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Jointly attend over latent and text tokens and update both streams.

        Args:
            latent: ``(B, N_latent, D)`` audio latent tokens.
            text_f: ``(B, N_text, D)`` text tokens.
            global_c: Condition for the text stream, expected ``(B, 1, D)``. A
                ``(B, D)`` tensor would broadcast along the token axis in ``modulate``.
            extended_c: Condition for the latent stream, ``(B, N_latent, D)`` or ``(B, 1, D)``.
            latent_rot: Optional RoPE rotations for the latent stream.

        Returns:
            ``(latent, text_f)`` after the block.
        """
        x_qkv, x_mod = self.latent_block.pre_attention(latent, extended_c, latent_rot)
        # NOTE(review): text tokens receive no RoPE here; presumably the upstream
        # text / audio-LLM features already carry positional information. Confirm.
        t_qkv, t_mod = self.text_block.pre_attention(text_f, global_c, rot=None)

        latent_len = latent.shape[1]
        # Concatenate latent then text along the token axis (dim 2 of (B, H, N, D_head)).
        joint_qkv = [torch.cat([x_part, t_part], dim=2) for x_part, t_part in zip(x_qkv, t_qkv)]
        attn_out = attention(*joint_qkv)  # (B, latent_len + text_len, D)
        x_attn_out = attn_out[:, :latent_len]  # (B, latent_len, D)
        t_attn_out = attn_out[:, latent_len:]  # (B, text_len, D)

        latent = self.latent_block.post_attention(latent, x_attn_out, x_mod)
        if not self.pre_only:
            text_f = self.text_block.post_attention(text_f, t_attn_out, t_mod)
        return latent, text_f

    # TODO: padding-mask support (planned, not implemented). Sketch:
    #   forward(..., latent_mask, text_mask) with {0,1} masks of shape
    #   (B, N_latent) and (B, N_text):
    #   1. Concatenate q/k/v as above.
    #   2. key_mask = torch.cat([latent_mask, text_mask], dim=1).bool()  # (B, N_total)
    #   3. Call an attention variant that accepts key_mask (e.g. set masked
    #      logits to -inf); the current `attention` has no mask argument.
    #   4. Split the output back into latent and text parts.
    #   5. Zero outputs at padded query positions: out * mask.unsqueeze(-1).
    #   6. Pass query_mask into post_attention so the residual and FFN updates
    #      are masked as well (skip the text update when text_block.pre_only).
class FinalBlock(nn.Module):
    """Final layer: adaLN-modulated LayerNorm followed by a channel-last convolution.

    Args:
        dim: Input width.
        out_dim: Output channels of the final ``ChannelLastConv1d``
            (kernel size 7, padding 3).
    """

    def __init__(self, dim, out_dim):
        super().__init__()
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True))
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.conv = ChannelLastConv1d(dim, out_dim, kernel_size=7, padding=3)

    def forward(self, latent, c):
        """Modulate the normalised ``latent`` by ``c``, then project to ``out_dim``."""
        modulation = self.adaLN_modulation(c)
        shift, scale = modulation.chunk(2, dim=-1)
        normed = self.norm(latent)
        return self.conv(modulate(normed, shift, scale))