swc2 committed on
Commit
ef932f5
·
1 Parent(s): f8d6437

add model select

Browse files
Files changed (4) hide show
  1. app.py +22 -5
  2. decode.py +2 -2
  3. model/spex_plus.py +5 -1
  4. model/spex_plus_plus.py +1 -1
app.py CHANGED
@@ -8,19 +8,30 @@ from decode import InferencePipeline
8
  from datahandler import AudioMixer, fix_audio_format
9
  from omegaconf import OmegaConf
10
 
 
 
 
 
11
 
12
-
13
- cfg = OmegaConf.load("config/config_ira.yaml")
14
- inter = InferencePipeline(cfg)
15
  datamix = AudioMixer()
16
 
17
 
18
  def gradio_TSE(input_audio_path, enroll_audio_path1, enroll_audio_path2, audio_type):
19
-
20
  print(f"User uploaded audio path: {input_audio_path}")
21
  print(f"User enroll audio path: {enroll_audio_path1}")
22
  print(f"User enroll audio path: {enroll_audio_path2}")
23
 
 
 
 
 
 
 
 
 
24
 
25
  audio_info = sf.info(input_audio_path)
26
  print(f"采样率: {audio_info.samplerate} Hz")
@@ -85,6 +96,12 @@ with gr.Blocks() as demo:
85
  value="clean",
86
  label="Input audio type?"
87
  )
 
 
 
 
 
 
88
  with gr.Row():
89
 
90
  enroll_audio1 = gr.Audio(label="Upload your first enroll audio", type="filepath")
@@ -101,7 +118,7 @@ with gr.Blocks() as demo:
101
  convert_button = gr.Button("Extract")
102
  convert_button.click(
103
  fn=gradio_TSE,
104
- inputs=[input_audio, enroll_audio1, enroll_audio2, audio_type],
105
  outputs=[noisy_audio_output, extracted_audio_output1, extracted_audio_output2]
106
  )
107
 
 
8
  from datahandler import AudioMixer, fix_audio_format
9
  from omegaconf import OmegaConf
10
 
11
+ MODEL_CACHE = {
12
+ "base_model": InferencePipeline(OmegaConf.load("config/config.yaml")),
13
+ "iter_model": InferencePipeline(OmegaConf.load("config/config_ira.yaml"))
14
+ }
15
 
16
+ # cfg = OmegaConf.load("config/config_ira.yaml")
17
+ # inter = InferencePipeline(cfg)
 
18
  datamix = AudioMixer()
19
 
20
 
21
  def gradio_TSE(input_audio_path, enroll_audio_path1, enroll_audio_path2, audio_type):
22
+ print(f"模型选择: {model_select}")
23
  print(f"User uploaded audio path: {input_audio_path}")
24
  print(f"User enroll audio path: {enroll_audio_path1}")
25
  print(f"User enroll audio path: {enroll_audio_path2}")
26
 
27
+ # if model_select == "base_model":
28
+ # cfg_path = "config/config_base.yaml"
29
+ # elif model_select == "iter_model":
30
+ # cfg_path = "config/config_iter.yaml"
31
+ # else:
32
+ # raise ValueError("未知模型类型")
33
+
34
+ inter = MODEL_CACHE[model_select]
35
 
36
  audio_info = sf.info(input_audio_path)
37
  print(f"采样率: {audio_info.samplerate} Hz")
 
96
  value="clean",
97
  label="Input audio type?"
98
  )
99
+
100
+ model_select = gr.Radio(
101
+ choices=["base_model", "iter_model"],
102
+ value="iter_model",
103
+ label="Select Model Type"
104
+ )
105
  with gr.Row():
106
 
107
  enroll_audio1 = gr.Audio(label="Upload your first enroll audio", type="filepath")
 
118
  convert_button = gr.Button("Extract")
119
  convert_button.click(
120
  fn=gradio_TSE,
121
+ inputs=[input_audio, enroll_audio1, enroll_audio2, audio_type, model_select],
122
  outputs=[noisy_audio_output, extracted_audio_output1, extracted_audio_output2]
123
  )
124
 
decode.py CHANGED
@@ -31,7 +31,7 @@ class NnetComputer(object):
31
  aux = aux.unsqueeze(0)
32
  print("raw",raw.shape)
33
  print("aux",aux.shape)
34
- sps,spk_pred,emb = self.nnet(raw, aux, aux_len)
35
  sp_samps = np.squeeze(sps.detach().cpu().numpy())
36
  return sp_samps
37
 
@@ -58,7 +58,7 @@ class InferencePipeline:
58
  return out_wav
59
 
60
  if __name__ == "__main__":
61
- cfg = OmegaConf.load("config/config_ira.yaml")
62
  pipeline = InferencePipeline(cfg)
63
 
64
  mix_path = "test_output_mixture.wav"
 
31
  aux = aux.unsqueeze(0)
32
  print("raw",raw.shape)
33
  print("aux",aux.shape)
34
+ sps = self.nnet(raw, aux, aux_len)
35
  sp_samps = np.squeeze(sps.detach().cpu().numpy())
36
  return sp_samps
37
 
 
58
  return out_wav
59
 
60
  if __name__ == "__main__":
61
+ cfg = OmegaConf.load("config/config.yaml")
62
  pipeline = InferencePipeline(cfg)
63
 
64
  mix_path = "test_output_mixture.wav"
model/spex_plus.py CHANGED
@@ -122,8 +122,12 @@ class SpEx_Plus(nn.Module):
122
  S1 = w1 * m1
123
  S2 = w2 * m2
124
  S3 = w3 * m3
 
 
 
 
125
 
126
- return self.decoder_1d_short(S1), self.decoder_1d_middle(S2)[:, :xlen1], self.decoder_1d_long(S3)[:, :xlen1], self.pred_linear(aux)
127
 
128
  class Extractor(nn.Module):
129
  def __init__(self,
 
122
  S1 = w1 * m1
123
  S2 = w2 * m2
124
  S3 = w3 * m3
125
+
126
+ out1 = self.decoder_1d_short(S1)
127
+ # out2 = self.decoder_1d_middle(S2)[:, :xlen1]
128
+ # out3 = self.decoder_1d_long(S3)[:, :xlen1]
129
 
130
+ return self.decoder_1d_short(S1)
131
 
132
  class Extractor(nn.Module):
133
  def __init__(self,
model/spex_plus_plus.py CHANGED
@@ -206,7 +206,7 @@ class SpEx_Plus_Double(nn.Module):
206
 
207
  est3 = self.ira(est2, aux, aux_len,xlen1, xlen2, xlen3, w1, w2, w3)
208
 
209
- return est3,self.pred_linear(aux), aux
210
 
211
  class Extractor(nn.Module):
212
  def __init__(self,
 
206
 
207
  est3 = self.ira(est2, aux, aux_len,xlen1, xlen2, xlen3, w1, w2, w3)
208
 
209
+ return est3
210
 
211
  class Extractor(nn.Module):
212
  def __init__(self,