Ubuntu committed on
Commit e422499 · 1 Parent(s): b6ec358

update llm

Files changed (1)
  1. speech/cosyvoice/llm/llm.py +28 -82
speech/cosyvoice/llm/llm.py CHANGED
@@ -162,51 +162,28 @@ class TransformerLM(torch.nn.Module):
 
     def get_speaker_conditioning(self, batch, device):
         """Extract speaker conditioning from reference audio or embeddings."""
 
-        if self.use_speaker_encoder and 'reference_mels' in batch:
-            reference_mels = batch['reference_mels'].to(device)
-            print('reference_mels shape: ', reference_mels.shape)
-            # Handle multiple references
-            if reference_mels.dim() == 4:  # [B, N, T, 80]
-                B, N, C, T = reference_mels.shape
-                conds = []
-
-                for i in range(N):
-                    ref_mel = reference_mels[:, i, :, :]
-                    if 'reference_mel_masks' in batch:
-                        mask = batch['reference_mel_masks'][:, i, :].unsqueeze(1).to(device)
-                    else:
-                        mask = None
-                    print('ref_mel shape: ', ref_mel.shape, 'mask: ', mask.shape)
-                    cond = self.speaker_encoder(ref_mel, mask)  # [B, spk_embed_dim]
-                    conds.append(cond)
-
-                # Average multiple references (like Tortoise)
-                speaker_embed = torch.stack(conds, dim=1).mean(dim=1)  # [B, spk_embed_dim]
-
-            else:  # Single reference [B, T, 80]
-                ref_mel = reference_mels.transpose(1, 2)  # [B, 80, T]
-                if 'reference_mel_mask' in batch:
-                    mask = batch['reference_mel_mask'].unsqueeze(1).to(device)
-                else:
-                    mask = None
-                speaker_embed = self.speaker_encoder(ref_mel, mask)
-
-            # Project to LLM dimension
-            speaker_embed = self.spk_embed_affine_layer(speaker_embed)
-            speaker_embed = speaker_embed.unsqueeze(1)  # [B, 1, llm_input_size]
-
-        elif 'embedding' in batch:
-            # Use provided embeddings (backward compatibility)
-            embedding = batch['embedding'].to(device)
-            embedding = F.normalize(embedding, dim=1)
-            speaker_embed = self.spk_embed_affine_layer(embedding)
-            speaker_embed = speaker_embed.unsqueeze(1)
-
-        else:
-            # No speaker conditioning
-            B = batch['text_token'].shape[0]
-            speaker_embed = torch.zeros(B, 1, self.llm_input_size).to(device)
+        reference_mels = batch['reference_mels'].to(device)
+        # Handle multiple references
+        B, N, C, T = reference_mels.shape
+        conds = []
+
+        for i in range(N):
+            ref_mel = reference_mels[:, i, :, :]
+            if 'reference_mel_masks' in batch:
+                mask = batch['reference_mel_masks'][:, i, :].unsqueeze(1).to(device)
+            else:
+                mask = None
+            print('ref_mel shape: ', ref_mel.shape, 'mask: ', mask.shape)
+            cond = self.speaker_encoder(ref_mel, mask)  # [B, spk_embed_dim]
+            conds.append(cond)
+
+        # Average multiple references (like Tortoise)
+        speaker_embed = torch.stack(conds, dim=1).mean(dim=1)  # [B, spk_embed_dim]
+
+        speaker_embed = F.normalize(speaker_embed, dim=1)
+        speaker_embed = self.spk_embed_affine_layer(speaker_embed)
+        speaker_embed = speaker_embed.unsqueeze(1)  # [B, 1, llm_input_size]
 
         return speaker_embed
 
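For context, here is a minimal, hypothetical sketch of the batch layout the rewritten get_speaker_conditioning expects and of its multi-reference averaging. The new code unconditionally unpacks B, N, C, T, so reference_mels must now be 4-D ([B, N, 80, T]); the old single-reference and precomputed-'embedding' fallbacks are gone, and forward gates on use_speaker_encoder instead. DummySpeakerEncoder, the Linear stand-in, and all dimensions below are placeholders, not the repo's classes. Note that the committed debug print dereferences mask.shape, so the sketch always supplies reference_mel_masks.

import torch
import torch.nn as nn
import torch.nn.functional as F

B, N, n_mels, T = 2, 3, 80, 200
spk_embed_dim, llm_input_size = 192, 896      # placeholder dimensions

class DummySpeakerEncoder(nn.Module):
    """Stand-in for the real speaker encoder: masked mean-pool over frames, then a projection."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(n_mels, spk_embed_dim)

    def forward(self, mel, mask):                                   # mel: [B, 80, T], mask: [B, 1, T]
        mel = mel * mask                                            # zero out padded frames
        pooled = mel.sum(dim=2) / mask.sum(dim=2).clamp(min=1.0)    # [B, 80]
        return self.proj(pooled)                                    # [B, spk_embed_dim]

batch = {
    'reference_mels': torch.randn(B, N, n_mels, T),    # N reference clips per utterance
    'reference_mel_masks': torch.ones(B, N, T),        # 1 = valid frame, 0 = padding
}

speaker_encoder = DummySpeakerEncoder()
spk_embed_affine_layer = nn.Linear(spk_embed_dim, llm_input_size)   # stand-in projection

conds = []
for i in range(N):
    ref_mel = batch['reference_mels'][:, i, :, :]                   # [B, 80, T]
    mask = batch['reference_mel_masks'][:, i, :].unsqueeze(1)       # [B, 1, T]
    conds.append(speaker_encoder(ref_mel, mask))                    # [B, spk_embed_dim]

speaker_embed = torch.stack(conds, dim=1).mean(dim=1)               # average the N references
speaker_embed = F.normalize(speaker_embed, dim=1)
speaker_embed = spk_embed_affine_layer(speaker_embed).unsqueeze(1)  # [B, 1, llm_input_size]
print(speaker_embed.shape)                                          # torch.Size([2, 1, 896])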
@@ -548,41 +525,6 @@ class Qwen2LM(TransformerLM):
         lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
         lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
         return lm_target, lm_input, lm_input_len
-
-    # def forward(
-    #         self,
-    #         batch: dict,
-    #         device: torch.device,
-    # ) -> Dict[str, Optional[torch.Tensor]]:
-    #     """
-    #     Args:
-    #         text: (B, L, D)
-    #         text_lengths: (B,)
-    #         audio: (B, T, N) or (B, T)
-    #         audio_lengths: (B,)
-    #     """
-    #     text_token = batch['text_token'].to(device)
-    #     text_token_len = batch['text_token_len'].to(device)
-    #     speech_token = batch['speech_token'].to(device)
-    #     speech_token_len = batch['speech_token_len'].to(device)
-
-    #     # 1. encode text_token
-    #     text_token_emb = self.llm.model.model.embed_tokens(text_token)
-
-    #     # 2. encode speech_token
-    #     speech_token_emb = self.speech_embedding(speech_token)
-
-    #     # 3. prepare llm_input/target
-    #     lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
-    #     lm_target = lm_target.to(device)
-
-    #     # 4. run lm forward
-    #     lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
-    #     logits = self.llm_decoder(lm_output)
-    #     loss = self.criterion_ce(logits, lm_target.to(device))
-    #     acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
-    #     return {'loss': loss, 'acc': acc}
 
     def forward(
         self,
@@ -601,8 +543,8 @@ class Qwen2LM(TransformerLM):
         speech_token = batch['speech_token'].to(device)
         speech_token_len = batch['speech_token_len'].to(device)
 
-        speaker_embed = self.get_speaker_conditioning(batch, device)  # [B, 1, llm_input_size]
-
+        if self.use_speaker_encoder:
+            embedding = self.get_speaker_conditioning(batch, device)  # [B, 1, llm_input_size]
 
         # 1. encode text_token
         text_token_emb = self.llm.model.model.embed_tokens(text_token)
@@ -611,7 +553,11 @@ class Qwen2LM(TransformerLM):
         speech_token_emb = self.speech_embedding(speech_token)
 
         # 3. prepare llm_input/target
-        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target_with_spk(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len, speaker_embed)
+        if self.use_speaker_encoder:
+            lm_target, lm_input, lm_input_len = self.prepare_lm_input_target_with_spk(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len, embedding)
+        else:
+            lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
+
         lm_target = lm_target.to(device)
 
         # 4. run lm forward
 
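The body of prepare_lm_input_target_with_spk is not part of this diff, so the sketch below only illustrates the general technique it presumably applies: prepend the [B, 1, llm_input_size] speaker embedding to each sample's text/speech embedding sequence and extend the target with IGNORE_ID so the added positions carry no loss. All names, shapes, and padding choices here are assumptions for illustration, not the repo's implementation.

import torch
from torch.nn.utils.rnn import pad_sequence

IGNORE_ID = -1                      # same sentinel used by the pad_sequence calls above
B, D, vocab = 2, 896, 4096          # placeholder batch size, LLM width, speech-token vocab
text_token_len = torch.tensor([5, 3])
speech_token_len = torch.tensor([7, 4])
text_token_emb = torch.randn(B, 5, D)
speech_token_emb = torch.randn(B, 7, D)
speech_token = torch.randint(0, vocab, (B, 7))
speaker_embed = torch.randn(B, 1, D)            # output of get_speaker_conditioning

lm_input, lm_target = [], []
for i in range(B):
    tl, sl = int(text_token_len[i]), int(speech_token_len[i])
    seq = torch.cat([speaker_embed[i],                       # speaker conditioning frame comes first
                     text_token_emb[i, :tl],
                     speech_token_emb[i, :sl]], dim=0)
    tgt = torch.cat([torch.full((1 + tl,), IGNORE_ID),       # no loss on the speaker/text positions
                     speech_token[i, :sl]], dim=0)
    lm_input.append(seq)
    lm_target.append(tgt)

lm_input_len = torch.tensor([x.size(0) for x in lm_input])
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=0.0)
lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
print(lm_input.shape, lm_target.shape, lm_input_len)          # e.g. [2, 13, 896], [2, 13], tensor([13, 8])

Because the prepended speaker position is labeled IGNORE_ID in the target, the cross-entropy computed in step 4 is unchanged by the extra frame while the decoder can still attend to it.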