Spaces:
Sleeping
Sleeping
remove streaming
Browse files
app.py
CHANGED
|
@@ -37,14 +37,8 @@ latency_ASR = 0.0
|
|
| 37 |
latency_LM = 0.0
|
| 38 |
latency_TTS = 0.0
|
| 39 |
|
| 40 |
-
text_str = ""
|
| 41 |
-
asr_output_str = ""
|
| 42 |
-
vad_output = None
|
| 43 |
-
audio_output = None
|
| 44 |
-
audio_output1 = None
|
| 45 |
LLM_response_arr = []
|
| 46 |
total_response_arr = []
|
| 47 |
-
start_record_time = None
|
| 48 |
enable_btn = gr.Button(interactive=True, visible=True)
|
| 49 |
|
| 50 |
# ------------------------
|
|
@@ -289,9 +283,8 @@ def flash_buttons():
|
|
| 289 |
|
| 290 |
|
| 291 |
@spaces.GPU
|
| 292 |
-
def
|
| 293 |
-
|
| 294 |
-
new_chunk: Tuple[int, np.ndarray],
|
| 295 |
TTS_option: str,
|
| 296 |
ASR_option: str,
|
| 297 |
LLM_option: str,
|
|
@@ -299,88 +292,62 @@ def transcribe(
|
|
| 299 |
input_text: str,
|
| 300 |
):
|
| 301 |
"""
|
| 302 |
-
Processes
|
| 303 |
|
| 304 |
-
This function handles the transcription of audio
|
| 305 |
-
and its transformation through a cascaded
|
| 306 |
-
|
| 307 |
-
It dynamically updates the transcription, text generation,
|
| 308 |
-
and synthesized speech output, while managing global states and latencies.
|
| 309 |
|
| 310 |
Args:
|
| 311 |
-
|
| 312 |
-
`
|
| 313 |
-
|
| 314 |
-
- `sr`: Sample rate of the new audio chunk.
|
| 315 |
-
- `y`: New audio data chunk.
|
| 316 |
TTS_option: Selected TTS model option.
|
| 317 |
ASR_option: Selected ASR model option.
|
| 318 |
LLM_option: Selected LLM model option.
|
| 319 |
type_option: Type of system ("Cascaded" or "E2E").
|
|
|
|
| 320 |
|
| 321 |
-
|
| 322 |
-
Tuple[
|
| 323 |
-
Optional[Tuple[int, np.ndarray]], Optional[Tuple[int, np.ndarray]]]:
|
| 324 |
A tuple containing:
|
| 325 |
-
-
|
| 326 |
-
-
|
| 327 |
-
-
|
| 328 |
-
- Audio output as a tuple of sample rate and audio waveform.
|
| 329 |
-
- User input audio as a tuple of sample rate and audio waveform.
|
| 330 |
|
| 331 |
Notes:
|
| 332 |
-
-
|
| 333 |
-
- Updates
|
| 334 |
-
- Manages latencies.
|
| 335 |
"""
|
| 336 |
-
sr, y = new_chunk
|
| 337 |
-
global text_str
|
| 338 |
-
global chat
|
| 339 |
-
global user_role
|
| 340 |
-
global audio_output
|
| 341 |
-
global audio_output1
|
| 342 |
-
global vad_output
|
| 343 |
-
global asr_output_str
|
| 344 |
-
global start_record_time
|
| 345 |
-
global sids
|
| 346 |
-
global spembs
|
| 347 |
global latency_ASR
|
| 348 |
global latency_LM
|
| 349 |
global latency_TTS
|
| 350 |
global LLM_response_arr
|
| 351 |
global total_response_arr
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
audio_box,
|
| 362 |
-
_,
|
| 363 |
-
_,
|
| 364 |
-
) in dialogue_model.handle_type_selection(
|
| 365 |
-
type_option, TTS_option, ASR_option, LLM_option
|
| 366 |
-
):
|
| 367 |
-
gr.Info("The models are being reloaded due to a browser refresh.")
|
| 368 |
-
yield (stream, asr_output_box, text_box, audio_box, gr.Audio(visible=False))
|
| 369 |
-
stream = y
|
| 370 |
-
text_str = ""
|
| 371 |
-
audio_output = None
|
| 372 |
-
audio_output1 = None
|
| 373 |
-
else:
|
| 374 |
-
stream = np.concatenate((stream, y))
|
| 375 |
-
# import pdb;pdb.set_trace()
|
| 376 |
dialogue_model.chat.init_chat(
|
| 377 |
{
|
| 378 |
"role": "system",
|
| 379 |
-
"content":
|
| 380 |
-
input_text
|
| 381 |
-
),
|
| 382 |
}
|
| 383 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
(
|
| 385 |
asr_output_str,
|
| 386 |
text_str,
|
|
@@ -403,44 +370,16 @@ def transcribe(
|
|
| 403 |
latency_LM,
|
| 404 |
latency_TTS,
|
| 405 |
)
|
| 406 |
-
|
|
|
|
| 407 |
if change:
|
| 408 |
-
print("
|
| 409 |
if asr_output_str != "":
|
| 410 |
total_response_arr.append(asr_output_str.replace("\n", " "))
|
| 411 |
LLM_response_arr.append(text_str.replace("\n", " "))
|
| 412 |
total_response_arr.append(text_str.replace("\n", " "))
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
elif start_record_time is not None:
|
| 416 |
-
current_record_time = time.time()
|
| 417 |
-
if current_record_time - start_record_time > 300:
|
| 418 |
-
gr.Info(
|
| 419 |
-
"Conversations are limited to 5 minutes. "
|
| 420 |
-
"The session will restart in approximately 60 seconds. "
|
| 421 |
-
"Please wait for the demo to reset. "
|
| 422 |
-
"Close this message once you have read it.",
|
| 423 |
-
duration=None,
|
| 424 |
-
)
|
| 425 |
-
yield stream, gr.Textbox(visible=False), gr.Textbox(
|
| 426 |
-
visible=False
|
| 427 |
-
), gr.Audio(visible=False), gr.Audio(visible=False)
|
| 428 |
-
dialogue_model.chat.buffer = []
|
| 429 |
-
text_str = ""
|
| 430 |
-
audio_output = None
|
| 431 |
-
audio_output1 = None
|
| 432 |
-
asr_output_str = ""
|
| 433 |
-
start_record_time = None
|
| 434 |
-
LLM_response_arr = []
|
| 435 |
-
total_response_arr = []
|
| 436 |
-
shutil.rmtree("flagged_data_points")
|
| 437 |
-
os.mkdir("flagged_data_points")
|
| 438 |
-
yield (stream, asr_output_str, text_str1, audio_output, audio_output1)
|
| 439 |
-
yield stream, gr.Textbox(visible=True), gr.Textbox(visible=True), gr.Audio(
|
| 440 |
-
visible=True
|
| 441 |
-
), gr.Audio(visible=False)
|
| 442 |
-
|
| 443 |
-
yield (stream, asr_output_str, text_str1, audio_output, audio_output1)
|
| 444 |
|
| 445 |
|
| 446 |
# ------------------------
|
|
@@ -464,28 +403,37 @@ examples = pd.DataFrame([
|
|
| 464 |
["Summarization", "You are summarizer. Summarize user's utterance."]
|
| 465 |
], columns=["Task", "LLM Prompt"])
|
| 466 |
with gr.Blocks(
|
| 467 |
-
title="
|
| 468 |
) as demo:
|
| 469 |
with gr.Row():
|
| 470 |
gr.Markdown(
|
| 471 |
"""
|
| 472 |
-
## ESPnet-SDS
|
| 473 |
-
Welcome to our
|
| 474 |
-
E2E spoken dialogue systems built using ESPnet-SDS
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 479 |
(https://github.com/siddhu001/espnet/tree/sds_demo_recipe/egs2/TEMPLATE/sds1#how-to-use).
|
| 480 |
"""
|
| 481 |
)
|
| 482 |
with gr.Row():
|
| 483 |
with gr.Column(scale=1):
|
| 484 |
user_audio = gr.Audio(
|
| 485 |
-
sources=["microphone"],
|
| 486 |
-
|
| 487 |
-
|
| 488 |
)
|
|
|
|
| 489 |
input_text=gr.Textbox(
|
| 490 |
label="LLM prompt",
|
| 491 |
visible=True,
|
|
@@ -524,10 +472,9 @@ with gr.Blocks(
|
|
| 524 |
visible=False,
|
| 525 |
)
|
| 526 |
with gr.Column(scale=1):
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
output_text = gr.Textbox(label="LLM output", interactive=False)
|
| 531 |
eval_radio = gr.Radio(
|
| 532 |
choices=[
|
| 533 |
"Latency",
|
|
@@ -550,7 +497,6 @@ with gr.Blocks(
|
|
| 550 |
visible=False,
|
| 551 |
)
|
| 552 |
output_eval_text = gr.Textbox(label="Evaluation Results", visible=False)
|
| 553 |
-
state = gr.State(value=None)
|
| 554 |
|
| 555 |
|
| 556 |
natural_response = gr.Textbox(
|
|
@@ -560,10 +506,12 @@ with gr.Blocks(
|
|
| 560 |
label="diversity_response", visible=False, interactive=False
|
| 561 |
)
|
| 562 |
ip_address = gr.Textbox(label="ip_address", visible=False, interactive=False)
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
|
|
|
|
|
|
| 567 |
)
|
| 568 |
radio.change(
|
| 569 |
fn=dialogue_model.handle_TTS_selection,
|
|
|
|
| 37 |
latency_LM = 0.0
|
| 38 |
latency_TTS = 0.0
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
LLM_response_arr = []
|
| 41 |
total_response_arr = []
|
|
|
|
| 42 |
enable_btn = gr.Button(interactive=True, visible=True)
|
| 43 |
|
| 44 |
# ------------------------
|
|
|
|
| 283 |
|
| 284 |
|
| 285 |
@spaces.GPU
|
| 286 |
+
def process_audio_file(
|
| 287 |
+
audio_file: Optional[Tuple[int, np.ndarray]],
|
|
|
|
| 288 |
TTS_option: str,
|
| 289 |
ASR_option: str,
|
| 290 |
LLM_option: str,
|
|
|
|
| 292 |
input_text: str,
|
| 293 |
):
|
| 294 |
"""
|
| 295 |
+
Processes a recorded audio file through the dialogue system.
|
| 296 |
|
| 297 |
+
This function handles the transcription of an uploaded audio file
|
| 298 |
+
and its transformation through a cascaded conversational AI system.
|
| 299 |
+
It processes the entire audio file at once (offline mode).
|
|
|
|
|
|
|
| 300 |
|
| 301 |
Args:
|
| 302 |
+
audio_file: A tuple containing:
|
| 303 |
+
- `sr`: Sample rate of the audio file.
|
| 304 |
+
- `y`: Audio data array.
|
|
|
|
|
|
|
| 305 |
TTS_option: Selected TTS model option.
|
| 306 |
ASR_option: Selected ASR model option.
|
| 307 |
LLM_option: Selected LLM model option.
|
| 308 |
type_option: Type of system ("Cascaded" or "E2E").
|
| 309 |
+
input_text: System prompt for the LLM.
|
| 310 |
|
| 311 |
+
Returns:
|
| 312 |
+
Tuple[str, str, Optional[Tuple[int, np.ndarray]]]:
|
|
|
|
| 313 |
A tuple containing:
|
| 314 |
+
- ASR output text (transcription).
|
| 315 |
+
- Generated LLM output text (response).
|
| 316 |
+
- Audio output as a tuple of sample rate and audio waveform (TTS).
|
|
|
|
|
|
|
| 317 |
|
| 318 |
Notes:
|
| 319 |
+
- Processes the complete audio file in one go.
|
| 320 |
+
- Updates latency metrics.
|
|
|
|
| 321 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
global latency_ASR
|
| 323 |
global latency_LM
|
| 324 |
global latency_TTS
|
| 325 |
global LLM_response_arr
|
| 326 |
global total_response_arr
|
| 327 |
+
|
| 328 |
+
if audio_file is None:
|
| 329 |
+
gr.Info("Please upload an audio file.")
|
| 330 |
+
return "", "", None
|
| 331 |
+
|
| 332 |
+
# Extract sample rate and audio data
|
| 333 |
+
sr, y = audio_file
|
| 334 |
+
|
| 335 |
+
# Initialize chat with system prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
dialogue_model.chat.init_chat(
|
| 337 |
{
|
| 338 |
"role": "system",
|
| 339 |
+
"content": input_text,
|
|
|
|
|
|
|
| 340 |
}
|
| 341 |
)
|
| 342 |
+
|
| 343 |
+
# Initialize variables
|
| 344 |
+
asr_output_str = ""
|
| 345 |
+
text_str = ""
|
| 346 |
+
audio_output = None
|
| 347 |
+
audio_output1 = None
|
| 348 |
+
stream = y # Use entire audio file as stream
|
| 349 |
+
|
| 350 |
+
# Process the audio file
|
| 351 |
(
|
| 352 |
asr_output_str,
|
| 353 |
text_str,
|
|
|
|
| 370 |
latency_LM,
|
| 371 |
latency_TTS,
|
| 372 |
)
|
| 373 |
+
|
| 374 |
+
# Store results
|
| 375 |
if change:
|
| 376 |
+
print("Processing completed")
|
| 377 |
if asr_output_str != "":
|
| 378 |
total_response_arr.append(asr_output_str.replace("\n", " "))
|
| 379 |
LLM_response_arr.append(text_str.replace("\n", " "))
|
| 380 |
total_response_arr.append(text_str.replace("\n", " "))
|
| 381 |
+
|
| 382 |
+
return asr_output_str, text_str, audio_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
|
| 384 |
|
| 385 |
# ------------------------
|
|
|
|
| 403 |
["Summarization", "You are summarizer. Summarize user's utterance."]
|
| 404 |
], columns=["Task", "LLM Prompt"])
|
| 405 |
with gr.Blocks(
|
| 406 |
+
title="ESPnet-SDS Offline Audio Processing",
|
| 407 |
) as demo:
|
| 408 |
with gr.Row():
|
| 409 |
gr.Markdown(
|
| 410 |
"""
|
| 411 |
+
## ESPnet-SDS (Offline Mode)
|
| 412 |
+
Welcome to our offline audio processing interface for various cascaded and
|
| 413 |
+
E2E spoken dialogue systems built using ESPnet-SDS toolkit.
|
| 414 |
+
|
| 415 |
+
**How to use:**
|
| 416 |
+
1. Upload or record an audio file
|
| 417 |
+
2. Configure the LLM prompt and select models
|
| 418 |
+
3. Click "Process Audio" to transcribe and generate a response
|
| 419 |
+
|
| 420 |
+
The system will:
|
| 421 |
+
- **Transcribe** your audio using ASR (Automatic Speech Recognition)
|
| 422 |
+
- **Generate** a response using the selected LLM
|
| 423 |
+
- **Synthesize** speech output using TTS (Text-to-Speech)
|
| 424 |
+
|
| 425 |
+
For more details, refer to the [README]
|
| 426 |
(https://github.com/siddhu001/espnet/tree/sds_demo_recipe/egs2/TEMPLATE/sds1#how-to-use).
|
| 427 |
"""
|
| 428 |
)
|
| 429 |
with gr.Row():
|
| 430 |
with gr.Column(scale=1):
|
| 431 |
user_audio = gr.Audio(
|
| 432 |
+
sources=["upload", "microphone"],
|
| 433 |
+
type="numpy",
|
| 434 |
+
label="Upload or Record Audio File",
|
| 435 |
)
|
| 436 |
+
process_btn = gr.Button("Process Audio", variant="primary")
|
| 437 |
input_text=gr.Textbox(
|
| 438 |
label="LLM prompt",
|
| 439 |
visible=True,
|
|
|
|
| 472 |
visible=False,
|
| 473 |
)
|
| 474 |
with gr.Column(scale=1):
|
| 475 |
+
output_asr_text = gr.Textbox(label="ASR Transcription", interactive=False)
|
| 476 |
+
output_text = gr.Textbox(label="LLM Response", interactive=False)
|
| 477 |
+
output_audio = gr.Audio(label="TTS Output", autoplay=True, visible=True, interactive=False)
|
|
|
|
| 478 |
eval_radio = gr.Radio(
|
| 479 |
choices=[
|
| 480 |
"Latency",
|
|
|
|
| 497 |
visible=False,
|
| 498 |
)
|
| 499 |
output_eval_text = gr.Textbox(label="Evaluation Results", visible=False)
|
|
|
|
| 500 |
|
| 501 |
|
| 502 |
natural_response = gr.Textbox(
|
|
|
|
| 506 |
label="diversity_response", visible=False, interactive=False
|
| 507 |
)
|
| 508 |
ip_address = gr.Textbox(label="ip_address", visible=False, interactive=False)
|
| 509 |
+
|
| 510 |
+
# Process button click event
|
| 511 |
+
process_btn.click(
|
| 512 |
+
process_audio_file,
|
| 513 |
+
inputs=[user_audio, radio, ASR_radio, LLM_radio, type_radio, input_text],
|
| 514 |
+
outputs=[output_asr_text, output_text, output_audio],
|
| 515 |
)
|
| 516 |
radio.change(
|
| 517 |
fn=dialogue_model.handle_TTS_selection,
|