Spaces:
Runtime error
Runtime error
| import torch | |
| import yaml | |
| from audiosr import download_checkpoint, default_audioldm_config, LatentDiffusion | |
def load_audiosr(ckpt_path=None, config=None, device=None, model_name="basic"):
    """Build an AudioSR ``LatentDiffusion`` model and load pretrained weights.

    Args:
        ckpt_path: Path to a checkpoint file. When ``None`` (the default), the
            checkpoint for ``model_name`` is obtained via
            ``download_checkpoint(model_name)``. An explicitly supplied path is
            used as-is.
        config: Optional path (``str`` or ``os.PathLike``) to a YAML config
            file. When ``None``, ``default_audioldm_config(model_name)`` is used.
        device: Target device. ``None`` or ``"auto"`` selects ``cuda:0`` if
            available, else ``mps`` if available, else ``cpu``. Any other value
            is passed through unchanged.
        model_name: Name of the pretrained variant to load.

    Returns:
        The ``LatentDiffusion`` model, switched to eval mode and moved to
        ``device``.

    Raises:
        TypeError: If ``config`` is given but is not a path.
    """
    import os

    if device is None or device == "auto":
        # Older torch builds have no ``torch.backends.mps``; treat that as
        # "MPS unavailable" instead of raising AttributeError.
        mps_backend = getattr(torch.backends, "mps", None)
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        elif mps_backend is not None and mps_backend.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")

    print("Loading AudioSR: %s" % model_name)
    print("Loading model on %s" % device)

    # Respect a caller-supplied checkpoint; only download when none is given.
    if ckpt_path is None:
        ckpt_path = download_checkpoint(model_name)

    if config is not None:
        # Explicit check instead of ``assert`` so validation survives ``-O``.
        if not isinstance(config, (str, os.PathLike)):
            raise TypeError(
                "config must be a path to a YAML file, got %s"
                % type(config).__name__
            )
        # Close the file handle deterministically.
        # NOTE(review): FullLoader is kept for compatibility with existing
        # configs; prefer yaml.safe_load if configs may come from untrusted
        # sources.
        with open(config, "r") as config_file:
            config = yaml.load(config_file, Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config(model_name)

    # The model constructor receives the resolved device through its params.
    config["model"]["params"]["device"] = device
    latent_diffusion = LatentDiffusion(**config["model"]["params"])

    # Load onto CPU first; the whole model is moved to ``device`` afterwards.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    latent_diffusion.load_state_dict(checkpoint["state_dict"], strict=True)
    # Drop the raw checkpoint dict so its tensors can be freed before the move.
    del checkpoint

    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.to(device)
    return latent_diffusion