import gradio as gr
from transformers import pipeline
import numpy as np
import os
import torch
import torchaudio  # Used to resample audio to the rate Silero VAD expects

print(f"DEBUG: Gradio version being used: {gr.__version__}")
# --- Configuration ---
MODEL_NAME = os.getenv("ASR_MODEL", "openai/whisper-base.en")
DEVICE = "cuda" if torch.cuda.is_available() and os.getenv("USE_GPU", "false").lower() == "true" else "cpu"
print(f"Using device: {DEVICE}")
# --- Global Variables ---
asr_pipeline = None
vad_model = None
vad_utils = None
# NOTE: audio_buffer is module-global, so it is shared by all connected
# clients; as written, the app is effectively single-user.
audio_buffer = []  # Accumulates incoming audio chunks
MAX_BUFFER_SECONDS = 10  # Max audio to buffer before forcing transcription
SILENCE_THRESHOLD_SECONDS = 1.5  # Silence duration before a speech segment is processed
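# Rough buffer arithmetic (assuming a 16 kHz mono stream): 10 s of audio is
# 160,000 float32 samples, about 640 KB, so the cap mainly bounds
# transcription latency rather than memory.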
# --- Load Models ---
def load_models():
    global asr_pipeline, vad_model, vad_utils
    try:
        print(f"Loading ASR model: {MODEL_NAME} on device: {DEVICE}")
        asr_pipeline = pipeline(
            task="automatic-speech-recognition",
            model=MODEL_NAME,
            device=DEVICE if DEVICE == "cuda" else -1,
        )
        print("ASR model loaded successfully.")

        print("Loading Silero VAD model...")
        # The Silero VAD model itself is small and runs efficiently on CPU.
        vad_model, vad_utils_tuple = torch.hub.load(
            repo_or_dir="snakers4/silero-vad",
            model="silero_vad",
            force_reload=False,  # Set to True if the cached model misbehaves
            onnx=True,  # ONNX runtime for better CPU performance
        )
        (get_speech_timestamps,
         save_audio,
         read_audio,
         VADIterator,
         collect_chunks) = vad_utils_tuple
        vad_utils = {
            "get_speech_timestamps": get_speech_timestamps,
            "VADIterator": VADIterator,
        }
        print("Silero VAD model loaded successfully.")
    except Exception as e:
        print(f"Error loading models: {e}")
        if asr_pipeline is None:
            print("ASR pipeline failed to load.")
        if vad_model is None:
            print("VAD model failed to load.")


load_models()  # Load models at startup
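# The comments further down note that a VADIterator-based design would suit
# continuous streaming better. A minimal, untested sketch of that approach
# (assumes a 16 kHz float32 stream and the fixed 512-sample window expected
# by recent Silero releases; not wired into the app):
def _vad_iterator_sketch(audio_float32_16k):
    """Feed fixed-size windows to Silero's VADIterator and collect
    speech start/end events; returns the list of event dicts."""
    vad_iterator = vad_utils["VADIterator"](vad_model, sampling_rate=16000)
    window = 512  # samples per window at 16 kHz
    events = []
    for i in range(0, len(audio_float32_16k) - window + 1, window):
        chunk = torch.from_numpy(audio_float32_16k[i:i + window])
        event = vad_iterator(chunk, return_seconds=True)  # None, {'start': s} or {'end': s}
        if event is not None:
            events.append(event)
    vad_iterator.reset_states()  # Reset between independent streams
    return events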
# --- Core Transcription Logic with VAD ---
def transcribe_with_vad(new_chunk_audio, history_state):
    global audio_buffer
    if new_chunk_audio is None or asr_pipeline is None or vad_model is None:
        return history_state.get("full_text", ""), history_state

    sample_rate, audio_data = new_chunk_audio
    # Gradio streams integer PCM; normalize to float32 in [-1.0, 1.0].
    if np.issubdtype(audio_data.dtype, np.integer):
        audio_data_float32 = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
    else:
        audio_data_float32 = audio_data.astype(np.float32)

    # Append to buffer
    audio_buffer.append(audio_data_float32)

    # If the buffer is still too short, wait for more audio.
    current_buffer_duration = sum(len(chunk) / sample_rate for chunk in audio_buffer)
    if current_buffer_duration < 0.2:  # Minimum duration worth processing
        return history_state.get("full_text", ""), history_state

    # Concatenate buffer for VAD processing
    full_audio_np = np.concatenate(audio_buffer)
    full_audio_tensor = torch.from_numpy(full_audio_np).float()
    # Use VAD to find speech timestamps. We are looking for the *end* of
    # speech segments. This is a simplified approach: we process when VAD
    # detects no speech in the combined buffer, or when the buffer grows too
    # long, or when there is a significant pause after the last speech. A
    # VADIterator-based design (see the sketch above) would handle continuous
    # streaming more robustly, but is harder to manage with Gradio's chunking.
    try:
        # Silero VAD only accepts 8 kHz or 16 kHz input, while browsers
        # typically stream at 44.1/48 kHz, so resample before analysis.
        VAD_SR = 16000
        if sample_rate != VAD_SR:
            vad_audio = torchaudio.functional.resample(full_audio_tensor, sample_rate, VAD_SR)
        else:
            vad_audio = full_audio_tensor
        # Returns a list of sample-index dicts, e.g. [{'start': 1523, 'end': 40432}].
        speech_timestamps = vad_utils["get_speech_timestamps"](
            vad_audio,
            vad_model,
            sampling_rate=VAD_SR,
            min_silence_duration_ms=500,  # ms of silence that counts as a break
        )

        process_now = False
        transcribed_text_segment = ""
        if not speech_timestamps:
            # No speech anywhere in the buffer: once enough audio has
            # accumulated, treat it as trailing silence and process.
            if current_buffer_duration > SILENCE_THRESHOLD_SECONDS:
                process_now = True
        elif current_buffer_duration > MAX_BUFFER_SECONDS:
            # Speech is still ongoing but the buffer is too long; force it.
            process_now = True
        else:
            # Speech detected: process if the last speech segment ended well
            # before the end of the buffer, i.e. a pause followed the speech.
            last_speech_end_s = speech_timestamps[-1]["end"] / VAD_SR
            if current_buffer_duration - last_speech_end_s > SILENCE_THRESHOLD_SECONDS:
                process_now = True
        if process_now and full_audio_np.any():  # Ensure there is actual audio data
            print(f"Processing {current_buffer_duration:.2f}s of buffered audio.")
            # Transcribe the entire current buffer (at its native sample rate;
            # the pipeline resamples internally for Whisper).
            transcription_result = asr_pipeline(
                {"sampling_rate": sample_rate, "raw": full_audio_np.copy()},  # Send a copy
                # Whisper-specific args can go here if needed, e.g.
                # chunk_length_s for long-form audio, or for multilingual models:
                # generate_kwargs={"task": "transcribe", "language": "en"}
            )
            new_text = transcription_result["text"].strip()
            if new_text:
                transcribed_text_segment = new_text + " "
                history_state["full_text"] = history_state.get("full_text", "") + transcribed_text_segment
                print(f"VAD processed: '{new_text}'")
            audio_buffer = []  # Clear buffer after processing
    except Exception as e:
        print(f"Error during VAD/transcription: {e}")
        # Fallback: transcribe whatever has accumulated, then clear the buffer.
        if audio_buffer:
            try:
                full_audio_fallback = np.concatenate(audio_buffer)
                if full_audio_fallback.any():
                    transcription_result = asr_pipeline(
                        {"sampling_rate": sample_rate, "raw": full_audio_fallback.copy()}
                    )
                    new_text = transcription_result["text"].strip()
                    if new_text:
                        history_state["full_text"] = history_state.get("full_text", "") + new_text + " "
                        print(f"Fallback processed: '{new_text}'")
            except Exception as fallback_e:
                print(f"Error during fallback transcription: {fallback_e}")
        audio_buffer = []  # Clear buffer

    return history_state.get("full_text", ""), history_state
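# A minimal offline smoke test for the function above (hypothetical helper,
# not called by the app): feeds two seconds of silence through once, which
# should exercise the "no speech detected" path without transcribing anything.
def _smoke_test_transcribe():
    sr = 16000
    silent_chunk = np.zeros(2 * sr, dtype=np.int16)
    text, state = transcribe_with_vad((sr, silent_chunk), {"full_text": ""})
    print(f"Smoke test transcript so far: '{text}'")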
# --- Gradio UI (largely unchanged; points to the new function and manages state) ---
with gr.Blocks(title="Live Transcription with VAD") as demo:
    gr.Markdown(
        f"""
        # 🎙️ Live Speech-to-Text with VAD & Hugging Face Whisper
        Speak into your microphone. Transcription will appear after speech segments.
        Using model: `{MODEL_NAME}` on device: `{DEVICE}`.
        VAD: Silero VAD
        """
    )
    if asr_pipeline is None or vad_model is None:
        gr.Markdown("## ⚠️ Error: Models Not Loaded. Check logs. ⚠️")

    transcription_history = gr.State({"full_text": ""})
    with gr.Row():
        audio_input = gr.Audio(
            sources=["microphone"],
            type="numpy",
            streaming=True,
            label="Speak Here (Streaming Active with VAD)",
        )
        transcription_output = gr.Textbox(
            label="Live Transcription", lines=15, interactive=False, show_copy_button=True
        )

    # 'stream_every' controls how often chunks arrive: smaller values mean
    # more frequent VAD checks (potentially more responsive) but also more
    # function calls. Recent Gradio releases name this 'stream_every'; older
    # ones used 'every' here instead.
    audio_input.stream(
        fn=transcribe_with_vad,
        inputs=[audio_input, transcription_history],
        outputs=[transcription_output, transcription_history],
        stream_every=0.5,  # Check buffer and VAD every 0.5 seconds
    )
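    # Because audio_buffer is a module-level global, concurrent visitors would
    # interleave audio. A hedged mitigation (parameter name assumes the
    # Gradio 4+ queue API): replace demo.queue() below with
    # demo.queue(default_concurrency_limit=1) so only one stream runs at a time.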
    def clear_transcription_state(current_state):
        global audio_buffer
        audio_buffer = []  # Also clear the audio buffer
        current_state["full_text"] = ""
        print("Transcription and audio buffer cleared.")
        return "", current_state

    clear_button = gr.Button("Clear Transcription & Buffer")
    clear_button.click(
        fn=clear_transcription_state,
        inputs=[transcription_history],
        outputs=[transcription_output, transcription_history],
    )

    gr.Markdown("---")
if __name__ == "__main__":
    # os.environ["ASR_MODEL"] = "openai/whisper-tiny.en"
    # os.environ["USE_GPU"] = "false"
    # load_models()  # Re-run if the env vars above are changed
    demo.queue().launch(debug=True, share=False)
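# Hedged dependency note: the imports above imply roughly this requirements.txt
# (exact versions untested; onnxruntime is needed because torch.hub.load is
# called with onnx=True):
#   gradio
#   transformers
#   torch
#   torchaudio
#   onnxruntime
#   numpy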