# VoiceFocus / stream_pipeline.py
# mariesig's picture
# vad (#3)
# 32dcdfe
import numpy as np
from constants import STREAMER_CLASSES
import gradio as gr
from stt_streamers import DeepgramStreamer
from sdk import SDKWrapper, SDKParams
from ui import LED_DOT_BLACK, LED_DOT_GREEN, LED_DOT_OFF, LED_DOT_RED, LED_DOT_YELLOW
# Latest transcript text for each stream. Written by the _set_transcript_*
# callbacks and clear_live_transcripts(); returned to the caller by stream_step().
_ENHANCED_TRANSCRIPT = ""
_RAW_TRANSCRIPT = ""
# Shared SDK wrapper: (re)initialised in _ensure_initialized() and used by
# stream_step() to process every incoming audio chunk.
SDK_STREAMING = SDKWrapper()
# Active STT streamer instances. Created by set_stt_streamer() and reset to
# None by shutdown_streamers(); None means "not initialised".
Streamer_enhanced = None
Streamer_raw = None
def _set_transcript_enhanced(text: str) -> None:
    """Store ``text`` as the current transcript of the enhanced stream.

    Passed as the ``on_update`` callback of the "Enhanced" streamer.
    """
    globals()["_ENHANCED_TRANSCRIPT"] = text
def _set_transcript_raw(text: str) -> None:
    """Store ``text`` as the current transcript of the raw stream.

    Passed as the ``on_update`` callback of the "Raw" streamer.
    """
    globals()["_RAW_TRANSCRIPT"] = text
def clear_live_transcripts():
    """Reset both live transcript buffers to the empty string."""
    global _RAW_TRANSCRIPT, _ENHANCED_TRANSCRIPT
    _ENHANCED_TRANSCRIPT = _RAW_TRANSCRIPT = ""
def render_system_status(status: str):
    """Translate a status keyword into its ``(led, message)`` display pair.

    Recognised keywords are ``"off"``, ``"init"``, ``"ready"`` and ``"error"``.

    Returns:
        A 2-tuple of the LED constant and the status message for ``status``.

    Raises:
        ValueError: If ``status`` is not one of the recognised keywords.
    """
    if status == "off":
        indicator, message = LED_DOT_OFF, "Off"
    elif status == "init":
        indicator, message = LED_DOT_YELLOW, "Initializing..."
    elif status == "ready":
        indicator, message = LED_DOT_GREEN, "Ready"
    elif status == "error":
        indicator, message = LED_DOT_RED, "Error. Please refresh the page."
    else:
        raise ValueError(f"Invalid status: {status}")
    return indicator, message
def shutdown_streamers():
    """Shut down both STT streamers (if any) and reset the globals to ``None``.

    Each streamer is shut down independently so that a failure in one does not
    prevent the other from being released. Errors are reported on stdout and
    never propagated: this is a best-effort cleanup.
    """
    global Streamer_enhanced, Streamer_raw
    try:
        for streamer in (Streamer_enhanced, Streamer_raw):
            if streamer is None:
                continue
            try:
                streamer.shutdown()
            except Exception as e:
                # Keep going: the remaining streamer must still be shut down.
                print(f"Error shutting down streamers: {e}")
    finally:
        Streamer_enhanced = None
        Streamer_raw = None
def set_stt_streamer(sample_rate: int, stt_model: str):
    """Create the "Enhanced" and "Raw" STT streamers for ``stt_model``.

    The streamer class is looked up in ``STREAMER_CLASSES``; unknown model
    names fall back to ``DeepgramStreamer``. On success the two instances are
    stored in the module globals ``Streamer_enhanced`` / ``Streamer_raw``.

    Args:
        sample_rate: Sample rate passed to the streamers as ``fs_hz``.
        stt_model: Key used to select the streamer class.

    Raises:
        RuntimeError: If either streamer cannot be constructed. Both globals
            are reset to ``None`` and a streamer that was already built is
            shut down so it is not leaked.
    """
    global Streamer_enhanced, Streamer_raw
    StreamerCls = STREAMER_CLASSES.get(stt_model, DeepgramStreamer)
    enhanced = None
    try:
        enhanced = StreamerCls(
            fs_hz=sample_rate,
            stream_name="Enhanced",
            on_update=_set_transcript_enhanced,
        )
        raw = StreamerCls(
            fs_hz=sample_rate,
            stream_name="Raw",
            on_update=_set_transcript_raw,
        )
    except Exception as e:
        # The first streamer may already exist when the second one fails;
        # release it instead of silently dropping the reference.
        if enhanced is not None:
            try:
                enhanced.shutdown()
            except Exception as cleanup_error:
                print(f"Error shutting down streamers: {cleanup_error}")
        Streamer_enhanced = None
        Streamer_raw = None
        raise RuntimeError(f"Error initializing STT streamer '{stt_model}': {e}") from e
    Streamer_enhanced = enhanced
    Streamer_raw = raw
def _to_float32_mono(y: np.ndarray) -> np.ndarray:
    """Convert an audio chunk to a flat, mono ``float32`` array.

    Multi-dimensional input is averaged over axis 1 (assumes a
    ``(samples, channels)`` layout — TODO confirm against the audio source).
    ``int16`` input is scaled by ``1 / 32768``; every other dtype is cast to
    ``float32`` without scaling.

    Args:
        y: Audio samples (anything ``np.asarray`` accepts).

    Returns:
        A 1-D ``float32`` array.
    """
    samples = np.asarray(y)
    # Remember the *input* dtype: averaging channels promotes int16 to
    # float64, which would otherwise bypass the int16 normalisation below.
    was_int16 = samples.dtype == np.int16
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    samples = samples.astype(np.float32)
    if was_int16:
        samples = samples / 32768.0
    return np.asarray(samples, dtype=np.float32).flatten()
