Update musiclm_pytorch.py
Browse files- musiclm_pytorch.py +10 -4
musiclm_pytorch.py
CHANGED
|
@@ -5,6 +5,7 @@ from torch import nn, einsum
|
|
| 5 |
from torchaudio.transforms import Spectrogram, TimeStretch, FrequencyMasking, TimeMasking
|
| 6 |
|
| 7 |
from audiolm_pytorch import AudioLM
|
|
|
|
| 8 |
|
| 9 |
from x_clip.tokenizer import tokenizer
|
| 10 |
from vector_quantize_pytorch import ResidualVQ
|
|
@@ -448,7 +449,7 @@ class MuLaN(nn.Module):
|
|
| 448 |
# music lm
|
| 449 |
|
| 450 |
@beartype
|
| 451 |
-
class MuLaNEmbedQuantizer(
|
| 452 |
def __init__(
|
| 453 |
self,
|
| 454 |
mulan: MuLaN,
|
|
@@ -494,6 +495,9 @@ class MuLaNEmbedQuantizer(nn.Module):
|
|
| 494 |
|
| 495 |
self.set_default_namespace(namespaces[0])
|
| 496 |
|
|
|
|
|
|
|
|
|
|
| 497 |
def set_default_namespace(self, namespace):
|
| 498 |
self._default_namespace = namespace
|
| 499 |
|
|
@@ -537,6 +541,8 @@ class MusicLM(nn.Module):
|
|
| 537 |
mulan_embed_quantizer: MuLaNEmbedQuantizer
|
| 538 |
):
|
| 539 |
super().__init__()
|
|
|
|
|
|
|
| 540 |
self.mulan_embed_quantizer = mulan_embed_quantizer
|
| 541 |
self.audio_lm = audio_lm
|
| 542 |
|
|
@@ -549,7 +555,7 @@ class MusicLM(nn.Module):
|
|
| 549 |
self.eval()
|
| 550 |
|
| 551 |
texts = tokenizer.tokenize(raw_texts)
|
| 552 |
-
cond_tokens = self.mulan_embed_quantizer(texts = texts)
|
| 553 |
|
| 554 |
-
|
| 555 |
-
|
|
|
|
|
|
| 5 |
from torchaudio.transforms import Spectrogram, TimeStretch, FrequencyMasking, TimeMasking
|
| 6 |
|
| 7 |
from audiolm_pytorch import AudioLM
|
| 8 |
+
from audiolm_pytorch.utils import AudioConditionerBase
|
| 9 |
|
| 10 |
from x_clip.tokenizer import tokenizer
|
| 11 |
from vector_quantize_pytorch import ResidualVQ
|
|
|
|
| 449 |
# music lm
|
| 450 |
|
| 451 |
@beartype
|
| 452 |
+
class MuLaNEmbedQuantizer(AudioConditionerBase):
|
| 453 |
def __init__(
|
| 454 |
self,
|
| 455 |
mulan: MuLaN,
|
|
|
|
| 495 |
|
| 496 |
self.set_default_namespace(namespaces[0])
|
| 497 |
|
| 498 |
+
def parameters(self):
|
| 499 |
+
return self.cond_embeddings.parameters()
|
| 500 |
+
|
| 501 |
def set_default_namespace(self, namespace):
|
| 502 |
self._default_namespace = namespace
|
| 503 |
|
|
|
|
| 541 |
mulan_embed_quantizer: MuLaNEmbedQuantizer
|
| 542 |
):
|
| 543 |
super().__init__()
|
| 544 |
+
assert not exists(audio_lm.audio_conditioner), 'mulan must not have been passed into AudioLM. it will be managed externally now, embedding the text into the joint embedding space for text-to-audio synthesis'
|
| 545 |
+
|
| 546 |
self.mulan_embed_quantizer = mulan_embed_quantizer
|
| 547 |
self.audio_lm = audio_lm
|
| 548 |
|
|
|
|
| 555 |
self.eval()
|
| 556 |
|
| 557 |
texts = tokenizer.tokenize(raw_texts)
|
|
|
|
| 558 |
|
| 559 |
+
text_embeds = self.mulan_embed_quantizer(texts = texts)
|
| 560 |
+
|
| 561 |
+
return self.audio_lm(text_embeds = text_embeds, **audio_lm_kwargs)
|