Ubuntu committed · e422499
1 Parent(s): b6ec358
update llm
Browse files: speech/cosyvoice/llm/llm.py (+28, -82)
speech/cosyvoice/llm/llm.py
CHANGED
@@ -162,51 +162,28 @@ class TransformerLM(torch.nn.Module):
 
     def get_speaker_conditioning(self, batch, device):
         """Extract speaker conditioning from reference audio or embeddings."""
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    conds.append(cond)
-
-                # Average multiple references (like Tortoise)
-                speaker_embed = torch.stack(conds, dim=1).mean(dim=1)  # [B, spk_embed_dim]
-
-            else:  # Single reference [B, T, 80]
-                ref_mel = reference_mels.transpose(1, 2)  # [B, 80, T]
-                if 'reference_mel_mask' in batch:
-                    mask = batch['reference_mel_mask'].unsqueeze(1).to(device)
-                else:
-                    mask = None
-                speaker_embed = self.speaker_encoder(ref_mel, mask)
-
-            # Project to LLM dimension
-            speaker_embed = self.spk_embed_affine_layer(speaker_embed)
-            speaker_embed = speaker_embed.unsqueeze(1)  # [B, 1, llm_input_size]
-
-        elif 'embedding' in batch:
-            # Use provided embeddings (backward compatibility)
-            embedding = batch['embedding'].to(device)
-            embedding = F.normalize(embedding, dim=1)
-            speaker_embed = self.spk_embed_affine_layer(embedding)
-            speaker_embed = speaker_embed.unsqueeze(1)
-
-        else:
-            # No speaker conditioning
-            B = batch['text_token'].shape[0]
-            speaker_embed = torch.zeros(B, 1, self.llm_input_size).to(device)
+
+        reference_mels = batch['reference_mels'].to(device)
+        # Handle multiple references
+        B, N, C, T = reference_mels.shape
+        conds = []
+
+        for i in range(N):
+            ref_mel = reference_mels[:, i, :, :]
+            if 'reference_mel_masks' in batch:
+                mask = batch['reference_mel_masks'][:, i, :].unsqueeze(1).to(device)
+            else:
+                mask = None
+            print('ref_mel shape: ', ref_mel.shape, 'mask: ', mask.shape)
+            cond = self.speaker_encoder(ref_mel, mask)  # [B, spk_embed_dim]
+            conds.append(cond)
+
+        # Average multiple references (like Tortoise)
+        speaker_embed = torch.stack(conds, dim=1).mean(dim=1)  # [B, spk_embed_dim]
+
+        speaker_embed = F.normalize(speaker_embed, dim=1)
+        speaker_embed = self.spk_embed_affine_layer(speaker_embed)
+        speaker_embed = speaker_embed.unsqueeze(1)  # [B, 1, llm_input_size]
 
         return speaker_embed
 
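The new code path above assumes every batch carries 'reference_mels' shaped [B, N, C, T], i.e. N reference clips per utterance. A minimal sketch of how those pieces fit together, using a stand-in speaker encoder rather than the repo's actual speaker_encoder module (the toy class and its dimensions are assumptions for illustration only):

import torch
import torch.nn.functional as F

class ToySpeakerEncoder(torch.nn.Module):
    # Stand-in for self.speaker_encoder: masked mean-pool over frames, then a projection.
    def __init__(self, n_mels=80, spk_embed_dim=192):
        super().__init__()
        self.proj = torch.nn.Linear(n_mels, spk_embed_dim)

    def forward(self, mel, mask=None):   # mel: [B, C, T], mask: [B, 1, T] or None
        if mask is not None:
            mel = mel * mask             # zero out padded frames
            pooled = mel.sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)
        else:
            pooled = mel.mean(dim=-1)    # [B, C]
        return self.proj(pooled)         # [B, spk_embed_dim]

B, N, C, T = 2, 3, 80, 300               # batch, refs per utterance, mel bins, frames
reference_mels = torch.randn(B, N, C, T)  # the [B, N, C, T] layout the diff unpacks
masks = torch.ones(B, N, T)               # per-reference frame masks

enc = ToySpeakerEncoder()
conds = [enc(reference_mels[:, i], masks[:, i].unsqueeze(1)) for i in range(N)]
speaker_embed = torch.stack(conds, dim=1).mean(dim=1)  # average references, as in the diff
speaker_embed = F.normalize(speaker_embed, dim=1)      # then L2-normalize
print(speaker_embed.shape)                             # torch.Size([2, 192])

Normalizing after averaging, rather than per reference, keeps the pooled embedding at unit scale. One caveat in the committed code itself: the leftover debug print accesses mask.shape, which will raise when mask is None.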
@@ -548,41 +525,6 @@ class Qwen2LM(TransformerLM):
         lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
         lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
         return lm_target, lm_input, lm_input_len
-
-    # def forward(
-    #     self,
-    #     batch: dict,
-    #     device: torch.device,
-    # ) -> Dict[str, Optional[torch.Tensor]]:
-    #     """
-    #     Args:
-    #         text: (B, L, D)
-    #         text_lengths: (B,)
-    #         audio: (B, T, N) or (B, T)
-    #         audio_lengths: (B,)
-    #     """
-    #     text_token = batch['text_token'].to(device)
-    #     text_token_len = batch['text_token_len'].to(device)
-    #     speech_token = batch['speech_token'].to(device)
-    #     speech_token_len = batch['speech_token_len'].to(device)
-
-
-    #     # 1. encode text_token
-    #     text_token_emb = self.llm.model.model.embed_tokens(text_token)
-
-    #     # 2. encode speech_token
-    #     speech_token_emb = self.speech_embedding(speech_token)
-
-    #     # 3. prepare llm_input/target
-    #     lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
-    #     lm_target = lm_target.to(device)
-
-    #     # 4. run lm forward
-    #     lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
-    #     logits = self.llm_decoder(lm_output)
-    #     loss = self.criterion_ce(logits, lm_target.to(device))
-    #     acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
-    #     return {'loss': loss, 'acc': acc}
 
     def forward(
         self,
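The retained prepare_lm_input_target above batches variable-length sequences with pad_sequence, using IGNORE_ID as the fill value so padded positions drop out of the cross-entropy loss. A quick illustration of that call; the -1 value for IGNORE_ID is an assumption, since the repo imports its own constant:

import torch
from torch.nn.utils.rnn import pad_sequence

IGNORE_ID = -1   # assumed value; CosyVoice defines its own IGNORE_ID constant

targets = [torch.tensor([5, 9, 2]), torch.tensor([7, 4])]   # per-utterance targets
lm_target = pad_sequence(targets, batch_first=True, padding_value=IGNORE_ID)
print(lm_target)
# tensor([[ 5,  9,  2],
#         [ 7,  4, -1]])   -> the -1 slots are skipped by the ignore-aware CE loss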
@@ -601,8 +543,8 @@ class Qwen2LM(TransformerLM):
         speech_token = batch['speech_token'].to(device)
         speech_token_len = batch['speech_token_len'].to(device)
 
-
-
+        if self.use_speaker_encoder:
+            embedding = self.get_speaker_conditioning(batch, device)  # [B, 1, llm_input_size]
 
         # 1. encode text_token
         text_token_emb = self.llm.model.model.embed_tokens(text_token)
@@ -611,7 +553,11 @@ class Qwen2LM(TransformerLM):
         speech_token_emb = self.speech_embedding(speech_token)
 
         # 3. prepare llm_input/target
-        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
+        if self.use_speaker_encoder:
+            lm_target, lm_input, lm_input_len = self.prepare_lm_input_target_with_spk(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len, embedding)
+        else:
+            lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
+
         lm_target = lm_target.to(device)
 
         # 4. run lm forward
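prepare_lm_input_target_with_spk itself is not part of this diff. A plausible, purely illustrative sketch of what such a helper might do with the [B, 1, llm_input_size] embedding: prepend it to the LM input and pad the target so no loss lands on the speaker slot. The function name suffix, shapes, and logic below are assumptions, not the repo's implementation:

import torch

IGNORE_ID = -1   # assumed value, as above

def prepare_lm_input_target_with_spk_sketch(lm_input, lm_target, lm_input_len, spk_embed):
    # Illustrative only. lm_input: [B, L, D] assembled text+speech embeddings,
    # lm_target: [B, L] aligned token targets, spk_embed: [B, 1, D] speaker conditioning.
    lm_input = torch.cat([spk_embed, lm_input], dim=1)    # [B, 1+L, D]
    pad = torch.full_like(lm_target[:, :1], IGNORE_ID)    # no loss on the speaker slot
    lm_target = torch.cat([pad, lm_target], dim=1)        # [B, 1+L]
    return lm_target, lm_input, lm_input_len + 1          # lengths grow by one

Whatever the real helper does, forward only calls it when use_speaker_encoder is set, so checkpoints trained without a speaker encoder keep the original input layout.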