lonesamurai commited on
Commit
8d46dec
·
verified ·
1 Parent(s): 538ce97

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +116 -27
README.md CHANGED
@@ -3,66 +3,155 @@ tags:
3
  - tts
4
  - voice-conversion
5
  - speech-synthesis
6
- - differentiable-proxy
 
 
7
  - qwen3-tts
8
  license: mit
9
  ---
10
 
11
- # RVQ Proxy Network
12
 
13
- A lightweight differentiable surrogate network that maps from Qwen3-TTS RVQ embedding space to high-level perceptual audio features: speaker embedding, wav2vec2 content features, and mel spectrogram.
14
 
15
- ## Purpose
16
 
17
- During voice conversion training, the standard pipeline (logits → argmax → RVQ tokens → decoder → waveform → feature extractors) is non-differentiable. The RVQ Proxy replaces this with a tiny differentiable network, enabling end-to-end training without audio decoding.
 
 
 
 
18
 
19
  ```
20
- model logits → softmaxE_softRVQProxyspeaker / wav2vec / mel
21
  ```
22
 
 
 
 
 
 
 
 
 
 
 
23
  ## Architecture
24
 
25
- - **Shared temporal encoder:** 3-layer 1D conv (receptive field ~560ms) with GroupNorm + GELU
26
- - **Speaker head:** 2-layer MLP + mean pooling → 2048-dim speaker embedding
27
- - **Wav2vec head:** Single linear projection → 768-dim features
28
- - **Mel head:** 2-layer MLP → 80-bin mel spectrogram
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- **Parameters:** ~6.7M
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  ## Checkpoints
33
 
34
  | File | Description |
35
  |------|-------------|
36
- | `rvq_proxy_10k.pt` | Best checkpoint (val speaker cosine = 0.9925) |
37
- | `rvq_proxy_10k_final.pt` | Final epoch checkpoint (epoch 20) |
38
 
39
- Both checkpoints include metadata (`input_dim`, `num_speaker_dims`) for easy loading.
 
 
 
 
 
 
40
 
41
  ## Usage
42
 
43
  ```python
44
- from exiv.components.models.qwen3_tts.sern.rvq_proxy import RVQProxy
45
  import torch
46
-
47
- ckpt = torch.load("rvq_proxy_10k.pt", map_location="cpu")
48
- proxy = RVQProxy(
49
- input_dim=ckpt["input_dim"],
50
- num_speaker_dims=ckpt["num_speaker_dims"]
 
 
 
 
 
 
 
51
  )
52
- proxy.load_state_dict(ckpt["proxy_state"])
53
  proxy.eval().cuda()
54
 
55
- # Forward pass
56
- out = proxy(E_soft, mask=mask) # E_soft: [B, T, 512]
57
- speaker = out["speaker"] # [B, 2048]
58
- wav2vec = out["wav2vec"] # [B, T, 768]
59
- mel = out["mel"] # [B, T, 80]
 
 
 
 
 
 
 
 
 
 
 
 
60
  ```
61
 
 
 
62
  ## Requirements
63
 
64
  - PyTorch ≥ 2.0
