# NOTE: removed non-code page artifacts ("Spaces: Sleeping") left over from scraping.
#!/usr/bin/env python
import logging
import os

import hydra
import numpy as np
import soundfile as sf
import torch as th
from omegaconf import OmegaConf
class NnetComputer(object):
    """Load a trained network from a checkpoint and run it for inference.

    Parameters
    ----------
    cpt_dir : str
        Path to a torch checkpoint file containing a ``"model_state_dict"`` entry.
    gpuid : int
        CUDA device index; a negative value selects the CPU.
    nnet_conf : torch.nn.Module
        Freshly instantiated model whose weights are restored from the checkpoint.
    """

    def __init__(self, cpt_dir, gpuid, nnet_conf):
        self.device = th.device(f"cuda:{gpuid}") if gpuid >= 0 else th.device("cpu")
        nnet = self._load_nnet(cpt_dir, nnet_conf)
        # .to() is a no-op for a CPU device, so it is safe to call unconditionally.
        self.nnet = nnet.to(self.device)
        self.nnet.eval()

    def _load_nnet(self, cpt_dir, model):
        """Restore *model*'s weights from the checkpoint at *cpt_dir* and return it."""
        # NOTE(review): th.load unpickles arbitrary objects -- only load
        # checkpoints from trusted sources.
        cpt = th.load(cpt_dir, map_location="cpu")
        model.load_state_dict(cpt["model_state_dict"])
        return model

    def compute(self, samps, aux_samps, aux_samps_len):
        """Run the network on a mixture and an enrollment (auxiliary) signal.

        Parameters
        ----------
        samps : array-like
            Mixture waveform samples.
        aux_samps : array-like
            Enrollment waveform samples; a batch dim is added before the forward pass.
        aux_samps_len : int
            Number of enrollment samples.

        Returns
        -------
        numpy.ndarray
            Extracted source with singleton dimensions squeezed out.
        """
        with th.no_grad():
            # as_tensor avoids copying when the input is already array-backed.
            raw = th.as_tensor(samps, dtype=th.float32, device=self.device)
            aux = th.as_tensor(aux_samps, dtype=th.float32, device=self.device)
            # assumes the network expects the length as a float32 tensor -- TODO confirm
            aux_len = th.as_tensor(aux_samps_len, dtype=th.float32, device=self.device)
            aux = aux.unsqueeze(0)
            # Debug prints replaced with lazy logging so inference stays quiet by default.
            logging.debug("raw shape: %s, aux shape: %s", raw.shape, aux.shape)
            sps = self.nnet(raw, aux, aux_len)
            return np.squeeze(sps.detach().cpu().numpy())
class InferencePipeline:
    """End-to-end target-speaker extraction: build the model, run it on wav files."""

    def __init__(self, config):
        # Instantiate the network from the hydra config, then restore its checkpoint.
        model_inst = hydra.utils.instantiate(config.model)
        self.computer_ = NnetComputer(config.test.checkpoint, config.test.gpu, model_inst)

    def run_inference(self, input_audio_path: str, enroll_audio_path: str) -> str:
        """Extract the enrolled speaker from a mixture wav.

        Parameters
        ----------
        input_audio_path : str
            Path to the mixture wav file.
        enroll_audio_path : str
            Path to the enrollment (reference speaker) wav file.

        Returns
        -------
        str
            Path of the wav file the extracted signal was written to.
        """
        mix_samps, sr = sf.read(input_audio_path)
        aux_samps, enroll_sr = sf.read(enroll_audio_path)
        if enroll_sr != sr:
            # NOTE(review): no resampling is performed here; a rate mismatch
            # silently degrades extraction quality, so at least warn about it.
            logging.warning(
                "sample-rate mismatch: mixture %s Hz vs enrollment %s Hz", sr, enroll_sr
            )
        samps = self.computer_.compute(mix_samps, aux_samps, len(aux_samps))
        # Trim to the mixture length, then rescale to the mixture's peak amplitude.
        norm = np.linalg.norm(mix_samps, np.inf)
        samps = samps[:mix_samps.size]
        peak = np.max(np.abs(samps))
        if peak > 0:
            # Guard against an all-zero output, which would otherwise yield NaNs.
            samps = samps * norm / peak
        out_wav = "temp_extracted.wav"
        sf.write(out_wav, samps, sr)
        return out_wav
if __name__ == "__main__":
    # Minimal smoke run: extract the enrolled speaker from a fixed test mixture.
    cfg = OmegaConf.load("config/config.yaml")
    pipeline = InferencePipeline(cfg)
    mix_path = "test_output_mixture.wav"
    enroll_path = "test_mix.wav"
    # The output path is chosen inside run_inference ("temp_extracted.wav");
    # the previously hard-coded "temp_output.wav" variable was never used.
    out_wav = pipeline.run_inference(mix_path, enroll_path)
    print("Done:", out_wav)