import gradio as gr
from transformers import pipeline
import numpy as np
import os
import torch
import torchaudio  # Used to resample mic audio to a rate Silero VAD accepts
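# Assumed runtime dependencies, based on the imports and the onnx=True flag
# used below: gradio, transformers, torch, torchaudio, numpy, onnxruntime.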
print(f"DEBUG: Gradio version being used: {gr.__version__}")
# --- Configuration ---
MODEL_NAME = os.getenv("ASR_MODEL", "openai/whisper-base.en")
DEVICE = "cuda" if torch.cuda.is_available() and os.getenv("USE_GPU", "false").lower() == "true" else "cpu"
print(f"Using device: {DEVICE}")
# --- Global Variables ---
asr_pipeline = None
vad_model = None
vad_utils = None
audio_buffer = []  # Accumulates audio chunks (module-global, so shared across all sessions)
MAX_BUFFER_SECONDS = 10  # Max audio to buffer before forcing transcription
SILENCE_THRESHOLD_SECONDS = 1.5  # How long silence lasts before a speech segment is processed
VAD_SAMPLE_RATE = 16000  # Silero VAD only accepts 8 kHz or 16 kHz input
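# Strategy: microphone chunks accumulate in audio_buffer; on every streaming
# callback, Silero VAD scans the concatenated buffer. The buffer is sent to
# Whisper (and then cleared) once a trailing pause is detected or the buffer
# exceeds MAX_BUFFER_SECONDS.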
# --- Load Models ---
def load_models():
    global asr_pipeline, vad_model, vad_utils
    try:
        print(f"Loading ASR model: {MODEL_NAME} on device: {DEVICE}")
        asr_pipeline = pipeline(
            task="automatic-speech-recognition",
            model=MODEL_NAME,
            device=DEVICE if DEVICE == "cuda" else -1,
        )
        print("ASR model loaded successfully.")
        print("Loading Silero VAD model...")
        # The Silero VAD model itself is small and runs efficiently on CPU.
        vad_model, vad_utils_tuple = torch.hub.load(
            repo_or_dir='snakers4/silero-vad',
            model='silero_vad',
            force_reload=False,  # Set to True if you have caching issues
            onnx=True,  # Use ONNX for better CPU performance
        )
        (get_speech_timestamps,
         save_audio,
         read_audio,
         VADIterator,
         collect_chunks) = vad_utils_tuple
        vad_utils = {
            "get_speech_timestamps": get_speech_timestamps,
            "VADIterator": VADIterator,
        }
        print("Silero VAD model loaded successfully.")
    except Exception as e:
        print(f"Error loading models: {e}")
        if asr_pipeline is None:
            print("ASR pipeline failed to load.")
        if vad_model is None:
            print("VAD model failed to load.")
load_models() # Load models at startup
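# Note: the first launch downloads the Whisper weights from the Hugging Face
# Hub and the Silero VAD model via torch.hub; later launches use local caches.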
# --- Core Transcription Logic with VAD ---
def transcribe_with_vad(new_chunk_audio, history_state):
    global audio_buffer
    if new_chunk_audio is None or asr_pipeline is None or vad_model is None:
        return history_state.get("full_text", ""), history_state

    sample_rate, audio_data = new_chunk_audio
    # Gradio delivers integer PCM; normalize to float32 in [-1, 1].
    if np.issubdtype(audio_data.dtype, np.integer):
        audio_data_float32 = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
    else:
        audio_data_float32 = audio_data.astype(np.float32)
    # Downmix stereo to mono if necessary.
    if audio_data_float32.ndim > 1:
        audio_data_float32 = audio_data_float32.mean(axis=1)

    # Append to buffer; if the buffer is still too short, wait for more audio.
    audio_buffer.append(audio_data_float32)
    current_buffer_duration = sum(len(chunk) / sample_rate for chunk in audio_buffer)
    if current_buffer_duration < 0.2:  # Minimum duration worth processing
        return history_state.get("full_text", ""), history_state

    # Concatenate the buffer for VAD processing.
    full_audio_np = np.concatenate(audio_buffer)
    full_audio_tensor = torch.from_numpy(full_audio_np).float()

    # Use VAD to find speech timestamps; we are looking for the *end* of
    # speech segments. This is a simplified approach: we process when VAD
    # detects no speech in the latest part of the buffer, or when the buffer
    # gets too long. A VADIterator-based approach would be more robust for
    # continuous streaming but is harder to manage with Gradio's chunking.
    try:
        # Silero VAD only supports 8 kHz and 16 kHz input; resample if the
        # microphone delivers a different rate (browsers often use 48 kHz).
        if sample_rate != VAD_SAMPLE_RATE:
            vad_audio_tensor = torchaudio.functional.resample(
                full_audio_tensor, sample_rate, VAD_SAMPLE_RATE
            )
        else:
            vad_audio_tensor = full_audio_tensor
        speech_timestamps = vad_utils["get_speech_timestamps"](
            vad_audio_tensor,
            vad_model,
            sampling_rate=VAD_SAMPLE_RATE,
            min_silence_duration_ms=500,  # ms of silence to consider a break
        )

        # Heuristic: process if no speech is detected in the combined buffer,
        # if the buffer is too long, or if there is a significant pause after
        # the last speech segment.
        process_now = False
        if not speech_timestamps:
            # No speech detected; once enough audio has accumulated, assume
            # it is trailing silence after speech and flush the buffer.
            if current_buffer_duration > SILENCE_THRESHOLD_SECONDS:
                process_now = True
        elif current_buffer_duration > MAX_BUFFER_SECONDS:
            # Buffer is too long; force transcription.
            process_now = True
        else:
            # Speech was detected: check whether the last speech segment ends
            # well before the end of the buffer, indicating a pause.
            last_speech_end_s = speech_timestamps[-1]["end"] / VAD_SAMPLE_RATE
            if current_buffer_duration - last_speech_end_s > SILENCE_THRESHOLD_SECONDS:
                process_now = True

        if process_now and full_audio_np.any():  # Ensure there is actual audio data
            print(f"Processing {current_buffer_duration:.2f}s of buffered audio.")
            # Transcribe the entire current buffer; the pipeline resamples
            # internally to the model's expected rate.
            transcription_result = asr_pipeline(
                {"sampling_rate": sample_rate, "raw": full_audio_np.copy()},  # Send a copy
                # Whisper-specific args can go here if needed, e.g.
                # chunk_length_s for long-form audio, or for multilingual models:
                # generate_kwargs={"task": "transcribe", "language": "<|en|>"}
            )
            new_text = transcription_result["text"].strip()
            if new_text:
                history_state["full_text"] = history_state.get("full_text", "") + new_text + " "
                print(f"VAD processed: '{new_text}'")
            audio_buffer = []  # Clear buffer after processing
    except Exception as e:
        print(f"Error during VAD/transcription: {e}")
        # Fallback: transcribe the accumulated buffer on error, then clear it.
        if audio_buffer:
            try:
                full_audio_fallback = np.concatenate(audio_buffer)
                if full_audio_fallback.any():
                    transcription_result = asr_pipeline(
                        {"sampling_rate": sample_rate, "raw": full_audio_fallback.copy()}
                    )
                    new_text = transcription_result["text"].strip()
                    if new_text:
                        history_state["full_text"] = history_state.get("full_text", "") + new_text + " "
                        print(f"Fallback processed: '{new_text}'")
            except Exception as fallback_e:
                print(f"Error during fallback transcription: {fallback_e}")
        audio_buffer = []  # Clear buffer

    return history_state.get("full_text", ""), history_state
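# For reference, a rough sketch of the more robust VADIterator-based approach
# mentioned above (unused by the app). It assumes 16 kHz mono float32 input
# and follows the usage shown in the silero-vad README: the iterator is fed
# fixed-size chunks (512 samples at 16 kHz) and emits a dict with a 'start'
# or 'end' key, in seconds, whenever a speech boundary is detected.
def _vad_iterator_sketch(audio_16k: torch.Tensor):
    vad_iterator = vad_utils["VADIterator"](vad_model)
    boundaries = []
    window = 512  # samples per VAD step at 16 kHz
    for i in range(0, len(audio_16k) - window + 1, window):
        event = vad_iterator(audio_16k[i:i + window], return_seconds=True)
        if event:  # e.g. {'start': 1.2} or {'end': 3.4}
            boundaries.append(event)
    vad_iterator.reset_states()  # Reset between independent audio streams
    return boundaries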
# --- Gradio UI ---
with gr.Blocks(title="Live Transcription with VAD") as demo:
    gr.Markdown(
        f"""
        # 🎙️ Live Speech-to-Text with VAD & Hugging Face Whisper
        Speak into your microphone. Transcription will appear after speech segments.
        Using model: `{MODEL_NAME}` on device: `{DEVICE}`.
        VAD: Silero VAD
        """
    )
    if asr_pipeline is None or vad_model is None:
        gr.Markdown("## ⚠️ Error: Models Not Loaded. Check logs. ⚠️")
    transcription_history = gr.State({"full_text": ""})
    with gr.Row():
        audio_input = gr.Audio(
            sources=["microphone"],
            type="numpy",
            streaming=True,
            label="Speak Here (Streaming Active with VAD)",
        )
        transcription_output = gr.Textbox(
            label="Live Transcription", lines=15, interactive=False, show_copy_button=True
        )
    # Adjust 'every' based on how frequently you want to check the VAD buffer.
    # A smaller 'every' means more frequent checks and potentially more
    # responsive VAD, but also more frequent function calls.
    audio_input.stream(
        fn=transcribe_with_vad,
        inputs=[audio_input, transcription_history],
        outputs=[transcription_output, transcription_history],
        every=0.5,  # Check the buffer and run VAD every 0.5 seconds
    )
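    # Note: the name of this polling argument varies across Gradio releases
    # (newer versions expose `stream_every` on `.stream()` instead of `every`);
    # the version printed at startup helps confirm which one applies.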
    def clear_transcription_state(current_state):
        global audio_buffer
        audio_buffer = []  # Also clear the audio buffer
        current_state["full_text"] = ""
        print("Transcription and audio buffer cleared.")
        return "", current_state

    clear_button = gr.Button("Clear Transcription & Buffer")
    clear_button.click(
        fn=clear_transcription_state,
        inputs=[transcription_history],
        outputs=[transcription_output, transcription_history],
    )
    gr.Markdown("---")
if __name__ == "__main__":
    # os.environ["ASR_MODEL"] = "openai/whisper-tiny.en"
    # os.environ["USE_GPU"] = "False"
    # load_models()  # Ensure models are loaded if running locally
    demo.queue().launch(debug=True, share=False)