owaski committed on
Commit
1acf95b
·
1 Parent(s): 8818ae1

remove streaming

Browse files
Files changed (1) hide show
  1. app.py +70 -122
app.py CHANGED
@@ -37,14 +37,8 @@ latency_ASR = 0.0
37
  latency_LM = 0.0
38
  latency_TTS = 0.0
39
 
40
- text_str = ""
41
- asr_output_str = ""
42
- vad_output = None
43
- audio_output = None
44
- audio_output1 = None
45
  LLM_response_arr = []
46
  total_response_arr = []
47
- start_record_time = None
48
  enable_btn = gr.Button(interactive=True, visible=True)
49
 
50
  # ------------------------
@@ -289,9 +283,8 @@ def flash_buttons():
289
 
290
 
291
  @spaces.GPU
292
- def transcribe(
293
- stream: np.ndarray,
294
- new_chunk: Tuple[int, np.ndarray],
295
  TTS_option: str,
296
  ASR_option: str,
297
  LLM_option: str,
@@ -299,88 +292,62 @@ def transcribe(
299
  input_text: str,
300
  ):
301
  """
302
- Processes and transcribes an audio stream in real-time.
303
 
304
- This function handles the transcription of audio input
305
- and its transformation through a cascaded
306
- or E2E conversational AI system.
307
- It dynamically updates the transcription, text generation,
308
- and synthesized speech output, while managing global states and latencies.
309
 
310
  Args:
311
- stream: The current audio stream buffer.
312
- `None` if the stream is being reset (e.g., after user refresh).
313
- new_chunk: A tuple containing:
314
- - `sr`: Sample rate of the new audio chunk.
315
- - `y`: New audio data chunk.
316
  TTS_option: Selected TTS model option.
317
  ASR_option: Selected ASR model option.
318
  LLM_option: Selected LLM model option.
319
  type_option: Type of system ("Cascaded" or "E2E").
 
320
 
321
- Yields:
322
- Tuple[Optional[np.ndarray], Optional[str], Optional[str],
323
- Optional[Tuple[int, np.ndarray]], Optional[Tuple[int, np.ndarray]]]:
324
  A tuple containing:
325
- - Updated stream buffer.
326
- - ASR output text.
327
- - Generated LLM output text.
328
- - Audio output as a tuple of sample rate and audio waveform.
329
- - User input audio as a tuple of sample rate and audio waveform.
330
 
331
  Notes:
332
- - Resets the session if the transcription exceeds 5 minutes.
333
- - Updates the Gradio interface elements dynamically.
334
- - Manages latencies.
335
  """
336
- sr, y = new_chunk
337
- global text_str
338
- global chat
339
- global user_role
340
- global audio_output
341
- global audio_output1
342
- global vad_output
343
- global asr_output_str
344
- global start_record_time
345
- global sids
346
- global spembs
347
  global latency_ASR
348
  global latency_LM
349
  global latency_TTS
350
  global LLM_response_arr
351
  global total_response_arr
352
- if stream is None:
353
- # Handle user refresh
354
- for (
355
- _,
356
- _,
357
- _,
358
- _,
359
- asr_output_box,
360
- text_box,
361
- audio_box,
362
- _,
363
- _,
364
- ) in dialogue_model.handle_type_selection(
365
- type_option, TTS_option, ASR_option, LLM_option
366
- ):
367
- gr.Info("The models are being reloaded due to a browser refresh.")
368
- yield (stream, asr_output_box, text_box, audio_box, gr.Audio(visible=False))
369
- stream = y
370
- text_str = ""
371
- audio_output = None
372
- audio_output1 = None
373
- else:
374
- stream = np.concatenate((stream, y))
375
- # import pdb;pdb.set_trace()
376
  dialogue_model.chat.init_chat(
377
  {
378
  "role": "system",
379
- "content": (
380
- input_text
381
- ),
382
  }
383
  )
 
 
 
 
 
 
 
 
 
384
  (
385
  asr_output_str,
386
  text_str,
@@ -403,44 +370,16 @@ def transcribe(
403
  latency_LM,
404
  latency_TTS,
405
  )
406
- text_str1 = text_str
 
407
  if change:
408
- print("Output changed")
409
  if asr_output_str != "":
410
  total_response_arr.append(asr_output_str.replace("\n", " "))
411
  LLM_response_arr.append(text_str.replace("\n", " "))
412
  total_response_arr.append(text_str.replace("\n", " "))
413
- if (text_str != "") and (start_record_time is None):
414
- start_record_time = time.time()
415
- elif start_record_time is not None:
416
- current_record_time = time.time()
417
- if current_record_time - start_record_time > 300:
418
- gr.Info(
419
- "Conversations are limited to 5 minutes. "
420
- "The session will restart in approximately 60 seconds. "
421
- "Please wait for the demo to reset. "
422
- "Close this message once you have read it.",
423
- duration=None,
424
- )
425
- yield stream, gr.Textbox(visible=False), gr.Textbox(
426
- visible=False
427
- ), gr.Audio(visible=False), gr.Audio(visible=False)
428
- dialogue_model.chat.buffer = []
429
- text_str = ""
430
- audio_output = None
431
- audio_output1 = None
432
- asr_output_str = ""
433
- start_record_time = None
434
- LLM_response_arr = []
435
- total_response_arr = []
436
- shutil.rmtree("flagged_data_points")
437
- os.mkdir("flagged_data_points")
438
- yield (stream, asr_output_str, text_str1, audio_output, audio_output1)
439
- yield stream, gr.Textbox(visible=True), gr.Textbox(visible=True), gr.Audio(
440
- visible=True
441
- ), gr.Audio(visible=False)
442
-
443
- yield (stream, asr_output_str, text_str1, audio_output, audio_output1)
444
 
445
 
446
  # ------------------------
@@ -464,28 +403,37 @@ examples = pd.DataFrame([
464
  ["Summarization", "You are summarizer. Summarize user's utterance."]
465
  ], columns=["Task", "LLM Prompt"])
466
  with gr.Blocks(
467
- title="E2E Spoken Dialog System",
468
  ) as demo:
469
  with gr.Row():
470
  gr.Markdown(
471
  """
472
- ## ESPnet-SDS
473
- Welcome to our unified web interface for various cascaded and
474
- E2E spoken dialogue systems built using ESPnet-SDS toolkit,
475
- supporting real-time automated evaluation metrics, and
476
- human-in-the-loop feedback collection.
477
-
478
- For more details on how to use the app, refer to the [README]
 
 
 
 
 
 
 
 
479
  (https://github.com/siddhu001/espnet/tree/sds_demo_recipe/egs2/TEMPLATE/sds1#how-to-use).
480
  """
481
  )
482
  with gr.Row():
483
  with gr.Column(scale=1):
484
  user_audio = gr.Audio(
485
- sources=["microphone"],
486
- streaming=True,
487
- waveform_options=gr.WaveformOptions(sample_rate=16000),
488
  )
 
489
  input_text=gr.Textbox(
490
  label="LLM prompt",
491
  visible=True,
@@ -524,10 +472,9 @@ with gr.Blocks(
524
  visible=False,
525
  )
526
  with gr.Column(scale=1):
527
- output_audio = gr.Audio(label="Output", autoplay=True, visible=True, interactive=False)
528
- output_audio1 = gr.Audio(label="Output1", autoplay=False, visible=False, interactive=False)
529
- output_asr_text = gr.Textbox(label="ASR output", interactive=False)
530
- output_text = gr.Textbox(label="LLM output", interactive=False)
531
  eval_radio = gr.Radio(
532
  choices=[
533
  "Latency",
@@ -550,7 +497,6 @@ with gr.Blocks(
550
  visible=False,
551
  )
552
  output_eval_text = gr.Textbox(label="Evaluation Results", visible=False)
553
- state = gr.State(value=None)
554
 
555
 
556
  natural_response = gr.Textbox(
@@ -560,10 +506,12 @@ with gr.Blocks(
560
  label="diversity_response", visible=False, interactive=False
561
  )
562
  ip_address = gr.Textbox(label="ip_address", visible=False, interactive=False)
563
- user_audio.stream(
564
- transcribe,
565
- inputs=[state, user_audio, radio, ASR_radio, LLM_radio, type_radio, input_text],
566
- outputs=[state, output_asr_text, output_text, output_audio, output_audio1],
 
 
567
  )
568
  radio.change(
569
  fn=dialogue_model.handle_TTS_selection,
 
37
  latency_LM = 0.0
38
  latency_TTS = 0.0
39
 
 
 
 
 
 
40
  LLM_response_arr = []
41
  total_response_arr = []
 
42
  enable_btn = gr.Button(interactive=True, visible=True)
43
 
44
  # ------------------------
 
283
 
284
 
285
  @spaces.GPU
286
+ def process_audio_file(
287
+ audio_file: Optional[Tuple[int, np.ndarray]],
 
288
  TTS_option: str,
289
  ASR_option: str,
290
  LLM_option: str,
 
292
  input_text: str,
293
  ):
294
  """
295
+ Processes a recorded audio file through the dialogue system.
296
 
297
+ This function handles the transcription of an uploaded audio file
298
+ and its transformation through a cascaded conversational AI system.
299
+ It processes the entire audio file at once (offline mode).
 
 
300
 
301
  Args:
302
+ audio_file: A tuple containing:
303
+ - `sr`: Sample rate of the audio file.
304
+ - `y`: Audio data array.
 
 
305
  TTS_option: Selected TTS model option.
306
  ASR_option: Selected ASR model option.
307
  LLM_option: Selected LLM model option.
308
  type_option: Type of system ("Cascaded" or "E2E").
309
+ input_text: System prompt for the LLM.
310
 
311
+ Returns:
312
+ Tuple[str, str, Optional[Tuple[int, np.ndarray]]]:
 
313
  A tuple containing:
314
+ - ASR output text (transcription).
315
+ - Generated LLM output text (response).
316
+ - Audio output as a tuple of sample rate and audio waveform (TTS).
 
 
317
 
318
  Notes:
319
+ - Processes the complete audio file in one go.
320
+ - Updates latency metrics.
 
321
  """
 
 
 
 
 
 
 
 
 
 
 
322
  global latency_ASR
323
  global latency_LM
324
  global latency_TTS
325
  global LLM_response_arr
326
  global total_response_arr
327
+
328
+ if audio_file is None:
329
+ gr.Info("Please upload an audio file.")
330
+ return "", "", None
331
+
332
+ # Extract sample rate and audio data
333
+ sr, y = audio_file
334
+
335
+ # Initialize chat with system prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  dialogue_model.chat.init_chat(
337
  {
338
  "role": "system",
339
+ "content": input_text,
 
 
340
  }
341
  )
342
+
343
+ # Initialize variables
344
+ asr_output_str = ""
345
+ text_str = ""
346
+ audio_output = None
347
+ audio_output1 = None
348
+ stream = y # Use entire audio file as stream
349
+
350
+ # Process the audio file
351
  (
352
  asr_output_str,
353
  text_str,
 
370
  latency_LM,
371
  latency_TTS,
372
  )
373
+
374
+ # Store results
375
  if change:
376
+ print("Processing completed")
377
  if asr_output_str != "":
378
  total_response_arr.append(asr_output_str.replace("\n", " "))
379
  LLM_response_arr.append(text_str.replace("\n", " "))
380
  total_response_arr.append(text_str.replace("\n", " "))
381
+
382
+ return asr_output_str, text_str, audio_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
 
384
 
385
  # ------------------------
 
403
  ["Summarization", "You are summarizer. Summarize user's utterance."]
404
  ], columns=["Task", "LLM Prompt"])
405
  with gr.Blocks(
406
+ title="ESPnet-SDS Offline Audio Processing",
407
  ) as demo:
408
  with gr.Row():
409
  gr.Markdown(
410
  """
411
+ ## ESPnet-SDS (Offline Mode)
412
+ Welcome to our offline audio processing interface for various cascaded and
413
+ E2E spoken dialogue systems built using ESPnet-SDS toolkit.
414
+
415
+ **How to use:**
416
+ 1. Upload or record an audio file
417
+ 2. Configure the LLM prompt and select models
418
+ 3. Click "Process Audio" to transcribe and generate a response
419
+
420
+ The system will:
421
+ - **Transcribe** your audio using ASR (Automatic Speech Recognition)
422
+ - **Generate** a response using the selected LLM
423
+ - **Synthesize** speech output using TTS (Text-to-Speech)
424
+
425
+ For more details, refer to the [README]
426
  (https://github.com/siddhu001/espnet/tree/sds_demo_recipe/egs2/TEMPLATE/sds1#how-to-use).
427
  """
428
  )
429
  with gr.Row():
430
  with gr.Column(scale=1):
431
  user_audio = gr.Audio(
432
+ sources=["upload", "microphone"],
433
+ type="numpy",
434
+ label="Upload or Record Audio File",
435
  )
436
+ process_btn = gr.Button("Process Audio", variant="primary")
437
  input_text=gr.Textbox(
438
  label="LLM prompt",
439
  visible=True,
 
472
  visible=False,
473
  )
474
  with gr.Column(scale=1):
475
+ output_asr_text = gr.Textbox(label="ASR Transcription", interactive=False)
476
+ output_text = gr.Textbox(label="LLM Response", interactive=False)
477
+ output_audio = gr.Audio(label="TTS Output", autoplay=True, visible=True, interactive=False)
 
478
  eval_radio = gr.Radio(
479
  choices=[
480
  "Latency",
 
497
  visible=False,
498
  )
499
  output_eval_text = gr.Textbox(label="Evaluation Results", visible=False)
 
500
 
501
 
502
  natural_response = gr.Textbox(
 
506
  label="diversity_response", visible=False, interactive=False
507
  )
508
  ip_address = gr.Textbox(label="ip_address", visible=False, interactive=False)
509
+
510
+ # Process button click event
511
+ process_btn.click(
512
+ process_audio_file,
513
+ inputs=[user_audio, radio, ASR_radio, LLM_radio, type_radio, input_text],
514
+ outputs=[output_asr_text, output_text, output_audio],
515
  )
516
  radio.change(
517
  fn=dialogue_model.handle_TTS_selection,