# Hugging Face Spaces app (Space status when captured: Running)
import os
import string
import tempfile
import time

import gradio as gr
import librosa
import nemo.collections.asr as nemo_asr
import spaces
import torch
from omegaconf import OmegaConf
# Important: Don't initialize CUDA in the main process for Spaces.
# The ASR model is not created at import time; load_model() builds it lazily
# the first time a transcription (or the "Load Selected Model" button) needs it.
# NOTE(review): no @spaces.GPU decorator is applied in the visible code, so on
# ZeroGPU hardware the model would be loaded without a GPU — confirm the Space
# hardware this app targets.
model = None  # populated lazily by load_model()
current_model_name = "nvidia/parakeet-tdt-0.6b-v2"  # default / most recently loaded checkpoint

# Checkpoints offered in the model dropdown.
available_models = [
    "nvidia/parakeet-tdt-0.6b-v2",
    "nvidia/parakeet-tdt-1.1b",
]
# Name of the checkpoint actually held in ``model``. It is tracked separately
# from ``current_model_name`` (the user's selection, which other code may
# update) so that changing the selection can never make a stale model look
# as if it were already loaded.
_loaded_model_name = None


def load_model(model_name=None):
    """Return the NeMo ASR model for ``model_name``, loading it if necessary.

    The model is cached in the module-level ``model`` global. It is only
    (re)loaded when nothing is cached yet or a different checkpoint is requested.

    Args:
        model_name: Pretrained checkpoint name passed to
            ``EncDecRNNTBPEModel.from_pretrained``. Falls back to
            ``current_model_name`` when falsy.

    Returns:
        The cached (possibly freshly loaded) model.

    Raises:
        Whatever ``from_pretrained`` raises. In that case the globals are left
        exactly as they were, still describing the previously loaded model.
    """
    global model, current_model_name, _loaded_model_name
    model_name = model_name or current_model_name
    if model is None or model_name != _loaded_model_name:
        print(f"Loading model {model_name} in worker process")
        # Load into a local first: the globals are only updated once loading
        # has succeeded, so a failed load cannot mislabel the cached model.
        new_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name)
        model = new_model
        _loaded_model_name = model_name
        current_model_name = model_name
        print(f"Model loaded on device: {model.device}")
    return model