65
- - See [Exiv](https://github.com/piyushK52/Exiv) for full integration with Qwen3-TTS
 
 
66
 
67
  ## License
68
 
 
3
  - tts
4
  - voice-conversion
5
  - speech-synthesis
6
+ - speaker-embedding
7
+ - speaker-proxy
8
+ - ecapa-tdnn
9
  - qwen3-tts
10
  license: mit
11
  ---
12
 
13
+ # Speaker Proxy Network (RVQ → Speaker Embedding)
14
 
15
+ A lightweight differentiable surrogate that maps **Qwen3-TTS RVQ embeddings** directly to **speaker embeddings**, bypassing the expensive audio-decoding feature-extraction pipeline during voice-conversion training.
16
 
17
+ > ⚠️ **Note:** This repository contains **only the Speaker Proxy**. The full RVQ proxy (speaker + wav2vec + mel) is a separate effort. This checkpoint is the standalone speaker branch, trained with a pure contrastive objective on real speaker labels.
18
 
19
+ ---
20
+
21
+ ## Why a Speaker Proxy?
22
+
23
+ During voice-conversion training, the standard pipeline is:
24
 
25
  ```
26
+ model logits → argmaxRVQ tokens decoderwaveform ECAPA-TDNN speaker embedding
27
  ```
28
 
29
+ This pipeline is **non-differentiable** because of `argmax` and the audio decoder. The Speaker Proxy replaces it with:
30
+
31
+ ```
32
+ model logits → softmax → RVQ sum embedding → SpeakerProxyECAPA → L2-normalized speaker embedding
33
+ ```
34
+
35
+ Everything after `softmax` is now differentiable, enabling end-to-end backpropagation through the entire voice-conversion objective.
36
+
37
+ ---
38
+
39
  ## Architecture
40
 
41
+ **SpeakerProxyECAPA** an ECAPA-TDNN-style network adapted for RVQ-sum inputs.
42
+
43
+ | Component | Details |
44
+ |-----------|---------|
45
+ | Input | `[B, T, 2048]` RVQ sum embedding (sum of 16 learned codebook embeddings) |
46
+ | Front-end | Conv1d projection + SE-Res2Blocks (dilations 2, 3, 4) |
47
+ | Pooling | Attentive Statistics Pooling (mean + std, attention-weighted) |
48
+ | Bottleneck | FC → 192-dim |
49
+ | Output | L2-normalized 192-dim speaker embedding |
50
+ | **Parameters** | **~4.6M** |
51
+
52
+ The architecture mirrors the original SpeechBrain ECAPA-TDNN but is trained end-to-end on RVQ inputs rather than raw audio spectrograms.
53
+
54
+ ---
55
+
56
+ ## Training
57
+
58
+ | Detail | Value |
59
+ |--------|-------|
60
+ | Dataset | `lonesamurai/emilia_clean_10k` (10,000 clips, 200 speakers) |
61
+ | Train / Val split | 8,000 / 2,000 clips |
62
+ | Epochs | ~200 |
63
+ | Loss | Pure contrastive — `(1−cos)²` alignment + `λ·ReLU(cos−margin)²` repulsion |
64
+ | λ (repel) | 5.0 |
65
+ | Optimizer | AdamW, lr = 1e-4, weight_decay = 1e-5 |
66
+ | Best val separation | **0.8141** |
67
+
68
+ ### Validation performance (contrastive separation metric)
69
+
70
+ - **Best checkpoint:** epoch ~140, separation = **0.8141**
71
+ - **Final checkpoint:** epoch ~197, separation ≈ 0.73 (plateaued)
72
 
73
+ ---
74
+
75
+ ## Comparison with Original ECAPA-TDNN
76
+
77
+ Tested on 5 seen + 5 unseen speakers from EMILIA:
78
+
79
+ | Metric | SpeakerProxy (Ours) | Original ECAPA-TDNN |
80
+ |--------|---------------------|---------------------|
81
+ | Seen-Seen off-diag mean | **0.050** | 0.094 |
82
+ | Unseen-Unseen off-diag mean | **−0.026** | 0.060 |
83
+ | Seen-Unseen off-diag mean | **−0.026** | 0.033 |
84
+ | **All off-diag mean** | **−0.009** | 0.053 |
85
+ | Off-diag std | 0.156 | **0.098** |
86
+ | Worst confusion (max) | 0.420 | **0.327** |
87
+ | Per-speaker separation (seen avg) | **0.992** | 0.940 |
88
+ | Per-speaker separation (unseen avg) | **1.024** | 0.955 |
89
+
90
+ **Takeaway:** Our proxy achieves **stronger average separation** than the original audio-based ECAPA, especially on **unseen speakers** (negative mean similarity vs. positive). The trade-off is slightly higher variance — a few outlier pairs show stronger confusion, but the vast majority of speaker pairs are pushed farther apart.
91
+
92
+ ---
93
 
94
  ## Checkpoints
95
 
96
  | File | Description |
97
  |------|-------------|
98
+ | `speaker_proxy_10k_best.pt` | **Best checkpoint** (val separation = 0.8141, ~epoch 140) |
 
99
 
100
+ The checkpoint contains:
101
+ - `model_state_dict`: full network weights
102
+ - `config`: architecture hyperparameters
103
+ - `epoch`: training epoch at save time
104
+ - `val_separation`: best validation metric
105
+
106
+ ---
107
 
108
  ## Usage
109
 
110
  ```python
 
111
  import torch
112
+ from exiv.components.models.qwen3_tts.sern.speaker_proxy_ecapa import SpeakerProxyECAPA
113
+
114
+ # Load checkpoint
115
+ checkpoint = torch.load("speaker_proxy_10k_best.pt", map_location="cpu")
116
+ config = checkpoint["config"]
117
+
118
+ # Build model
119
+ proxy = SpeakerProxyECAPA(
120
+ input_dim=config["input_dim"], # 2048
121
+ embed_dim=config["embed_dim"], # 192
122
+ channels=config["channels"], # 512
123
+ num_blocks=config["num_blocks"], # 3
124
  )
125
+ proxy.load_state_dict(checkpoint["model_state_dict"])
126
  proxy.eval().cuda()
127
 
128
+ # Forward pass — E_rvq is the sum of 16 RVQ embedding tables
129
+ # E_rvq: [B, T, 2048] from Qwen3-TTS RVQ tokens
130
+ speaker_embedding = proxy(E_rvq) # [B, 192], L2-normalized
131
+ ```
132
+
133
+ ### Computing RVQ sum embeddings from Qwen3-TTS tokens
134
+
135
+ ```python
136
+ # Extract the 16 embedding tables from Qwen3-TTS
137
+ embedding_tables = [
138
+ model.model.embed_tokens[i].weight for i in range(16)
139
+ ]
140
+
141
+ # tokens: [B, T, 16] integer RVQ indices
142
+ E_rvq = torch.stack([
143
+ embedding_tables[i][tokens[..., i]] for i in range(16)
144
+ ], dim=-1).sum(dim=-1) # [B, T, 2048]
145
  ```
146
 
147
+ ---
148
+
149
  ## Requirements
150
 
151
  - PyTorch ≥ 2.0
152
+ - See [Exiv](https://github.com/piyushK52/Exiv) for full integration with Qwen3-TTS SERN adapter
153
+
154
+ ---
155
 
156
  ## License
157