File size: 2,221 Bytes
7eddfc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef932f5
7eddfc5
 
 
 
ab3af29
7eddfc5
ab3af29
7eddfc5
 
 
 
515c096
ab3af29
7eddfc5
 
 
 
 
 
 
 
515c096
7eddfc5
 
 
 
ef932f5
7eddfc5
 
 
 
956c248
3931696
7eddfc5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#!/usr/bin/env python

import os
import logging
import numpy as np
import torch as th
import soundfile as sf
import hydra
from omegaconf import OmegaConf



class NnetComputer(object):
    """Wrap a network for single-utterance inference.

    Loads weights from a checkpoint into a pre-built model, places the model
    on the requested device and switches it to evaluation mode.
    """

    # Debug messages go through the logging module instead of bare ``print``
    # so callers can enable/disable them via logging configuration.
    _logger = logging.getLogger(__name__)

    def __init__(self, cpt_dir, gpuid, nnet_conf):
        """
        Args:
            cpt_dir: path of a checkpoint file readable by ``torch.load`` that
                contains a ``"model_state_dict"`` entry.
            gpuid: CUDA device index; a negative value selects the CPU.
            nnet_conf: an already-constructed model exposing
                ``load_state_dict`` (presumably a ``torch.nn.Module`` --
                confirm against the caller); its weights are overwritten.
        """
        self.device = th.device(f"cuda:{gpuid}") if gpuid >= 0 else th.device("cpu")
        nnet = self._load_nnet(cpt_dir, nnet_conf)
        # Weights are loaded with map_location="cpu"; only move for GPU runs.
        self.nnet = nnet.to(self.device) if gpuid >= 0 else nnet
        self.nnet.eval()

    def _load_nnet(self, cpt_dir, model):
        """Load ``cpt["model_state_dict"]`` from *cpt_dir* into *model* and return it."""
        # NOTE(review): torch.load may unpickle the file (depending on the
        # installed PyTorch version's ``weights_only`` default); only load
        # checkpoints from trusted sources.
        cpt = th.load(cpt_dir, map_location="cpu")
        model.load_state_dict(cpt["model_state_dict"])

        return model

    def compute(self, samps, aux_samps, aux_samps_len):
        """Run the network on one mixture / enrollment pair.

        Args:
            samps: mixture waveform (array-like of floats).
            aux_samps: enrollment (reference speaker) waveform.
            aux_samps_len: length of ``aux_samps`` in samples; passed to the
                network as a float32 scalar tensor.

        Returns:
            numpy.ndarray: the network output moved to the CPU with all
            singleton dimensions squeezed away.
        """
        with th.no_grad():
            raw = th.tensor(samps, dtype=th.float32, device=self.device)
            aux = th.tensor(aux_samps, dtype=th.float32, device=self.device)
            aux_len = th.tensor(aux_samps_len, dtype=th.float32, device=self.device)
            # Only the enrollment signal gets an explicit leading batch dim;
            # the mixture is passed as-is (presumably the model handles that
            # -- confirm against the model definition).
            aux = aux.unsqueeze(0)
            self._logger.debug("raw %s", tuple(raw.shape))
            self._logger.debug("aux %s", tuple(aux.shape))
            sps = self.nnet(raw, aux, aux_len)
            sp_samps = np.squeeze(sps.detach().cpu().numpy())
            return sp_samps

class InferencePipeline:
    """Extract a target speaker from a mixture given an enrollment recording.

    The network is built from ``config.model`` with Hydra. Its weights are
    loaded from ``config.test.checkpoint`` onto device ``config.test.gpu``
    (negative = CPU).
    """

    def __init__(self, config):

        model_inst = hydra.utils.instantiate(config.model)

        self.computer_ = NnetComputer(config.test.checkpoint, config.test.gpu, model_inst)

    def run_inference(
        self,
        input_audio_path: str,
        enroll_audio_path: str,
        out_path: str = "temp_extracted.wav",
    ) -> str:
        """Run extraction on one mixture file and write the result to disk.

        Args:
            input_audio_path: path of the mixture audio file.
            enroll_audio_path: path of the enrollment (target speaker) file.
            out_path: destination of the extracted audio; defaults to the
                previously hard-coded ``"temp_extracted.wav"``.

        Returns:
            The path the extracted audio was written to (``out_path``).

        Raises:
            ValueError: if the mixture and enrollment sample rates differ.
        """
        mix_samps, sr = sf.read(input_audio_path)
        aux_samps, aux_sr = sf.read(enroll_audio_path)
        # The network sees raw samples, so a rate mismatch between the two
        # inputs would silently produce meaningless output.
        if aux_sr != sr:
            raise ValueError(
                f"sample rate mismatch: mixture {sr} Hz vs enrollment {aux_sr} Hz"
            )

        samps = self.computer_.compute(mix_samps, aux_samps, len(aux_samps))
        # Infinity norm of the mixture; for a 1-D signal this is its peak
        # absolute amplitude, used to match the output loudness.
        norm = np.linalg.norm(mix_samps, np.inf)
        samps = samps[:mix_samps.size]
        peak = np.max(np.abs(samps)) if samps.size else 0.0
        # Guard against an all-zero output, which would otherwise divide by
        # zero and write NaNs.
        if peak > 0:
            samps = samps * norm / peak

        sf.write(out_path, samps, sr)
        return out_path

if __name__ == "__main__":
    # Script entry point: run one hard-coded extraction example.
    # All paths are relative to the current working directory.
    cfg = OmegaConf.load("config/config.yaml")
    pipeline = InferencePipeline(cfg)

    mix_path = "test_output_mixture.wav"
    enroll_path = "test_mix.wav"
    # run_inference decides where the extracted audio is written and
    # returns that path, so no separate output-path variable is needed.
    out_wav = pipeline.run_inference(mix_path, enroll_path)
    print("Done:", out_wav)