updated requirements
Browse files- models/vallex.py +8 -3
models/vallex.py
CHANGED
|
@@ -22,7 +22,6 @@ import torch.nn.functional as F
|
|
| 22 |
# from icefall.utils import make_pad_mask
|
| 23 |
# from torchmetrics.classification import MulticlassAccuracy
|
| 24 |
|
| 25 |
-
|
| 26 |
from modules.embedding import SinePositionalEmbedding, TokenEmbedding
|
| 27 |
from modules.transformer import (
|
| 28 |
AdaptiveLayerNorm,
|
|
@@ -493,7 +492,10 @@ class VALLE(VALLF):
|
|
| 493 |
x = self.ar_text_embedding(text)
|
| 494 |
# Add language embedding
|
| 495 |
prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
|
| 496 |
-
|
|
|
|
|
|
|
|
|
|
| 497 |
x[:, :enroll_x_lens, :] += self.ar_language_embedding(prompt_language_id)
|
| 498 |
x[:, enroll_x_lens:, :] += self.ar_language_embedding(text_language_id)
|
| 499 |
x = self.ar_text_prenet(x)
|
|
@@ -599,7 +601,10 @@ class VALLE(VALLF):
|
|
| 599 |
x = self.nar_text_embedding(text)
|
| 600 |
# Add language embedding
|
| 601 |
prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
|
| 602 |
-
|
|
|
|
|
|
|
|
|
|
| 603 |
x[:, :enroll_x_lens, :] += self.nar_language_embedding(prompt_language_id)
|
| 604 |
x[:, enroll_x_lens:, :] += self.nar_language_embedding(text_language_id)
|
| 605 |
x = self.nar_text_prenet(x)
|
|
|
|
| 22 |
# from icefall.utils import make_pad_mask
|
| 23 |
# from torchmetrics.classification import MulticlassAccuracy
|
| 24 |
|
|
|
|
| 25 |
from modules.embedding import SinePositionalEmbedding, TokenEmbedding
|
| 26 |
from modules.transformer import (
|
| 27 |
AdaptiveLayerNorm,
|
|
|
|
| 492 |
x = self.ar_text_embedding(text)
|
| 493 |
# Add language embedding
|
| 494 |
prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
|
| 495 |
+
if isinstance(text_language, str):
|
| 496 |
+
text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
|
| 497 |
+
elif isinstance(text_language, List):
|
| 498 |
+
text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
|
| 499 |
x[:, :enroll_x_lens, :] += self.ar_language_embedding(prompt_language_id)
|
| 500 |
x[:, enroll_x_lens:, :] += self.ar_language_embedding(text_language_id)
|
| 501 |
x = self.ar_text_prenet(x)
|
|
|
|
| 601 |
x = self.nar_text_embedding(text)
|
| 602 |
# Add language embedding
|
| 603 |
prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
|
| 604 |
+
if isinstance(text_language, str):
|
| 605 |
+
text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
|
| 606 |
+
elif isinstance(text_language, List):
|
| 607 |
+
text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
|
| 608 |
x[:, :enroll_x_lens, :] += self.nar_language_embedding(prompt_language_id)
|
| 609 |
x[:, enroll_x_lens:, :] += self.nar_language_embedding(text_language_id)
|
| 610 |
x = self.nar_text_prenet(x)
|