| def transcribe(audio, model_name="nvidia/parakeet-tdt-0.6b-v2", state="", audio_buffer=None, last_processed_time=0): | |
| # Load the model inside the GPU worker process | |
| import numpy as np | |
| import soundfile as sf | |
| import librosa | |
| import os | |
| model = load_model(model_name) | |
| if audio_buffer is None: | |
| audio_buffer = [] | |
| if audio is None or isinstance(audio, int): | |
| print(f"Skipping invalid audio input: {type(audio)}") | |
| return state, state, audio_buffer, last_processed_time | |
| print(f"Received audio input of type: {type(audio)}") | |
| if isinstance(audio, tuple) and len(audio) == 2 and isinstance(audio[1], np.ndarray): | |
| sample_rate, audio_data = audio | |
| print(f"Sample rate: {sample_rate}, Audio shape: {audio_data.shape}") | |
| # Append chunk to buffer | |
| audio_buffer.append(audio_data) | |
| # Calculate total duration in seconds | |
| total_samples = sum(arr.shape[0] for arr in audio_buffer) | |
| total_duration = total_samples / sample_rate | |
| print(f"Total buffered duration: {total_duration:.2f}s") | |
| # Process 5-second chunks with 2-second step size (3-second overlap) | |
| # Using longer chunks usually helps with transcription accuracy | |
| chunk_duration = 5.0 # seconds (increased from 2.0) | |
| step_size = 2.0 # seconds (increased from 1.0) | |
| # min_samples = int(chunk_duration * 16000) # 5s at 16kHz | |
| if total_duration < chunk_duration: | |
| print(f"Buffering audio, total duration: {total_duration:.2f}s") | |
| return state, state, audio_buffer, last_processed_time | |
| try: | |
| # Concatenate buffered chunks | |
| full_audio = np.concatenate(audio_buffer) | |
| # Resample to 16kHz if needed | |
| if sample_rate != 16000: | |
| print(f"Resampling from {sample_rate}Hz to 16000Hz") | |
| full_audio = librosa.resample(full_audio.astype(float), orig_sr=sample_rate, target_sr=16000) | |
| sample_rate = 16000 | |
| else: | |
| full_audio = full_audio.astype(float) | |
| # Normalize audio (helps with consistent volume levels) | |
| # if np.abs(full_audio).max() > 0: | |
| # full_audio = full_audio / np.abs(full_audio).max() * 0.9 | |
| # print("Audio normalized to improve transcription") | |
| # Process chunks | |
| new_state = state | |
| current_time = last_processed_time | |
| total_samples_16k = len(full_audio) | |
| while current_time + chunk_duration <= total_duration: | |
| start_sample = int(current_time * sample_rate) | |
| end_sample = int((current_time + chunk_duration) * sample_rate) | |
| if end_sample > total_samples_16k: | |
| end_sample = total_samples_16k | |
| chunk = full_audio[start_sample:end_sample] | |
| print(f"Processing chunk from {current_time:.2f}s to {current_time + chunk_duration:.2f}s") | |
| # Save to temporary WAV file | |
| temp_file = "temp_audio.wav" | |
| sf.write(temp_file, chunk, samplerate=16000) | |
| # Transcribe | |
| print(f"Transcribing chunk of duration {chunk_duration}s...") | |
| hypothesis = model.transcribe([temp_file])[0] | |
| transcription = hypothesis.text | |
| print(f"Transcription: {transcription}") | |
| os.remove(temp_file) | |
| print("Temporary file removed.") | |
| # Append transcription if non-empty | |
| if transcription.strip(): | |
| new_state = new_state + " " + transcription if new_state else transcription | |
| current_time += step_size | |
| # Update last processed time | |
| last_processed_time = current_time | |
| # Trim buffer to keep only unprocessed audio | |
| keep_samples = int((total_duration - current_time) * sample_rate) | |
| if keep_samples > 0: | |
| audio_buffer = [full_audio[-keep_samples:]] | |
| else: | |
| audio_buffer = [] | |
| print(f"New state: {new_state}") | |
| return new_state, transcription, audio_buffer, last_processed_time # Return last transcription for streaming_text | |
| except Exception as e: | |
| print(f"Error processing audio: {e}") | |
| return state, state, audio_buffer, last_processed_time | |
| print(f"Invalid audio input format: {type(audio)}") | |
| return state, state, audio_buffer, last_processed_time | |
def transcribe_file(audio_file, model_name="nvidia/parakeet-tdt-0.6b-v2"):
    """Transcribe a whole recorded or uploaded audio file in a single pass.

    Args:
        audio_file: Path to the audio file (from ``gr.Audio(type="filepath")``),
            or ``None``/empty when nothing was recorded or uploaded.
        model_name: Checkpoint passed to ``load_model``, which also caches it in
            the module-level ``model`` global.

    Returns:
        The transcription text. If no file was given, or transcription failed,
        a human-readable message is returned instead, because the result is
        shown directly in a textbox.
    """
    if not audio_file:
        return "No audio file provided. Please upload an audio file."
    try:
        asr_model = load_model(model_name)
        print(f"Processing file: {audio_file}")
        # Transcribe the entire file at once.
        hypothesis = asr_model.transcribe([audio_file])[0]
        # Hypothesis objects expose ``.text``; fall back to the value itself.
        transcription = getattr(hypothesis, "text", hypothesis)
        print(f"File transcription: {transcription}")
        return transcription
    except Exception as e:
        print(f"Error transcribing file: {e}")
        return f"Error transcribing file: {str(e)}"
