tags:
  - tts
  - voice-conversion
  - speech-synthesis
  - speaker-embedding
  - speaker-proxy
  - ecapa-tdnn
  - qwen3-tts
license: mit

Speaker Proxy Network (RVQ → Speaker Embedding)

A lightweight differentiable surrogate that maps Qwen3-TTS RVQ embeddings directly to speaker embeddings, bypassing the expensive audio-decoding → feature-extraction pipeline during voice-conversion training.

⚠️ Note: This repository contains only the Speaker Proxy. The full RVQ proxy (speaker + wav2vec + mel) is a separate effort. This checkpoint is the standalone speaker branch, trained with a pure contrastive objective on real speaker labels.


Why a Speaker Proxy?

During voice-conversion training, the standard pipeline is:

model logits → argmax → RVQ tokens → decoder → waveform → ECAPA-TDNN → speaker embedding

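This pipeline is non-differentiable: argmax turns the logits into discrete token indices, so no gradient reaches the logits through the audio decoder or ECAPA-TDNN. It is also expensive, because every training step must decode a waveform. The Speaker Proxy replaces it with: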

model logits → softmax → RVQ sum embedding → SpeakerProxyECAPA → L2-normalized speaker embedding

Every step in this chain, softmax included, is differentiable. Gradients from the speaker objective therefore flow back to the model logits, enabling end-to-end backpropagation through the entire voice-conversion objective.
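The soft path can be sketched as a softmax-weighted average over each codebook's embedding table, summed over the 16 codebooks. This is a minimal sketch, assuming the logits are laid out as `[B, T, 16, V]` and the embedding tables are stacked into one `[16, V, D]` tensor; the function name `soft_rvq_embedding` and this tensor layout are hypothetical, not part of Qwen3-TTS or Exiv.

```python
import torch


def soft_rvq_embedding(logits: torch.Tensor, tables: torch.Tensor) -> torch.Tensor:
    """Differentiable stand-in for argmax -> embedding lookup -> sum.

    logits: [B, T, Q, V] per-codebook logits (Q codebooks, V entries each)
    tables: [Q, V, D] stacked codebook embedding tables (hypothetical layout)
    returns: [B, T, D] expected RVQ sum embedding
    """
    probs = torch.softmax(logits, dim=-1)  # [B, T, Q, V]; each row sums to 1
    # For each codebook q: probs[..., q, :] @ tables[q], then sum over q.
    return torch.einsum("btqv,qvd->btd", probs, tables)
```

When the logits are sharply peaked, the result approaches the hard `argmax` lookup-and-sum; for softer distributions it is the expected embedding under the model's predicted token distribution, and calling `.backward()` on any loss computed from it populates `logits.grad`.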


Architecture

SpeakerProxyECAPA — an ECAPA-TDNN-style network adapted for RVQ-sum inputs.

| Component | Details |
|---|---|
| Input | [B, T, 2048] RVQ sum embedding (sum of 16 learned codebook embeddings) |
| Front-end | Conv1d projection + SE-Res2Blocks (dilations 2, 3, 4) |
| Pooling | Attentive Statistics Pooling (mean + std, attention-weighted) |
| Bottleneck | FC → 192-dim |
| Output | L2-normalized 192-dim speaker embedding |
| Parameters | ~4.6M |

The architecture mirrors the original SpeechBrain ECAPA-TDNN but is trained end-to-end on RVQ sum embeddings rather than on spectral features (mel filterbanks) computed from audio.
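To make the table concrete, here is a heavily simplified sketch with the same input/output contract: a Conv1d front-end over the 2048-dim RVQ sum, attentive statistics pooling (attention-weighted mean and standard deviation over time), a linear bottleneck to 192 dims, and L2 normalization. It is not the actual `SpeakerProxyECAPA`: the SE-Res2Blocks are replaced by a single dilated convolution, the pooling omits ECAPA's global-context input, and the class names `AttentiveStatsPool` and `TinySpeakerProxy` are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentiveStatsPool(nn.Module):
    """Attention-weighted mean and std over time: [B, C, T] -> [B, 2C]."""

    def __init__(self, channels: int, attn_dim: int = 128):
        super().__init__()
        self.attn = nn.Sequential(
            nn.Conv1d(channels, attn_dim, kernel_size=1),
            nn.Tanh(),
            nn.Conv1d(attn_dim, channels, kernel_size=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = torch.softmax(self.attn(x), dim=-1)  # per-channel weights over time
        mean = (w * x).sum(dim=-1)
        var = (w * x.pow(2)).sum(dim=-1) - mean.pow(2)
        std = var.clamp(min=1e-6).sqrt()
        return torch.cat([mean, std], dim=-1)


class TinySpeakerProxy(nn.Module):
    """Hypothetical stand-in: [B, T, input_dim] -> [B, embed_dim], unit norm."""

    def __init__(self, input_dim: int = 2048, channels: int = 512, embed_dim: int = 192):
        super().__init__()
        self.front = nn.Sequential(
            nn.Conv1d(input_dim, channels, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.BatchNorm1d(channels),
            # A single dilated conv in place of the SE-Res2Blocks (dilations 2, 3, 4).
            nn.Conv1d(channels, channels, kernel_size=3, dilation=2, padding=2),
            nn.ReLU(),
            nn.BatchNorm1d(channels),
        )
        self.pool = AttentiveStatsPool(channels)
        self.bottleneck = nn.Linear(2 * channels, embed_dim)

    def forward(self, e_rvq: torch.Tensor) -> torch.Tensor:
        x = self.front(e_rvq.transpose(1, 2))  # Conv1d expects [B, C, T]
        z = self.bottleneck(self.pool(x))
        return F.normalize(z, dim=-1)  # L2-normalized speaker embedding
```

Both convolutions are padded so the time length is preserved, and pooling collapses the variable-length time axis, which is what lets one fixed-size speaker embedding come out of any number of RVQ frames.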


Training

| Detail | Value |
|---|---|
| Dataset | lonesamurai/emilia_clean_10k (10,000 clips, 200 speakers) |
| Train / Val split | 8,000 / 2,000 clips |
| Epochs | ~200 |
| Loss | Pure contrastive: (1 − cos)² alignment + λ · ReLU(cos − margin)² repulsion |
| λ (repel) | 5.0 |
| Optimizer | AdamW, lr = 1e-4, weight_decay = 1e-5 |
| Best val separation | 0.8141 |
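The loss row can be read as a pairwise objective over a batch: same-speaker pairs are pulled toward cosine 1 by (1 − cos)², and different-speaker pairs are penalized by λ · ReLU(cos − margin)² only when their similarity exceeds the margin. The sketch below assumes exactly that reading, averages each term over its pairs, and uses a hypothetical margin of 0.2. The README states neither the margin value nor how pairs are formed, so `contrastive_speaker_loss` and its defaults are illustrative only.

```python
import torch
import torch.nn.functional as F


def contrastive_speaker_loss(
    emb: torch.Tensor,      # [N, D] speaker embeddings (normalized inside)
    labels: torch.Tensor,   # [N] integer speaker IDs
    margin: float = 0.2,    # hypothetical; not given in the README
    lam: float = 5.0,       # λ (repel) from the training table
) -> torch.Tensor:
    emb = F.normalize(emb, dim=-1)
    cos = emb @ emb.T                                  # [N, N] cosine similarities
    same = labels[:, None] == labels[None, :]
    off_diag = ~torch.eye(len(labels), dtype=torch.bool, device=emb.device)

    pos = same & off_diag                              # same speaker, different clip
    neg = ~same                                        # different speakers

    zero = cos.new_zeros(())
    align = ((1 - cos[pos]) ** 2).mean() if pos.any() else zero
    repel = (F.relu(cos[neg] - margin) ** 2).mean() if neg.any() else zero
    return align + lam * repel
```

The squared hinge means different-speaker pairs already below the margin contribute no gradient, so training effort concentrates on pairs that are still too similar.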

Validation performance (contrastive separation metric)

  • Best checkpoint: epoch ~140, separation = 0.8141
  • Final checkpoint: epoch ~197, separation ≈ 0.73 (plateaued below the epoch-140 best)

Comparison with Original ECAPA-TDNN

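Tested on 5 seen (in training) and 5 unseen (held-out) speakers from EMILIA. The off-diagonal rows summarize cosine similarities between different speakers (lower is better); the separation rows are higher-is-better: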

| Metric | SpeakerProxy (Ours) | Original ECAPA-TDNN |
|---|---|---|
| Seen-Seen off-diag mean | 0.050 | 0.094 |
| Unseen-Unseen off-diag mean | −0.026 | 0.060 |
| Seen-Unseen off-diag mean | −0.026 | 0.033 |
| All off-diag mean | −0.009 | 0.053 |
| Off-diag std | 0.156 | 0.098 |
| Worst confusion (max) | 0.420 | 0.327 |
| Per-speaker separation (seen avg) | 0.992 | 0.940 |
| Per-speaker separation (unseen avg) | 1.024 | 0.955 |
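These metrics can be derived from a speaker-by-speaker cosine-similarity matrix S, for example between L2-normalized per-speaker mean embeddings, so that S_ii = 1. The sketch assumes that reading and also assumes per-speaker separation is S_ii minus the mean of speaker i's off-diagonal row. Both are inferences rather than definitions given in the README, but they agree with the table. For a seen speaker under the proxy, 1 − (4·0.050 + 5·(−0.026))/9 ≈ 0.992. The "All" row is the pair-weighted average of the three blocks (10 seen-seen, 10 unseen-unseen, 25 seen-unseen pairs): (10·0.050 − 10·0.026 − 25·0.026)/45 ≈ −0.009. The function name `similarity_report` is hypothetical.

```python
import torch


def similarity_report(sim: torch.Tensor, seen: torch.Tensor) -> dict:
    """Summarize an [S, S] speaker cosine-similarity matrix.

    sim:  [S, S] symmetric similarities between speakers
    seen: [S] bool, True for speakers that appeared in training
    """
    n = sim.shape[0]
    off = ~torch.eye(n, dtype=torch.bool)
    ss = seen[:, None] & seen[None, :] & off      # seen-seen pairs
    uu = ~seen[:, None] & ~seen[None, :] & off    # unseen-unseen pairs
    su = seen[:, None] ^ seen[None, :]            # seen-unseen pairs (never diagonal)

    # Per-speaker separation (assumed definition): self-similarity minus
    # the mean similarity to every other speaker.
    row_mean = (sim * off).sum(dim=1) / (n - 1)
    separation = sim.diagonal() - row_mean

    # Both triangles are used, so each pair counts twice; means are unaffected,
    # but the README does not say how its std was computed.
    return {
        "seen_seen_mean": sim[ss].mean().item(),
        "unseen_unseen_mean": sim[uu].mean().item(),
        "seen_unseen_mean": sim[su].mean().item(),
        "all_off_diag_mean": sim[off].mean().item(),
        "off_diag_std": sim[off].std().item(),
        "worst_confusion": sim[off].max().item(),
        "separation_seen_avg": separation[seen].mean().item(),
        "separation_unseen_avg": separation[~seen].mean().item(),
    }
```

Given a `[S, D]` tensor `centroids` of per-speaker mean embeddings, `sim` could be built as `F.normalize(centroids, dim=-1) @ F.normalize(centroids, dim=-1).T` (with `torch.nn.functional` imported as `F`).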

Takeaway: Our proxy achieves stronger average separation than the original audio-based ECAPA-TDNN, especially on unseen speakers, where the mean cross-speaker similarity is negative (−0.026) rather than positive (0.060). The trade-off is a wider spread: the off-diagonal std is 0.156 vs. 0.098, and the worst-case confusion is 0.420 vs. 0.327. A few outlier pairs are therefore confused more strongly, even though speaker pairs are pushed farther apart on average.


Checkpoints

| File | Description |
|---|---|
| speaker_proxy_10k_best.pt | Best checkpoint (val separation = 0.8141, ~epoch 140) |

The checkpoint contains:

  • model_state_dict: full network weights
  • config: architecture hyperparameters
  • epoch: training epoch at save time
  • val_separation: best validation metric
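Because the checkpoint is a plain dictionary, its metadata can be inspected before any Exiv code is imported. The following is a small sketch, assuming the four keys listed above. `describe_checkpoint` is a hypothetical helper, and passing `weights_only=True` to `torch.load` assumes the stored config holds only plain Python types (ints, floats, strings, dicts).

```python
import torch


def describe_checkpoint(path: str) -> dict:
    """Return the training metadata and tensor-element count stored in a checkpoint."""
    ckpt = torch.load(path, map_location="cpu", weights_only=True)
    state = ckpt["model_state_dict"]
    return {
        "epoch": ckpt["epoch"],
        "val_separation": ckpt["val_separation"],
        "config": ckpt["config"],
        # Counts every tensor in the state dict, including BatchNorm buffers,
        # so it can be slightly larger than the trainable-parameter count.
        "num_tensor_elements": sum(t.numel() for t in state.values()),
    }
```

For example, `describe_checkpoint("speaker_proxy_10k_best.pt")["config"]` returns the hyperparameters passed to the constructor in the usage example.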

Usage

import torch
from exiv.components.models.qwen3_tts.sern.speaker_proxy_ecapa import SpeakerProxyECAPA

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load checkpoint (dict with model_state_dict, config, epoch, val_separation)
checkpoint = torch.load("speaker_proxy_10k_best.pt", map_location="cpu")
config = checkpoint["config"]

# Build model from the stored hyperparameters
proxy = SpeakerProxyECAPA(
    input_dim=config["input_dim"],      # 2048
    embed_dim=config["embed_dim"],      # 192
    channels=config["channels"],        # 512
    num_blocks=config["num_blocks"],    # 3
)
proxy.load_state_dict(checkpoint["model_state_dict"])
proxy = proxy.to(device).eval()

# Forward pass: E_rvq is the sum of the 16 RVQ codebook embeddings,
# shape [B, T, 2048], built from Qwen3-TTS RVQ tokens (see below).
# It must be on the same device as the proxy.
with torch.no_grad():  # drop no_grad when backpropagating through the proxy
    speaker_embedding = proxy(E_rvq.to(device))  # [B, 192], L2-normalized

Computing RVQ sum embeddings from Qwen3-TTS tokens

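import torch

# Extract the 16 embedding tables from Qwen3-TTS (`model` is the loaded Qwen3-TTS model)
embedding_tables = [
    model.model.embed_tokens[i].weight for i in range(16)
]

# tokens: [B, T, 16] integer (torch.long) RVQ indices, one column per codebook.
# Look up each codebook's embedding, then sum over the 16 codebooks.
E_rvq = torch.stack([
    embedding_tables[i][tokens[..., i]] for i in range(16)
], dim=-1).sum(dim=-1)  # [B, T, 2048]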
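If all 16 codebooks share the same size V, the per-codebook loop can be replaced by one advanced-indexing lookup into a stacked `[16, V, 2048]` table. This is a sketch under that equal-size assumption; `rvq_sum_embedding` is a hypothetical helper, and `tables` could be built as `torch.stack(embedding_tables)` from the snippet above.

```python
import torch


def rvq_sum_embedding(tables: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Sum of per-codebook embeddings.

    tables: [Q, V, D] stacked codebook tables (Q = 16, D = 2048 here)
    tokens: [B, T, Q] integer indices; tokens[..., q] indexes tables[q]
    returns: [B, T, D]
    """
    q_idx = torch.arange(tables.shape[0], device=tables.device)  # [Q]
    # Broadcasting q_idx [Q] against tokens [B, T, Q] selects tables[q, tokens[..., q]]
    # at every position, giving [B, T, Q, D]; summing over Q yields the RVQ sum.
    return tables[q_idx, tokens].sum(dim=-2)
```

This produces the same values as the list-based version while issuing a single gather instead of 16 separate lookups.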

Requirements

  • PyTorch ≥ 2.0
  • See Exiv for the full integration with the Qwen3-TTS SERN adapter

License

MIT