def _ensure_initialized(sr: int, streaming_sr, stt_model: str, enhancement_level: float):
    """Make sure the SDK processor and STT streamers match the current settings.

    Re-initialisation happens when no sample rate has been recorded yet, when
    the sample rate changed, when a streamer is missing, or when the active
    streamers are not instances of the class selected by ``stt_model``.
    Otherwise only the enhancement level is updated if it differs.

    Args:
        sr: Sample rate of the incoming audio.
        streaming_sr: Sample rate the pipeline was last initialised with, or
            ``None`` if it has not been initialised.
        stt_model: Key into ``STREAMER_CLASSES``.
        enhancement_level: Enhancement level handed to the SDK.

    Returns:
        ``(sample_rate, led, text)``. ``sample_rate`` is ``None`` and the
        status is "error" when initialisation failed.
    """
    # Use the same fallback as set_stt_streamer() so an unknown model name
    # does not raise KeyError here before any error handling is in place.
    streamer_cls = STREAMER_CLASSES.get(stt_model, DeepgramStreamer)
    needs_init = (
        streaming_sr is None
        or streaming_sr != sr
        or Streamer_enhanced is None
        or Streamer_raw is None
        or not isinstance(Streamer_enhanced, streamer_cls)
        or not isinstance(Streamer_raw, streamer_cls)
    )
    if not needs_init:
        if SDK_STREAMING.enhancement_level != enhancement_level:
            SDK_STREAMING.change_enhancement_level(enhancement_level)
        return streaming_sr, *render_system_status("ready")
    try:
        # Release the old streamers before building new ones.
        shutdown_streamers()
        sdk_params = SDKParams(
            sample_rate=sr,
            enhancement_level=enhancement_level,
        )
        SDK_STREAMING.init_processor(sdk_params)
        set_stt_streamer(sr, stt_model)
        return sr, *render_system_status("ready")
    except Exception as e:
        gr.Warning(f"Streaming process failed: {e}")
        return None, *render_system_status("error")
def stream_step(audio_stream, streaming_sr, stt_model, enhancement_level, input_gain_db):
    """Process one streamed audio chunk and return the updated UI values.

    Args:
        audio_stream: ``(sample_rate, chunk)`` pair, or ``None``.
        streaming_sr: Sample rate the pipeline was last initialised with.
        stt_model: Key selecting the STT streamer class.
        enhancement_level: Enhancement level in percent; divided by 100 before
            being handed to the SDK.
        input_gain_db: Input gain in dB; only applied when greater than zero.

    Returns:
        ``(streaming_sr, system_led, system_text, enhanced_transcript,
        raw_transcript, vad_led)``.
    """
    if audio_stream is None:
        return streaming_sr, *render_system_status("off"), _ENHANCED_TRANSCRIPT, _RAW_TRANSCRIPT, LED_DOT_OFF
    sr, chunk = audio_stream
    if chunk is None:
        return streaming_sr, *render_system_status("off"), _ENHANCED_TRANSCRIPT, _RAW_TRANSCRIPT, LED_DOT_OFF
    enhancement_level_float = enhancement_level / 100.0
    streaming_sr, system_led, system_text = _ensure_initialized(
        sr=sr,
        streaming_sr=streaming_sr,
        stt_model=stt_model,
        enhancement_level=enhancement_level_float,
    )
    if streaming_sr is None:
        # Initialisation failed and has already been reported. Skip processing:
        # the streamers are unset, so continuing would only fail again and
        # raise a second warning for the same problem.
        return streaming_sr, system_led, system_text, _ENHANCED_TRANSCRIPT, _RAW_TRANSCRIPT, LED_DOT_OFF
    try:
        y = _to_float32_mono(chunk)
        if input_gain_db and input_gain_db > 0:
            # dB -> linear amplitude factor; clip to keep samples in [-1, 1].
            gain_linear = np.float32(10.0 ** (float(input_gain_db) / 20.0))
            y = np.clip(y * gain_linear, -1.0, 1.0).astype(np.float32)
        enhanced_chunk_16k, vad_detected = SDK_STREAMING.process_with_vad(y)
        enhanced_chunk_16k = np.asarray(enhanced_chunk_16k, dtype=np.float32).flatten()
        Streamer_raw.process_chunk(y)
        Streamer_enhanced.process_chunk(enhanced_chunk_16k)
        vad_led = LED_DOT_GREEN if vad_detected else LED_DOT_BLACK
        return streaming_sr, system_led, system_text, _ENHANCED_TRANSCRIPT, _RAW_TRANSCRIPT, vad_led
    except Exception as e:
        gr.Warning(f"Streaming process failed: {e}")
        err_led, err_text = render_system_status("error")
        return streaming_sr, err_led, err_text, _ENHANCED_TRANSCRIPT, _RAW_TRANSCRIPT, LED_DOT_OFF
def on_start_recording():
    """Clear the live transcripts and report the "init" system status.

    Returns:
        ``("", "", led, text)``: two empty strings followed by the LED and
        message for the "init" status.
    """
    clear_live_transcripts()
    return "", "", *render_system_status("init")
def on_stop_recording():
    """Report the "off" system status after recording stops.

    Returns:
        ``(led, led, text, None)``: the "off" LED twice, the "off" message,
        and ``None``. Presumably the second LED drives the VAD indicator and
        ``None`` resets the stored sample rate — confirm against the UI wiring.
    """
    off_led, off_text = render_system_status("off")
    return (off_led, off_led, off_text, None)