# Define the Gradio interface
with gr.Blocks(title="Real-time Speech-to-Text with NeMo") as demo:
    gr.Markdown("# 🎙️ Real-time Speech-to-Text Transcription")
    gr.Markdown("Powered by NVIDIA NeMo")

    # Model selection and loading
    with gr.Row():
        with gr.Column(scale=3):
            model_dropdown = gr.Dropdown(
                choices=available_models,
                value=current_model_name,
                label="Select ASR Model",
            )
        with gr.Column(scale=1):
            load_button = gr.Button("Load Selected Model", elem_id="load-button", elem_classes=["btn-blue"])

    # Status indicator for model loading
    model_status = gr.Textbox(
        value=f"Current model: {current_model_name}",
        label="Model Status",
        container=False,
    )

    # Create tabs for real-time and file-based transcription
    with gr.Tabs():
        # File-based transcription tab
        with gr.TabItem("File Transcription"):
            with gr.Row():
                with gr.Column(scale=2):
                    # Records or accepts an uploaded file and hands transcribe_file()
                    # a path on disk. "upload" is listed so the label and the
                    # instructions below ("or upload an existing audio file") hold.
                    audio_recorder = gr.Audio(
                        sources=["microphone", "upload"],
                        type="filepath",
                        label="Record or upload audio file",
                    )
                    with gr.Row():
                        transcribe_btn = gr.Button("Transcribe Audio File", variant="primary")
                        clear_file_btn = gr.Button("Clear Transcript", variant="secondary")
                with gr.Column(scale=3):
                    file_transcription = gr.Textbox(
                        label="File Transcription",
                        placeholder="Transcription will appear here after clicking 'Transcribe Audio File'",
                        lines=10,
                    )

        # Real-time transcription tab
        with gr.TabItem("Real-time Transcription"):
            with gr.Row():
                with gr.Column(scale=2):
                    audio_input = gr.Audio(
                        sources=["microphone"],
                        type="numpy",
                        streaming=True,
                        label="Speak into your microphone",
                        waveform_options=gr.WaveformOptions(
                            sample_rate=16000
                        ),
                    )
                    clear_btn = gr.Button("Clear Transcript", variant="secondary")
                with gr.Column(scale=3):
                    text_output = gr.Textbox(
                        label="Transcription",
                        placeholder="Your speech will appear here...",
                        lines=10,
                    )
                    streaming_text = gr.Textbox(
                        label="Real-time Transcription",
                        placeholder="Real-time results will appear here...",
                        lines=2,
                    )

    # Per-session state: accumulated transcript, unconsumed audio, and the
    # offset (seconds into the buffer) of the next window to transcribe.
    state = gr.State("")
    audio_buffer = gr.State(value=None)
    last_processed_time = gr.State(value=0)

    def update_model(model_name):
        """Eagerly load ``model_name`` and reset the streaming buffers.

        Returns a status message plus reset values for ``audio_buffer`` and
        ``last_processed_time``.
        """
        # Let load_model() decide whether a (re)load is needed and record the
        # new checkpoint itself. Assigning current_model_name here first made
        # load_model() believe the selection was already loaded, so the
        # switch never happened.
        try:
            load_model(model_name)
            status_message = f"Current model: {model_name} (loaded)"
            print(f"Model {model_name} loaded successfully")
        except Exception as e:
            status_message = f"Current model: {model_name} (will be loaded on first use)"
            print(f"Model will be loaded on first use: {e}")
        return status_message, None, 0  # Reset audio buffer and last processed time

    # Load model button event
    load_button.click(
        fn=update_model,
        inputs=[model_dropdown],
        outputs=[model_status, audio_buffer, last_processed_time],
    )

    # Handle the audio stream for real-time transcription. ``streaming_text``
    # is the visible textbox defined above. It must not be rebound to a
    # gr.State, or the live transcript is written somewhere invisible.
    audio_input.stream(
        fn=transcribe,
        inputs=[audio_input, model_dropdown, state, audio_buffer, last_processed_time],
        outputs=[state, streaming_text, audio_buffer, last_processed_time],
    )

    # Handle file transcription
    transcribe_btn.click(
        fn=transcribe_file,
        inputs=[audio_recorder, model_dropdown],
        outputs=[file_transcription],
    )

    def clear_transcription():
        """Reset the transcript, both text views and the streaming buffers."""
        print("Clearing real-time transcription")
        # state, text_output, streaming_text, audio_buffer, last_processed_time
        return "", "", "", None, 0

    def clear_file_transcription():
        """Empty the file-transcription textbox."""
        print("Clearing file transcription")
        return ""

    # Set up clear button event handlers (one handler resets every real-time output)
    clear_btn.click(
        fn=clear_transcription,
        inputs=[],
        outputs=[state, text_output, streaming_text, audio_buffer, last_processed_time],
    )
    clear_file_btn.click(
        fn=clear_file_transcription,
        inputs=[],
        outputs=[file_transcription],
    )

    def update_output(transcript):
        """Mirror the full transcript and show only its last 15 words as the live view."""
        words = transcript.split()
        tail = " ".join(words[-15:]) if len(words) > 15 else transcript
        return transcript, tail

    # Update the main text output when the state changes
    state.change(
        fn=update_output,
        inputs=[state],
        outputs=[text_output, streaming_text],
    )

    gr.Markdown("## 📝 Instructions")
    gr.Markdown("""
### Real-time Transcription:
1. Select an ASR model from the dropdown menu
2. Click 'Load Selected Model' to load the model
3. Click the microphone button to start recording
4. Speak clearly into your microphone
5. The transcription will appear in real-time
6. Click 'Clear Transcript' to reset the transcription
### File Transcription:
1. Select an ASR model from the dropdown menu
2. Click 'Load Selected Model' to load the model
3. Switch to the 'File Transcription' tab
4. Record audio by clicking the microphone button or upload an existing audio file
5. Click 'Transcribe Audio File' to process the recording
6. The complete transcription will appear in the text box
7. Click 'Clear Transcript' to reset the file transcription
""")
# Entry point: start the Gradio server only when this file is executed
# directly, never as a side effect of importing it.
if __name__ == "__main__":
    demo.launch()