#!/usr/bin/env python
"""Single-utterance target-speaker extraction inference.

Loads a checkpointed separation network (instantiated from a Hydra config),
feeds it a mixture waveform plus an enrollment (reference) waveform of the
target speaker, and writes the extracted signal to a wav file.
"""
import logging

import numpy as np
import soundfile as sf
import torch as th

import hydra
from omegaconf import OmegaConf

logger = logging.getLogger(__name__)


class NnetComputer(object):
    """Holds a trained network and runs forward passes on raw waveforms."""

    def __init__(self, cpt_dir, gpuid, nnet_conf):
        """
        Args:
            cpt_dir: path to the checkpoint file (passed to ``th.load``).
            gpuid: CUDA device index; any negative value selects CPU.
            nnet_conf: an already-instantiated model whose weights are
                restored from the checkpoint.
        """
        self.device = th.device(f"cuda:{gpuid}") if gpuid >= 0 else th.device("cpu")
        nnet = self._load_nnet(cpt_dir, nnet_conf)
        self.nnet = nnet.to(self.device) if gpuid >= 0 else nnet
        # Inference only: freeze dropout / batch-norm statistics.
        self.nnet.eval()

    def _load_nnet(self, cpt_dir, model):
        """Restore model weights from the checkpoint at ``cpt_dir``."""
        cpt = th.load(cpt_dir, map_location="cpu")
        model.load_state_dict(cpt["model_state_dict"])
        return model

    def compute(self, samps, aux_samps, aux_samps_len):
        """Run the network on one mixture/enrollment pair.

        Args:
            samps: mixture samples (array-like of float).
            aux_samps: enrollment samples of the target speaker.
            aux_samps_len: number of enrollment samples.

        Returns:
            numpy array with the extracted waveform (squeezed).
        """
        with th.no_grad():
            raw = th.tensor(samps, dtype=th.float32, device=self.device)
            aux = th.tensor(aux_samps, dtype=th.float32, device=self.device)
            # NOTE(review): length is passed as float32 to match the original
            # code; confirm the model does not expect an integer dtype here.
            aux_len = th.tensor(aux_samps_len, dtype=th.float32, device=self.device)
            # Add a batch dimension to the enrollment utterance.
            aux = aux.unsqueeze(0)
            # Debug prints replaced with lazy logging calls.
            logger.debug("raw shape: %s", raw.shape)
            logger.debug("aux shape: %s", aux.shape)
            sps = self.nnet(raw, aux, aux_len)
            return np.squeeze(sps.detach().cpu().numpy())


class InferencePipeline:
    """End-to-end wav-in / wav-out wrapper around ``NnetComputer``."""

    def __init__(self, config):
        # Build the network from the Hydra config, then restore its weights.
        model_inst = hydra.utils.instantiate(config.model)
        self.computer_ = NnetComputer(config.test.checkpoint, config.test.gpu, model_inst)

    def run_inference(self, input_audio_path: str, enroll_audio_path: str) -> str:
        """Extract the enrolled speaker from a mixture and write a wav.

        Args:
            input_audio_path: path to the mixture wav.
            enroll_audio_path: path to the enrollment wav of the target speaker.

        Returns:
            Path of the written output file ("temp_extracted.wav").
        """
        mix_samps, sr = sf.read(input_audio_path)
        aux_samps, enroll_sr = sf.read(enroll_audio_path)
        if enroll_sr != sr:
            # The model has no resampler; a silent rate mismatch degrades output.
            logger.warning(
                "sample-rate mismatch: mixture %d Hz vs enrollment %d Hz",
                sr, enroll_sr,
            )
        samps = self.computer_.compute(mix_samps, aux_samps, len(aux_samps))
        # Rescale so the output peak matches the mixture peak (inf-norm = max |x|).
        norm = np.linalg.norm(mix_samps, np.inf)
        samps = samps[:mix_samps.size]
        peak = np.max(np.abs(samps))
        if peak > 0:  # guard against division by zero on an all-zero output
            samps = samps * norm / peak
        out_wav = "temp_extracted.wav"
        sf.write(out_wav, samps, sr)
        return out_wav


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    cfg = OmegaConf.load("config/config.yaml")
    pipeline = InferencePipeline(cfg)
    mix_path = "test_output_mixture.wav"
    enroll_path = "test_mix.wav"
    out_wav = pipeline.run_inference(mix_path, enroll_path)
    print("Done:", out_wav)