primepake committed
Commit feec49e · Parent: d1b8469

update speaker encoder

Files changed (1): speech/cosyvoice/llm/llm.py (+135 −2)
speech/cosyvoice/llm/llm.py CHANGED
@@ -27,7 +27,72 @@ from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
 from cosyvoice.utils.common import th_accuracy
 from cosyvoice.utils.file_utils import logging
 from cosyvoice.utils.mask import make_pad_mask
-
+from cosyvoice.transformer.attention import MultiHeadedAttention
+from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
+
+class LearnableSpeakerEncoder(nn.Module):
+    """
+    Speaker encoder inspired by Tortoise-TTS for CosyVoice2.
+    Processes mel-spectrograms through attention blocks to extract speaker characteristics.
+    """
+    def __init__(
+        self,
+        mel_dim: int = 80,
+        model_dim: int = 512,
+        output_dim: int = 192,
+        num_blocks: int = 6,
+        num_heads: int = 8,
+        kernel_size: int = 1,
+        dropout: float = 0.0,
+    ):
+        super().__init__()
+
+        # Initial projection (like Tortoise's ConditioningEncoder)
+        self.init = nn.Conv1d(mel_dim, model_dim, kernel_size=kernel_size)
+
+        self.blocks = nn.ModuleList([
+            ConformerEncoderLayer(
+                size=model_dim,
+                self_attn=MultiHeadedAttention(
+                    n_head=num_heads,
+                    n_feat=model_dim,
+                    dropout_rate=dropout,
+                ),
+                feed_forward=None,  # Can add if needed
+                feed_forward_macaron=None,
+                conv_module=None,
+                dropout_rate=dropout,
+                normalize_before=True,
+            ) for _ in range(num_blocks)
+        ])
+
+        # Output projection
+        self.output_proj = nn.Linear(model_dim, output_dim)
+
+    def forward(self, x, mask=None):
+        """
+        Args:
+            x: mel-spectrogram [B, 80, T]
+            mask: padding mask [B, 1, T]
+        Returns:
+            speaker embedding [B, output_dim]
+        """
+        # Initial conv
+        h = self.init(x)  # [B, model_dim, T]
+        h = h.transpose(1, 2)  # [B, T, model_dim]
+
+        # Apply attention blocks
+        for block in self.blocks:
+            h, _ = block(h, mask)
+
+        # Pool over time (take first position like Tortoise)
+        # Could also use mean pooling
+        output = h[:, 0, :]  # [B, model_dim]
+
+        # Project to output dimension
+        output = self.output_proj(output)  # [B, output_dim]
+
+        return F.normalize(output, p=2, dim=1)
 
 class TransformerLM(torch.nn.Module):
     def __init__(
@@ -43,6 +108,8 @@ class TransformerLM(torch.nn.Module):
         length_normalized_loss: bool = True,
         lsm_weight: float = 0.0,
         spk_embed_dim: int = 192,
+        use_speaker_encoder: bool = False,
+        max_conditioning_inputs: int = 3,
     ):
         super().__init__()
         self.llm_input_size = llm_input_size
@@ -69,12 +136,78 @@ class TransformerLM(torch.nn.Module):
         )
 
         # 3. [Optional] build speech token related modules
+
+        self.use_speaker_encoder = use_speaker_encoder
+        self.max_conditioning_inputs = max_conditioning_inputs
+
+        if use_speaker_encoder:
+            self.speaker_encoder = LearnableSpeakerEncoder(
+                mel_dim=80,
+                model_dim=512,
+                output_dim=spk_embed_dim,
+                num_blocks=6,
+                num_heads=8,
+            )
+            self.spk_embed_affine_layer = nn.Linear(spk_embed_dim, llm_input_size)
+        else:
+            # Fallback to embedding-based approach
+            self.spk_embed_affine_layer = nn.Linear(spk_embed_dim, llm_input_size)
+
         self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
-        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
 
         # 4. sampling method
         self.sampling = sampling
 
+    def get_speaker_conditioning(self, batch, device):
+        """Extract speaker conditioning from reference audio or embeddings."""
+
+        if self.use_speaker_encoder and 'reference_mels' in batch:
+            reference_mels = batch['reference_mels'].to(device)
+
+            # Handle multiple references
+            if reference_mels.dim() == 4:  # [B, N, T, 80]
+                B, N, T, D = reference_mels.shape
+                conds = []
+
+                for i in range(N):
+                    ref_mel = reference_mels[:, i, :, :].transpose(1, 2)  # [B, 80, T]
+                    if 'reference_mel_masks' in batch:
+                        mask = batch['reference_mel_masks'][:, i, :].unsqueeze(1).to(device)
+                    else:
+                        mask = None
+
+                    cond = self.speaker_encoder(ref_mel, mask)  # [B, spk_embed_dim]
+                    conds.append(cond)
+
+                # Average multiple references (like Tortoise)
+                speaker_embed = torch.stack(conds, dim=1).mean(dim=1)  # [B, spk_embed_dim]
+
+            else:  # Single reference [B, T, 80]
+                ref_mel = reference_mels.transpose(1, 2)  # [B, 80, T]
+                if 'reference_mel_mask' in batch:
+                    mask = batch['reference_mel_mask'].unsqueeze(1).to(device)
+                else:
+                    mask = None
+                speaker_embed = self.speaker_encoder(ref_mel, mask)
+
+            # Project to LLM dimension
+            speaker_embed = self.spk_embed_affine_layer(speaker_embed)
+            speaker_embed = speaker_embed.unsqueeze(1)  # [B, 1, llm_input_size]
+
+        elif 'embedding' in batch:
+            # Use provided embeddings (backward compatibility)
+            embedding = batch['embedding'].to(device)
+            embedding = F.normalize(embedding, dim=1)
+            speaker_embed = self.spk_embed_affine_layer(embedding)
+            speaker_embed = speaker_embed.unsqueeze(1)
+
+        else:
+            # No speaker conditioning
+            B = batch['text_token'].shape[0]
+            speaker_embed = torch.zeros(B, 1, self.llm_input_size).to(device)
+
+        return speaker_embed
+
     def encode(
         self,
         text: torch.Tensor,
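
For orientation, the pattern the new LearnableSpeakerEncoder implements is: a 1x1 conv projects each mel frame into the model dimension, a stack of self-attention blocks mixes information across frames, the first frame is taken as a summary token (as in Tortoise), and a linear projection plus L2 normalization yields a unit-norm speaker embedding. Below is a minimal, self-contained sketch of that pattern, not the commit's code: torch.nn.TransformerEncoderLayer stands in for this repo's ConformerEncoderLayer, so the block internals are an illustrative assumption.

# Sketch of the LearnableSpeakerEncoder pattern (illustrative only;
# nn.TransformerEncoderLayer substitutes for cosyvoice's ConformerEncoderLayer).
import torch
import torch.nn as nn
import torch.nn.functional as F

class SpeakerEncoderSketch(nn.Module):
    def __init__(self, mel_dim=80, model_dim=512, output_dim=192,
                 num_blocks=6, num_heads=8):
        super().__init__()
        self.init = nn.Conv1d(mel_dim, model_dim, kernel_size=1)  # per-frame projection
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads,
                                       batch_first=True)
            for _ in range(num_blocks)
        ])
        self.output_proj = nn.Linear(model_dim, output_dim)

    def forward(self, x):                    # x: [B, 80, T]
        h = self.init(x).transpose(1, 2)     # [B, T, model_dim]
        for block in self.blocks:
            h = block(h)
        out = self.output_proj(h[:, 0, :])   # first-frame pooling, as in the diff
        return F.normalize(out, p=2, dim=1)  # unit-norm speaker embedding

enc = SpeakerEncoderSketch()
emb = enc(torch.randn(2, 80, 250))           # two utterances, 250 mel frames each
print(emb.shape)                             # torch.Size([2, 192])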
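
The dispatch in get_speaker_conditioning is keyed entirely on which fields the batch dict carries. The sketch below lays out the three input layouts it accepts, using only key names and shapes that appear in the diff above; the batch sizes and lengths are arbitrary placeholders.

# Sketch of the batch layouts get_speaker_conditioning dispatches on
# (key names and shapes taken from the diff; values are dummy data).
import torch

B, N, T = 2, 3, 200  # batch size, references per utterance, mel frames

batch_multi = {
    'text_token': torch.zeros(B, 10, dtype=torch.long),
    'reference_mels': torch.randn(B, N, T, 80),                  # dim() == 4 -> per-reference encode, then mean
    'reference_mel_masks': torch.ones(B, N, T, dtype=torch.bool),
}
batch_single = {
    'text_token': torch.zeros(B, 10, dtype=torch.long),
    'reference_mels': torch.randn(B, T, 80),                     # dim() == 3 -> single-reference path
}
batch_legacy = {
    'text_token': torch.zeros(B, 10, dtype=torch.long),
    'embedding': torch.randn(B, 192),                            # pre-extracted embedding, backward-compatible path
}
# With none of these keys present, the method returns a zero tensor
# of shape [B, 1, llm_input_size], i.e. no speaker conditioning.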