# NOTE(review): the three lines that were here ("swc2's picture",
# "add model select", "ef932f5") were web-page scrape residue from a
# commit header and broke Python parsing; preserved as this comment.
#!/usr/bin/env python
import os
import logging
import numpy as np
import torch as th
import soundfile as sf
import hydra
from omegaconf import OmegaConf
logger = logging.getLogger(__name__)


class NnetComputer(object):
    """Loads a trained separation network from a checkpoint and runs it on waveforms."""

    def __init__(self, cpt_dir, gpuid, nnet_conf):
        """
        Args:
            cpt_dir: path to the checkpoint file (passed straight to ``th.load``).
            gpuid: CUDA device index; any negative value selects CPU.
            nnet_conf: an already-instantiated model whose weights are replaced
                by the checkpoint's ``model_state_dict``.
        """
        self.device = th.device(f"cuda:{gpuid}") if gpuid >= 0 else th.device("cpu")
        nnet = self._load_nnet(cpt_dir, nnet_conf)
        # ``.to()`` is a no-op when the model is already on the target device,
        # so the previous ``if gpuid >= 0`` guard around it was unnecessary.
        self.nnet = nnet.to(self.device)
        self.nnet.eval()

    def _load_nnet(self, cpt_dir, model):
        """Restore model weights from the checkpoint at ``cpt_dir`` (CPU-mapped)."""
        cpt = th.load(cpt_dir, map_location="cpu")
        model.load_state_dict(cpt["model_state_dict"])
        return model

    def compute(self, samps, aux_samps, aux_samps_len):
        """Run the network on a mixture and an enrollment (auxiliary) waveform.

        Args:
            samps: mixture samples (array-like of float).
            aux_samps: enrollment samples for the target speaker.
            aux_samps_len: length of ``aux_samps`` in samples.

        Returns:
            np.ndarray: the separated waveform with singleton dims squeezed out.
        """
        with th.no_grad():
            raw = th.tensor(samps, dtype=th.float32, device=self.device)
            aux = th.tensor(aux_samps, dtype=th.float32, device=self.device)
            # NOTE(review): a float32 length looks odd — most models expect an
            # integer length; kept as-is to match the trained model's interface.
            aux_len = th.tensor(aux_samps_len, dtype=th.float32, device=self.device)
            aux = aux.unsqueeze(0)  # add a batch dim for the enrollment only
            # Was a pair of bare print() calls; use the module logger with
            # lazy %-formatting so debug output can be silenced in production.
            logger.debug("raw %s", raw.shape)
            logger.debug("aux %s", aux.shape)
            sps = self.nnet(raw, aux, aux_len)
            # ``.detach()`` is redundant under no_grad; dropped.
            return np.squeeze(sps.cpu().numpy())
class InferencePipeline:
    """End-to-end target-speaker extraction: load the model, separate, write a wav."""

    def __init__(self, config):
        """
        Args:
            config: OmegaConf config providing ``model`` (hydra-instantiable)
                plus ``test.checkpoint`` and ``test.gpu`` entries.
        """
        model_inst = hydra.utils.instantiate(config.model)
        self.computer_ = NnetComputer(config.test.checkpoint, config.test.gpu, model_inst)

    def run_inference(self, input_audio_path: str, enroll_audio_path: str,
                      output_path: str = "temp_extracted.wav") -> str:
        """Extract the enrolled speaker from a mixture and write the result.

        Args:
            input_audio_path: path to the mixture wav.
            enroll_audio_path: path to the enrollment (reference) wav.
            output_path: destination wav path; the default keeps the
                previously hard-coded filename for backward compatibility.

        Returns:
            The path of the written wav file.
        """
        mix_samps, sr = sf.read(input_audio_path)
        aux_samps, enroll_sr = sf.read(enroll_audio_path)
        if enroll_sr != sr:
            # Previously the enrollment rate was read and silently discarded;
            # a mismatch degrades results, so surface it (warn, don't crash).
            logging.warning("sample-rate mismatch: mixture=%d, enrollment=%d",
                            sr, enroll_sr)
        samps = self.computer_.compute(mix_samps, aux_samps, len(aux_samps))
        # Rescale so the output peak matches the mixture peak.
        norm = np.linalg.norm(mix_samps, np.inf)  # == max(|mix|) for 1-D input
        samps = samps[:mix_samps.size]  # network may emit padding; trim to input length
        peak = np.max(np.abs(samps))
        if peak > 0:  # guard against division by zero on an all-silent output
            samps = samps * norm / peak
        sf.write(output_path, samps, sr)
        return output_path
if __name__ == "__main__":
    # Smoke-test entry point: run one extraction with the demo wav files.
    cfg = OmegaConf.load("config/config.yaml")
    pipeline = InferencePipeline(cfg)
    mix_path = "test_output_mixture.wav"
    enroll_path = "test_mix.wav"
    # The previously assigned ``out_path`` was never used — run_inference
    # returns the path it wrote to; removed the dead local.
    out_wav = pipeline.run_inference(mix_path, enroll_path)
    print("Done:", out_wav)