|
|
import spaces  # noqa: F401  # must be imported before torch on HF ZeroGPU spaces

import os
import shutil
import time
from typing import Generator, Optional, Tuple

import gradio as gr
import nltk
import numpy as np
import pandas as pd
import torch
from huggingface_hub import HfApi

from espnet2.sds.espnet_model import ESPnetSDSModelInterface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face access token for gated models (e.g. Llama); None if unset.
access_token = os.environ.get("HF_TOKEN")

# Default model for each stage of the cascaded pipeline.
ASR_name = "pyf98/owsm_ctc_v3.1_1B"
LLM_name = "meta-llama/Llama-3.2-1B-Instruct"
TTS_name = "espnet/kan-bayashi_ljspeech_vits"

# Selectable options per stage, shown as radio buttons in the UI.
ASR_options = "pyf98/owsm_ctc_v3.1_1B,espnet/owsm_v3.1_ebf".split(",")
LLM_options = "meta-llama/Llama-3.2-1B-Instruct".split(",")
TTS_options = "espnet/kan-bayashi_ljspeech_vits,espnet/kan-bayashi_vctk_multi_spk_vits".split(",")
# Bug fix: this was a bare comma-joined string (missing the .split(",") all
# sibling *_options have), so iterating it yielded single characters instead
# of metric names during warmup.
Eval_options = "Latency,TTS Intelligibility,TTS Speech Quality,ASR WER,Text Dialog Metrics".split(",")
upload_to_hub = None

# Shared dialogue pipeline (ASR -> LLM -> TTS), created once at import time.
dialogue_model = ESPnetSDSModelInterface(
    ASR_name, LLM_name, TTS_name, "Cascaded", access_token
)

# Currently loaded model per stage; None until first selection.
ASR_curr_name = None
LLM_curr_name = None
TTS_curr_name = None

# Per-stage latencies (seconds), updated on each transcribe() call.
latency_ASR = 0.0
latency_LM = 0.0
latency_TTS = 0.0

# Conversation state shared across streaming callbacks.
text_str = ""
asr_output_str = ""
vad_output = None
audio_output = None
audio_output1 = None
LLM_response_arr = []
total_response_arr = []
start_record_time = None

