Spaces:
Runtime error
Runtime error
Update pico_model.py
Browse files- pico_model.py +1 -31
pico_model.py
CHANGED
|
@@ -12,36 +12,6 @@ from audioldm.audio.stft import TacotronSTFT
|
|
| 12 |
from audioldm.variational_autoencoder.autoencoder import AutoencoderKL
|
| 13 |
from audioldm.utils import default_audioldm_config, get_metadata
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def build_pretrained_models(name):
|
| 18 |
-
checkpoint = torch.load(get_metadata()[name]["path"], map_location="cpu")
|
| 19 |
-
scale_factor = checkpoint["state_dict"]["scale_factor"].item()
|
| 20 |
-
|
| 21 |
-
vae_state_dict = {k[18:]: v for k, v in checkpoint["state_dict"].items() if "first_stage_model." in k}
|
| 22 |
-
|
| 23 |
-
config = default_audioldm_config(name)
|
| 24 |
-
vae_config = config["model"]["params"]["first_stage_config"]["params"]
|
| 25 |
-
vae_config["scale_factor"] = scale_factor
|
| 26 |
-
|
| 27 |
-
vae = AutoencoderKL(**vae_config)
|
| 28 |
-
vae.load_state_dict(vae_state_dict)
|
| 29 |
-
|
| 30 |
-
fn_STFT = TacotronSTFT(
|
| 31 |
-
config["preprocessing"]["stft"]["filter_length"],
|
| 32 |
-
config["preprocessing"]["stft"]["hop_length"],
|
| 33 |
-
config["preprocessing"]["stft"]["win_length"],
|
| 34 |
-
config["preprocessing"]["mel"]["n_mel_channels"],
|
| 35 |
-
config["preprocessing"]["audio"]["sampling_rate"],
|
| 36 |
-
config["preprocessing"]["mel"]["mel_fmin"],
|
| 37 |
-
config["preprocessing"]["mel"]["mel_fmax"],
|
| 38 |
-
)
|
| 39 |
-
|
| 40 |
-
vae.eval()
|
| 41 |
-
fn_STFT.eval()
|
| 42 |
-
|
| 43 |
-
return vae, fn_STFT
|
| 44 |
-
|
| 45 |
def _init_layer(layer):
|
| 46 |
"""Initialize a Linear or Convolutional layer. """
|
| 47 |
nn.init.xavier_uniform_(layer.weight)
|
|
@@ -260,7 +230,7 @@ class PicoDiffusion(ClapText_Onset_2_Audio_Diffusion):
|
|
| 260 |
ckpt = clap_load_state_dict(freeze_text_encoder_ckpt, skip_params=True)
|
| 261 |
del_parameter_key = ["text_branch.embeddings.position_ids"]
|
| 262 |
ckpt = {f"freeze_text_encoder.model.{k}":v for k, v in ckpt.items() if k not in del_parameter_key}
|
| 263 |
-
diffusion_ckpt = torch.load(diffusion_pt)
|
| 264 |
del diffusion_ckpt["class_emb.weight"]
|
| 265 |
ckpt.update(diffusion_ckpt)
|
| 266 |
self.load_state_dict(ckpt)
|
|
|
|
| 12 |
from audioldm.variational_autoencoder.autoencoder import AutoencoderKL
|
| 13 |
from audioldm.utils import default_audioldm_config, get_metadata
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
def _init_layer(layer):
|
| 16 |
"""Initialize a Linear or Convolutional layer. """
|
| 17 |
nn.init.xavier_uniform_(layer.weight)
|
|
|
|
| 230 |
ckpt = clap_load_state_dict(freeze_text_encoder_ckpt, skip_params=True)
|
| 231 |
del_parameter_key = ["text_branch.embeddings.position_ids"]
|
| 232 |
ckpt = {f"freeze_text_encoder.model.{k}":v for k, v in ckpt.items() if k not in del_parameter_key}
|
| 233 |
+
diffusion_ckpt = torch.load(diffusion_pt, map_location=torch.device(self.device))
|
| 234 |
del diffusion_ckpt["class_emb.weight"]
|
| 235 |
ckpt.update(diffusion_ckpt)
|
| 236 |
self.load_state_dict(ckpt)
|