|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
ein notation: |
|
|
b - batch |
|
|
n - sequence |
|
|
nt - text sequence |
|
|
nw - raw wave length |
|
|
d - dimension |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from librosa.filters import mel as librosa_mel_fn |
|
|
from torch import nn |
|
|
from x_transformers.x_transformers import apply_rotary_pos_emb |
|
|
|
|
|
mel_basis_cache = {} |
|
|
hann_window_cache = {} |
|
|
|
|
|
from f5_tts.model.modules import AdaLayerNormZero, Attention, AttnProcessor, FeedForward |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CrossAttention(nn.Module):
    """Multi-head cross-attention whose key/value stream may have a different
    feature width (``dim_to_k``) than the query stream (``dim``).

    The attention math itself is delegated to ``processor`` (invoked as
    ``processor(self, x_for_q, x_for_k, mask=mask, rope=rope)``) so it can be
    swapped without touching this module's parameters.

    Args:
        processor: callable implementing the attention computation.
        dim: feature dimension of the query input (and of the output).
        dim_to_k: feature dimension of the key/value input.
        heads: number of attention heads.
        dim_head: per-head dimension; inner width is ``heads * dim_head``.
        dropout: dropout probability applied after the output projection.

    Raises:
        ImportError: if the installed PyTorch lacks
            ``F.scaled_dot_product_attention`` (added in PyTorch 2.0).
    """

    def __init__(
        self,
        processor: CrossAttnProcessor,
        dim: int,
        dim_to_k: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
    ):
        super().__init__()

        # The fused SDPA kernel only exists from PyTorch 2.0 onwards.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "Attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

        self.processor = processor

        self.dim = dim
        self.heads = heads
        self.inner_dim = dim_head * heads
        self.dropout = dropout

        # Queries project from the `dim`-wide stream; keys/values from the
        # (possibly different) `dim_to_k`-wide stream.
        self.to_q = nn.Linear(dim, self.inner_dim)
        self.to_k = nn.Linear(dim_to_k, self.inner_dim)
        self.to_v = nn.Linear(dim_to_k, self.inner_dim)

        # Output projection back to `dim`, followed by dropout.
        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(self.inner_dim, dim))
        self.to_out.append(nn.Dropout(dropout))

    def forward(
        self,
        x_for_q: float["b n d"],
        x_for_k: float["b n d"] | None = None,
        mask: bool["b n"] | None = None,
        rope=None,  # rotary position embedding, forwarded untouched to the processor
    ) -> torch.Tensor:
        """Compute cross-attention: queries from ``x_for_q``, keys/values from ``x_for_k``."""
        return self.processor(
            self,
            x_for_q,
            x_for_k,
            mask=mask,
            rope=rope,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CrossAttnProcessor:
    """Default attention math for ``CrossAttention``.

    Projects queries from one stream and keys/values from another, optionally
    applies rotary position embeddings, runs PyTorch's fused scaled-dot-product
    attention, and finally zeroes out masked positions of the output.
    """

    def __init__(self):
        pass

    def __call__(
        self,
        attn: CrossAttention,
        x_for_q: float["b n d"],
        x_for_k: float["b n d"],
        mask: bool["b n"] | None = None,
        rope=None,
    ) -> torch.FloatTensor:
        bsz = x_for_q.shape[0]

        # Linear projections: queries from one stream, keys/values from the other.
        q = attn.to_q(x_for_q)
        k = attn.to_k(x_for_k)
        v = attn.to_v(x_for_k)

        # Rotary position embedding, applied before the heads are split off.
        if rope is not None:
            freqs, xpos_scale = rope
            if xpos_scale is not None:
                q_scale, k_scale = xpos_scale, xpos_scale**-1.0
            else:
                q_scale = k_scale = 1.0
            q = apply_rotary_pos_emb(q, freqs, q_scale)
            k = apply_rotary_pos_emb(k, freqs, k_scale)

        # Split the inner dimension into heads: (b, n, h*d) -> (b, h, n, d).
        n_heads = attn.heads
        d_head = k.shape[-1] // n_heads
        q = q.view(bsz, -1, n_heads, d_head).transpose(1, 2)
        k = k.view(bsz, -1, n_heads, d_head).transpose(1, 2)
        v = v.view(bsz, -1, n_heads, d_head).transpose(1, 2)

        # Broadcast the boolean padding mask over heads and query positions.
        # NOTE(review): the expand assumes the mask length matches the key
        # sequence length — confirm behavior when q and k lengths differ.
        if mask is not None:
            attn_mask = mask.unsqueeze(1).unsqueeze(1).expand(
                bsz, n_heads, q.shape[-2], k.shape[-2]
            )
        else:
            attn_mask = None

        out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
        )
        # Merge heads back: (b, h, n, d) -> (b, n, h*d).
        out = out.transpose(1, 2).reshape(bsz, -1, n_heads * d_head)
        out = out.to(q.dtype)

        # Output projection, then dropout.
        out = attn.to_out[0](out)
        out = attn.to_out[1](out)

        # Zero masked positions so padding does not leak into later layers.
        if mask is not None:
            out = out.masked_fill(~mask.unsqueeze(-1), 0.0)

        return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CADiTBlock(nn.Module):
    """DiT block augmented with cross-attention to a conditioning sequence.

    Per call: AdaLN-Zero-modulated self-attention over ``x``, then
    cross-attention from ``x`` to ``y``, then a gated feed-forward — each
    stage wrapped in a residual connection.
    """

    def __init__(self, dim, text_dim, heads, dim_head, ff_mult=4, dropout=0.1):
        super().__init__()

        # Self-attention stage, modulated by the timestep embedding.
        self.attn_norm = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor=AttnProcessor(),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
        )

        # Cross-attention stage: queries are `dim`-wide, keys/values `text_dim`-wide.
        self.cross_attn_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.cross_attn = CrossAttention(
            processor=CrossAttnProcessor(),
            dim=dim,
            dim_to_k=text_dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
        )

        # Position-wise feed-forward stage.
        self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(
            dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
        )

    def forward(
        self,
        x,
        y,
        t,
        mask=None,
        rope=None,
    ):
        """Apply self-attention, cross-attention to ``y``, and feed-forward.

        Args:
            x: main sequence (b, n, dim) per the module's ein notation.
            y: conditioning sequence attended to by cross-attention.
            t: timestep embedding driving the AdaLN-Zero modulation.
            mask: optional boolean padding mask.
            rope: optional rotary embedding, forwarded to both attentions.
        """
        # AdaLN-Zero yields the pre-attention norm plus the gates/shift/scale
        # reused later by the feed-forward stage.
        pre_attn, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)

        # Gated residual around self-attention.
        x = x + gate_msa[:, None] * self.attn(x=pre_attn, mask=mask, rope=rope)

        # Plain (ungated) residual around cross-attention to the conditioning
        # stream. NOTE(review): rope/mask computed for x are reused here even
        # though keys come from y — confirm the two sequences share a length.
        x = x + self.cross_attn(self.cross_attn_norm(x), y, mask=mask, rope=rope)

        # Feed-forward with the AdaLN shift/scale, then gated residual.
        pre_ff = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        x = x + gate_mlp[:, None] * self.ff(pre_ff)

        return x
|
|
|
|
|
|
|
|
|