Spaces:
Runtime error
Runtime error
| import hydra | |
| from omegaconf import DictConfig | |
| import torch | |
| from remfx.models import RemFXChainInference | |
| import torchaudio | |
def _load_checkpoint(model, ckpt_path, device):
    """Load the ``state_dict`` entry of a checkpoint into ``model`` on ``device``.

    Args:
        model: Module exposing ``load_state_dict`` and ``to``.
        ckpt_path: Path of a checkpoint file containing a ``"state_dict"`` key.
        device: Device used both as ``map_location`` and as the model's target.

    Returns:
        The same ``model`` object, with weights loaded and moved to ``device``.
    """
    # NOTE: torch.load unpickles the file, so only load checkpoints that come
    # from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    model.to(device)
    return model


def main(cfg: DictConfig):
    """Run the effect-removal chain on one audio file and save the result.

    Reads these keys from ``cfg``: ``ckpts`` (mapping of effect name to an
    entry with ``model`` and ``ckpt_path``), ``classifier``,
    ``classifier_ckpt``, ``sample_rate``, ``num_bins``,
    ``inference_effects_ordering``, ``inference_effects_shuffle``,
    ``inference_use_all_effect_models``, ``audio_input`` and, optionally,
    ``output_path`` (defaults to ``"./output.wav"``).

    Args:
        cfg: Hydra/OmegaConf configuration holding the keys listed above.
    """
    # The device choice does not depend on the model, so pick it once.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("Loading models...")
    models = {}
    for effect in cfg.ckpts:
        effect_cfg = cfg.ckpts[effect]
        model = hydra.utils.instantiate(effect_cfg.model, _convert_="partial")
        models[effect] = _load_checkpoint(model, effect_cfg.ckpt_path, device)

    classifier = hydra.utils.instantiate(cfg.classifier, _convert_="partial")
    _load_checkpoint(classifier, cfg.classifier_ckpt, device)

    # NOTE(review): the models are never explicitly switched to eval() here;
    # confirm whether RemFXChainInference takes care of that.
    inference_model = RemFXChainInference(
        models,
        sample_rate=cfg.sample_rate,
        num_bins=cfg.num_bins,
        effect_order=cfg.inference_effects_ordering,
        classifier=classifier,
        shuffle_effect_order=cfg.inference_effects_shuffle,
        use_all_effect_models=cfg.inference_use_all_effect_models,
    )

    audio_file = cfg.audio_input
    print("Loading", audio_file)
    audio, sr = torchaudio.load(audio_file)
    # Resample to the sample rate given in the config
    audio = torchaudio.transforms.Resample(sr, cfg.sample_rate)(audio)
    # Convert to mono
    audio = audio.mean(0, keepdim=True)
    # Add dimension for batch
    audio = audio.unsqueeze(0)
    audio = audio.to(device)

    # The same tensor fills the first two batch slots; the remaining two are
    # unused. NOTE(review): presumably (input, target, ...) — confirm against
    # RemFXChainInference.
    batch = [audio, audio, None, None]
    # Pure inference: skip autograd bookkeeping so no graph is built and the
    # output can be saved without detaching.
    with torch.no_grad():
        _, y = inference_model(batch, 0, verbose=True)
    y = y.cpu()

    output_path = cfg.output_path if "output_path" in cfg else "./output.wav"
    print("Saving output to", output_path)
    torchaudio.save(output_path, y[0], sample_rate=cfg.sample_rate)
# Script entry point: run the pipeline only when this file is executed directly.
# NOTE(review): main() is invoked without the cfg argument it declares;
# presumably a @hydra.main decorator (not visible in this view) is meant to
# supply it — confirm.
if __name__ == "__main__":
    main()