---
tags:
- tts
- voice-conversion
- speech-synthesis
- speaker-embedding
- speaker-proxy
- ecapa-tdnn
- qwen3-tts
license: mit
---
# Speaker Proxy Network (RVQ → Speaker Embedding)
A lightweight differentiable surrogate that maps **Qwen3-TTS RVQ embeddings** directly to **speaker embeddings**, bypassing the expensive audio-decoding → feature-extraction pipeline during voice-conversion training.
> ⚠️ **Note:** This repository contains **only the Speaker Proxy**. The full RVQ proxy (speaker + wav2vec + mel) is a separate effort. This checkpoint is the standalone speaker branch, trained with a pure contrastive objective on real speaker labels.
---
## Why a Speaker Proxy?
During voice-conversion training, the standard pipeline is:
```
model logits → argmax → RVQ tokens → decoder → waveform → ECAPA-TDNN → speaker embedding
```
This pipeline is **non-differentiable** because of `argmax` and the audio decoder. The Speaker Proxy replaces it with:
```
model logits → softmax → RVQ sum embedding → SpeakerProxyECAPA → L2-normalized speaker embedding
```
Every operation from the logits onward, including `softmax` itself, is differentiable, enabling end-to-end backpropagation through the entire voice-conversion objective.
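
The `softmax → RVQ sum embedding` step can be read as replacing each hard codebook lookup with a probability-weighted mixture of that codebook's embeddings. The following is a framework-agnostic NumPy sketch, not the Exiv implementation. It assumes per-codebook logits of shape `[B, T, Q, V]` and stacked codebook tables of shape `[Q, V, D]`; `soft_rvq_sum` and `hard_rvq_sum` are hypothetical names.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax."""
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def soft_rvq_sum(logits: np.ndarray, tables: np.ndarray) -> np.ndarray:
    """Hypothetical differentiable stand-in for argmax -> lookup -> sum.

    logits: [B, T, Q, V] per-codebook logits (Q codebooks, V entries each)
    tables: [Q, V, D]    codebook embedding tables
    returns [B, T, D]    sum over codebooks of probability-weighted embeddings
    """
    probs = softmax(logits, axis=-1)  # [B, T, Q, V]
    # For each codebook q: probs[..., q, :] @ tables[q] -> [B, T, D], then sum over q.
    return np.einsum("btqv,qvd->btd", probs, tables)


def hard_rvq_sum(tokens: np.ndarray, tables: np.ndarray) -> np.ndarray:
    """Non-differentiable reference: look up each token index and sum over codebooks."""
    num_codebooks = tables.shape[0]
    return sum(tables[q][tokens[..., q]] for q in range(num_codebooks))
```

With sharply peaked logits the soft sum approaches the hard lookup-and-sum; with flat logits it falls back to the sum of each codebook's mean embedding. `torch.einsum` accepts the same subscript string, so a PyTorch port keeps gradients flowing back to the logits.
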
---
## Architecture
**SpeakerProxyECAPA** — an ECAPA-TDNN-style network adapted for RVQ-sum inputs.
| Component | Details |
|-----------|---------|
| Input | `[B, T, 2048]` RVQ sum embedding (sum of 16 learned codebook embeddings) |
| Front-end | Conv1d projection + SE-Res2Blocks (dilations 2, 3, 4) |
| Pooling | Attentive Statistics Pooling (mean + std, attention-weighted) |
| Bottleneck | FC → 192-dim |
| Output | L2-normalized 192-dim speaker embedding |
| **Parameters** | **~4.6M** |
The architecture mirrors SpeechBrain's original ECAPA-TDNN but is trained end-to-end on RVQ-sum inputs rather than on filterbank features computed from audio.
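
The pooling and output rows of the table can be illustrated with a simplified NumPy sketch of attentive statistics pooling followed by the FC bottleneck and L2 normalization. This is an assumption-laden illustration, not the SpeakerProxyECAPA code: the attention here is a single tanh layer producing channel-wise scores, whereas SpeechBrain's ECAPA-TDNN also feeds global context statistics into the attention. All function and parameter names are hypothetical.

```python
import numpy as np


def attentive_stats_pool(h: np.ndarray, W: np.ndarray, b: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Attention-weighted mean and std over time (simplified, hypothetical).

    h: [B, T, C] frame-level features
    W: [C, A], b: [A], v: [A, C] attention parameters (channel-wise scores)
    returns [B, 2C] concatenation of weighted mean and weighted std
    """
    scores = np.tanh(h @ W + b) @ v                      # [B, T, C]
    scores = scores - scores.max(axis=1, keepdims=True)  # stabilize softmax over time
    alpha = np.exp(scores)
    alpha = alpha / alpha.sum(axis=1, keepdims=True)     # softmax over T, per channel
    mean = (alpha * h).sum(axis=1)                       # [B, C]
    var = (alpha * h**2).sum(axis=1) - mean**2
    std = np.sqrt(np.clip(var, 1e-8, None))              # clamp for numerical safety
    return np.concatenate([mean, std], axis=-1)          # [B, 2C]


def speaker_head(h, W, b, v, W_fc, b_fc):
    """Pooling -> FC bottleneck -> L2-normalized embedding."""
    pooled = attentive_stats_pool(h, W, b, v)
    emb = pooled @ W_fc + b_fc                           # [B, E], e.g. E = 192
    return emb / np.linalg.norm(emb, axis=-1, keepdims=True)
```

When the attention parameters are all zero the weights become uniform over time, and the pooled output reduces to the plain per-channel mean and standard deviation.
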
---
## Training
| Detail | Value |
|--------|-------|
| Dataset | `lonesamurai/emilia_clean_10k` (10,000 clips, 200 speakers) |
| Train / Val split | 8,000 / 2,000 clips |
| Epochs | ~200 |
| Loss | Pure contrastive — `(1−cos)²` alignment + `λ·ReLU(cos−margin)²` repulsion |
| λ (repel) | 5.0 |
| Optimizer | AdamW, lr = 1e-4, weight_decay = 1e-5 |
| Best val separation | **0.8141** |
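
The loss in the table can be sketched in NumPy as follows. The card does not say how pairs are formed or what margin was used, so this sketch assumes pairs are taken within a batch, each term is averaged over its pairs, and `margin` defaults to a placeholder of 0.0; `contrastive_loss` is a hypothetical name.

```python
import numpy as np


def contrastive_loss(emb: np.ndarray, labels: np.ndarray, lam: float = 5.0, margin: float = 0.0) -> float:
    """(1 - cos)^2 on same-speaker pairs + lam * ReLU(cos - margin)^2 on different-speaker pairs.

    emb: [N, D] L2-normalized embeddings; labels: [N] integer speaker IDs.
    margin=0.0 is a placeholder; the card does not state the margin used.
    """
    cos = emb @ emb.T                                # pairwise cosine (rows are unit-norm)
    same = labels[:, None] == labels[None, :]
    off_diag = ~np.eye(len(labels), dtype=bool)
    pos = same & off_diag                            # same speaker, distinct clips
    neg = ~same                                      # different speakers
    align = ((1.0 - cos[pos]) ** 2).mean() if pos.any() else 0.0
    repel = (np.maximum(cos[neg] - margin, 0.0) ** 2).mean() if neg.any() else 0.0
    return float(align + lam * repel)
```

Perfectly clustered, mutually orthogonal speakers give zero loss, while collapsing every clip onto one direction costs exactly `lam`, since every cross-speaker pair then has cosine 1.
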
### Validation performance (contrastive separation metric)
- **Best checkpoint:** epoch ~140, separation = **0.8141**
- **Final checkpoint:** epoch ~197, separation ≈ 0.73 (lower than the best checkpoint; the metric had plateaued)
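
The card does not define the "contrastive separation metric". One plausible definition, given here purely as a hypothetical sketch, is the mean same-speaker cosine similarity minus the mean cross-speaker cosine similarity, computed per speaker; `per_speaker_separation` is a hypothetical name.

```python
import numpy as np


def per_speaker_separation(emb: np.ndarray, labels: np.ndarray) -> dict:
    """Hypothetical metric: mean same-speaker cosine minus mean cross-speaker cosine.

    emb: [N, D] L2-normalized embeddings; labels: [N] integer speaker IDs.
    Each speaker needs at least two clips so the same-speaker mean is defined.
    """
    cos = emb @ emb.T
    result = {}
    for spk in np.unique(labels):
        mine = labels == spk
        intra = cos[np.ix_(mine, mine)]
        n = mine.sum()
        intra_mean = (intra.sum() - np.trace(intra)) / (n * (n - 1))  # exclude self-pairs
        inter_mean = cos[np.ix_(mine, ~mine)].mean()
        result[int(spk)] = float(intra_mean - inter_mean)
    return result
```

Under this definition the value ranges from −2 to 2, and it exceeds 1 whenever cross-speaker similarity is negative.
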
---
## Comparison with Original ECAPA-TDNN
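Evaluated on 5 seen speakers (present during training) and 5 unseen speakers from EMILIA. Off-diagonal entries are cosine similarities between different speakers (lower is better); per-speaker separation is higher-is-better.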
| Metric | SpeakerProxy (Ours) | Original ECAPA-TDNN |
|--------|---------------------|---------------------|
| Seen-Seen off-diag mean | **0.050** | 0.094 |
| Unseen-Unseen off-diag mean | **−0.026** | 0.060 |
| Seen-Unseen off-diag mean | **−0.026** | 0.033 |
| **All off-diag mean** | **−0.009** | 0.053 |
| Off-diag std | 0.156 | **0.098** |
| Worst confusion (max) | 0.420 | **0.327** |
| Per-speaker separation (seen avg) | **0.992** | 0.940 |
| Per-speaker separation (unseen avg) | **1.024** | 0.955 |
**Takeaway:** On average, our proxy separates speakers more strongly than the original audio-based ECAPA-TDNN, most notably on **unseen speakers**, where its mean off-diagonal similarity is negative (−0.026) while ECAPA's is positive (0.060). The trade-off is a wider spread: the off-diagonal std is 0.156 vs. 0.098, and the worst-confused pair reaches 0.420 vs. 0.327. A few outlier pairs are therefore confused more strongly, even though the average pair is pushed farther apart.
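
The off-diagonal rows of the comparison table summarize a speaker-by-speaker cosine-similarity matrix. The card does not say how that matrix is built; the sketch below assumes one embedding per speaker (for example, the mean of its clip embeddings) plus a seen/unseen flag per speaker. `offdiag_stats` is a hypothetical name.

```python
import numpy as np


def offdiag_stats(centroids: np.ndarray, seen: np.ndarray) -> dict:
    """Summary statistics of a speaker-by-speaker cosine-similarity matrix.

    centroids: [S, D] one embedding per speaker (e.g. mean of its clip embeddings)
    seen:      [S] bool, True for speakers that appeared in training
    """
    c = centroids / np.linalg.norm(centroids, axis=-1, keepdims=True)
    sim = c @ c.T
    off = ~np.eye(len(c), dtype=bool)  # exclude each speaker's self-similarity

    def block_mean(rows: np.ndarray, cols: np.ndarray) -> float:
        mask = off & rows[:, None] & cols[None, :]
        return float(sim[mask].mean())

    return {
        "seen_seen": block_mean(seen, seen),
        "unseen_unseen": block_mean(~seen, ~seen),
        "seen_unseen": block_mean(seen, ~seen),
        "all_mean": float(sim[off].mean()),
        "all_std": float(sim[off].std()),
        "max": float(sim[off].max()),  # worst confusion
    }
```

Each block needs at least two off-diagonal entries to be defined, which holds for the 5 + 5 speaker setup above.
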
---
## Checkpoints
| File | Description |
|------|-------------|
| `speaker_proxy_10k_best.pt` | **Best checkpoint** (val separation = 0.8141, ~epoch 140) |
The checkpoint contains:
- `model_state_dict`: full network weights
- `config`: architecture hyperparameters
- `epoch`: training epoch at save time
- `val_separation`: best validation metric
---
## Usage
```python
import torch
from exiv.components.models.qwen3_tts.sern.speaker_proxy_ecapa import SpeakerProxyECAPA

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load checkpoint (weights + architecture config)
checkpoint = torch.load("speaker_proxy_10k_best.pt", map_location="cpu")
config = checkpoint["config"]

# Build model
proxy = SpeakerProxyECAPA(
    input_dim=config["input_dim"],    # 2048
    embed_dim=config["embed_dim"],    # 192
    channels=config["channels"],      # 512
    num_blocks=config["num_blocks"],  # 3
)
proxy.load_state_dict(checkpoint["model_state_dict"])
proxy.eval().to(device)

# Forward pass. E_rvq: [B, T, 2048], the sum of the embeddings looked up from
# the 16 RVQ codebook tables for Qwen3-TTS tokens (see the next section).
# no_grad is for inference only; omit it when the proxy serves as a training
# loss so that gradients can flow back into E_rvq.
with torch.no_grad():
    speaker_embedding = proxy(E_rvq.to(device))  # [B, 192], L2-normalized
```
### Computing RVQ sum embeddings from Qwen3-TTS tokens
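```python
import torch

# `model` is the loaded Qwen3-TTS model; gather its 16 RVQ embedding tables,
# each of shape [codebook_size, 2048]
embedding_tables = [
    model.model.embed_tokens[i].weight for i in range(16)
]

# tokens: [B, T, 16] integer RVQ indices (one index per codebook per frame)
# Look up each codebook's embedding, then sum over the 16 codebooks.
E_rvq = torch.stack([
    embedding_tables[i][tokens[..., i]] for i in range(16)
], dim=0).sum(dim=0)  # [B, T, 2048]
```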
---
## Requirements
- PyTorch ≥ 2.0
- See [Exiv](https://github.com/piyushK52/Exiv) for full integration with Qwen3-TTS SERN adapter
---
## License
MIT