# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Dehua Tao) # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ ein notation: b - batch n - sequence nt - text sequence nw - raw wave length d - dimension """ from __future__ import annotations # import math # from typing import Optional import torch import torch.nn.functional as F # import torchaudio from librosa.filters import mel as librosa_mel_fn from torch import nn from x_transformers.x_transformers import apply_rotary_pos_emb mel_basis_cache = {} hann_window_cache = {} from f5_tts.model.modules import AdaLayerNormZero, Attention, AttnProcessor, FeedForward # Cross-attention with audio as query and text as key/value class CrossAttention(nn.Module): def __init__( self, processor: CrossAttnProcessor, dim: int, dim_to_k: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0, ): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." ) self.processor = processor self.dim = dim self.heads = heads self.inner_dim = dim_head * heads self.dropout = dropout self.to_q = nn.Linear(dim, self.inner_dim) self.to_k = nn.Linear(dim_to_k, self.inner_dim) self.to_v = nn.Linear(dim_to_k, self.inner_dim) self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(self.inner_dim, dim)) self.to_out.append(nn.Dropout(dropout)) def forward( self, x_for_q: float["b n d"], # (noisy + masked) audio input, x_for_q # noqa: F722 x_for_k: float["b n d"] = None, # text input, x_for_k # noqa: F722 mask: bool["b n"] | None = None, # noqa: F722 rope=None, # rotary position embedding for x ) -> torch.Tensor: return self.processor( self, x_for_q, x_for_k, mask=mask, rope=rope, ) # Cross-attention processor class CrossAttnProcessor: def __init__(self): pass def __call__( self, attn: CrossAttention, x_for_q: float["b n d"], # (noisy + masked) audio input, x_for_q # noqa: F722 x_for_k: float["b n d"], # text input, x_for_k # noqa: F722 mask: bool["b n"] | None = None, # noqa: F722 rope=None, # rotary position embedding ) -> torch.FloatTensor: batch_size = x_for_q.shape[0] # `sample` projections. query = attn.to_q(x_for_q) key = attn.to_k(x_for_k) value = attn.to_v(x_for_k) # apply rotary position embedding if rope is not None: freqs, xpos_scale = rope q_xpos_scale, k_xpos_scale = ( (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) ) query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) # attention inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # mask. e.g. inference got a batch with different target durations, mask out the padding if mask is not None: attn_mask = mask attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' attn_mask = attn_mask.expand( batch_size, attn.heads, query.shape[-2], key.shape[-2] ) else: attn_mask = None x = F.scaled_dot_product_attention( query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False ) x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) x = x.to(query.dtype) # linear proj x = attn.to_out[0](x) # dropout x = attn.to_out[1](x) if mask is not None: mask = mask.unsqueeze(-1) x = x.masked_fill(~mask, 0.0) return x # Cross-attention DiT Block class CADiTBlock(nn.Module): def __init__(self, dim, text_dim, heads, dim_head, ff_mult=4, dropout=0.1): super().__init__() self.attn_norm = AdaLayerNormZero(dim) self.attn = Attention( processor=AttnProcessor(), dim=dim, heads=heads, dim_head=dim_head, dropout=dropout, ) self.cross_attn_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.cross_attn = CrossAttention( processor=CrossAttnProcessor(), dim=dim, dim_to_k=text_dim, heads=heads, dim_head=dim_head, dropout=dropout, ) self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.ff = FeedForward( dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh" ) def forward( self, x, y, t, mask=None, rope=None, ): # x: audio input, y: text input, t: time embedding ## for self-attention # pre-norm & modulation for attention input norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t) # attention attn_output = self.attn(x=norm, mask=mask, rope=rope) # process attention output for input x x = x + gate_msa.unsqueeze(1) * attn_output ## for cross-attention ca_norm = self.cross_attn_norm(x) cross_attn_output = self.cross_attn(ca_norm, y, mask=mask, rope=rope) x = x + cross_attn_output norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] ff_output = self.ff(norm) x = x + gate_mlp.unsqueeze(1) * ff_output return x