Update app.py
Browse files
app.py
CHANGED
|
@@ -7,6 +7,7 @@ import torch.nn as nn
|
|
| 7 |
import lightning_module
|
| 8 |
import pdb
|
| 9 |
import jiwer
|
|
|
|
| 10 |
# ASR part
|
| 11 |
from transformers import pipeline
|
| 12 |
p = pipeline("automatic-speech-recognition")
|
|
@@ -19,6 +20,11 @@ transformation = jiwer.Compose([
|
|
| 19 |
jiwer.ReduceToListOfListOfWords(word_delimiter=" ")
|
| 20 |
])
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
class ChangeSampleRate(nn.Module):
|
| 23 |
def __init__(self, input_rate: int, output_rate: int):
|
| 24 |
super().__init__()
|
|
@@ -35,7 +41,8 @@ class ChangeSampleRate(nn.Module):
|
|
| 35 |
output = round_down * (1. - indices.fmod(1.)).unsqueeze(0) + round_up * indices.fmod(1.).unsqueeze(0)
|
| 36 |
return output
|
| 37 |
|
| 38 |
-
model = lightning_module.BaselineLightningModule.load_from_checkpoint("
|
|
|
|
| 39 |
def calc_mos(audio_path, ref):
|
| 40 |
wav, sr = torchaudio.load(audio_path)
|
| 41 |
osr = 16_000
|
|
@@ -46,7 +53,7 @@ def calc_mos(audio_path, ref):
|
|
| 46 |
trans = p(audio_path)["text"]
|
| 47 |
# WER
|
| 48 |
wer = jiwer.wer(ref, trans, truth_transform=transformation, hypothesis_transform=transformation)
|
| 49 |
-
|
| 50 |
batch = {
|
| 51 |
'wav': out_wavs,
|
| 52 |
'domains': torch.tensor([0]),
|
|
@@ -54,10 +61,17 @@ def calc_mos(audio_path, ref):
|
|
| 54 |
}
|
| 55 |
with torch.no_grad():
|
| 56 |
output = model(batch)
|
| 57 |
-
|
| 58 |
predic_mos = output.mean(dim=1).squeeze().detach().numpy()*2 + 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
-
return predic_mos, trans, wer
|
| 61 |
|
| 62 |
description ="""
|
| 63 |
MOS prediction demo using UTMOS-strong w/o phoneme encoder model, which is trained on the main track dataset.
|
|
@@ -71,9 +85,14 @@ Add WER interface.
|
|
| 71 |
|
| 72 |
iface = gr.Interface(
|
| 73 |
fn=calc_mos,
|
| 74 |
-
inputs=[gr.Audio(type='filepath'
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
description=description,
|
| 78 |
allow_flagging="auto",
|
| 79 |
)
|
|
|
|
import lightning_module
import pdb  # NOTE(review): debug-only import left in; consider removing before release
import jiwer

# ASR part: default Hugging Face speech-recognition pipeline, used to
# transcribe the uploaded audio so WER can be computed against the reference.
from transformers import pipeline
p = pipeline("automatic-speech-recognition")
|
|
|
|
| 20 |
jiwer.ReduceToListOfListOfWords(word_delimiter=" ")
|
| 21 |
])
|
| 22 |
|
# WPM part: wav2vec2 phoneme recognizer (espeak phoneme vocabulary), used
# downstream to count recognized phonemes for the phonemes-per-minute metric.
# NOTE(review): Wav2Vec2PhonemeCTCTokenizer is imported but not referenced in
# the visible code — presumably needed elsewhere; verify before removing.
from transformers import Wav2Vec2PhonemeCTCTokenizer, Wav2Vec2Processor, Wav2Vec2ForCTC
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
phoneme_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
|
| 27 |
+
|
| 28 |
class ChangeSampleRate(nn.Module):
|
| 29 |
def __init__(self, input_rate: int, output_rate: int):
|
| 30 |
super().__init__()
|
|
|
|
| 41 |
output = round_down * (1. - indices.fmod(1.)).unsqueeze(0) + round_up * indices.fmod(1.).unsqueeze(0)
|
| 42 |
return output
|
| 43 |
|
# Load the trained MOS-prediction checkpoint and switch it to eval mode for
# inference (disables dropout/batch-norm updates). The checkpoint file is
# expected alongside the app at runtime.
model = lightning_module.BaselineLightningModule.load_from_checkpoint("epoch=3-step=7459.ckpt").eval()
|
| 45 |
+
|
| 46 |
def calc_mos(audio_path, ref):
|
| 47 |
wav, sr = torchaudio.load(audio_path)
|
| 48 |
osr = 16_000
|
|
|
|
| 53 |
trans = p(audio_path)["text"]
|
| 54 |
# WER
|
| 55 |
wer = jiwer.wer(ref, trans, truth_transform=transformation, hypothesis_transform=transformation)
|
| 56 |
+
# MOS
|
| 57 |
batch = {
|
| 58 |
'wav': out_wavs,
|
| 59 |
'domains': torch.tensor([0]),
|
|
|
|
| 61 |
}
|
| 62 |
with torch.no_grad():
|
| 63 |
output = model(batch)
|
|
|
|
| 64 |
predic_mos = output.mean(dim=1).squeeze().detach().numpy()*2 + 3
|
| 65 |
+
# Phonemes per minute (PPM)
|
| 66 |
+
with torch.no_grad():
|
| 67 |
+
logits = phoneme_model(out_wavs).logits
|
| 68 |
+
phone_predicted_ids = torch.argmax(logits, dim=-1)
|
| 69 |
+
phone_transcription = processor.batch_decode(phone_predicted_ids)
|
| 70 |
+
lst_phonemes = phone_transcription[0].split(" ")
|
| 71 |
+
wav_vad = torchaudio.functional.vad(wav, sample_rate=sr)
|
| 72 |
+
ppm = len(lst_phonemes) / (wav_vad.shape[-1] / sr) * 60
|
| 73 |
|
| 74 |
+
return predic_mos, trans, wer, phone_transcription, ppm
|
| 75 |
|
| 76 |
description ="""
|
| 77 |
MOS prediction demo using UTMOS-strong w/o phoneme encoder model, which is trained on the main track dataset.
|
|
|
|
| 85 |
|
| 86 |
# Gradio UI: takes an audio file plus a reference transcript, and reports the
# predicted MOS, the ASR hypothesis, WER, predicted phonemes, and speech rate
# (phonemes per minute) — all produced by calc_mos.
# Fixes user-facing typos: "referance"/"Referance" -> reference/Reference,
# "Phonemes per minutes" -> "Phonemes per minute".
iface = gr.Interface(
    fn=calc_mos,
    inputs=[gr.Audio(type='filepath', label="Audio to evaluate"),
            gr.Textbox(placeholder="Input reference here", label="Reference")],
    outputs=[gr.Textbox(placeholder="Predicted MOS", label="Predicted MOS"),
             gr.Textbox(placeholder="Hypothesis", label="Hypothesis"),
             gr.Textbox(placeholder="Word Error Rate", label="WER"),
             gr.Textbox(placeholder="Predicted Phonemes", label="Predicted Phonemes"),
             gr.Textbox(placeholder="Phonemes per minute", label="PPM")],
    title="Laronix's Voice Quality Checking System Demo",
    description=description,
    allow_flagging="auto",
)
|