herimor committed on
Commit
997a03d
·
1 Parent(s): 2f74ca1

Fix cold start. Add SharedGenerationState

Browse files
Files changed (2) hide show
  1. app.py +97 -24
  2. requirements.txt +1 -1
app.py CHANGED
@@ -24,6 +24,7 @@ from voxtream.utils.app import (
24
  CUSTOM_CSS,
25
  AppConfig,
26
  GenerationControl,
 
27
  SpeakingRateState,
28
  VisualizationState,
29
  build_low_latency_audio_head,
@@ -87,6 +88,7 @@ def demo_app(
87
  synthesize_fn,
88
  speaking_rate_state: SpeakingRateState,
89
  generation_control: GenerationControl,
 
90
  ):
91
  with gr.Blocks(
92
  css=CUSTOM_CSS,
@@ -175,13 +177,14 @@ def demo_app(
175
  text_progress = gr.HTML(
176
  render_text_progress(app_config, None), elem_id="text-progress-container"
177
  )
 
178
 
179
  def validate_inputs(audio, ttext):
180
  if not audio:
181
  return gr.update(
182
  visible=True, value="⚠️ Please provide a prompt audio."
183
  ), gr.update(interactive=False)
184
- if not ttext.strip():
185
  return gr.update(
186
  visible=True, value="⚠️ Please provide target text."
187
  ), gr.update(interactive=False)
@@ -197,10 +200,22 @@ def demo_app(
197
  inputs=prompt_enhancement,
198
  outputs=prompt_enhancement_msg,
199
  )
 
 
 
 
 
 
 
 
 
 
 
200
  speaking_rate_control.release(
201
- fn=lambda value: speaking_rate_state.update(value),
202
- inputs=speaking_rate_control,
203
  queue=False,
 
204
  )
205
 
206
  for inp in [prompt_audio, target_text]:
@@ -211,6 +226,7 @@ def demo_app(
211
  )
212
 
213
  def prepare_generation(speaking_rate, enable_rate):
 
214
  generation_control.start()
215
  speaking_rate_state.start(speaking_rate)
216
  return (
@@ -218,8 +234,9 @@ def demo_app(
218
  gr.update(interactive=False),
219
  empty_rate_plot(app_config, show_target=enable_rate),
220
  render_text_progress(app_config, None),
221
- render_audio_stream(app_config, session_id=uuid.uuid4().hex),
222
  *generation_button_updates(running=True),
 
223
  )
224
 
225
  submit_btn.click(
@@ -234,6 +251,7 @@ def demo_app(
234
  pause_btn,
235
  resume_btn,
236
  stop_btn,
 
237
  ],
238
  show_progress="hidden",
239
  ).then(
@@ -246,6 +264,7 @@ def demo_app(
246
  streaming_input,
247
  speaking_rate_control,
248
  enable_speaking_rate,
 
249
  ],
250
  outputs=[
251
  output_audio,
@@ -256,25 +275,29 @@ def demo_app(
256
  pause_btn,
257
  resume_btn,
258
  stop_btn,
 
259
  ],
260
  )
261
 
262
- def pause_generation():
263
  generation_control.pause()
 
264
  return generation_button_updates(running=True, paused=True)
265
 
266
- def resume_generation():
267
  generation_control.resume()
 
268
  return generation_button_updates(running=True)
269
 
270
- def stop_generation():
271
  generation_control.stop()
272
  speaking_rate_state.stop()
 
273
  return generation_button_updates(running=False)
274
 
275
  pause_btn.click(
276
  fn=pause_generation,
277
- inputs=[],
278
  outputs=[pause_btn, resume_btn, stop_btn],
279
  js=(
280
  "() => { if (window.voxtreamLowLatencyAudio) { "
@@ -284,7 +307,7 @@ def demo_app(
284
  )
285
  resume_btn.click(
286
  fn=resume_generation,
287
- inputs=[],
288
  outputs=[pause_btn, resume_btn, stop_btn],
289
  js=(
290
  "() => { if (window.voxtreamLowLatencyAudio) { "
@@ -294,7 +317,7 @@ def demo_app(
294
  )
295
  stop_btn.click(
296
  fn=stop_generation,
297
- inputs=[],
298
  outputs=[pause_btn, resume_btn, stop_btn],
299
  js=(
300
  "() => { if (window.voxtreamLowLatencyAudio) { "
@@ -303,8 +326,11 @@ def demo_app(
303
  queue=False,
304
  )
305
 
306
- clear_btn.click(
307
- fn=lambda: (
 
 
 
308
  gr.update(value=None),
309
  gr.update(value=""),
310
  gr.update(value=None, visible=False),
@@ -315,8 +341,12 @@ def demo_app(
315
  render_text_progress(app_config, None),
316
  render_audio_stream(app_config, session_id=uuid.uuid4().hex),
317
  *generation_button_updates(running=False),
318
- ),
319
- inputs=[],
 
 
 
 
320
  outputs=[
321
  prompt_audio,
322
  target_text,
@@ -330,6 +360,7 @@ def demo_app(
330
  pause_btn,
331
  resume_btn,
332
  stop_btn,
 
333
  ],
334
  )
335
 
@@ -354,19 +385,21 @@ def demo_app(
354
  pause_btn,
355
  resume_btn,
356
  stop_btn,
 
357
  ],
358
  fn=synthesize_fn,
359
  cache_examples=False,
360
  )
361
 
362
  ex.dataset.click(
363
- fn=lambda: clear_outputs(app_config),
364
  inputs=[],
365
  outputs=[
366
  output_audio,
367
  rate_plot,
368
  text_progress,
369
  stream_audio,
 
370
  ],
371
  queue=False,
372
  ).then(
@@ -376,7 +409,7 @@ def demo_app(
376
  queue=False,
377
  )
378
 
379
- demo.launch()
380
 
381
 
382
  def main():
@@ -435,6 +468,7 @@ def main():
435
  speech_generator = SpeechGenerator(config, spk_rate_config)
436
  speaking_rate_state = SpeakingRateState(app_config.speaking_rate_default)
437
  generation_control = GenerationControl()
 
438
  chunk_size = int(config.mimi_sr * app_config.min_chunk_sec)
439
 
440
  @spaces.GPU
@@ -446,13 +480,18 @@ def main():
446
  streaming_input,
447
  speaking_rate_control,
448
  enable_speaking_rate=True,
 
449
  ):
450
- stream_session_id = uuid.uuid4().hex
 
 
 
451
  stream_seq = 0
452
 
453
  if not prompt_audio_path or not target_text:
454
  speaking_rate_state.stop()
455
  generation_control.finish()
 
456
  yield (
457
  gr.update(value=None, visible=False),
458
  gr.update(interactive=True),
@@ -460,13 +499,38 @@ def main():
460
  render_text_progress(app_config, None),
461
  render_audio_stream(app_config, session_id=stream_session_id),
462
  *generation_button_updates(running=False),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
463
  )
464
  return
465
 
466
  ensure_generator_on_cuda(speech_generator)
467
  speaking_rate_state.ensure_started(speaking_rate_control)
468
  speaking_rate_gen = (
469
- speaking_rate_state.values() if enable_speaking_rate else None
 
 
 
 
470
  )
471
  text_metadata = build_text_progress_metadata(
472
  target_text,
@@ -511,14 +575,14 @@ def main():
511
 
512
  stream_iter = iter(stream)
513
  while True:
514
- if not generation_control.wait_if_paused():
515
  stopped = True
516
  break
517
  try:
518
  frame, _, progress = next(stream_iter)
519
  except StopIteration:
520
  break
521
- if generation_control.is_stopped():
522
  stopped = True
523
  break
524
 
@@ -528,7 +592,7 @@ def main():
528
  plot_update, text_update = visualization.update(progress)
529
 
530
  if buffer_len >= chunk_size:
531
- if generation_control.is_stopped():
532
  stopped = True
533
  break
534
  audio = np.concatenate(buffer)
@@ -547,14 +611,16 @@ def main():
547
  active=True,
548
  ),
549
  *generation_button_updates(
550
- running=True, paused=generation_control.is_paused()
 
551
  ),
 
552
  )
553
 
554
  buffer = []
555
  buffer_len = 0
556
 
557
- stopped = stopped or generation_control.is_stopped()
558
  if stopped and hasattr(stream, "close"):
559
  stream.close()
560
  final_text = visualization.final_text()
@@ -580,8 +646,10 @@ def main():
580
  active=True,
581
  ),
582
  *generation_button_updates(
583
- running=True, paused=generation_control.is_paused()
 
584
  ),
 
585
  )
586
 
587
  if len(total_buffer) > 0:
@@ -598,6 +666,7 @@ def main():
598
 
599
  speaking_rate_state.stop()
600
  generation_control.finish()
 
601
  yield (
602
  gr.update(value=file_path, visible=True),
603
  gr.update(interactive=True),
@@ -612,10 +681,12 @@ def main():
612
  final=True,
613
  ),
614
  *generation_button_updates(running=False),
 
615
  )
616
  else:
617
  speaking_rate_state.stop()
618
  generation_control.finish()
 
619
  yield (
620
  gr.update(value=None, visible=False),
621
  gr.update(interactive=True),
@@ -630,6 +701,7 @@ def main():
630
  final=True,
631
  ),
632
  *generation_button_updates(running=False),
 
633
  )
634
 
635
  demo_app(
@@ -639,6 +711,7 @@ def main():
639
  synthesize_fn,
640
  speaking_rate_state,
641
  generation_control,
 
642
  )
643
 
644
 
 
24
  CUSTOM_CSS,
25
  AppConfig,
26
  GenerationControl,
27
+ SharedGenerationState,
28
  SpeakingRateState,
29
  VisualizationState,
30
  build_low_latency_audio_head,
 
88
  synthesize_fn,
89
  speaking_rate_state: SpeakingRateState,
90
  generation_control: GenerationControl,
91
+ shared_generation_state: SharedGenerationState,
92
  ):
93
  with gr.Blocks(
94
  css=CUSTOM_CSS,
 
177
  text_progress = gr.HTML(
178
  render_text_progress(app_config, None), elem_id="text-progress-container"
179
  )
180
+ generation_session = gr.State("")
181
 
182
  def validate_inputs(audio, ttext):
183
  if not audio:
184
  return gr.update(
185
  visible=True, value="⚠️ Please provide a prompt audio."
186
  ), gr.update(interactive=False)
187
+ if not ttext or not ttext.strip():
188
  return gr.update(
189
  visible=True, value="⚠️ Please provide target text."
190
  ), gr.update(interactive=False)
 
200
  inputs=prompt_enhancement,
201
  outputs=prompt_enhancement_msg,
202
  )
203
+
204
+ def update_speaking_rate(value, session_id):
205
+ speaking_rate_state.update(value)
206
+ shared_generation_state.update_speaking_rate(session_id, value)
207
+
208
+ speaking_rate_control.input(
209
+ fn=update_speaking_rate,
210
+ inputs=[speaking_rate_control, generation_session],
211
+ queue=False,
212
+ show_progress="hidden",
213
+ )
214
  speaking_rate_control.release(
215
+ fn=update_speaking_rate,
216
+ inputs=[speaking_rate_control, generation_session],
217
  queue=False,
218
+ show_progress="hidden",
219
  )
220
 
221
  for inp in [prompt_audio, target_text]:
 
226
  )
227
 
228
  def prepare_generation(speaking_rate, enable_rate):
229
+ session_id = shared_generation_state.create(speaking_rate)
230
  generation_control.start()
231
  speaking_rate_state.start(speaking_rate)
232
  return (
 
234
  gr.update(interactive=False),
235
  empty_rate_plot(app_config, show_target=enable_rate),
236
  render_text_progress(app_config, None),
237
+ render_audio_stream(app_config, session_id=session_id),
238
  *generation_button_updates(running=True),
239
+ session_id,
240
  )
241
 
242
  submit_btn.click(
 
251
  pause_btn,
252
  resume_btn,
253
  stop_btn,
254
+ generation_session,
255
  ],
256
  show_progress="hidden",
257
  ).then(
 
264
  streaming_input,
265
  speaking_rate_control,
266
  enable_speaking_rate,
267
+ generation_session,
268
  ],
269
  outputs=[
270
  output_audio,
 
275
  pause_btn,
276
  resume_btn,
277
  stop_btn,
278
+ generation_session,
279
  ],
280
  )
281
 
282
+ def pause_generation(session_id):
283
  generation_control.pause()
284
+ shared_generation_state.pause(session_id)
285
  return generation_button_updates(running=True, paused=True)
286
 
287
+ def resume_generation(session_id):
288
  generation_control.resume()
289
+ shared_generation_state.resume(session_id)
290
  return generation_button_updates(running=True)
291
 
292
+ def stop_generation(session_id):
293
  generation_control.stop()
294
  speaking_rate_state.stop()
295
+ shared_generation_state.stop(session_id)
296
  return generation_button_updates(running=False)
297
 
298
  pause_btn.click(
299
  fn=pause_generation,
300
+ inputs=generation_session,
301
  outputs=[pause_btn, resume_btn, stop_btn],
302
  js=(
303
  "() => { if (window.voxtreamLowLatencyAudio) { "
 
307
  )
308
  resume_btn.click(
309
  fn=resume_generation,
310
+ inputs=generation_session,
311
  outputs=[pause_btn, resume_btn, stop_btn],
312
  js=(
313
  "() => { if (window.voxtreamLowLatencyAudio) { "
 
317
  )
318
  stop_btn.click(
319
  fn=stop_generation,
320
+ inputs=generation_session,
321
  outputs=[pause_btn, resume_btn, stop_btn],
322
  js=(
323
  "() => { if (window.voxtreamLowLatencyAudio) { "
 
326
  queue=False,
327
  )
328
 
329
+ def clear_generation(session_id):
330
+ generation_control.stop()
331
+ speaking_rate_state.stop()
332
+ shared_generation_state.stop(session_id)
333
+ return (
334
  gr.update(value=None),
335
  gr.update(value=""),
336
  gr.update(value=None, visible=False),
 
341
  render_text_progress(app_config, None),
342
  render_audio_stream(app_config, session_id=uuid.uuid4().hex),
343
  *generation_button_updates(running=False),
344
+ "",
345
+ )
346
+
347
+ clear_btn.click(
348
+ fn=clear_generation,
349
+ inputs=generation_session,
350
  outputs=[
351
  prompt_audio,
352
  target_text,
 
360
  pause_btn,
361
  resume_btn,
362
  stop_btn,
363
+ generation_session,
364
  ],
365
  )
366
 
 
385
  pause_btn,
386
  resume_btn,
387
  stop_btn,
388
+ generation_session,
389
  ],
390
  fn=synthesize_fn,
391
  cache_examples=False,
392
  )
393
 
394
  ex.dataset.click(
395
+ fn=lambda: (*clear_outputs(app_config), ""),
396
  inputs=[],
397
  outputs=[
398
  output_audio,
399
  rate_plot,
400
  text_progress,
401
  stream_audio,
402
+ generation_session,
403
  ],
404
  queue=False,
405
  ).then(
 
409
  queue=False,
410
  )
411
 
412
+ demo.queue(default_concurrency_limit=1).launch()
413
 
414
 
415
  def main():
 
468
  speech_generator = SpeechGenerator(config, spk_rate_config)
469
  speaking_rate_state = SpeakingRateState(app_config.speaking_rate_default)
470
  generation_control = GenerationControl()
471
+ shared_generation_state = SharedGenerationState()
472
  chunk_size = int(config.mimi_sr * app_config.min_chunk_sec)
473
 
474
  @spaces.GPU
 
480
  streaming_input,
481
  speaking_rate_control,
482
  enable_speaking_rate=True,
483
+ generation_session_id="",
484
  ):
485
+ control_session_id = generation_session_id or shared_generation_state.create(
486
+ speaking_rate_control
487
+ )
488
+ stream_session_id = control_session_id or uuid.uuid4().hex
489
  stream_seq = 0
490
 
491
  if not prompt_audio_path or not target_text:
492
  speaking_rate_state.stop()
493
  generation_control.finish()
494
+ shared_generation_state.finish(control_session_id)
495
  yield (
496
  gr.update(value=None, visible=False),
497
  gr.update(interactive=True),
 
499
  render_text_progress(app_config, None),
500
  render_audio_stream(app_config, session_id=stream_session_id),
501
  *generation_button_updates(running=False),
502
+ control_session_id,
503
+ )
504
+ return
505
+
506
+ if shared_generation_state.is_stopped(control_session_id):
507
+ speaking_rate_state.stop()
508
+ generation_control.finish()
509
+ shared_generation_state.finish(control_session_id)
510
+ yield (
511
+ gr.update(value=None, visible=False),
512
+ gr.update(interactive=True),
513
+ empty_rate_plot(app_config, show_target=enable_speaking_rate),
514
+ render_text_progress(app_config, None),
515
+ render_audio_stream(
516
+ app_config,
517
+ session_id=stream_session_id,
518
+ active=False,
519
+ final=True,
520
+ ),
521
+ *generation_button_updates(running=False),
522
+ control_session_id,
523
  )
524
  return
525
 
526
  ensure_generator_on_cuda(speech_generator)
527
  speaking_rate_state.ensure_started(speaking_rate_control)
528
  speaking_rate_gen = (
529
+ shared_generation_state.speaking_rate_values(
530
+ control_session_id, speaking_rate_control
531
+ )
532
+ if enable_speaking_rate
533
+ else None
534
  )
535
  text_metadata = build_text_progress_metadata(
536
  target_text,
 
575
 
576
  stream_iter = iter(stream)
577
  while True:
578
+ if not shared_generation_state.wait_if_paused(control_session_id):
579
  stopped = True
580
  break
581
  try:
582
  frame, _, progress = next(stream_iter)
583
  except StopIteration:
584
  break
585
+ if shared_generation_state.is_stopped(control_session_id):
586
  stopped = True
587
  break
588
 
 
592
  plot_update, text_update = visualization.update(progress)
593
 
594
  if buffer_len >= chunk_size:
595
+ if shared_generation_state.is_stopped(control_session_id):
596
  stopped = True
597
  break
598
  audio = np.concatenate(buffer)
 
611
  active=True,
612
  ),
613
  *generation_button_updates(
614
+ running=True,
615
+ paused=shared_generation_state.is_paused(control_session_id),
616
  ),
617
+ control_session_id,
618
  )
619
 
620
  buffer = []
621
  buffer_len = 0
622
 
623
+ stopped = stopped or shared_generation_state.is_stopped(control_session_id)
624
  if stopped and hasattr(stream, "close"):
625
  stream.close()
626
  final_text = visualization.final_text()
 
646
  active=True,
647
  ),
648
  *generation_button_updates(
649
+ running=True,
650
+ paused=shared_generation_state.is_paused(control_session_id),
651
  ),
652
+ control_session_id,
653
  )
654
 
655
  if len(total_buffer) > 0:
 
666
 
667
  speaking_rate_state.stop()
668
  generation_control.finish()
669
+ shared_generation_state.finish(control_session_id)
670
  yield (
671
  gr.update(value=file_path, visible=True),
672
  gr.update(interactive=True),
 
681
  final=True,
682
  ),
683
  *generation_button_updates(running=False),
684
+ control_session_id,
685
  )
686
  else:
687
  speaking_rate_state.stop()
688
  generation_control.finish()
689
+ shared_generation_state.finish(control_session_id)
690
  yield (
691
  gr.update(value=None, visible=False),
692
  gr.update(interactive=True),
 
701
  final=True,
702
  ),
703
  *generation_button_updates(running=False),
704
+ control_session_id,
705
  )
706
 
707
  demo_app(
 
711
  synthesize_fn,
712
  speaking_rate_state,
713
  generation_control,
714
+ shared_generation_state,
715
  )
716
 
717
 
requirements.txt CHANGED
@@ -1 +1 @@
1
- voxtream==0.2.2
 
1
+ voxtream==0.2.3