primepake committed
Commit 3f67e2c · 1 Parent(s): feec49e
speech/cosyvoice/dataset/processor.py CHANGED
@@ -312,12 +312,7 @@ def compute_fbank(data,
         waveform = sample['speech']
         feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
         if token_mel_ratio != 0:
-            # trim to align speech_token and speech_feat
-            # token_len = int(min(feat.shape[0] / token_mel_ratio, sample["speech_token"].shape[0]))
-            # feat = feat[:token_mel_ratio * token_len]
-            # sample["speech_token"] = sample["speech_token"][:token_len]
-
-            # Convert speech_token to tensor if it's a list
+
             if isinstance(sample["speech_token"], list):
                 speech_token_tensor = torch.tensor(sample["speech_token"])
             else:
@@ -334,6 +329,80 @@ def compute_fbank(data,
         yield sample


+def extract_reference_mel_from_speech(
+        data,
+        feat_extractor,
+        min_length=2.0,
+        max_length=6.0,
+        num_crops=2,  # Multiple random crops from same utterance
+        training=True,
+        sample_rate=24000,
+        mode='train'
+):
+    """
+    Extract mel spectrograms from current speech waveform with random cropping.
+    This creates multiple random crops from the same utterance for training diversity.
+    """
+    for sample in data:
+        # Use the current speech waveform
+        waveform = sample['speech']  # [1, T]
+        speech_length = waveform.shape[1]
+
+        # Convert time to samples
+        min_samples = int(min_length * sample_rate)
+        max_samples = int(max_length * sample_rate)
+
+        reference_mels = []
+        reference_mel_lengths = []
+
+        # Skip if utterance is too short
+        if speech_length < min_samples:
+            logging.warning(f"Speech for {sample['utt']} is too short ({speech_length/sample_rate:.2f}s)")
+            sample['reference_mels'] = []
+            sample['reference_mel_lengths'] = []
+            sample['num_references'] = 0
+            yield sample
+            continue
+
+        # Generate multiple crops from the same utterance
+        crops_to_generate = num_crops if training else 1
+
+        for i in range(crops_to_generate):
+            if training and speech_length > max_samples:
+                # Random crop during training
+                crop_length = random.randint(min_samples, min(max_samples, speech_length))
+                start_idx = random.randint(0, speech_length - crop_length)
+                audio_segment = waveform[:, start_idx:start_idx + crop_length]
+            elif speech_length > max_samples:
+                # Center crop during inference
+                start_idx = (speech_length - max_samples) // 2
+                audio_segment = waveform[:, start_idx:start_idx + max_samples]
+            else:
+                # Use full audio if shorter than max_length
+                audio_segment = waveform
+                # For training, if we need multiple crops but audio is short,
+                # we can add slight variations
+                if training and i > 0:
+                    # Add very slight noise for variation
+                    noise = torch.randn_like(audio_segment) * 0.001
+                    audio_segment = audio_segment + noise
+
+            # Normalize audio segment
+            max_val = audio_segment.abs().max()
+            if max_val > 0:
+                audio_segment = audio_segment / max_val
+
+            # Extract mel spectrogram
+            mel = feat_extractor(audio_segment).squeeze(0)  # Remove batch dim [C, T]
+            reference_mels.append(mel)
+            reference_mel_lengths.append(mel.shape[1])
+
+        sample['reference_mels'] = reference_mels
+        sample['reference_mel_lengths'] = reference_mel_lengths
+        sample['num_references'] = len(reference_mels)
+
+        yield sample
+
+
 def compute_f0(data, sample_rate, hop_size, mode='train'):
     """ Extract f0

@@ -505,19 +574,19 @@ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, m
     else:
         logging.fatal('Unsupported batch type {}'.format(batch_type))

-
-def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
+def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False, use_speaker_encoder=False):
     """ Padding the data into training data

     Args:
         data: Iterable[List[{key, feat, label}]]
+        use_speaker_encoder: Whether to prepare reference mels for speaker encoder

     Returns:
         Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
     """
     for sample in data:
         assert isinstance(sample, list)
-        speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
+        speech_feat_len = torch.tensor([x['speech_feat'].size(0) for x in sample],  # Changed from size(1) to size(0)
                                        dtype=torch.int32)
         order = torch.argsort(speech_feat_len, descending=True)

@@ -525,11 +594,19 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
         speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
         speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
         speech = pad_sequence(speech, batch_first=True, padding_value=0)
-        speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
+
+        # Handle speech_token - check if it's already a tensor
+        speech_token = []
+        for i in order:
+            if isinstance(sample[i]['speech_token'], torch.Tensor):
+                speech_token.append(sample[i]['speech_token'])
+            else:
+                speech_token.append(torch.tensor(sample[i]['speech_token']))
         speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
         speech_token = pad_sequence(speech_token,
                                     batch_first=True,
                                     padding_value=0)
+
         speech_feat = [sample[i]['speech_feat'] for i in order]
         speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
         speech_feat = pad_sequence(speech_feat,
@@ -541,6 +618,7 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
         text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
         utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
         spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
+
         batch = {
             "utts": utts,
             "speech": speech,
@@ -555,6 +633,71 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
             "utt_embedding": utt_embedding,
             "spk_embedding": spk_embedding,
         }
+
+        # Handle reference mels for speaker encoder
+        if use_speaker_encoder:
+            # Collect all reference mels
+            all_reference_mels = []
+            all_reference_mel_lengths = []
+            all_num_references = []
+
+            for i in order:
+                ref_mels = sample[i].get('reference_mels', [])
+                ref_lengths = sample[i].get('reference_mel_lengths', [])
+                num_refs = sample[i].get('num_references', 0)
+
+                all_reference_mels.append(ref_mels)
+                all_reference_mel_lengths.append(ref_lengths)
+                all_num_references.append(num_refs)
+
+            # Determine max number of references in batch
+            max_num_refs = max(all_num_references) if all_num_references else 0
+
+            if max_num_refs > 0:
+                # Find dimensions
+                batch_size = len(order)
+                max_mel_length = 0
+                mel_dim = 80  # default
+
+                # Find max mel length and mel dimension
+                for ref_mels in all_reference_mels:
+                    for mel in ref_mels:
+                        if isinstance(mel, torch.Tensor) and mel.numel() > 0:
+                            max_mel_length = max(max_mel_length, mel.shape[1])
+                            mel_dim = mel.shape[0]
+
+                if max_mel_length > 0:
+                    # Create padded tensor [B, N, C, T]
+                    padded_reference_mels = torch.zeros(batch_size, max_num_refs, mel_dim, max_mel_length)
+                    padded_reference_mel_lengths = torch.zeros(batch_size, max_num_refs, dtype=torch.int32)
+                    reference_mel_masks = torch.zeros(batch_size, max_num_refs, max_mel_length)
+
+                    for b_idx, (ref_mels, ref_lengths) in enumerate(zip(all_reference_mels, all_reference_mel_lengths)):
+                        for r_idx in range(min(len(ref_mels), max_num_refs)):
+                            if r_idx < len(ref_mels) and isinstance(ref_mels[r_idx], torch.Tensor):
+                                mel = ref_mels[r_idx]
+                                length = ref_lengths[r_idx] if r_idx < len(ref_lengths) else mel.shape[1]
+                                actual_length = min(length, mel.shape[1], max_mel_length)
+
+                                padded_reference_mels[b_idx, r_idx, :, :actual_length] = mel[:, :actual_length]
+                                padded_reference_mel_lengths[b_idx, r_idx] = actual_length
+                                reference_mel_masks[b_idx, r_idx, :actual_length] = 1.0
+
+                    batch['reference_mels'] = padded_reference_mels
+                    batch['reference_mel_lengths'] = padded_reference_mel_lengths
+                    batch['reference_mel_masks'] = reference_mel_masks
+                else:
+                    # No valid mels, create dummy
+                    batch['reference_mels'] = torch.zeros(batch_size, 1, mel_dim, 100)
+                    batch['reference_mel_lengths'] = torch.zeros(batch_size, 1, dtype=torch.int32)
+                    batch['reference_mel_masks'] = torch.zeros(batch_size, 1, 100)
+            else:
+                # No references available, create dummy for batch
+                batch_size = len(order)
+                batch['reference_mels'] = torch.zeros(batch_size, 1, 80, 100)
+                batch['reference_mel_lengths'] = torch.zeros(batch_size, 1, dtype=torch.int32)
+                batch['reference_mel_masks'] = torch.zeros(batch_size, 1, 100)
+
         if gan is True:
             # in gan train, we need pitch_feat
             pitch_feat = [sample[i]['pitch_feat'] for i in order]
@@ -568,16 +711,24 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
             # only gan train needs speech, delete it to save memory
            del batch["speech"]
            del batch["speech_len"]
+
         if dpo is True:
-            reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
+            reject_speech_token = []
+            for i in order:
+                if isinstance(sample[i]['reject_speech_token'], torch.Tensor):
+                    reject_speech_token.append(sample[i]['reject_speech_token'])
+                else:
+                    reject_speech_token.append(torch.tensor(sample[i]['reject_speech_token']))
             reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
             reject_speech_token = pad_sequence(reject_speech_token,
                                                batch_first=True,
                                                padding_value=0)
             batch['reject_speech_token'] = reject_speech_token
             batch['reject_speech_token_len'] = reject_speech_token_len
+
         if use_spk_embedding is True:
             batch["embedding"] = batch["spk_embedding"]
         else:
             batch["embedding"] = batch["utt_embedding"]
-        yield batch
+
+        yield batch
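Note on the dataset changes above: `extract_reference_mel_from_speech` attaches a per-utterance list of mel crops, and `padding(..., use_speaker_encoder=True)` collates them into `[B, N, C, T]` tensors plus per-crop lengths and masks. Below is a minimal, self-contained sketch of that collation step only; `pad_reference_mels`, the toy shapes, and the `mel_dim=80` / dummy-length-100 defaults are illustrative assumptions, not code from the commit.

import torch

def pad_reference_mels(samples, mel_dim=80, dummy_len=100):
    """Toy version of the reference-mel collation in padding():
    returns mels [B, N, C, T], lengths [B, N] and masks [B, N, T]."""
    num_refs = [len(s.get('reference_mels', [])) for s in samples]
    max_refs = max(num_refs) if num_refs else 0
    batch_size = len(samples)
    if max_refs == 0:
        # mirrors the dummy-tensor fallback branch in the commit
        return (torch.zeros(batch_size, 1, mel_dim, dummy_len),
                torch.zeros(batch_size, 1, dtype=torch.int32),
                torch.zeros(batch_size, 1, dummy_len))
    max_len = max(m.shape[1] for s in samples for m in s.get('reference_mels', []))
    mels = torch.zeros(batch_size, max_refs, mel_dim, max_len)
    lens = torch.zeros(batch_size, max_refs, dtype=torch.int32)
    masks = torch.zeros(batch_size, max_refs, max_len)
    for b, s in enumerate(samples):
        for r, mel in enumerate(s.get('reference_mels', [])):
            t = mel.shape[1]
            mels[b, r, :, :t] = mel   # copy crop into the padded slot
            lens[b, r] = t
            masks[b, r, :t] = 1.0     # 1 over valid frames, 0 over padding
    return mels, lens, masks

# two utterances with a different number / length of reference crops
samples = [
    {'reference_mels': [torch.randn(80, 120), torch.randn(80, 95)]},
    {'reference_mels': [torch.randn(80, 60)]},
]
mels, lens, masks = pad_reference_mels(samples)
print(mels.shape, lens.tolist(), masks.sum(dim=-1).tolist())
# torch.Size([2, 2, 80, 120]) [[120, 95], [60, 0]] [[120.0, 95.0], [60.0, 0.0]]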
speech/cosyvoice/llm/llm.py CHANGED
@@ -403,11 +403,19 @@ class Qwen2LM(TransformerLM):
             length_normalized_loss: bool = True,
             lsm_weight: float = 0.0,
             mix_ratio: List[int] = [5, 15],
+            use_speaker_encoder: bool = False,  # Add this
+            spk_embed_dim: int = 192,  # Add this
+            max_conditioning_inputs: int = 2,  # Add this
     ):
         torch.nn.Module.__init__(self)
         self.llm_input_size = llm_input_size
         self.llm_output_size = llm_output_size
         self.speech_token_size = speech_token_size
+        # Initialize speaker encoder settings
+        self.use_speaker_encoder = use_speaker_encoder
+        self.spk_embed_dim = spk_embed_dim
+        self.max_conditioning_inputs = max_conditioning_inputs
+
         # 2. build speech token language model related modules
         self.sos_eos = 0
         self.task_id = 1
@@ -425,7 +433,16 @@ class Qwen2LM(TransformerLM):

         # 3. [Optional] build speech token related modules
         self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
-
+        # 4. Speaker encoder or embedding projection
+        if use_speaker_encoder:
+            self.speaker_encoder = LearnableSpeakerEncoder(
+                mel_dim=80,
+                model_dim=512,
+                output_dim=spk_embed_dim,
+                num_blocks=6,
+                num_heads=8,
+            )
+            self.spk_embed_affine_layer = nn.Linear(spk_embed_dim, llm_input_size)
         # 4. sampling method
         self.sampling = sampling
         self.mix_ratio = mix_ratio
@@ -434,6 +451,60 @@ class Qwen2LM(TransformerLM):
         self.stop_token_ids = [speech_token_size + i for i in range(3)]
         self.vllm_output_queue = {}

+    def prepare_lm_input_target_with_spk(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len, speaker_embed):
+        lm_target, lm_input = [], []
+        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
+        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
+        text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
+        speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
+        for i in range(len(text_token)):
+            # bistream sequence
+            spk_emb = speaker_embed[i]  # [1, llm_input_size]
+            if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
+                this_lm_target, this_lm_input = [], []
+                this_lm_target.append(IGNORE_ID)
+                this_lm_target.append(IGNORE_ID)  # For speaker embedding
+                this_lm_input.append(self.llm_embedding.weight[self.sos_eos].reshape(1, -1))
+                this_lm_input.append(spk_emb)  # Add speaker embedding
+
+                for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
+                    this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
+                    this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
+                    if len(this_text_token) == self.mix_ratio[0]:
+                        assert len(this_speech_token) == self.mix_ratio[1]
+                        this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
+                        this_lm_target += this_speech_token
+                        this_lm_target.append(self.speech_token_size + 2)
+                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
+                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
+                    else:
+                        this_lm_target += [-1] * len(this_text_token)
+                        this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
+                        this_lm_target.append(self.speech_token_size)
+                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
+                        this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
+                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
+                this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
+            # unistream sequence
+            else:
+                # this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
+                this_lm_target = torch.tensor([IGNORE_ID, IGNORE_ID] + [IGNORE_ID] * text_token_len[i] + speech_token[i].tolist() + [self.speech_token_size])
+                this_lm_input = torch.concat([
+                    self.llm_embedding.weight[self.sos_eos].reshape(1, -1),
+                    spk_emb,
+                    text_token_emb[i],
+                    self.llm_embedding.weight[self.task_id].reshape(1, -1),
+                    speech_token_emb[i]
+                ], dim=0)
+            lm_target.append(this_lm_target)
+            lm_input.append(this_lm_input)
+        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
+        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
+        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
+        return lm_target, lm_input, lm_input_len
+
     def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
         lm_target, lm_input = [], []
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
@@ -493,6 +564,7 @@ class Qwen2LM(TransformerLM):
         speech_token = batch['speech_token'].to(device)
         speech_token_len = batch['speech_token_len'].to(device)

+
         # 1. encode text_token
         text_token_emb = self.llm.model.model.embed_tokens(text_token)

@@ -509,6 +581,43 @@ class Qwen2LM(TransformerLM):
         loss = self.criterion_ce(logits, lm_target.to(device))
         acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
         return {'loss': loss, 'acc': acc}
+
+    def forward_spk(
+        self,
+        batch: dict,
+        device: torch.device,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        """
+        Args:
+            text: (B, L, D)
+            text_lengths: (B,)
+            audio: (B, T, N) or (B, T)
+            audio_lengths: (B,)
+        """
+        text_token = batch['text_token'].to(device)
+        text_token_len = batch['text_token_len'].to(device)
+        speech_token = batch['speech_token'].to(device)
+        speech_token_len = batch['speech_token_len'].to(device)
+
+        speaker_embed = self.get_speaker_conditioning(batch, device)  # [B, 1, llm_input_size]
+
+        # 1. encode text_token
+        text_token_emb = self.llm.model.model.embed_tokens(text_token)
+
+        # 2. encode speech_token
+        speech_token_emb = self.speech_embedding(speech_token)
+
+        # 3. prepare llm_input/target
+        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target_with_spk(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len, speaker_embed)
+        lm_target = lm_target.to(device)
+
+        # 4. run lm forward
+        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
+        logits = self.llm_decoder(lm_output)
+        loss = self.criterion_ce(logits, lm_target.to(device))
+        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
+        return {'loss': loss, 'acc': acc}

     def forward_dpo(
         self,
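For orientation on the LLM changes above: the unistream branch of `prepare_lm_input_target_with_spk` inserts one extra speaker-embedding frame directly after the <sos> frame, which is why the target gains a second leading IGNORE_ID compared with the original `prepare_lm_input_target`. The toy, self-contained illustration below shows only that sequence layout; the embedding size, the IGNORE_ID value, and the random tensors are assumptions for the sketch, not values from the commit.

import torch

IGNORE_ID = -100          # assumption: an ignore index of the kind used by the CE loss
D = 8                     # toy embedding size standing in for llm_input_size

# toy per-utterance pieces
sos_eos_emb = torch.randn(1, D)   # stands in for llm_embedding.weight[sos_eos]
task_id_emb = torch.randn(1, D)   # stands in for llm_embedding.weight[task_id]
spk_emb = torch.randn(1, D)       # speaker embedding already projected to D
text_emb = torch.randn(4, D)      # 4 text tokens
speech_emb = torch.randn(6, D)    # 6 speech tokens
speech_tokens = torch.arange(6)   # their ids
speech_token_size = 100           # eos id == speech_token_size

# unistream input: <sos> [spk] text <task> speech  -> one extra frame vs. the original layout
lm_input = torch.concat([sos_eos_emb, spk_emb, text_emb, task_id_emb, speech_emb], dim=0)

# target: ignore <sos>, the speaker frame, and the text span, then predict speech ids + eos;
# the second IGNORE_ID accounts for the inserted speaker-embedding frame
lm_target = torch.tensor([IGNORE_ID, IGNORE_ID] + [IGNORE_ID] * text_emb.size(0)
                         + speech_tokens.tolist() + [speech_token_size])

print(lm_input.shape)   # torch.Size([13, 8])
print(lm_target.shape)  # torch.Size([13])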