# Source: LEMAS-Edit/lemas_tts/model/backbones/prosody_encoder.py
# Uploaded by Approximetal using huggingface_hub (commit f36e46d, verified).
"""
Prosody encoder backbone based on the Pretssel ECAPA-TDNN architecture.
This module provides:
- ProsodyEncoder: wraps an ECAPA-TDNN model to produce utterance-level
prosody embeddings from 80-dim FBANK features.
- extract_fbank_16k: utility to compute 80-bin FBANK from 16kHz audio.
It is self-contained (no fairseq2 dependency) and can be used inside
CFM or other models as a conditioning network.
"""
from __future__ import annotations
from pathlib import Path
from typing import List, Optional, Tuple
import json
import torch
import torchaudio
from torch import Tensor
from torch import nn
from torch.nn import Conv1d, LayerNorm, Module, ModuleList, ReLU, Sigmoid, Tanh, init
import torch.nn.functional as F
# Sample rate in Hz; extract_fbank_16k passes it to kaldi.fbank as `sample_frequency`.
AUDIO_SAMPLE_RATE = 16_000
class ECAPA_TDNN(Module):
    """ECAPA-TDNN core used in the Pretssel prosody encoder.

    Layout:
      * ``blocks[0]`` is a ``TDNNBlock`` mapping ``input_dim`` to ``channels[0]``.
      * ``blocks[1:]`` are ``len(channels) - 2`` ``SERes2NetBlock`` layers.
      * The outputs of all SE-Res2Net layers are concatenated on the channel
        axis and fused by ``mfa``, pooled over time by attentive statistics
        pooling, layer-normalised and projected to ``embed_dim`` with a 1x1
        convolution.

    NOTE: ``mfa`` is built with ``channels[-1]`` input channels, so the
    concatenated SE-Res2Net outputs (``sum(channels[1:-1])``) have to add up
    to ``channels[-1]`` for ``forward`` to succeed.

    Input features have shape (B, T, input_dim); the output is an
    L2-normalised embedding of shape (B, embed_dim).
    """

    def __init__(
        self,
        channels: List[int],
        kernel_sizes: List[int],
        dilations: List[int],
        attention_channels: int,
        res2net_scale: int,
        se_channels: int,
        global_context: bool,
        groups: List[int],
        embed_dim: int,
        input_dim: int,
    ):
        """Build the network.

        Args:
            channels: Output channels per stage; the last entry is the width
                of the fusion (``mfa``) layer.
            kernel_sizes: Kernel size per stage (same length as ``channels``).
            dilations: Dilation per stage (same length as ``channels``).
            attention_channels: Hidden width of the pooling attention.
            res2net_scale: Number of channel groups in each Res2Net block.
            se_channels: Bottleneck width of the squeeze-excitation blocks.
            global_context: Whether pooling attention also sees utterance-level
                mean/std.
            groups: Convolution groups per stage; ``groups[-1]`` is used for
                the fusion layer.
            embed_dim: Size of the returned embedding.
            input_dim: Number of input feature channels.

        Raises:
            ValueError: If ``channels``, ``kernel_sizes`` and ``dilations``
                do not all have the same length.
        """
        super().__init__()
        # Explicit exception instead of `assert`, which is removed under `python -O`.
        if not (len(channels) == len(kernel_sizes) == len(dilations)):
            raise ValueError(
                "channels, kernel_sizes and dilations must have the same length, "
                f"got {len(channels)}, {len(kernel_sizes)} and {len(dilations)}."
            )
        self.channels = channels
        self.embed_dim = embed_dim

        self.blocks = ModuleList()
        self.blocks.append(
            TDNNBlock(
                input_dim,
                channels[0],
                kernel_sizes[0],
                dilations[0],
                groups[0],
            )
        )
        # Every stage except the first and the last is an SE-Res2Net block.
        for i in range(1, len(channels) - 1):
            self.blocks.append(
                SERes2NetBlock(
                    channels[i - 1],
                    channels[i],
                    res2net_scale=res2net_scale,
                    se_channels=se_channels,
                    kernel_size=kernel_sizes[i],
                    dilation=dilations[i],
                    groups=groups[i],
                )
            )

        # Multi-layer feature aggregation over the concatenated block outputs.
        self.mfa = TDNNBlock(
            channels[-1],
            channels[-1],
            kernel_sizes[-1],
            dilations[-1],
            groups=groups[-1],
        )
        self.asp = AttentiveStatisticsPooling(
            channels[-1],
            attention_channels=attention_channels,
            global_context=global_context,
        )
        # Pooling returns mean and std stacked on the channel axis, hence * 2.
        self.asp_norm = LayerNorm(channels[-1] * 2, eps=1e-12)
        self.fc = Conv1d(
            in_channels=channels[-1] * 2,
            out_channels=embed_dim,
            kernel_size=1,
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise every Conv1d weight with Xavier-uniform (ReLU gain)."""

        def encoder_init(m: Module) -> None:
            if isinstance(m, Conv1d):
                init.xavier_uniform_(m.weight, init.calculate_gain("relu"))

        self.apply(encoder_init)

    def forward(
        self,
        x: Tensor,
        padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Encode a batch of feature sequences.

        Args:
            x: Features of shape (B, T, C).
            padding_mask: Optional (B, T) mask, 1 for valid frames and 0 for
                padding; forwarded to every block and to the pooling layer.

        Returns:
            L2-normalised embeddings of shape (B, embed_dim).
        """
        x = x.transpose(1, 2)  # (B, C, T)

        layer_outputs = []
        for layer in self.blocks:
            x = layer(x, padding_mask=padding_mask)
            layer_outputs.append(x)

        # The first (plain TDNN) output is not part of the aggregation.
        x = torch.cat(layer_outputs[1:], dim=1)
        x = self.mfa(x)

        x = self.asp(x, padding_mask=padding_mask)  # (B, 2 * C, 1)
        x = self.asp_norm(x.transpose(1, 2)).transpose(1, 2)

        x = self.fc(x)  # (B, embed_dim, 1)
        x = x.transpose(1, 2).squeeze(1)  # (B, embed_dim)
        return F.normalize(x, dim=-1)
class TDNNBlock(Module):
    """Conv1d -> ReLU -> LayerNorm (over channels) on (B, C, T) tensors."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        dilation: int,
        groups: int = 1,
    ):
        super().__init__()
        # Symmetric padding; the time length is preserved whenever
        # dilation * (kernel_size - 1) is even.
        pad = (dilation * (kernel_size - 1)) // 2
        self.conv = Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            padding=pad,
            dilation=dilation,
            groups=groups,
        )
        self.activation = ReLU()
        self.norm = LayerNorm(out_channels, eps=1e-12)

    def forward(self, x: Tensor, padding_mask: Optional[Tensor] = None) -> Tensor:
        """Apply the block to ``x`` of shape (B, C, T).

        ``padding_mask`` is accepted so all blocks share one call signature;
        it is not used here.
        """
        hidden = self.conv(x)
        hidden = self.activation(hidden)
        # LayerNorm works on the last axis, so move channels there and back.
        hidden = hidden.transpose(1, 2)
        hidden = self.norm(hidden)
        return hidden.transpose(1, 2)
class Res2NetBlock(Module):
    """Res2Net-style block operating on (B, C, T) tensors.

    The input is split into ``scale`` channel chunks. Chunk 0 is passed
    through unchanged; chunk 1 goes through its own ``TDNNBlock``; every later
    chunk is added to the previous chunk's output before going through its
    ``TDNNBlock``. All chunk outputs are concatenated on the channel axis.

    NOTE: because chunk 0 is copied through and later chunks are summed with
    the previous output, the per-chunk input and output widths have to agree
    for ``scale > 2`` (i.e. ``in_channels == out_channels``).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        scale: int = 8,
        kernel_size: int = 3,
        dilation: int = 1,
    ):
        """Build the block.

        Args:
            in_channels: Number of input channels; must be divisible by ``scale``.
            out_channels: Number of output channels; must be divisible by ``scale``.
            scale: Number of channel chunks (at least 1).
            kernel_size: Kernel size of each per-chunk ``TDNNBlock``.
            dilation: Dilation of each per-chunk ``TDNNBlock``.

        Raises:
            ValueError: If ``scale`` is smaller than 1 or does not divide both
                channel counts.
        """
        super().__init__()
        # Explicit exceptions instead of `assert`, which is removed under `python -O`.
        if scale < 1:
            raise ValueError(f"scale must be a positive integer, got {scale}.")
        if in_channels % scale != 0 or out_channels % scale != 0:
            raise ValueError(
                f"in_channels ({in_channels}) and out_channels ({out_channels}) "
                f"must both be divisible by scale ({scale})."
            )

        in_channel = in_channels // scale
        hidden_channel = out_channels // scale
        # One TDNN per chunk except the first, which is an identity path.
        self.blocks = ModuleList(
            [
                TDNNBlock(
                    in_channel,
                    hidden_channel,
                    kernel_size=kernel_size,
                    dilation=dilation,
                )
                for _ in range(scale - 1)
            ]
        )
        self.scale = scale

    def forward(self, x: Tensor) -> Tensor:
        """Process ``x`` of shape (B, C, T) and return the concatenated chunks."""
        outputs = []
        previous = None
        for i, chunk in enumerate(torch.chunk(x, self.scale, dim=1)):
            if i == 0:
                previous = chunk
            elif i == 1:
                previous = self.blocks[i - 1](chunk)
            else:
                # Hierarchical connection: feed the previous chunk's output in.
                previous = self.blocks[i - 1](chunk + previous)
            outputs.append(previous)
        return torch.cat(outputs, dim=1)
