Spaces:
Runtime error
Runtime error
kAIto47802
commited on
Commit
·
8537948
1
Parent(s):
a18d920
Fix and add quick option
Browse files
app.py
CHANGED
|
@@ -24,24 +24,25 @@ cfg.config = "fusion_stage3"
|
|
| 24 |
cfg.print_config = False
|
| 25 |
cfg.data_config = None
|
| 26 |
cfg.phase = "inference"
|
| 27 |
-
cfg.weight = None
|
| 28 |
cfg.num_workers = 1
|
| 29 |
|
| 30 |
@spaces.GPU
|
| 31 |
@torch.inference_mode()
|
| 32 |
-
def predict_mos(audio_path: str, domain: str) -> float:
|
| 33 |
data = pd.DataFrame({"file_path": [audio_path]})
|
| 34 |
data["dataset"] = domain
|
| 35 |
-
data[
|
| 36 |
-
|
| 37 |
preds = 0.0
|
| 38 |
for fold in range(5):
|
| 39 |
cfg.now_fold = fold
|
|
|
|
| 40 |
model = get_model(cfg, device).eval()
|
| 41 |
for _ in range(5):
|
| 42 |
test_dataset = get_dataset(cfg, data, "test")
|
| 43 |
p = model(*[torch.tensor(t).unsqueeze(0).to(device) for t in test_dataset[0][:-1]])
|
| 44 |
-
preds += p.cpu().numpy()[0]
|
|
|
|
|
|
|
| 45 |
preds /= 25.0
|
| 46 |
return preds
|
| 47 |
|
|
@@ -65,12 +66,22 @@ with gr.Blocks() as demo:
|
|
| 65 |
"blizzard2011",
|
| 66 |
],
|
| 67 |
label="Data-domain ID for the MOS prediction",
|
| 68 |
-
value="sarulab"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
)
|
| 70 |
submit = gr.Button(value="Submit")
|
| 71 |
|
| 72 |
with gr.Column():
|
| 73 |
output = gr.Textbox(label="Predicted MOS", type="text")
|
| 74 |
-
submit.click(fn=predict_mos, inputs=[audio, domain], outputs=[output])
|
| 75 |
|
| 76 |
demo.queue().launch()
|
|
|
|
| 24 |
cfg.print_config = False
|
| 25 |
cfg.data_config = None
|
| 26 |
cfg.phase = "inference"
|
|
|
|
| 27 |
cfg.num_workers = 1
|
| 28 |
|
| 29 |
@spaces.GPU
|
| 30 |
@torch.inference_mode()
|
| 31 |
+
def predict_mos(audio_path: str, domain: str, quick: bool) -> float:
|
| 32 |
data = pd.DataFrame({"file_path": [audio_path]})
|
| 33 |
data["dataset"] = domain
|
| 34 |
+
data["mos"] = 0
|
|
|
|
| 35 |
preds = 0.0
|
| 36 |
for fold in range(5):
|
| 37 |
cfg.now_fold = fold
|
| 38 |
+
cfg.weight = f"models/fusion_stage3/fold{fold}_s42_best_model.pth"
|
| 39 |
model = get_model(cfg, device).eval()
|
| 40 |
for _ in range(5):
|
| 41 |
test_dataset = get_dataset(cfg, data, "test")
|
| 42 |
p = model(*[torch.tensor(t).unsqueeze(0).to(device) for t in test_dataset[0][:-1]])
|
| 43 |
+
preds += p.cpu().numpy()[0][0]
|
| 44 |
+
if quick:
|
| 45 |
+
return preds
|
| 46 |
preds /= 25.0
|
| 47 |
return preds
|
| 48 |
|
|
|
|
| 66 |
"blizzard2011",
|
| 67 |
],
|
| 68 |
label="Data-domain ID for the MOS prediction",
|
| 69 |
+
value="sarulab",
|
| 70 |
+
)
|
| 71 |
+
quick = gr.Checkbox(
|
| 72 |
+
label="Quick prediction",
|
| 73 |
+
value=True,
|
| 74 |
+
info=(
|
| 75 |
+
"UTMOSv2 makes predictions repeatedly for five randomly selected frames "
|
| 76 |
+
"of the input speech waveform for all five folds. "
|
| 77 |
+
"To make quick predictions by reducing this to a single repetition, "
|
| 78 |
+
"check this checkbox:",
|
| 79 |
+
),
|
| 80 |
)
|
| 81 |
submit = gr.Button(value="Submit")
|
| 82 |
|
| 83 |
with gr.Column():
|
| 84 |
output = gr.Textbox(label="Predicted MOS", type="text")
|
| 85 |
+
submit.click(fn=predict_mos, inputs=[audio, domain, quick], outputs=[output])
|
| 86 |
|
| 87 |
demo.queue().launch()
|