primepake committed
Commit 75e50c2 · 1 Parent(s): 3f67e2c

update LLM with speaker encoder

Files changed (2)
  1. speech/config.yaml +23 -0
  2. speech/cosyvoice/llm/llm.py +60 -0
speech/config.yaml CHANGED
@@ -17,6 +17,16 @@ token_mel_ratio: 2
 chunk_size: 25 # streaming inference chunk size, in token
 num_decoding_left_chunks: -1 # streaming inference flow decoder left chunk size, <0 means use all left chunks
 
+use_speaker_encoder: True # New flag
+speaker_encoder_config:
+    mel_dim: 80
+    model_dim: 512
+    output_dim: !ref <spk_embed_dim> # 192
+    num_blocks: 6
+    num_heads: 8
+    kernel_size: 1
+    dropout: 0.1
+    max_conditioning_inputs: 3 # Support multiple references
 # model params
 # for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
 # for system/third_party class/function, we do not require this.
@@ -27,6 +37,9 @@ llm: !new:cosyvoice.llm.llm.Qwen2LM
     length_normalized_loss: True
     lsm_weight: 0
     mix_ratio: [5, 15]
+    use_speaker_encoder: !ref <use_speaker_encoder> # Add this
+    spk_embed_dim: !ref <spk_embed_dim>
+    max_conditioning_inputs: 3
     llm: !new:cosyvoice.llm.llm.Qwen2Encoder
         pretrain_path: !ref <qwen_pretrain_path>
     sampling: !name:cosyvoice.utils.common.ras_sampling
@@ -34,6 +47,14 @@ llm: !new:cosyvoice.llm.llm.Qwen2LM
         top_k: 25
         win_size: 10
         tau_r: 0.1
+#
+extract_reference_mel: !name:cosyvoice.dataset.processor.extract_reference_mel_from_speech
+    feat_extractor: !ref <feat_extractor>
+    min_length: 0.5
+    max_length: 6.0
+    num_crops: 2 # Multiple crops from same utterance
+    training: True
+    sample_rate: !ref <sample_rate>
 
 flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
     input_size: 512
@@ -180,6 +201,7 @@ batch: !name:cosyvoice.dataset.processor.batch
     max_frames_in_batch: 25000
 padding: !name:cosyvoice.dataset.processor.padding
     use_spk_embedding: False # change to True during sft
+    use_speaker_encoder: !ref <use_speaker_encoder>
 
 
 # dataset processor pipeline
@@ -189,6 +211,7 @@ data_pipeline: [
     !ref <filter>,
     !ref <resample>,
     !ref <compute_fbank>,
+    !ref <extract_reference_mel>, # Add this for speaker encoder
    !ref <parse_embedding>,
    !ref <shuffle>,
    !ref <sort>,
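
Note: the implementation of the new extract_reference_mel stage is not part of this commit, so the sketch below is only a rough guess at what a CosyVoice-style processor with these parameters might do. The crop logic, the [T, mel_dim] feature layout, and the output keys (reference_mels, reference_mel_lengths) are assumptions inferred from the config entries above and from the inference_spk signature below.

import random

# Hypothetical sketch -- the real cosyvoice.dataset.processor.extract_reference_mel_from_speech
# may differ. Parameter names mirror the new config entries.
def extract_reference_mel_from_speech(data, feat_extractor, min_length=0.5, max_length=6.0,
                                      num_crops=2, training=True, sample_rate=24000, mode='train'):
    """Yield each sample with `num_crops` reference mel crops attached for the speaker encoder."""
    for sample in data:
        speech = sample['speech']  # [1, T] waveform, as used by the other processor stages
        total_samples = speech.shape[1]
        mels, lengths = [], []
        for _ in range(num_crops):
            # pick a crop duration in seconds, clipped to the utterance length
            dur = random.uniform(min_length, max_length) if training else max_length
            crop_samples = min(int(dur * sample_rate), total_samples)
            start = random.randint(0, total_samples - crop_samples) if training else 0
            crop = speech[:, start:start + crop_samples]
            mel = feat_extractor(crop).squeeze(dim=0).transpose(0, 1)  # assumed [T', mel_dim]
            mels.append(mel)
            lengths.append(mel.shape[0])
        sample['reference_mels'] = mels
        sample['reference_mel_lengths'] = lengths
        yield sample
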
speech/cosyvoice/llm/llm.py CHANGED
@@ -665,6 +665,66 @@ class Qwen2LM(TransformerLM):
         rejected_logps = (rejected_logps * chosen_lm_mask).mean(dim=-1)
         return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}
 
+    @torch.inference_mode()
+    def inference_spk(
+        self,
+        text: torch.Tensor,
+        text_len: torch.Tensor,
+        prompt_text: torch.Tensor,
+        prompt_text_len: torch.Tensor,
+        prompt_speech_token: torch.Tensor,
+        prompt_speech_token_len: torch.Tensor,
+        embedding: torch.Tensor = None,
+        reference_mels: torch.Tensor = None,
+        reference_mel_lengths: torch.Tensor = None,
+        reference_mel_masks: torch.Tensor = None,
+        sampling: int = 25,
+        max_token_text_ratio: float = 20,
+        min_token_text_ratio: float = 2,
+        uuid: str = '',
+    ) -> Generator[torch.Tensor, None, None]:
+        device = text.device
+        text = torch.concat([prompt_text, text], dim=1)
+        text_len += prompt_text_len
+        text = self.llm.model.model.embed_tokens(text)
+
+        # Get speaker conditioning
+        if self.use_speaker_encoder and reference_mels is not None:
+            # Use speaker encoder
+            batch = {
+                'reference_mels': reference_mels,
+                'reference_mel_lengths': reference_mel_lengths,
+                'reference_mel_masks': reference_mel_masks
+            }
+            speaker_embed = self.get_speaker_conditioning(batch, device)  # [1, 1, llm_input_size]
+        elif embedding is not None and embedding.shape[0] != 0:
+            # Use provided embeddings
+            embedding = F.normalize(embedding, dim=1)
+            speaker_embed = self.spk_embed_affine_layer(embedding)
+            speaker_embed = speaker_embed.unsqueeze(1)
+        else:
+            # No speaker conditioning
+            speaker_embed = torch.zeros(1, 1, self.llm_input_size).to(device)
+
+        # 3. concat llm_input with speaker embedding
+        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        if prompt_speech_token_len != 0:
+            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
+        else:
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
+
+        # Include speaker embedding in the sequence
+        lm_input = torch.concat([sos_eos_emb, speaker_embed, text, task_id_emb, prompt_speech_token_emb], dim=1)
+
+        # 4. cal min/max_length
+        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
+        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
+
+        # 5. step by step decode
+        for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
+            yield token
+
     @torch.inference_mode()
     def inference(
         self,
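
For reference, a minimal (hypothetical) call into the new path might look like the sketch below. The tensor shapes used for reference_mels and reference_mel_lengths are assumptions, since get_speaker_conditioning and the speaker encoder module itself are not shown in this commit.

import torch

# Hypothetical usage sketch for inference_spk. `llm` is assumed to be a loaded Qwen2LM with
# use_speaker_encoder=True; text_tokens and ref_mels would normally come from the CosyVoice
# frontend (text tokenizer + feat_extractor / extract_reference_mel). Shapes are assumptions.
text_tokens = torch.tensor([[12, 87, 904, 33]])                    # [1, L] text token ids
text_len = torch.tensor([text_tokens.shape[1]], dtype=torch.int32)
empty_text = torch.zeros(1, 0, dtype=text_tokens.dtype)            # no prompt text
empty_speech = torch.zeros(1, 0, dtype=torch.int32)                # no prompt speech tokens
zero_len = torch.zeros(1, dtype=torch.int32)

ref_mels = torch.randn(1, 2, 300, 80)          # assumed [B, num_refs, T, mel_dim] layout
ref_lens = torch.tensor([[300, 240]])          # frames actually used in each reference crop

speech_tokens = []
for token in llm.inference_spk(text=text_tokens, text_len=text_len,
                               prompt_text=empty_text, prompt_text_len=zero_len,
                               prompt_speech_token=empty_speech, prompt_speech_token_len=zero_len,
                               reference_mels=ref_mels, reference_mel_lengths=ref_lens):
    speech_tokens.append(token)                # streamed speech tokens for the flow/vocoder stage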