---
tags:
- tts
- voice-conversion
- speech-synthesis
- speaker-embedding
- speaker-proxy
- ecapa-tdnn
- qwen3-tts
license: mit
---

# Speaker Proxy Network (RVQ → Speaker Embedding)

A lightweight differentiable surrogate that maps **Qwen3-TTS RVQ embeddings** directly to **speaker embeddings**, bypassing the expensive audio-decoding → feature-extraction pipeline during voice-conversion training.

> ⚠️ **Note:** This repository contains **only the Speaker Proxy**. The full RVQ proxy (speaker + wav2vec + mel) is a separate effort. This checkpoint is the standalone speaker branch, trained with a purely contrastive objective on real speaker labels.

---

## Why a Speaker Proxy?

During voice-conversion training, the standard pipeline is:

```
model logits → argmax → RVQ tokens → decoder → waveform → ECAPA-TDNN → speaker embedding
```

This pipeline is **non-differentiable**: `argmax` emits discrete token indices that pass no gradient back to the logits, and the audio decoder only runs on those discrete tokens. It is also expensive, because every training step must decode a waveform and run ECAPA-TDNN on it. The Speaker Proxy replaces it with:

```
model logits → softmax → RVQ sum embedding → SpeakerProxyECAPA → L2-normalized speaker embedding
```

Every step in this chain, from the logits onward, is differentiable, enabling end-to-end backpropagation through the entire voice-conversion objective.

---

## Architecture

**SpeakerProxyECAPA**: an ECAPA-TDNN-style network adapted for RVQ-sum inputs.

| Component | Details |
|-----------|---------|
| Input | `[B, T, 2048]` RVQ sum embedding (sum of 16 learned codebook embeddings) |
| Front-end | Conv1d projection + SE-Res2Blocks (dilations 2, 3, 4) |
| Pooling | Attentive Statistics Pooling (mean + std, attention-weighted) |
| Bottleneck | FC → 192-dim |
| Output | L2-normalized 192-dim speaker embedding |
| **Parameters** | **~4.6M** |

Here `B` is the batch size and `T` the number of RVQ frames. The architecture mirrors the original SpeechBrain ECAPA-TDNN, but it is trained end-to-end on RVQ inputs instead of on filterbank features computed from audio.
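
A minimal NumPy sketch of the `softmax → RVQ sum embedding` step, assuming the relaxation is a softmax-weighted average of each codebook's embedding rows, summed over the codebooks (the card does not spell out the exact formula). The function names `hard_rvq_sum` and `soft_rvq_sum`, the `temperature` parameter, and the toy sizes are hypothetical and not taken from Exiv.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def hard_rvq_sum(tokens, codebooks):
    # argmax path: one embedding row per codebook, summed over codebooks.
    # tokens: [T, Q] ints, codebooks: [Q, V, D] -> [T, D]
    return sum(codebooks[q][tokens[:, q]] for q in range(codebooks.shape[0]))

def soft_rvq_sum(logits, codebooks, temperature=1.0):
    # Relaxed path (hypothetical): probability-weighted average of each
    # codebook's rows, summed over codebooks.
    # logits: [T, Q, V], codebooks: [Q, V, D] -> [T, D]
    probs = softmax(logits / temperature, axis=-1)     # [T, Q, V]
    return np.einsum("tqv,qvd->td", probs, codebooks)  # sums over q and v

# Toy sizes (hypothetical); the real input has Q = 16 codebooks and D = 2048
rng = np.random.default_rng(0)
T, Q, V, D = 5, 16, 32, 8
codebooks = rng.normal(size=(Q, V, D))
tokens = rng.integers(0, V, size=(T, Q))
# Logits with a clear winner at each target token
logits = rng.normal(size=(T, Q, V)) + 20.0 * np.eye(V)[tokens]

hard = hard_rvq_sum(tokens, codebooks)
soft = soft_rvq_sum(logits, codebooks, temperature=0.1)
print(np.abs(hard - soft).max())  # close to zero: a peaked softmax approaches the lookup
```

In PyTorch the same contraction would be `torch.einsum("tqv,qvd->td", probs, codebooks)` on tensors, so gradients can flow from the proxy's output back into the logits. Lowering `temperature` moves the relaxation closer to the argmax path, at the cost of gradients that vanish as the softmax saturates.
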

---

## Training

| Detail | Value |
|--------|-------|
| Dataset | `lonesamurai/emilia_clean_10k` (10,000 clips, 200 speakers) |
| Train / Val split | 8,000 / 2,000 clips |
| Epochs | ~200 |
| Loss | Pure contrastive: `(1−cos)²` alignment + `λ·ReLU(cos−margin)²` repulsion |
| λ (repel) | 5.0 |
| Optimizer | AdamW, lr = 1e-4, weight_decay = 1e-5 |
| Best val separation | **0.8141** |

### Validation performance (contrastive separation metric)

- **Best checkpoint:** epoch ~140, separation = **0.8141**
- **Final checkpoint:** epoch ~197, separation ≈ 0.73 (plateaued)

---

## Comparison with Original ECAPA-TDNN

Tested on 10 EMILIA speakers: 5 seen during proxy training and 5 unseen. The off-diagonal rows report cosine similarity between different speakers (lower is better); the per-speaker separation rows are higher-is-better.

| Metric | SpeakerProxy (Ours) | Original ECAPA-TDNN |
|--------|---------------------|---------------------|
| Seen-Seen off-diag mean | **0.050** | 0.094 |
| Unseen-Unseen off-diag mean | **−0.026** | 0.060 |
| Seen-Unseen off-diag mean | **−0.026** | 0.033 |
| **All off-diag mean** | **−0.009** | 0.053 |
| Off-diag std | 0.156 | **0.098** |
| Worst confusion (max) | 0.420 | **0.327** |
| Per-speaker separation (seen avg) | **0.992** | 0.940 |
| Per-speaker separation (unseen avg) | **1.024** | 0.955 |

**Takeaway:** Our proxy achieves **stronger average separation** than the original audio-based ECAPA-TDNN, especially on **unseen speakers** (mean cross-speaker similarity −0.026 vs. 0.060). The trade-off is a wider spread: the off-diagonal std is higher (0.156 vs. 0.098) and the worst-confused pair is more similar (0.420 vs. 0.327), so a few outlier pairs are confused more strongly even though speaker pairs are, on average, pushed farther apart.
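
A minimal NumPy sketch of the loss in the Training table, assuming it is applied to all clip pairs in a batch: same-speaker pairs use `(1−cos)²`, different-speaker pairs use `λ·ReLU(cos−margin)²`, and each term is averaged over its pairs. The pairing scheme, the averaging, and `margin = 0.2` are assumptions (the card does not state the margin); only λ = 5.0 comes from the table. The names `contrastive_loss` and `lam` are hypothetical.

```python
import numpy as np

def l2_normalize(x, eps=1e-8):
    # Scale each row to unit length (eps avoids division by zero)
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)

def contrastive_loss(emb, labels, margin=0.2, lam=5.0):
    # emb: [N, E] embeddings, labels: [N] speaker ids.
    # margin=0.2 is a hypothetical value; lam=5.0 is the card's λ (repel).
    z = l2_normalize(emb)
    cos = z @ z.T                                  # [N, N] pairwise cosine similarity
    same = labels[:, None] == labels[None, :]
    pos = same & ~np.eye(len(labels), dtype=bool)  # same speaker, different clips
    neg = ~same                                    # different speakers
    # Alignment: pull same-speaker pairs toward cos = 1
    align = np.mean((1.0 - cos[pos]) ** 2) if pos.any() else 0.0
    # Repulsion: penalize different-speaker pairs only above the margin
    repel = np.mean(np.maximum(cos[neg] - margin, 0.0) ** 2) if neg.any() else 0.0
    return align + lam * repel

labels = np.array([0, 0, 1, 1])
separated = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
collapsed = np.ones((4, 2))  # every clip maps to the same embedding
print(contrastive_loss(separated, labels))  # ~0: positives aligned, negatives orthogonal
print(contrastive_loss(collapsed, labels))  # ~3.2 = 5.0 * (1 - 0.2) ** 2
```

The squared hinge leaves different-speaker pairs alone once their similarity falls below the margin, so the repulsion term only acts on pairs that are still confusable, while λ weights it against alignment.
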

---

## Checkpoints

| File | Description |
|------|-------------|
| `speaker_proxy_10k_best.pt` | **Best checkpoint** (val separation = 0.8141, ~epoch 140) |

The checkpoint contains:

- `model_state_dict`: full network weights
- `config`: architecture hyperparameters
- `epoch`: training epoch at save time
- `val_separation`: best validation metric

---

## Usage

```python
import torch
from exiv.components.models.qwen3_tts.sern.speaker_proxy_ecapa import SpeakerProxyECAPA

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load checkpoint (dict with model_state_dict, config, epoch, val_separation)
checkpoint = torch.load("speaker_proxy_10k_best.pt", map_location="cpu")
config = checkpoint["config"]

# Build model from the stored hyperparameters
proxy = SpeakerProxyECAPA(
    input_dim=config["input_dim"],    # 2048
    embed_dim=config["embed_dim"],    # 192
    channels=config["channels"],      # 512
    num_blocks=config["num_blocks"],  # 3
)
proxy.load_state_dict(checkpoint["model_state_dict"])
proxy = proxy.to(device).eval()

# Forward pass: E_rvq is the sum of the 16 RVQ codebook embeddings,
# shape [B, T, 2048], computed from Qwen3-TTS RVQ tokens (next snippet)
speaker_embedding = proxy(E_rvq.to(device))  # [B, 192], L2-normalized
```

### Computing RVQ sum embeddings from Qwen3-TTS tokens

```python
import torch

# `model` is the loaded Qwen3-TTS model; collect its 16 RVQ embedding tables
embedding_tables = [model.model.embed_tokens[i].weight for i in range(16)]

# tokens: [B, T, 16] integer RVQ indices (one column per codebook).
# Look up each codebook's embedding, then sum over the 16 codebooks.
E_rvq = torch.stack(
    [embedding_tables[i][tokens[..., i]] for i in range(16)],
    dim=-1,
).sum(dim=-1)  # [B, T, 2048]
```

---

## Requirements

- PyTorch ≥ 2.0
- See [Exiv](https://github.com/piyushK52/Exiv) for the full integration with the Qwen3-TTS SERN adapter

---

## License

MIT