Update MuCodec/generate.py
Browse files- MuCodec/generate.py +2 -1
MuCodec/generate.py
CHANGED
|
@@ -12,6 +12,7 @@ import numpy as np
|
|
| 12 |
from tools.get_melvaehifigan48k import build_pretrained_models
|
| 13 |
import tools.torch_tools as torch_tools
|
| 14 |
from safetensors.torch import load_file
|
|
|
|
| 15 |
|
| 16 |
class MuCodec:
|
| 17 |
def __init__(self, \
|
|
@@ -26,7 +27,7 @@ class MuCodec:
|
|
| 26 |
|
| 27 |
self.MAX_DURATION = 360
|
| 28 |
if load_main_model:
|
| 29 |
-
audio_ldm_path =
|
| 30 |
self.vae, self.stft = build_pretrained_models(audio_ldm_path)
|
| 31 |
self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
|
| 32 |
main_config = {
|
|
|
|
| 12 |
from tools.get_melvaehifigan48k import build_pretrained_models
|
| 13 |
import tools.torch_tools as torch_tools
|
| 14 |
from safetensors.torch import load_file
|
| 15 |
+
from cached_path import cached_path
|
| 16 |
|
| 17 |
class MuCodec:
|
| 18 |
def __init__(self, \
|
|
|
|
| 27 |
|
| 28 |
self.MAX_DURATION = 360
|
| 29 |
if load_main_model:
|
| 30 |
+
audio_ldm_path = str(cached_path("hf://haoheliu/audioldm_48k/audioldm_48k.pth"))
|
| 31 |
self.vae, self.stft = build_pretrained_models(audio_ldm_path)
|
| 32 |
self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
|
| 33 |
main_config = {
|