Update demos/musicgen_app.py
Browse files- demos/musicgen_app.py +56 -1
demos/musicgen_app.py
CHANGED
|
@@ -93,6 +93,44 @@ def make_waveform(*args, **kwargs):
|
|
| 93 |
return out
|
| 94 |
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
def load_model(version="facebook/musicgen-small"):
|
| 97 |
global MODEL
|
| 98 |
print("Loading Musivesal musicgen-small") # , version
|
|
@@ -101,8 +139,25 @@ def load_model(version="facebook/musicgen-small"):
|
|
| 101 |
del MODEL
|
| 102 |
torch.cuda.empty_cache()
|
| 103 |
MODEL = None # in case loading would crash
|
| 104 |
-
MODEL = MusicGen.get_pretrained("data")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
print("Custom model loaded.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
|
| 108 |
def load_diffusion():
|
|
|
|
| 93 |
return out
|
| 94 |
|
| 95 |
|
| 96 |
+
def _delete_param(cfg: DictConfig, full_name: str):
    """Delete the dotted key ``full_name`` from ``cfg`` when it is present.

    Every component of ``full_name`` except the last is followed down from
    ``cfg``. If one of them is missing the function returns without
    modifying anything. Otherwise the struct flag of the innermost node is
    switched off, the final key is deleted if it exists, and the struct
    flag is switched back on.

    Args:
        cfg: Configuration node the dotted path is resolved against.
        full_name: Dot-separated path of the key to remove.
    """
    *parent_keys, leaf_key = full_name.split(".")

    node = cfg
    for key in parent_keys:
        if key not in node:
            # The path does not exist, so there is nothing to delete.
            return
        node = node[key]

    # The struct flag is set to True afterwards regardless of its prior value.
    OmegaConf.set_struct(node, False)
    if leaf_key in node:
        del node[leaf_key]
    OmegaConf.set_struct(node, True)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def load_lm_model(
    file_or_url_or_id: tp.Union[Path, str],
    device=None,
):
    """Build the language model described by a checkpoint and load its weights.

    The checkpoint is read with ``torch.load``; its ``"xp.cfg"`` entry is
    turned into an OmegaConf config and its ``"best_state"`` entry is loaded
    into the model returned by ``get_lm_model``.

    Args:
        file_or_url_or_id: Checkpoint location, passed unchanged to
            ``torch.load``.
        device: Device used as ``map_location`` and recorded in the config.
            ``None`` selects ``"cuda"`` when CUDA is available, else
            ``"cpu"``; previously ``None`` was stringified to ``"None"``.

    Returns:
        The model in eval mode, with the adjusted config attached as
        ``model.cfg``.
    """
    if device is None:
        # str(None) would otherwise be stored as the device name "None".
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    pkg = torch.load(file_or_url_or_id, map_location=device)

    cfg = OmegaConf.create(pkg["xp.cfg"])
    cfg.device = str(device)
    # Full precision on CPU, half precision on any other device.
    cfg.dtype = "float32" if cfg.device == "cpu" else "float16"

    # Remove these keys from the stored config before building the model.
    for param_name in (
        "conditioners.self_wav.chroma_stem.cache_path",
        "conditioners.args.merge_text_conditions_p",
        "conditioners.args.drop_desc_p",
    ):
        _delete_param(cfg, param_name)

    model = get_lm_model(cfg)
    model.load_state_dict(pkg["best_state"])
    model.eval()
    model.cfg = cfg
    return model
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device=None):
    """Return the result of ``CompressionModel.get_pretrained`` for the given source.

    Args:
        file_or_url_or_id: Identifier forwarded unchanged to
            ``CompressionModel.get_pretrained``.
        device: Forwarded as the ``device`` keyword argument.
    """
    compression_model = CompressionModel.get_pretrained(
        file_or_url_or_id,
        device=device,
    )
    return compression_model
|
| 132 |
+
|
| 133 |
+
|
| 134 |
def load_model(version="facebook/musicgen-small"):
    """Load the custom MusicGen model into the module-level ``MODEL``.

    When ``MODEL`` is unset or its ``name`` differs from ``version``, the
    current model is dropped, the CUDA cache is emptied, and a new
    ``MusicGen`` is assembled from the local language-model checkpoint and
    the ``facebook/encodec_32khz`` compression model.

    Args:
        version: Name compared against ``MODEL.name`` to decide whether a
            reload is needed. It does not select which weights are loaded.
    """
    global MODEL
    print("Loading Musivesal musicgen-small")  # , version
    # NOTE(review): the model is created under the name
    # "musiversal/musicgen-small", so with the default ``version`` this
    # condition looks like it is true on every call -- confirm intent.
    if MODEL is None or MODEL.name != version:
        # Clear PyTorch CUDA cache and delete model
        del MODEL
        torch.cuda.empty_cache()
        MODEL = None  # in case loading would crash
        # Fixed: the device string was misspelled "cudu", which is not a
        # valid torch device.
        device = "cuda"
        lm = load_lm_model("../data/state_dict.bin", device=device)
        compression_model = load_compression_model(
            "facebook/encodec_32khz", device=device
        )
        MODEL = MusicGen("musiversal/musicgen-small", compression_model, lm)
        # NOTE(review): original nesting of this message is not recoverable
        # from the diff view; assumed to belong to the reload branch.
        print("Custom model loaded.")
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# def load_model(version="facebook/musicgen-small"):
|
| 152 |
+
# global MODEL
|
| 153 |
+
# print("Loading Musivesal musicgen-small") # , version
|
| 154 |
+
# if MODEL is None or MODEL.name != version:
|
| 155 |
+
# # Clear PyTorch CUDA cache and delete model
|
| 156 |
+
# del MODEL
|
| 157 |
+
# torch.cuda.empty_cache()
|
| 158 |
+
# MODEL = None # in case loading would crash
|
| 159 |
+
# MODEL = MusicGen.get_pretrained("data")
|
| 160 |
+
# print("Custom model loaded.")
|
| 161 |
|
| 162 |
|
| 163 |
def load_diffusion():
|