# Reusable Gradio update that enables/shows a feedback button.
enable_btn = gr.Button(interactive=True, visible=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def handle_eval_selection(
    option: str,
    TTS_audio_output: str,
    LLM_Output: str,
    ASR_audio_output: str,
    ASR_transcript: str,
):
    """Evaluate one metric of the cascaded conversational pipeline.

    Intended to score a cascaded system's outputs, with supported
    ``option`` values:

    - "Latency": ASR / LLM / TTS module latencies.
    - "TTS Intelligibility": ASR-based intelligibility of the TTS audio.
    - "TTS Speech Quality": speech-quality scores of the TTS audio.
    - "ASR WER": word error rate of the ASR output.
    - "Text Dialog Metrics": perplexity, diversity, and relevance scores.

    Parameters
    ----------
    option : str
        The evaluation metric selected by the user (see list above).
    TTS_audio_output : np.ndarray
        Audio produced by the TTS module.
    LLM_Output : str
        Text produced by the LLM module.
    ASR_audio_output : np.ndarray
        Audio used for ASR evaluation.
    ASR_transcript : str
        Transcript produced by the ASR module.

    Returns
    -------
    None
        The current implementation is a stub and always returns ``None``;
        the full version streams result strings back to Gradio.
    """
    # Kept for parity with the full implementation, which mutates these.
    global LLM_response_arr, total_response_arr
    return None
|
|
|
|
|
|
|
|
def handle_eval_selection_E2E(
    option: str,
    TTS_audio_output: str,
    LLM_Output: str,
):
    """Evaluate one metric of the E2E conversational model.

    Intended to score an end-to-end system's outputs, with supported
    ``option`` values:

    - "Latency": total system latency.
    - "TTS Intelligibility": ASR-based intelligibility of the TTS audio.
    - "TTS Speech Quality": speech-quality scores of the TTS audio.
    - "Text Dialog Metrics": perplexity and diversity scores.

    Parameters
    ----------
    option : str
        The evaluation metric selected by the user (see list above).
    TTS_audio_output : np.ndarray
        Audio produced by the model.
    LLM_Output : str
        Text produced by the model.

    Returns
    -------
    None
        The current implementation is a stub and always returns ``None``;
        the full version streams result strings back to Gradio.
    """
    # Kept for parity with the full implementation, which mutates these.
    global LLM_response_arr, total_response_arr
    return None
|
|
|
|
|
|
|
|
def start_warmup():
    """
    Initializes and warms up the dialogue and evaluation model.

    This function is designed to ensure that all components of the
    dialogue model are pre-loaded and ready for execution, avoiding
    delays during runtime.

    Options that fail to load are removed from the corresponding
    option list; if the current default model fails, the first
    surviving option becomes the new default.
    """
    global dialogue_model
    global ASR_options
    global LLM_options
    global TTS_options
    global ASR_name
    global LLM_name
    global TTS_name

    def _prune_options(options, default_name, load_fn, kind):
        """Warm up each option via load_fn; drop options that raise.

        Returns the surviving options and a (possibly updated) default name.
        """
        kept = []
        for opt in options:
            try:
                # Exhaust the generator so the model is fully loaded.
                for _ in load_fn(opt):
                    continue
                kept.append(opt)
            except Exception as e:
                print(e)
                print(
                    "Removing " + opt + " from " + kind
                    + " options since it cannot be loaded."
                )
        # Fall back to the first loadable option if the default failed.
        if kept and default_name not in kept:
            default_name = kept[0]
        return kept, default_name

    # Bug fix: the previous index-juggling shared one `remove` counter across
    # all three loops without resetting it, so removals during the ASR pass
    # skewed (even negated) the indices used by the LLM and TTS passes.
    # Pruning each list independently avoids the problem entirely.
    ASR_options, ASR_name = _prune_options(
        ASR_options, ASR_name, dialogue_model.handle_ASR_selection, "ASR"
    )
    LLM_options, LLM_name = _prune_options(
        LLM_options, LLM_name, dialogue_model.handle_LLM_selection, "LLM"
    )
    TTS_options, TTS_name = _prune_options(
        TTS_options, TTS_name, dialogue_model.handle_TTS_selection, "TTS"
    )

    dialogue_model.handle_E2E_selection()
    dialogue_model.client = None

    # Reload the defaults so the cascaded pipeline starts from a known state.
    for _ in dialogue_model.handle_TTS_selection(TTS_name):
        continue
    for _ in dialogue_model.handle_ASR_selection(ASR_name):
        continue
    for _ in dialogue_model.handle_LLM_selection(LLM_name):
        continue

    # One dummy pass through every evaluation metric to warm caches.
    dummy_input = torch.randn(3000, dtype=torch.float16, device="cpu").cpu().numpy()
    dummy_text = "This is dummy text"
    # NOTE(review): Eval_options should be a list of metric names; if it is
    # still a comma-joined string this iterates characters (harmless while
    # handle_eval_selection ignores its arguments, but fix at the definition).
    for opt in Eval_options:
        handle_eval_selection(opt, dummy_input, dummy_text, dummy_input, dummy_text)
|
|
|
|
|
|
|
|
def flash_buttons():
    """
    Enables human feedback buttons after displaying system output.

    Yields:
        Tuple: two empty strings (clearing the feedback text boxes)
        followed by eight enabled-button updates.
    """
    yield ("", "") + (enable_btn,) * 8
|
|
|
|
|
|
|
|
def transcribe(
    stream: np.ndarray,
    new_chunk: Tuple[int, np.ndarray],
    TTS_option: str,
    ASR_option: str,
    LLM_option: str,
    type_option: str,
    input_text: str,
):
    """
    Processes and transcribes an audio stream in real-time.

    This function handles the transcription of audio input
    and its transformation through a cascaded
    or E2E conversational AI system.
    It dynamically updates the transcription, text generation,
    and synthesized speech output, while managing global states and latencies.

    Args:
        stream: The current audio stream buffer.
            `None` if the stream is being reset (e.g., after user refresh).
        new_chunk: A tuple containing:
            - `sr`: Sample rate of the new audio chunk.
            - `y`: New audio data chunk.
        TTS_option: Selected TTS model option.
        ASR_option: Selected ASR model option.
        LLM_option: Selected LLM model option.
        type_option: Type of system ("Cascaded" or "E2E").
        input_text: System prompt injected into the chat before each turn.

    Yields:
        Tuple[Optional[np.ndarray], Optional[str], Optional[str],
        Optional[Tuple[int, np.ndarray]], Optional[Tuple[int, np.ndarray]]]:
        A tuple containing:
            - Updated stream buffer.
            - ASR output text.
            - Generated LLM output text.
            - Audio output as a tuple of sample rate and audio waveform.
            - User input audio as a tuple of sample rate and audio waveform.

    Notes:
        - Resets the session if the transcription exceeds 5 minutes.
        - Updates the Gradio interface elements dynamically.
        - Manages latencies.
    """
    sr, y = new_chunk
    # Conversation/session state lives at module level so it survives
    # successive streaming callbacks from Gradio.
    global text_str
    # NOTE(review): chat, user_role, sids, and spembs are not defined at
    # module scope in this file — possibly leftovers from an earlier
    # version; confirm before relying on them.
    global chat
    global user_role
    global audio_output
    global audio_output1
    global vad_output
    global asr_output_str
    global start_record_time
    global sids
    global spembs
    global latency_ASR
    global latency_LM
    global latency_TTS
    global LLM_response_arr
    global total_response_arr
    if stream is None:
        # Fresh session (e.g. browser refresh): reload the selected models
        # and surface each intermediate UI update as it is produced.
        for (
            _,
            _,
            _,
            _,
            asr_output_box,
            text_box,
            audio_box,
            _,
            _,
        ) in dialogue_model.handle_type_selection(
            type_option, TTS_option, ASR_option, LLM_option
        ):
            gr.Info("The models are being reloaded due to a browser refresh.")
            yield (stream, asr_output_box, text_box, audio_box, gr.Audio(visible=False))
        # Start the buffer from the new chunk and clear per-turn outputs.
        stream = y
        text_str = ""
        audio_output = None
        audio_output1 = None
    else:
        # Ongoing session: append the new chunk to the rolling buffer.
        stream = np.concatenate((stream, y))
    # (Re)seed the chat with the user-provided system prompt.
    dialogue_model.chat.init_chat(
        {
            "role": "system",
            "content": (
                input_text
            ),
        }
    )
    # Run one pass of the pipeline; it returns updated outputs, latencies,
    # the (possibly trimmed) buffer, and whether the output changed.
    (
        asr_output_str,
        text_str,
        audio_output,
        audio_output1,
        latency_ASR,
        latency_LM,
        latency_TTS,
        stream,
        change,
    ) = dialogue_model(
        y,
        sr,
        stream,
        asr_output_str,
        text_str,
        audio_output,
        audio_output1,
        latency_ASR,
        latency_LM,
        latency_TTS,
    )
    # Snapshot the LLM text before any session-reset below clears text_str.
    text_str1 = text_str
    if change:
        # Record the new turn for the text-dialog evaluation metrics.
        print("Output changed")
        if asr_output_str != "":
            total_response_arr.append(asr_output_str.replace("\n", " "))
        LLM_response_arr.append(text_str.replace("\n", " "))
        total_response_arr.append(text_str.replace("\n", " "))
    if (text_str != "") and (start_record_time is None):
        # First system response: start the conversation timer.
        start_record_time = time.time()
    elif start_record_time is not None:
        current_record_time = time.time()
        if current_record_time - start_record_time > 300:
            # Conversation exceeded 5 minutes: warn, hide widgets, reset
            # all session state, then restore the widgets.
            gr.Info(
                "Conversations are limited to 5 minutes. "
                "The session will restart in approximately 60 seconds. "
                "Please wait for the demo to reset. "
                "Close this message once you have read it.",
                duration=None,
            )
            yield stream, gr.Textbox(visible=False), gr.Textbox(
                visible=False
            ), gr.Audio(visible=False), gr.Audio(visible=False)
            dialogue_model.chat.buffer = []
            text_str = ""
            audio_output = None
            audio_output1 = None
            asr_output_str = ""
            start_record_time = None
            LLM_response_arr = []
            total_response_arr = []
            # Clear collected feedback data points.
            # NOTE(review): rmtree raises if "flagged_data_points" does not
            # exist — confirm the directory is always created beforehand.
            shutil.rmtree("flagged_data_points")
            os.mkdir("flagged_data_points")
            yield (stream, asr_output_str, text_str1, audio_output, audio_output1)
            yield stream, gr.Textbox(visible=True), gr.Textbox(visible=True), gr.Audio(
                visible=True
            ), gr.Audio(visible=False)

    # Normal per-chunk update of the UI.
    yield (stream, asr_output_str, text_str1, audio_output, audio_output1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level startup: prepare resources, warm up models, build UI data.
api = HfApi()
nltk.download("averaged_perceptron_tagger_eng")  # required by text metrics
start_warmup()

# Default system prompt shown (and used) for the LLM.
default_instruct = (
    "You are a helpful and friendly AI "
    "assistant. "
    "You are polite, respectful, and aim to "
    "provide concise and complete responses of "
    "less than 15 words."
)

# Example tasks for the UI. The mid-module `import pandas as pd` was moved
# to the import block at the top of the file (PEP 8).
examples = pd.DataFrame(
    [
        ["General Purpose Conversation", default_instruct],
        ["Translation", "You are a translator. Translate user text into English."],
        [
            "General Purpose Conversation with Disfluencies",
            "Please reply to user with lot of filler words like ummm, so",
        ],
        ["Summarization", "You are summarizer. Summarize user's utterance."],
    ],
    columns=["Task", "LLM Prompt"],
)
|
|
# Build the Gradio UI: left column = input and model selection,
# right column = pipeline outputs and (hidden) evaluation controls.
with gr.Blocks(
    title="E2E Spoken Dialog System",
) as demo:
    with gr.Row():
        # App header and usage pointer.
        gr.Markdown(
            """
            ## ESPnet-SDS
            Welcome to our unified web interface for various cascaded and
            E2E spoken dialogue systems built using ESPnet-SDS toolkit,
            supporting real-time automated evaluation metrics, and
            human-in-the-loop feedback collection.

            For more details on how to use the app, refer to the [README]
            (https://github.com/siddhu001/espnet/tree/sds_demo_recipe/egs2/TEMPLATE/sds1#how-to-use).
            """
        )
    with gr.Row():
        with gr.Column(scale=1):
            # Streaming microphone input at 16 kHz.
            user_audio = gr.Audio(
                sources=["microphone"],
                streaming=True,
                waveform_options=gr.WaveformOptions(sample_rate=16000),
            )
            # Editable LLM system prompt.
            input_text=gr.Textbox(
                label="LLM prompt",
                visible=True,
                interactive=True,
                value=default_instruct
            )
            with gr.Row():
                # Pipeline type (only "Cascaded" exposed in this build).
                type_radio = gr.Radio(
                    choices=["Cascaded"],
                    label="Choose type of Spoken Dialog:",
                    value="Cascaded",
                )
            with gr.Row():
                ASR_radio = gr.Radio(
                    choices=ASR_options,
                    label="Choose ASR:",
                    value=ASR_name,
                )
            with gr.Row():
                LLM_radio = gr.Radio(
                    choices=LLM_options,
                    label="Choose LLM:",
                    value=LLM_name,
                )
            with gr.Row():
                # TTS selector (named `radio` for historical reasons).
                radio = gr.Radio(
                    choices=TTS_options,
                    label="Choose TTS:",
                    value=TTS_name,
                )
            with gr.Row():
                # E2E model selector; hidden while only "Cascaded" is offered.
                E2Eradio = gr.Radio(
                    choices=["mini-omni"],
                    label="Choose E2E model:",
                    value="mini-omni",
                    visible=False,
                )
        with gr.Column(scale=1):
            # Synthesized response (autoplays) plus a hidden secondary output.
            output_audio = gr.Audio(label="Output", autoplay=True, visible=True, interactive=False)
            output_audio1 = gr.Audio(label="Output1", autoplay=False, visible=False, interactive=False)
            output_asr_text = gr.Textbox(label="ASR output", interactive=False)
            output_text = gr.Textbox(label="LLM output", interactive=False)
            # Evaluation selectors, hidden in this build (no .change handlers
            # are wired for them below).
            eval_radio = gr.Radio(
                choices=[
                    "Latency",
                    "TTS Intelligibility",
                    "TTS Speech Quality",
                    "ASR WER",
                    "Text Dialog Metrics",
                ],
                label="Choose Evaluation metrics:",
                visible=False,
            )
            eval_radio_E2E = gr.Radio(
                choices=[
                    "Latency",
                    "TTS Intelligibility",
                    "TTS Speech Quality",
                    "Text Dialog Metrics",
                ],
                label="Choose Evaluation metrics:",
                visible=False,
            )
            output_eval_text = gr.Textbox(label="Evaluation Results", visible=False)
            # Per-session audio buffer passed in and out of transcribe().
            state = gr.State(value=None)

    # Hidden fields for human-feedback collection.
    natural_response = gr.Textbox(
        label="natural_response", visible=False, interactive=False
    )
    diversity_response = gr.Textbox(
        label="diversity_response", visible=False, interactive=False
    )
    ip_address = gr.Textbox(label="ip_address", visible=False, interactive=False)

    # Main streaming loop: each mic chunk goes through transcribe().
    user_audio.stream(
        transcribe,
        inputs=[state, user_audio, radio, ASR_radio, LLM_radio, type_radio, input_text],
        outputs=[state, output_asr_text, output_text, output_audio, output_audio1],
    )
    # Model-selection handlers clear the output widgets while reloading.
    radio.change(
        fn=dialogue_model.handle_TTS_selection,
        inputs=[radio],
        outputs=[output_asr_text, output_text, output_audio],
    )
    LLM_radio.change(
        fn=dialogue_model.handle_LLM_selection,
        inputs=[LLM_radio],
        outputs=[output_asr_text, output_text, output_audio],
    )
    ASR_radio.change(
        fn=dialogue_model.handle_ASR_selection,
        inputs=[ASR_radio],
        outputs=[output_asr_text, output_text, output_audio],
    )
    # Switching pipeline type toggles which selectors/metrics are visible.
    type_radio.change(
        fn=dialogue_model.handle_type_selection,
        inputs=[type_radio, radio, ASR_radio, LLM_radio],
        outputs=[
            radio,
            ASR_radio,
            LLM_radio,
            E2Eradio,
            output_asr_text,
            output_text,
            output_audio,
            eval_radio,
            eval_radio_E2E,
        ],
    )
    # NOTE(review): flash_buttons yields 10 values but only 2 outputs are
    # wired here — the 8 feedback-button updates appear to have been dropped
    # from this build; confirm against the full demo.
    output_audio.play(
        flash_buttons, [], [natural_response, diversity_response]
    )

# Serialize requests (single concurrent worker, queue of 10) and launch.
demo.queue(max_size=10, default_concurrency_limit=1)
demo.launch(debug=True)
|
|
|