primepake committed · Commit 3f67e2c · 1 Parent(s): feec49e

add LLM

Browse files:
- speech/cosyvoice/dataset/processor.py (+163, -12)
- speech/cosyvoice/llm/llm.py (+110, -1)
speech/cosyvoice/dataset/processor.py  CHANGED

@@ -312,12 +312,7 @@ def compute_fbank(data,
         waveform = sample['speech']
         feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
         if token_mel_ratio != 0:
-
-            # token_len = int(min(feat.shape[0] / token_mel_ratio, sample["speech_token"].shape[0]))
-            # feat = feat[:token_mel_ratio * token_len]
-            # sample["speech_token"] = sample["speech_token"][:token_len]
-
-            # Convert speech_token to tensor if it's a list
+
             if isinstance(sample["speech_token"], list):
                 speech_token_tensor = torch.tensor(sample["speech_token"])
             else:

@@ -334,6 +329,80 @@ def compute_fbank(data,
         yield sample


+def extract_reference_mel_from_speech(
+        data,
+        feat_extractor,
+        min_length=2.0,
+        max_length=6.0,
+        num_crops=2,  # Multiple random crops from same utterance
+        training=True,
+        sample_rate=24000,
+        mode='train'
+):
+    """
+    Extract mel spectrograms from current speech waveform with random cropping.
+    This creates multiple random crops from the same utterance for training diversity.
+    """
+    for sample in data:
+        # Use the current speech waveform
+        waveform = sample['speech']  # [1, T]
+        speech_length = waveform.shape[1]
+
+        # Convert time to samples
+        min_samples = int(min_length * sample_rate)
+        max_samples = int(max_length * sample_rate)
+
+        reference_mels = []
+        reference_mel_lengths = []
+
+        # Skip if utterance is too short
+        if speech_length < min_samples:
+            logging.warning(f"Speech for {sample['utt']} is too short ({speech_length/sample_rate:.2f}s)")
+            sample['reference_mels'] = []
+            sample['reference_mel_lengths'] = []
+            sample['num_references'] = 0
+            yield sample
+            continue
+
+        # Generate multiple crops from the same utterance
+        crops_to_generate = num_crops if training else 1
+
+        for i in range(crops_to_generate):
+            if training and speech_length > max_samples:
+                # Random crop during training
+                crop_length = random.randint(min_samples, min(max_samples, speech_length))
+                start_idx = random.randint(0, speech_length - crop_length)
+                audio_segment = waveform[:, start_idx:start_idx + crop_length]
+            elif speech_length > max_samples:
+                # Center crop during inference
+                start_idx = (speech_length - max_samples) // 2
+                audio_segment = waveform[:, start_idx:start_idx + max_samples]
+            else:
+                # Use full audio if shorter than max_length
+                audio_segment = waveform
+                # For training, if we need multiple crops but audio is short,
+                # we can add slight variations
+                if training and i > 0:
+                    # Add very slight noise for variation
+                    noise = torch.randn_like(audio_segment) * 0.001
+                    audio_segment = audio_segment + noise
+
+            # Normalize audio segment
+            max_val = audio_segment.abs().max()
+            if max_val > 0:
+                audio_segment = audio_segment / max_val
+
+            # Extract mel spectrogram
+            mel = feat_extractor(audio_segment).squeeze(0)  # Remove batch dim [C, T]
+            reference_mels.append(mel)
+            reference_mel_lengths.append(mel.shape[1])
+
+        sample['reference_mels'] = reference_mels
+        sample['reference_mel_lengths'] = reference_mel_lengths
+        sample['num_references'] = len(reference_mels)
+
+        yield sample
+
 def compute_f0(data, sample_rate, hop_size, mode='train'):
     """ Extract f0

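Example (not part of the commit): the new extract_reference_mel_from_speech stage is a generator that tags each sample with reference_mels, reference_mel_lengths, and num_references. A minimal sketch of driving it in isolation, assuming the import path below and using a placeholder torchaudio mel extractor instead of the repo's configured feat_extractor:

import torch
import torchaudio

# Assumed import path for the function added in this commit.
from cosyvoice.dataset.processor import extract_reference_mel_from_speech

# Placeholder 80-bin mel extractor; the repo wires in its own feat_extractor.
mel_extractor = torchaudio.transforms.MelSpectrogram(
    sample_rate=24000, n_fft=1024, hop_length=256, n_mels=80)

def dummy_samples():
    # One 3-second dummy utterance in the dict format the stage expects.
    yield {'utt': 'utt_0', 'speech': torch.randn(1, 3 * 24000)}

for sample in extract_reference_mel_from_speech(
        dummy_samples(), feat_extractor=mel_extractor,
        min_length=2.0, max_length=6.0, num_crops=2, training=True):
    # Each crop is an [80, T_i] mel; num_references counts the crops.
    print(sample['num_references'], [m.shape for m in sample['reference_mels']])

Each yielded crop is an [n_mels, T] tensor, so downstream batching only has to pad along the time axis.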
@@ -505,19 +574,19 @@ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
     else:
         logging.fatal('Unsupported batch type {}'.format(batch_type))

-
-def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
+def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False, use_speaker_encoder=False):
     """ Padding the data into training data

     Args:
         data: Iterable[List[{key, feat, label}]]
+        use_speaker_encoder: Whether to prepare reference mels for speaker encoder

     Returns:
         Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
     """
     for sample in data:
         assert isinstance(sample, list)
-        speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
+        speech_feat_len = torch.tensor([x['speech_feat'].size(0) for x in sample],  # Changed from size(1) to size(0)
                                        dtype=torch.int32)
         order = torch.argsort(speech_feat_len, descending=True)

@@ -525,11 +594,19 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
         speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
         speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
         speech = pad_sequence(speech, batch_first=True, padding_value=0)
-        speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
+
+        # Handle speech_token - check if it's already a tensor
+        speech_token = []
+        for i in order:
+            if isinstance(sample[i]['speech_token'], torch.Tensor):
+                speech_token.append(sample[i]['speech_token'])
+            else:
+                speech_token.append(torch.tensor(sample[i]['speech_token']))
         speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
         speech_token = pad_sequence(speech_token,
                                     batch_first=True,
                                     padding_value=0)
+
         speech_feat = [sample[i]['speech_feat'] for i in order]
         speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
         speech_feat = pad_sequence(speech_feat,

@@ -541,6 +618,7 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
         text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
         utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
         spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
+
         batch = {
             "utts": utts,
             "speech": speech,

@@ -555,6 +633,71 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
             "utt_embedding": utt_embedding,
             "spk_embedding": spk_embedding,
         }
+
+        # Handle reference mels for speaker encoder
+        if use_speaker_encoder:
+            # Collect all reference mels
+            all_reference_mels = []
+            all_reference_mel_lengths = []
+            all_num_references = []
+
+            for i in order:
+                ref_mels = sample[i].get('reference_mels', [])
+                ref_lengths = sample[i].get('reference_mel_lengths', [])
+                num_refs = sample[i].get('num_references', 0)
+
+                all_reference_mels.append(ref_mels)
+                all_reference_mel_lengths.append(ref_lengths)
+                all_num_references.append(num_refs)
+
+            # Determine max number of references in batch
+            max_num_refs = max(all_num_references) if all_num_references else 0
+
+            if max_num_refs > 0:
+                # Find dimensions
+                batch_size = len(order)
+                max_mel_length = 0
+                mel_dim = 80  # default
+
+                # Find max mel length and mel dimension
+                for ref_mels in all_reference_mels:
+                    for mel in ref_mels:
+                        if isinstance(mel, torch.Tensor) and mel.numel() > 0:
+                            max_mel_length = max(max_mel_length, mel.shape[1])
+                            mel_dim = mel.shape[0]
+
+                if max_mel_length > 0:
+                    # Create padded tensor [B, N, C, T]
+                    padded_reference_mels = torch.zeros(batch_size, max_num_refs, mel_dim, max_mel_length)
+                    padded_reference_mel_lengths = torch.zeros(batch_size, max_num_refs, dtype=torch.int32)
+                    reference_mel_masks = torch.zeros(batch_size, max_num_refs, max_mel_length)
+
+                    for b_idx, (ref_mels, ref_lengths) in enumerate(zip(all_reference_mels, all_reference_mel_lengths)):
+                        for r_idx in range(min(len(ref_mels), max_num_refs)):
+                            if r_idx < len(ref_mels) and isinstance(ref_mels[r_idx], torch.Tensor):
+                                mel = ref_mels[r_idx]
+                                length = ref_lengths[r_idx] if r_idx < len(ref_lengths) else mel.shape[1]
+                                actual_length = min(length, mel.shape[1], max_mel_length)
+
+                                padded_reference_mels[b_idx, r_idx, :, :actual_length] = mel[:, :actual_length]
+                                padded_reference_mel_lengths[b_idx, r_idx] = actual_length
+                                reference_mel_masks[b_idx, r_idx, :actual_length] = 1.0
+
+                    batch['reference_mels'] = padded_reference_mels
+                    batch['reference_mel_lengths'] = padded_reference_mel_lengths
+                    batch['reference_mel_masks'] = reference_mel_masks
+                else:
+                    # No valid mels, create dummy
+                    batch['reference_mels'] = torch.zeros(batch_size, 1, mel_dim, 100)
+                    batch['reference_mel_lengths'] = torch.zeros(batch_size, 1, dtype=torch.int32)
+                    batch['reference_mel_masks'] = torch.zeros(batch_size, 1, 100)
+            else:
+                # No references available, create dummy for batch
+                batch_size = len(order)
+                batch['reference_mels'] = torch.zeros(batch_size, 1, 80, 100)
+                batch['reference_mel_lengths'] = torch.zeros(batch_size, 1, dtype=torch.int32)
+                batch['reference_mel_masks'] = torch.zeros(batch_size, 1, 100)
+
         if gan is True:
             # in gan train, we need pitch_feat
             pitch_feat = [sample[i]['pitch_feat'] for i in order]

@@ -568,16 +711,24 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
             # only gan train needs speech, delete it to save memory
             del batch["speech"]
             del batch["speech_len"]
+
         if dpo is True:
-            reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
+            reject_speech_token = []
+            for i in order:
+                if isinstance(sample[i]['reject_speech_token'], torch.Tensor):
+                    reject_speech_token.append(sample[i]['reject_speech_token'])
+                else:
+                    reject_speech_token.append(torch.tensor(sample[i]['reject_speech_token']))
             reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
             reject_speech_token = pad_sequence(reject_speech_token,
                                                batch_first=True,
                                                padding_value=0)
             batch['reject_speech_token'] = reject_speech_token
             batch['reject_speech_token_len'] = reject_speech_token_len
+
         if use_spk_embedding is True:
             batch["embedding"] = batch["spk_embedding"]
         else:
             batch["embedding"] = batch["utt_embedding"]
-        yield batch
+
+        yield batch
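Example (not part of the commit): the use_speaker_encoder branch added to padding() packs a variable number of reference crops per utterance into fixed-shape tensors plus frame masks. The same [B, N, C, T] layout can be reproduced standalone with dummy tensors:

import torch

# Two utterances: one with 2 reference crops, one with 1, all 80-bin mels.
ref_mels_per_utt = [
    [torch.randn(80, 120), torch.randn(80, 90)],
    [torch.randn(80, 150)],
]

B = len(ref_mels_per_utt)
N = max(len(mels) for mels in ref_mels_per_utt)      # max crops in the batch
T = max(mel.shape[1] for mels in ref_mels_per_utt for mel in mels)

padded = torch.zeros(B, N, 80, T)                    # [B, N, C, T] as in padding()
lengths = torch.zeros(B, N, dtype=torch.int32)
masks = torch.zeros(B, N, T)

for b, mels in enumerate(ref_mels_per_utt):
    for n, mel in enumerate(mels):
        t = mel.shape[1]
        padded[b, n, :, :t] = mel
        lengths[b, n] = t
        masks[b, n, :t] = 1.0

# padded: torch.Size([2, 2, 80, 150]); lengths: [[120, 90], [150, 0]]
print(padded.shape, lengths.tolist(), masks.sum(dim=-1).tolist())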
speech/cosyvoice/llm/llm.py  CHANGED

@@ -403,11 +403,19 @@ class Qwen2LM(TransformerLM):
             length_normalized_loss: bool = True,
             lsm_weight: float = 0.0,
             mix_ratio: List[int] = [5, 15],
+            use_speaker_encoder: bool = False,  # Add this
+            spk_embed_dim: int = 192,  # Add this
+            max_conditioning_inputs: int = 2,  # Add this
     ):
         torch.nn.Module.__init__(self)
         self.llm_input_size = llm_input_size
         self.llm_output_size = llm_output_size
         self.speech_token_size = speech_token_size
+        # Initialize speaker encoder settings
+        self.use_speaker_encoder = use_speaker_encoder
+        self.spk_embed_dim = spk_embed_dim
+        self.max_conditioning_inputs = max_conditioning_inputs
+
         # 2. build speech token language model related modules
         self.sos_eos = 0
         self.task_id = 1

@@ -425,7 +433,16 @@ class Qwen2LM(TransformerLM):

         # 3. [Optional] build speech token related modules
         self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
-
+        # 4. Speaker encoder or embedding projection
+        if use_speaker_encoder:
+            self.speaker_encoder = LearnableSpeakerEncoder(
+                mel_dim=80,
+                model_dim=512,
+                output_dim=spk_embed_dim,
+                num_blocks=6,
+                num_heads=8,
+            )
+            self.spk_embed_affine_layer = nn.Linear(spk_embed_dim, llm_input_size)
         # 4. sampling method
         self.sampling = sampling
         self.mix_ratio = mix_ratio
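Note: LearnableSpeakerEncoder is instantiated above but is not defined anywhere in this commit. Purely as a mental model of a module that accepts these keyword arguments (mel_dim, model_dim, output_dim, num_blocks, num_heads) and maps reference mels to a fixed-size speaker embedding, a hypothetical stand-in might look like the sketch below; the repository's actual class may be organized quite differently:

import torch
import torch.nn as nn

class SpeakerEncoderSketch(nn.Module):
    """Hypothetical stand-in for LearnableSpeakerEncoder: mel frames
    [B, mel_dim, T] -> one embedding [B, output_dim] via self-attention
    blocks and (masked) mean pooling over time."""

    def __init__(self, mel_dim=80, model_dim=512, output_dim=192,
                 num_blocks=6, num_heads=8):
        super().__init__()
        self.input_proj = nn.Linear(mel_dim, model_dim)
        layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads,
                                           batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, num_layers=num_blocks)
        self.output_proj = nn.Linear(model_dim, output_dim)

    def forward(self, mel, mask=None):
        x = self.input_proj(mel.transpose(1, 2))         # [B, T, model_dim]
        pad = (mask == 0) if mask is not None else None   # True where padded
        x = self.blocks(x, src_key_padding_mask=pad)
        if mask is not None:
            w = mask.unsqueeze(-1)                        # [B, T, 1]
            x = (x * w).sum(dim=1) / w.sum(dim=1).clamp(min=1)
        else:
            x = x.mean(dim=1)
        return self.output_proj(x)                        # [B, output_dim]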
@@ -434,6 +451,60 @@ class Qwen2LM(TransformerLM):
         self.stop_token_ids = [speech_token_size + i for i in range(3)]
         self.vllm_output_queue = {}

+    def prepare_lm_input_target_with_spk(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len, speaker_embed):
+        lm_target, lm_input = [], []
+        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
+        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
+        text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
+        speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
+        for i in range(len(text_token)):
+            # bistream sequence
+            spk_emb = speaker_embed[i]  # [1, llm_input_size]
+            if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
+                this_lm_target, this_lm_input = [], []
+                this_lm_target.append(IGNORE_ID)
+                this_lm_target.append(IGNORE_ID)  # For speaker embedding
+                this_lm_input.append(self.llm_embedding.weight[self.sos_eos].reshape(1, -1))
+                this_lm_input.append(spk_emb)  # Add speaker embedding
+
+                for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
+                    this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
+                    this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
+                    if len(this_text_token) == self.mix_ratio[0]:
+                        assert len(this_speech_token) == self.mix_ratio[1]
+                        this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
+                        this_lm_target += this_speech_token
+                        this_lm_target.append(self.speech_token_size + 2)
+                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
+                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
+                    else:
+                        this_lm_target += [-1] * len(this_text_token)
+                        this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
+                        this_lm_target.append(self.speech_token_size)
+                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
+                        this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
+                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
+                this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
+            # unistream sequence
+            else:
+
+                # this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
+                this_lm_target = torch.tensor([IGNORE_ID, IGNORE_ID] + [IGNORE_ID] * text_token_len[i] + speech_token[i].tolist() + [self.speech_token_size])
+
+                this_lm_input = torch.concat([
+                    self.llm_embedding.weight[self.sos_eos].reshape(1, -1),
+                    spk_emb,
+                    text_token_emb[i],
+                    self.llm_embedding.weight[self.task_id].reshape(1, -1),
+                    speech_token_emb[i]
+                ], dim=0)
+            lm_target.append(this_lm_target)
+            lm_input.append(this_lm_input)
+        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
+        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
+        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
+        return lm_target, lm_input, lm_input_len
+
     def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
         lm_target, lm_input = [], []
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)

@@ -493,6 +564,7 @@ class Qwen2LM(TransformerLM):
         speech_token = batch['speech_token'].to(device)
         speech_token_len = batch['speech_token_len'].to(device)

+
         # 1. encode text_token
         text_token_emb = self.llm.model.model.embed_tokens(text_token)

@@ -509,6 +581,43 @@ class Qwen2LM(TransformerLM):
         loss = self.criterion_ce(logits, lm_target.to(device))
         acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
         return {'loss': loss, 'acc': acc}
+
+    def forward_spk(
+            self,
+            batch: dict,
+            device: torch.device,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        """
+        Args:
+            text: (B, L, D)
+            text_lengths: (B,)
+            audio: (B, T, N) or (B, T)
+            audio_lengths: (B,)
+        """
+        text_token = batch['text_token'].to(device)
+        text_token_len = batch['text_token_len'].to(device)
+        speech_token = batch['speech_token'].to(device)
+        speech_token_len = batch['speech_token_len'].to(device)
+
+        speaker_embed = self.get_speaker_conditioning(batch, device)  # [B, 1, llm_input_size]
+
+
+        # 1. encode text_token
+        text_token_emb = self.llm.model.model.embed_tokens(text_token)
+
+        # 2. encode speech_token
+        speech_token_emb = self.speech_embedding(speech_token)
+
+        # 3. prepare llm_input/target
+        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target_with_spk(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len, speaker_embed)
+        lm_target = lm_target.to(device)
+
+        # 4. run lm forward
+        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
+        logits = self.llm_decoder(lm_output)
+        loss = self.criterion_ce(logits, lm_target.to(device))
+        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
+        return {'loss': loss, 'acc': acc}

     def forward_dpo(
             self,
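Note: forward_spk() calls self.get_speaker_conditioning(batch, device), which is also not included in this commit. Given the modules the commit does add (speaker_encoder and spk_embed_affine_layer) and the [B, N, C, T] reference-mel batch produced by padding(), one plausible reading of that helper is sketched below; the encoder call convention and the crop-averaging step are assumptions, not the author's code:

import torch

def get_speaker_conditioning_sketch(model, batch, device):
    """Hypothetical helper matching the call in forward_spk(): encode each
    reference crop, average the crop embeddings, and project to llm_input_size."""
    mels = batch['reference_mels'].to(device)           # [B, N, C, T]
    masks = batch['reference_mel_masks'].to(device)     # [B, N, T]
    B, N, C, T = mels.shape

    # Encode every crop independently: [B*N, C, T] -> [B*N, spk_embed_dim]
    crop_emb = model.speaker_encoder(mels.reshape(B * N, C, T),
                                     masks.reshape(B * N, T))
    crop_emb = crop_emb.reshape(B, N, -1)

    # Average over crops that contain at least one valid frame.
    valid = (masks.sum(dim=-1) > 0).float().unsqueeze(-1)      # [B, N, 1]
    spk_emb = (crop_emb * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1)

    # Project to llm_input_size and add the sequence slot consumed by
    # prepare_lm_input_target_with_spk: [B, 1, llm_input_size].
    return model.spk_embed_affine_layer(spk_emb).unsqueeze(1)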