eusip committed on
Commit
73c9101
·
verified ·
1 Parent(s): 8d95e16

Update demos/musicgen_app.py

Browse files
Files changed (1) hide show
  1. demos/musicgen_app.py +56 -1
demos/musicgen_app.py CHANGED
@@ -93,6 +93,44 @@ def make_waveform(*args, **kwargs):
93
  return out
94
 
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  def load_model(version="facebook/musicgen-small"):
97
  global MODEL
98
  print("Loading Musivesal musicgen-small") # , version
@@ -101,8 +139,25 @@ def load_model(version="facebook/musicgen-small"):
101
  del MODEL
102
  torch.cuda.empty_cache()
103
  MODEL = None # in case loading would crash
104
- MODEL = MusicGen.get_pretrained("data")
 
 
 
 
 
105
  print("Custom model loaded.")
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
 
108
  def load_diffusion():
 
93
  return out
94
 
95
 
96
def _delete_param(cfg: DictConfig, full_name: str):
    """Remove the config entry at dotted path *full_name* from *cfg*, if present.

    Walks down each intermediate section of the path; if any section is
    missing the function silently does nothing. Struct mode is temporarily
    disabled on the parent node so the deletion is allowed, then restored.
    """
    *sections, leaf = full_name.split(".")
    node = cfg
    for section in sections:
        if section not in node:
            return  # intermediate section absent: nothing to delete
        node = node[section]
    OmegaConf.set_struct(node, False)
    if leaf in node:
        del node[leaf]
    OmegaConf.set_struct(node, True)
107
+
108
+
109
def load_lm_model(
    file_or_url_or_id: tp.Union[Path, str],
    device=None,
):
    """Rebuild a language model from a serialized checkpoint.

    Loads the package with ``torch.load``, reconstructs the training config,
    pins device/dtype (float32 on CPU, float16 otherwise), strips config
    keys the current code no longer accepts, then instantiates the LM and
    restores its best state. The model is returned in eval mode with its
    config attached as ``model.cfg``.

    NOTE(review): ``torch.load`` unpickles arbitrary objects — only feed it
    trusted checkpoint files.
    """
    pkg = torch.load(file_or_url_or_id, map_location=device)
    cfg = OmegaConf.create(pkg["xp.cfg"])
    cfg.device = str(device)
    cfg.dtype = "float32" if cfg.device == "cpu" else "float16"
    # Drop legacy conditioner options that newer builds reject.
    for stale_key in (
        "conditioners.self_wav.chroma_stem.cache_path",
        "conditioners.args.merge_text_conditions_p",
        "conditioners.args.drop_desc_p",
    ):
        _delete_param(cfg, stale_key)
    model = get_lm_model(cfg)
    model.load_state_dict(pkg["best_state"])
    model.eval()
    model.cfg = cfg
    return model
128
+
129
+
130
def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device=None):
    """Fetch a pretrained compression model onto *device* and return it."""
    compression_model = CompressionModel.get_pretrained(file_or_url_or_id, device=device)
    return compression_model
132
+
133
+
134
def load_model(version="facebook/musicgen-small"):
    """(Re)load the global MODEL from the local fine-tuned checkpoint.

    Keeps the current model when one matching *version* is already loaded;
    otherwise frees the old model, clears the CUDA cache, and assembles a
    MusicGen instance from the local LM state dict plus the pretrained
    EnCodec compression model.

    Side effects: rebinds the module-level global ``MODEL``.
    """
    global MODEL
    print("Loading Musivesal musicgen-small")  # , version
    if MODEL is None or MODEL.name != version:
        # Clear PyTorch CUDA cache and delete model
        del MODEL
        torch.cuda.empty_cache()
        MODEL = None  # in case loading would crash
        # BUG FIX: the device string was misspelled "cudu" — torch only
        # accepts valid device names such as "cuda"/"cpu", so model loading
        # raised before anything reached the GPU.
        lm = load_lm_model("../data/state_dict.bin", device="cuda")
        compression_model = load_compression_model(
            "facebook/encodec_32khz", device="cuda"
        )
        MODEL = MusicGen("musiversal/musicgen-small", compression_model, lm)
        print("Custom model loaded.")
161
 
162
 
163
  def load_diffusion():