swc2 commited on
Commit
956c248
·
1 Parent(s): eada018

change to v3.0

Browse files
Files changed (2) hide show
  1. datahandler.py +2 -2
  2. decode.py +4 -5
datahandler.py CHANGED
@@ -18,8 +18,8 @@ class AudioMixer(object):
18
  def __init__(
19
  self,
20
  sample_rate=16000,
21
- mean_snr=-4,
22
- var_snr=10,
23
  mean_loudness=-24,
24
  var_loudness=10
25
  ):
 
18
  def __init__(
19
  self,
20
  sample_rate=16000,
21
+ mean_snr=-3,
22
+ var_snr=8,
23
  mean_loudness=-24,
24
  var_loudness=10
25
  ):
decode.py CHANGED
@@ -43,19 +43,17 @@ class InferencePipeline:
43
 
44
  self.computer_ = NnetComputer(config.test.checkpoint,config.test.gpu, model_inst)
45
 
46
- def run_inference(self, input_audio_path: str, enroll_audio_path: str) -> str:
47
 
48
  mix_samps, sr = sf.read(input_audio_path)
49
  aux_samps, sr2 = sf.read(enroll_audio_path)
50
- aux_samps[10:]
51
 
52
  samps = self.computer_.compute(mix_samps, aux_samps, len(aux_samps))
53
  norm = np.linalg.norm(mix_samps, np.inf)
54
  samps = samps[:mix_samps.size]
55
  samps = samps * norm / np.max(np.abs(samps))
56
 
57
-
58
- out_wav = "temp_extracted.wav"
59
  sf.write(out_wav, samps, sr)
60
  return out_wav
61
 
@@ -65,7 +63,8 @@ if __name__ == "__main__":
65
 
66
  mix_path = "test_output_mixture.wav"
67
  enroll_path = "test_mix.wav"
68
- out_wav = pipeline.run_inference(mix_path, enroll_path)
 
69
  print("Done:", out_wav)
70
 
71
 
 
43
 
44
  self.computer_ = NnetComputer(config.test.checkpoint,config.test.gpu, model_inst)
45
 
46
+ def run_inference(self, input_audio_path: str, enroll_audio_path: str, out_path: str) -> str:
47
 
48
  mix_samps, sr = sf.read(input_audio_path)
49
  aux_samps, sr2 = sf.read(enroll_audio_path)
 
50
 
51
  samps = self.computer_.compute(mix_samps, aux_samps, len(aux_samps))
52
  norm = np.linalg.norm(mix_samps, np.inf)
53
  samps = samps[:mix_samps.size]
54
  samps = samps * norm / np.max(np.abs(samps))
55
 
56
+ out_wav = out_path
 
57
  sf.write(out_wav, samps, sr)
58
  return out_wav
59
 
 
63
 
64
  mix_path = "test_output_mixture.wav"
65
  enroll_path = "test_mix.wav"
66
+ out_path = "temp_output.wav"
67
+ out_wav = pipeline.run_inference(mix_path, enroll_path, out_path)
68
  print("Done:", out_wav)
69
 
70