class SEBlock(Module):
    """Squeeze-and-excitation gate for (B, C, T) tensors.

    The time axis is averaged (over valid frames only when a mask is given),
    passed through a two-layer 1x1-conv bottleneck ending in a sigmoid, and the
    resulting per-channel gate multiplies the input.
    """

    def __init__(
        self,
        in_channels: int,
        se_channels: int,
        out_channels: int,
    ):
        super().__init__()
        self.conv1 = Conv1d(in_channels, se_channels, kernel_size=1)
        self.relu = ReLU(inplace=True)
        self.conv2 = Conv1d(se_channels, out_channels, kernel_size=1)
        self.sigmoid = Sigmoid()

    def forward(self, x: Tensor, padding_mask: Optional[Tensor] = None) -> Tensor:
        """Gate ``x``.

        Args:
            x: Tensor of shape (B, C, T).
            padding_mask: Optional (B, T) mask with 1 for valid frames and 0
                for padding.

        Returns:
            ``x`` scaled by the gate; the gate has shape (B, out_channels, 1)
            and is broadcast against ``x``.
        """
        if padding_mask is None:
            squeezed = x.mean(dim=2, keepdim=True)
        else:
            valid = padding_mask.unsqueeze(1)  # (B, 1, T)
            n_valid = valid.sum(dim=2, keepdim=True)
            masked_sum = (x * valid).sum(dim=2, keepdim=True)
            # Clamp so a fully padded row does not divide by zero.
            squeezed = masked_sum / torch.clamp(n_valid, min=1.0)

        gate = self.conv1(squeezed)
        gate = self.relu(gate)
        gate = self.conv2(gate)
        gate = self.sigmoid(gate)
        return gate * x
