add model select
- app.py +22 -5
- decode.py +2 -2
- model/spex_plus.py +5 -1
- model/spex_plus_plus.py +1 -1
app.py
CHANGED
@@ -8,19 +8,30 @@ from decode import InferencePipeline
 from datahandler import AudioMixer, fix_audio_format
 from omegaconf import OmegaConf
 
+MODEL_CACHE = {
+    "base_model": InferencePipeline(OmegaConf.load("config/config.yaml")),
+    "iter_model": InferencePipeline(OmegaConf.load("config/config_ira.yaml"))
+}
 
-
-cfg = OmegaConf.load("config/config_ira.yaml")
-inter = InferencePipeline(cfg)
+# cfg = OmegaConf.load("config/config_ira.yaml")
+# inter = InferencePipeline(cfg)
 datamix = AudioMixer()
 
 
 def gradio_TSE(input_audio_path, enroll_audio_path1, enroll_audio_path2, audio_type):
-
+    print(f"Model selection: {model_select}")
     print(f"User uploaded audio path: {input_audio_path}")
     print(f"User enroll audio path: {enroll_audio_path1}")
     print(f"User enroll audio path: {enroll_audio_path2}")
 
+    # if model_select == "base_model":
+    #     cfg_path = "config/config_base.yaml"
+    # elif model_select == "iter_model":
+    #     cfg_path = "config/config_iter.yaml"
+    # else:
+    #     raise ValueError("Unknown model type")
+
+    inter = MODEL_CACHE[model_select]
 
     audio_info = sf.info(input_audio_path)
     print(f"Sample rate: {audio_info.samplerate} Hz")

@@ -85,6 +96,12 @@ with gr.Blocks() as demo:
         value="clean",
         label="Input audio type?"
     )
+
+    model_select = gr.Radio(
+        choices=["base_model", "iter_model"],
+        value="iter_model",
+        label="Select Model Type"
+    )
     with gr.Row():
 
         enroll_audio1 = gr.Audio(label="Upload your first enroll audio", type="filepath")

@@ -101,7 +118,7 @@
     convert_button = gr.Button("Extract")
     convert_button.click(
         fn=gradio_TSE,
-        inputs=[input_audio, enroll_audio1, enroll_audio2, audio_type],
+        inputs=[input_audio, enroll_audio1, enroll_audio2, audio_type, model_select],
        outputs=[noisy_audio_output, extracted_audio_output1, extracted_audio_output2]
     )
 
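Note on the app.py hunks above: MODEL_CACHE constructs both InferencePipeline objects once at startup, so switching models in the UI is a dict lookup instead of a config reload, at the cost of keeping both checkpoints resident in memory. One loose end: convert_button.click now passes five inputs, but the gradio_TSE signature shown as unchanged context still takes four parameters while its body reads model_select. A minimal sketch of the signature the handler would need (the parameter name is assumed to match the body; this line is not part of the commit):

def gradio_TSE(input_audio_path, enroll_audio_path1, enroll_audio_path2,
               audio_type, model_select):
    # Gradio passes component values positionally, so the Radio value
    # arrives as the fifth argument and selects the cached pipeline.
    inter = MODEL_CACHE[model_select]
    ...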
decode.py
CHANGED
@@ -31,7 +31,7 @@ class NnetComputer(object):
         aux = aux.unsqueeze(0)
         print("raw",raw.shape)
         print("aux",aux.shape)
-        sps
+        sps = self.nnet(raw, aux, aux_len)
         sp_samps = np.squeeze(sps.detach().cpu().numpy())
         return sp_samps
 

@@ -58,7 +58,7 @@ class InferencePipeline:
         return out_wav
 
 if __name__ == "__main__":
-    cfg = OmegaConf.load("config/
+    cfg = OmegaConf.load("config/config.yaml")
     pipeline = InferencePipeline(cfg)
 
     mix_path = "test_output_mixture.wav"
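The first decode.py hunk completes the truncated forward call: the network now also receives the enrollment length, i.e. a three-argument forward(mix, aux, aux_len) in the SpEx style. A shape sketch of that call, with all sizes and the 16 kHz rate assumed rather than taken from the repo:

import torch

raw = torch.randn(1, 16000)              # (1, T) mixture, one second at an assumed 16 kHz
aux = torch.randn(1, 32000)              # (1, T_aux) enrollment utterance
aux_len = torch.tensor([aux.shape[-1]])  # per-utterance length the model consumes
# sps = self.nnet(raw, aux, aux_len)     # the call the diff restores inside NnetComputer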
model/spex_plus.py
CHANGED
@@ -122,8 +122,12 @@ class SpEx_Plus(nn.Module):
         S1 = w1 * m1
         S2 = w2 * m2
         S3 = w3 * m3
+
+        out1 = self.decoder_1d_short(S1)
+        # out2 = self.decoder_1d_middle(S2)[:, :xlen1]
+        # out3 = self.decoder_1d_long(S3)[:, :xlen1]
 
-        return self.decoder_1d_short(S1)
+        return self.decoder_1d_short(S1)
 
 class Extractor(nn.Module):
     def __init__(self,
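In the model/spex_plus.py hunk, out1 already holds the decoded short-filter branch, yet the return statement runs decoder_1d_short on S1 a second time. If the middle and long decoders stay commented out, a small follow-up (not part of this commit) would reuse the tensor:

out1 = self.decoder_1d_short(S1)
# out2 = self.decoder_1d_middle(S2)[:, :xlen1]
# out3 = self.decoder_1d_long(S3)[:, :xlen1]

return out1  # reuse the already-decoded output instead of decoding S1 twice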
model/spex_plus_plus.py
CHANGED
@@ -206,7 +206,7 @@ class SpEx_Plus_Double(nn.Module):
 
         est3 = self.ira(est2, aux, aux_len,xlen1, xlen2, xlen3, w1, w2, w3)
 
-        return est3
+        return est3
 
 class Extractor(nn.Module):
     def __init__(self,