# Provenance: first push of codes and models for g2p, t2u, tokenizer and detokenizer (commit cd8454d, author gooorillax)
# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Dehua Tao)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
# import math
# from typing import Optional
import torch
import torch.nn.functional as F
# import torchaudio
from librosa.filters import mel as librosa_mel_fn
from torch import nn
from x_transformers.x_transformers import apply_rotary_pos_emb
mel_basis_cache = {}
hann_window_cache = {}
from f5_tts.model.modules import AdaLayerNormZero, Attention, AttnProcessor, FeedForward
# Cross-attention with audio as query and text as key/value
class CrossAttention(nn.Module):
def __init__(
self,
processor: CrossAttnProcessor,
dim: int,
dim_to_k: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
self.processor = processor
self.dim = dim
self.heads = heads
self.inner_dim = dim_head * heads
self.dropout = dropout
self.to_q = nn.Linear(dim, self.inner_dim)
self.to_k = nn.Linear(dim_to_k, self.inner_dim)
self.to_v = nn.Linear(dim_to_k, self.inner_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, dim))
self.to_out.append(nn.Dropout(dropout))
def forward(
self,
x_for_q: float["b n d"], # (noisy + masked) audio input, x_for_q # noqa: F722
x_for_k: float["b n d"] = None, # text input, x_for_k # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding for x
) -> torch.Tensor:
return self.processor(
self,
x_for_q,
x_for_k,
mask=mask,
rope=rope,
)
# Cross-attention processor
class CrossAttnProcessor:
def __init__(self):
pass
def __call__(
self,
attn: CrossAttention,
x_for_q: float["b n d"], # (noisy + masked) audio input, x_for_q # noqa: F722
x_for_k: float["b n d"], # text input, x_for_k # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding
) -> torch.FloatTensor:
batch_size = x_for_q.shape[0]
# `sample` projections.
query = attn.to_q(x_for_q)
key = attn.to_k(x_for_k)
value = attn.to_v(x_for_k)
# apply rotary position embedding
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (
(xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
# attention
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = mask
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(
batch_size, attn.heads, query.shape[-2], key.shape[-2]
)
else:
attn_mask = None
x = F.scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# linear proj
x = attn.to_out[0](x)
# dropout
x = attn.to_out[1](x)
if mask is not None:
mask = mask.unsqueeze(-1)
x = x.masked_fill(~mask, 0.0)
return x
# Cross-attention DiT Block
class CADiTBlock(nn.Module):
def __init__(self, dim, text_dim, heads, dim_head, ff_mult=4, dropout=0.1):
super().__init__()
self.attn_norm = AdaLayerNormZero(dim)
self.attn = Attention(
processor=AttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
)
self.cross_attn_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.cross_attn = CrossAttention(
processor=CrossAttnProcessor(),
dim=dim,
dim_to_k=text_dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
)
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(
dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
)
def forward(
self,
x,
y,
t,
mask=None,
rope=None,
): # x: audio input, y: text input, t: time embedding
## for self-attention
# pre-norm & modulation for attention input
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
# attention
attn_output = self.attn(x=norm, mask=mask, rope=rope)
# process attention output for input x
x = x + gate_msa.unsqueeze(1) * attn_output
## for cross-attention
ca_norm = self.cross_attn_norm(x)
cross_attn_output = self.cross_attn(ca_norm, y, mask=mask, rope=rope)
x = x + cross_attn_output
norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm)
x = x + gate_mlp.unsqueeze(1) * ff_output
return x