class AttentiveStatisticsPooling(Module):
    """Attention-weighted mean/std pooling over time for (N, C, L) tensors.

    Per-channel attention weights over the time axis are produced by a small
    TDNN -> tanh -> 1x1-conv network. With ``global_context`` enabled, the
    attention network additionally sees the (masked) utterance-level mean and
    standard deviation of every channel, repeated along time.
    """

    def __init__(
        self, channels: int, attention_channels: int = 128, global_context: bool = True
    ):
        super().__init__()
        self.eps = 1e-12
        self.global_context = global_context
        # With global context the attention input is [x, mean, std] -> 3 * C.
        attention_in = channels * 3 if global_context else channels
        self.tdnn = TDNNBlock(attention_in, attention_channels, 1, 1)
        self.tanh = Tanh()
        self.conv = Conv1d(attention_channels, channels, kernel_size=1)

    @staticmethod
    def _weighted_stats(
        x: Tensor, weights: Tensor, dim: int = 2, eps: float = 1e-12
    ) -> Tuple[Tensor, Tensor]:
        """Return the weighted mean and std of ``x`` along ``dim``.

        ``weights`` is expected to sum to one along ``dim``; the variance is
        clamped to ``eps`` before the square root.
        """
        mean = (weights * x).sum(dim)
        centered_sq = (x - mean.unsqueeze(dim)).pow(2)
        variance = (weights * centered_sq).sum(dim)
        return mean, torch.sqrt(variance.clamp(eps))

    def forward(self, x: Tensor, padding_mask: Optional[Tensor] = None) -> Tensor:
        """Pool ``x`` over time.

        Args:
            x: Tensor of shape (N, C, L).
            padding_mask: Optional (N, L) mask, non-zero for valid frames and
                0 for padding. All frames are treated as valid when omitted.

        Returns:
            Tensor of shape (N, 2 * C, 1) holding mean then std per channel.
        """
        batch, _, frames = x.shape

        if padding_mask is None:
            mask = torch.ones(batch, frames, device=x.device, dtype=x.dtype)
        else:
            mask = padding_mask
        mask = mask.unsqueeze(1)  # (N, 1, L)

        attention_in = x
        if self.global_context:
            # Uniform weights over the valid frames of each sequence.
            n_valid = mask.sum(dim=2, keepdim=True).to(x)
            ctx_mean, ctx_std = self._weighted_stats(x, mask / n_valid)
            ctx_mean = ctx_mean.unsqueeze(2).repeat(1, 1, frames)
            ctx_std = ctx_std.unsqueeze(2).repeat(1, 1, frames)
            attention_in = torch.cat([x, ctx_mean, ctx_std], dim=1)

        scores = self.tdnn(attention_in)
        scores = self.conv(self.tanh(scores))
        # Padded frames get zero attention weight after the softmax.
        scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = F.softmax(scores, dim=2)

        mean, std = self._weighted_stats(x, weights)
        return torch.cat((mean, std), dim=1).unsqueeze(2)
class SERes2NetBlock(Module):
    """Residual block: TDNN -> Res2Net -> TDNN -> squeeze-excitation.

    A 1x1 convolution projects the residual path when the input and output
    channel counts differ; otherwise the input is added back unchanged.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        res2net_scale: int = 8,
        se_channels: int = 128,
        kernel_size: int = 1,
        dilation: int = 1,
        groups: int = 1,
    ):
        super().__init__()
        self.out_channels = out_channels

        # The outer TDNN layers are pointwise (kernel 1); only the Res2Net
        # part uses the configured kernel size and dilation.
        self.tdnn1 = TDNNBlock(
            in_channels, out_channels, kernel_size=1, dilation=1, groups=groups
        )
        self.res2net_block = Res2NetBlock(
            out_channels, out_channels, res2net_scale, kernel_size, dilation
        )
        self.tdnn2 = TDNNBlock(
            out_channels, out_channels, kernel_size=1, dilation=1, groups=groups
        )
        self.se_block = SEBlock(out_channels, se_channels, out_channels)

        needs_projection = in_channels != out_channels
        self.shortcut = (
            Conv1d(in_channels, out_channels, kernel_size=1)
            if needs_projection
            else None
        )

    def forward(self, x: Tensor, padding_mask: Optional[Tensor] = None) -> Tensor:
        """Apply the block to ``x`` of shape (B, C, T).

        ``padding_mask`` is only consumed by the squeeze-excitation stage.
        """
        skip = x if self.shortcut is None else self.shortcut(x)

        out = self.tdnn1(x)
        out = self.res2net_block(out)
        out = self.tdnn2(out)
        out = self.se_block(out, padding_mask=padding_mask)
        return out + skip
def extract_fbank_16k(audio_16k: Tensor) -> Tensor:
    """Compute 80-dim FBANK features from 16kHz audio.

    Inputs shorter than 400 samples (one default 25 ms kaldi.fbank window at
    16 kHz) are tiled along time until they are long enough.

    Args:
        audio_16k: Tensor of shape (T,) or (1, T)

    Returns:
        fbank: Tensor of shape (T_fbank, 80)

    Raises:
        ValueError: If the audio contains no samples.
    """
    if audio_16k.ndim == 1:
        audio_16k = audio_16k.unsqueeze(0)

    num_samples = audio_16k.shape[-1]
    if num_samples == 0:
        # Previously surfaced as an opaque ZeroDivisionError from the tiling
        # arithmetic below.
        raise ValueError("extract_fbank_16k received empty audio (0 samples).")

    # Ensure minimum length for kaldi.fbank window (default 25ms @16k -> 400 samples)
    min_len = 400
    if num_samples < min_len:
        repeat_times = (min_len // num_samples) + 1
        # The input is at least 2-D here, so tile along the time axis only.
        audio_16k = audio_16k.repeat(1, repeat_times)

    fbank = torchaudio.compliance.kaldi.fbank(
        audio_16k,
        num_mel_bins=80,
        sample_frequency=AUDIO_SAMPLE_RATE,
    )
    return fbank
class ProsodyEncoder(nn.Module):
    """High-level wrapper for the Pretssel prosody encoder.

    Builds an ``ECAPA_TDNN`` from the ``"model"`` section of a JSON config,
    loads its weights from a checkpoint and optionally freezes them.

    Usage:
        encoder = ProsodyEncoder(cfg_path, ckpt_path, freeze=True)
        emb = encoder(fbank_batch)  # (B, embed_dim), e.g. (B, 512)
    """

    # Key prefixes under which the encoder weights may be stored in a checkpoint.
    _STATE_PREFIXES = ("prosody_encoder_model.", "prosody_encoder.")

    def __init__(self, cfg_path: "str | Path", ckpt_path: "str | Path", freeze: bool = True):
        """Create the encoder.

        Args:
            cfg_path: Path (``Path`` or ``str``) of the JSON config file.
            ckpt_path: Path of the checkpoint passed to ``torch.load``.
            freeze: When true, ``requires_grad`` is disabled on all encoder
                parameters (the module's train/eval mode is not changed).
        """
        super().__init__()
        model_cfg = self._load_pretssel_model_cfg(cfg_path)
        self.encoder = self._build_prosody_encoder(model_cfg)
        self._load_prosody_encoder_state(self.encoder, ckpt_path)
        if freeze:
            for p in self.encoder.parameters():
                p.requires_grad = False

    @staticmethod
    def _load_pretssel_model_cfg(cfg_path: "str | Path") -> dict:
        """Read the JSON config and return its top-level ``"model"`` section.

        Raises:
            ValueError: If the config has no top-level ``"model"`` key.
        """
        # Path(...) lets callers pass plain strings; JSON is read as UTF-8
        # rather than with the locale's default encoding.
        cfg = json.loads(Path(cfg_path).read_text(encoding="utf-8"))
        if "model" not in cfg:
            raise ValueError(f"{cfg_path} does not contain a top-level 'model' key.")
        return cfg["model"]

    @staticmethod
    def _build_prosody_encoder(model_cfg: dict) -> "ECAPA_TDNN":
        """Instantiate ``ECAPA_TDNN`` from the ``prosody_*`` config entries."""
        encoder = ECAPA_TDNN(
            channels=model_cfg["prosody_channels"],
            kernel_sizes=model_cfg["prosody_kernel_sizes"],
            dilations=model_cfg["prosody_dilations"],
            attention_channels=model_cfg["prosody_attention_channels"],
            res2net_scale=model_cfg["prosody_res2net_scale"],
            se_channels=model_cfg["prosody_se_channels"],
            global_context=model_cfg["prosody_global_context"],
            groups=model_cfg["prosody_groups"],
            embed_dim=model_cfg["prosody_embed_dim"],
            input_dim=model_cfg["input_feat_per_channel"],
        )
        return encoder

    @staticmethod
    def _load_prosody_encoder_state(model: Module, ckpt_path: "str | Path") -> None:
        """Load ``model`` weights from ``ckpt_path``.

        If the checkpoint is a dict with string keys and any key starts with
        ``prosody_encoder.`` or ``prosody_encoder_model.``, only those keys
        are kept and the leading prefix is removed. Otherwise the checkpoint
        is used as-is.

        Raises:
            RuntimeError: If any key is missing from or unexpected for ``model``.
        """
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code for untrusted checkpoints. Only load files from a
        # trusted source; consider `weights_only=True` where the installed
        # PyTorch supports it.
        state = torch.load(ckpt_path, map_location="cpu")

        prefixes = ProsodyEncoder._STATE_PREFIXES
        if (
            isinstance(state, dict)
            and all(isinstance(k, str) for k in state)
            and any(k.startswith(prefixes) for k in state)
        ):
            stripped = {}
            for key, value in state.items():
                for prefix in prefixes:
                    if key.startswith(prefix):
                        # Remove the leading prefix only; a later occurrence of
                        # the same text inside the key is part of the
                        # parameter name and must be kept.
                        stripped[key[len(prefix):]] = value
                        break
            state = stripped

        # strict=False so that both lists can be reported in one error message.
        missing, unexpected = model.load_state_dict(state, strict=False)
        if missing or unexpected:
            raise RuntimeError(
                f"Error loading checkpoint {ckpt_path}: missing keys={missing}, "
                f"unexpected keys={unexpected}"
            )

    def forward(self, fbank: Tensor, padding_mask: Optional[Tensor] = None) -> Tensor:
        """Compute utterance-level prosody embeddings.

        Args:
            fbank: Tensor of shape (B, T, 80)
            padding_mask: Optional tensor of shape (B, T) with 1 for valid.

        Returns:
            emb: Tensor of shape (B, embed_dim); ``embed_dim`` comes from the
                config's ``prosody_embed_dim`` (512 in the usage example).
        """
        return self.encoder(fbank, padding_mask=padding_mask)