Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import tempfile | |
| import os | |
| import time | |
| import datetime | |
| import csv | |
| import warnings | |
| import numpy as np | |
# Silence warnings that are expected from the ML stack (deprecation notices
# and torch.cuda messages) so they do not flood the application log.
warnings.filterwarnings("ignore", message=".*deprecated.*")
warnings.filterwarnings("ignore", message=".*torch.cuda.*")

# Optional heavy dependencies. Each import is guarded so the module can still
# be imported when a package is missing; the affected feature checks for the
# ``None`` placeholder and reports a helpful error at call time.
try:
    from nemo.collections.asr.models import ASRModel
except Exception as nemo_exc:
    # Any failure while importing NeMo is recorded, not just ImportError.
    ASRModel = None
    _NEMO_IMPORT_ERROR = str(nemo_exc)
else:
    _NEMO_IMPORT_ERROR = None

try:
    from pydub import AudioSegment
except ImportError:
    AudioSegment = None

try:
    import yt_dlp as youtube_dl
except ImportError:
    youtube_dl = None
# --- Model configuration -----------------------------------------------------
MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v3"
SAMPLE_RATE = 16_000  # Parakeet expects 16kHz audio
LONG_AUDIO_THRESHOLD_S = 8 * 60  # 8 minutes - switch to local attention
YT_LENGTH_LIMIT_S = 60 * 60  # Limit YouTube videos to 1 hour

# True when running on Hugging Face Spaces (detected through the SPACE_ID
# environment variable). YouTube won't work there due to network restrictions.
IS_HF_SPACE = "SPACE_ID" in os.environ

# Supported languages (auto-detected by the model), as "Name (code)" labels.
_LANGUAGE_CODES = (
    ("Bulgarian", "bg"), ("Croatian", "hr"), ("Czech", "cs"), ("Danish", "da"),
    ("Dutch", "nl"), ("English", "en"), ("Estonian", "et"), ("Finnish", "fi"),
    ("French", "fr"), ("German", "de"), ("Greek", "el"), ("Hungarian", "hu"),
    ("Italian", "it"), ("Latvian", "lv"), ("Lithuanian", "lt"), ("Maltese", "mt"),
    ("Polish", "pl"), ("Portuguese", "pt"), ("Romanian", "ro"), ("Slovak", "sk"),
    ("Slovenian", "sl"), ("Spanish", "es"), ("Swedish", "sv"), ("Russian", "ru"),
    ("Ukrainian", "uk"),
)
SUPPORTED_LANGUAGES = [f"{name} ({code})" for name, code in _LANGUAGE_CODES]

# Cache for the lazily loaded Parakeet model; filled in by _init_parakeet().
_PARAKEET_STATE = dict(initialized=False, model=None, device="cpu")
def _init_parakeet() -> None:
    """Initialize the Parakeet model lazily on first use.

    Does nothing when ``_PARAKEET_STATE`` is already marked initialized.
    Otherwise loads ``MODEL_NAME`` through NeMo, switches it to eval mode,
    moves it to CUDA in bfloat16 when a GPU is available, and stores the
    model and device name in ``_PARAKEET_STATE``.

    Raises:
        gr.Error: If the NeMo import failed at module load, or if loading or
            placing the model fails. The original exception is chained as
            ``__cause__`` so the full traceback is kept.
    """
    if _PARAKEET_STATE["initialized"]:
        return

    if ASRModel is None:
        error_msg = _NEMO_IMPORT_ERROR or "Unknown import error"
        raise gr.Error(
            f"NeMo toolkit import failed: {error_msg}. "
            "Please run: pip install nemo_toolkit[asr]"
        )

    # Detect device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Initializing Parakeet model on device: {device}")

    try:
        model = ASRModel.from_pretrained(model_name=MODEL_NAME)
        model.eval()
        if device == "cuda":
            model.to("cuda")
            model.to(torch.bfloat16)
    except Exception as e:
        # Message is truncated for the UI; the chained cause keeps the details.
        raise gr.Error(f"Failed to initialize Parakeet model: {str(e)[:200]}") from e

    # Publish the state only after the model is fully loaded and placed.
    _PARAKEET_STATE.update({
        "initialized": True,
        "model": model,
        "device": device,
    })
    print("Parakeet model initialized successfully.")
def get_device_info() -> str:
    """Return the name of the device used for inference.

    Once the model has been initialized this is the device recorded in
    ``_PARAKEET_STATE``; before that it is a fresh probe of CUDA availability.
    """
    if not _PARAKEET_STATE["initialized"]:
        return "cuda" if torch.cuda.is_available() else "cpu"
    return _PARAKEET_STATE["device"]
def _load_and_preprocess_audio(audio_path: str) -> tuple[str, float]:
    """Load an audio file and convert it to 16kHz mono when needed.

    Args:
        audio_path: Path to the input audio file, in any format pydub can read.

    Returns:
        ``(processed_path, duration_seconds)``. ``processed_path`` is
        ``audio_path`` itself when no conversion was needed. Otherwise it is
        a WAV file named ``processed_audio.wav`` inside a fresh temporary
        directory, which the caller is responsible for deleting.

    Raises:
        gr.Error: If pydub is not installed.
    """
    import shutil

    if AudioSegment is None:
        raise gr.Error("pydub not installed. Please run: pip install pydub")

    audio = AudioSegment.from_file(audio_path)
    duration_sec = audio.duration_seconds

    needs_resample = audio.frame_rate != SAMPLE_RATE
    needs_downmix = audio.channels > 1
    if not (needs_resample or needs_downmix):
        return audio_path, duration_sec

    # Resample to 16kHz if needed
    if needs_resample:
        audio = audio.set_frame_rate(SAMPLE_RATE)
    # Convert to mono if stereo or multi-channel
    if needs_downmix:
        audio = audio.set_channels(1)

    # Export to temp file
    temp_dir = tempfile.mkdtemp()
    processed_path = os.path.join(temp_dir, "processed_audio.wav")
    try:
        # export() hands back the open output file object; close it so the
        # descriptor is released and the file can be deleted later.
        exported = audio.export(processed_path, format="wav")
        if exported is not None:
            exported.close()
    except Exception:
        # Do not leave an orphaned temp directory behind on failure.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return processed_path, duration_sec
def _format_srt_time(seconds: float) -> str:
    """Render a time offset in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    Negative inputs are clamped to zero. Milliseconds are truncated, not
    rounded, and the hour field grows beyond two digits when needed.
    """
    span = datetime.timedelta(seconds=max(0.0, seconds))
    whole_seconds = int(span.total_seconds())
    hh, remainder = divmod(whole_seconds, 3600)
    mm, ss = divmod(remainder, 60)
    millis = span.microseconds // 1000
    return f"{hh:02d}:{mm:02d}:{ss:02d},{millis:03d}"
def _generate_srt_content(segment_timestamps: list) -> str:
    """Build the text of an SRT subtitle file from segment timestamps.

    Each entry must provide ``'start'``, ``'end'`` and ``'segment'``. Cues are
    numbered from 1 and each is followed by a blank line; an empty input
    yields an empty string.
    """
    cue_lines = []
    for number, entry in enumerate(segment_timestamps, start=1):
        begin = _format_srt_time(entry['start'])
        finish = _format_srt_time(entry['end'])
        cue_lines.extend((str(number), f"{begin} --> {finish}", entry['segment'], ""))
    return "\n".join(cue_lines)
def _generate_csv_content(segment_timestamps: list) -> str:
    """Build CSV text with a header row and one row per segment.

    Columns are start and end time (seconds, two decimals) and the segment
    text. Each entry must provide ``'start'``, ``'end'`` and ``'segment'``.
    """
    import io

    buffer = io.StringIO()
    sheet = csv.writer(buffer)
    sheet.writerow(["Start (s)", "End (s)", "Segment"])
    sheet.writerows(
        (f"{entry['start']:.2f}", f"{entry['end']:.2f}", entry['segment'])
        for entry in segment_timestamps
    )
    return buffer.getvalue()
def _render_timestamped_text(ts_data: list, timestamp_level: str):
    """Format timestamp entries as one ``[time] text`` line per entry.

    Returns None for a level other than "segment", "word" or "char".
    """
    if timestamp_level == "segment":
        return "\n".join(
            f"[{ts.get('start', 0):.2f}s - {ts.get('end', 0):.2f}s] {ts.get('segment', '')}"
            for ts in ts_data
        )
    if timestamp_level == "word":
        return "\n".join(
            f"[{ts.get('start', 0):.2f}s] {ts.get('word', '')}" for ts in ts_data
        )
    if timestamp_level == "char":
        return "\n".join(
            f"[{ts.get('start', 0):.3f}s] {ts.get('char', '')}" for ts in ts_data
        )
    return None


def _write_segment_downloads(ts_data: list) -> tuple[str, str]:
    """Write CSV and SRT files for segment timestamps; return both paths."""
    # The directory is deliberately not removed here: the files back the
    # download buttons and must outlive this request.
    temp_dir = tempfile.mkdtemp()
    csv_path = os.path.join(temp_dir, "transcription.csv")
    with open(csv_path, 'w', encoding='utf-8') as f:
        f.write(_generate_csv_content(ts_data))
    srt_path = os.path.join(temp_dir, "transcription.srt")
    with open(srt_path, 'w', encoding='utf-8') as f:
        f.write(_generate_srt_content(ts_data))
    return csv_path, srt_path


def transcribe_audio(
    audio_path: str,
    return_timestamps: bool,
    timestamp_level: str,
):
    """
    Transcribe audio file using Parakeet.

    Args:
        audio_path: Path to the audio file
        return_timestamps: Whether to include timestamps
        timestamp_level: Level of timestamps ("word", "segment", or "char")

    Returns:
        Tuple of (transcription_text, csv_update, srt_update). The two updates
        are ``gr.update`` objects carrying the CSV/SRT file path and a
        visibility flag; files are only produced for segment-level timestamps.

    Raises:
        gr.Error: If no audio is given, the model cannot be initialized, the
            GPU runs out of memory, or transcription fails for any other
            reason. The underlying exception is chained as the cause.
    """
    if not audio_path:
        raise gr.Error("Please provide an audio file to transcribe.")

    # Initialize model on first use
    _init_parakeet()
    model = _PARAKEET_STATE["model"]
    device = _PARAKEET_STATE["device"]

    processed_path = None
    long_audio_settings_applied = False

    try:
        # Preprocess audio
        gr.Info("Loading and preprocessing audio...")
        processed_path, duration_sec = _load_and_preprocess_audio(audio_path)

        # Apply long audio settings if needed
        if duration_sec > LONG_AUDIO_THRESHOLD_S:
            gr.Info(f"Audio is {duration_sec:.0f}s (>{LONG_AUDIO_THRESHOLD_S}s). Applying local attention for long audio.")
            # Set the flag before touching the model: if the first call
            # succeeds and the second fails, the revert in ``finally`` must
            # still run or the model stays in local-attention mode.
            long_audio_settings_applied = True
            try:
                model.change_attention_model("rel_pos_local_attn", [256, 256])
                model.change_subsampling_conv_chunking_factor(1)
            except Exception as e:
                gr.Warning(f"Could not apply long audio settings: {e}")

        # Ensure model is on correct device with correct dtype
        if device == "cuda":
            model.to("cuda")
            model.to(torch.bfloat16)
        else:
            model.to("cpu")
            model.to(torch.float32)

        # Transcribe
        gr.Info("Transcribing audio...")
        print(f"DEBUG: Calling transcribe with timestamps={return_timestamps}")
        output = model.transcribe([processed_path], timestamps=return_timestamps)
        print(f"DEBUG: Transcription complete, got output type: {type(output)}")

        if not output or not isinstance(output, list) or not output[0]:
            raise gr.Error("Transcription failed or produced unexpected output.")

        # Extract text
        transcription_text = output[0].text if hasattr(output[0], 'text') else str(output[0])
        print(f"DEBUG: Extracted text: {transcription_text[:100] if transcription_text else 'empty'}...")

        # Handle timestamps
        csv_path = None
        srt_path = None
        if return_timestamps and hasattr(output[0], 'timestamp') and output[0].timestamp:
            timestamps = output[0].timestamp
            # Fall back to the plain text when the requested level is absent.
            if timestamp_level in timestamps:
                ts_data = timestamps[timestamp_level]
                rendered = _render_timestamped_text(ts_data, timestamp_level)
                if rendered is not None:
                    transcription_text = rendered
                # Download files are only offered for segment-level output.
                if timestamp_level == "segment":
                    csv_path, srt_path = _write_segment_downloads(ts_data)

        gr.Info("Transcription complete!")
        print(f"DEBUG: Returning transcription of length {len(transcription_text)}")

        # Return with download buttons visibility using gr.update()
        return (
            transcription_text,
            gr.update(value=csv_path, visible=csv_path is not None),
            gr.update(value=srt_path, visible=srt_path is not None),
        )
    except gr.Error:
        raise
    except torch.cuda.OutOfMemoryError as e:
        raise gr.Error("CUDA out of memory. Please try a shorter audio file.") from e
    except Exception as e:
        raise gr.Error(f"Transcription failed: {str(e)[:200]}") from e
    finally:
        # Revert long audio settings
        if long_audio_settings_applied:
            try:
                model.change_attention_model("rel_pos")
                model.change_subsampling_conv_chunking_factor(-1)
            except Exception as e:
                # Best effort, but leave a trace: a failed revert affects
                # every later request that reuses the cached model.
                print(f"WARNING: Could not revert long audio settings: {e}")
        # Clean up temp file
        if processed_path and processed_path != audio_path:
            try:
                os.remove(processed_path)
                os.rmdir(os.path.dirname(processed_path))
            except OSError:
                pass
        # Note: We intentionally keep the model on GPU to avoid reload overhead
        # The model will be reused for subsequent transcriptions
def _get_yt_html_embed(yt_url: str) -> str:
    """Generate YouTube embed HTML for display.

    The video ID is taken from the ``v`` query parameter wherever it appears
    in the query string, from the path of ``youtu.be`` short links, or from
    ``/embed/``, ``/shorts/``, ``/live/`` and ``/v/`` paths. If none of those
    match, the text after ``?v=`` is used as a last resort.

    The ID is percent-encoded before it is placed in the markup, so a crafted
    URL cannot break out of the ``src`` attribute and inject HTML.

    Args:
        yt_url: The user-supplied YouTube URL.

    Returns:
        An HTML snippet containing a centered ``<iframe>`` for the video.
    """
    from urllib.parse import parse_qs, quote, urlparse

    candidate = yt_url.strip()
    if "://" not in candidate:
        # Let scheme-less input such as "youtu.be/<id>" parse with a host.
        candidate = "https://" + candidate

    video_id = ""
    try:
        parsed = urlparse(candidate)
        host = (parsed.hostname or "").lower()
        path_parts = [part for part in parsed.path.split("/") if part]
        query_ids = parse_qs(parsed.query).get("v")
        if query_ids:
            video_id = query_ids[0]
        elif host == "youtu.be" or host.endswith(".youtu.be"):
            video_id = path_parts[0] if path_parts else ""
        elif len(path_parts) >= 2 and path_parts[0] in ("embed", "shorts", "live", "v"):
            video_id = path_parts[1]
    except ValueError:
        # urlparse rejects some malformed inputs; use the fallback below.
        video_id = ""

    if not video_id:
        video_id = yt_url.split("?v=")[-1].split("&")[0]

    # Ordinary IDs (letters, digits, "-" and "_") pass through unchanged.
    safe_id = quote(video_id, safe="")
    return (
        f'<center><iframe width="500" height="320" '
        f'src="https://www.youtube.com/embed/{safe_id}"></iframe></center>'
    )
def _yt_duration_seconds(info: dict) -> int:
    """Return the video length in seconds from a yt-dlp info dict (0 if unknown)."""
    duration = info.get("duration")
    if isinstance(duration, (int, float)):
        return int(duration)
    # Fall back to the "H:MM:SS" / "MM:SS" / "SS" string; a missing or None
    # value counts as zero instead of crashing on ``None.split``.
    total = 0
    for part in str(info.get("duration_string") or "0").split(":"):
        total = total * 60 + int(part)
    return total


def _download_yt_audio(yt_url: str, filepath: str) -> None:
    """Download a YouTube video to ``filepath`` after checking its length.

    Args:
        yt_url: URL of the YouTube video.
        filepath: Output path handed to yt-dlp as its ``outtmpl``.

    Raises:
        gr.Error: If yt-dlp is not installed, the metadata lookup or download
            fails, or the video is longer than ``YT_LENGTH_LIMIT_S`` seconds.
    """
    if youtube_dl is None:
        raise gr.Error("yt-dlp not installed. Please run: pip install yt-dlp")

    try:
        # Context manager so the metadata-only client is closed after use.
        with youtube_dl.YoutubeDL() as info_loader:
            info = info_loader.extract_info(yt_url, download=False)
    except youtube_dl.utils.DownloadError as err:
        err_str = str(err)
        if "Failed to resolve" in err_str or "No address associated" in err_str:
            raise gr.Error(
                "YouTube download failed due to network restrictions. "
                "This feature requires running the app locally. "
                "On Hugging Face Spaces, outbound connections to YouTube are blocked."
            ) from err
        raise gr.Error(str(err)) from err

    # Parse duration
    file_length_s = _yt_duration_seconds(info)
    if file_length_s > YT_LENGTH_LIMIT_S:
        yt_limit_hms = time.strftime("%H:%M:%S", time.gmtime(YT_LENGTH_LIMIT_S))
        file_hms = time.strftime("%H:%M:%S", time.gmtime(file_length_s))
        raise gr.Error(f"Maximum YouTube length is {yt_limit_hms}, got {file_hms}.")

    ydl_opts = {
        "outtmpl": filepath,
        "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best",
    }
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        try:
            ydl.download([yt_url])
        except (youtube_dl.utils.DownloadError, youtube_dl.utils.ExtractorError) as err:
            # Catch DownloadError as well so download failures reach the UI
            # as gr.Error, like metadata failures do.
            raise gr.Error(str(err)) from err
def transcribe_youtube(
    yt_url: str,
    return_timestamps: bool,
    timestamp_level: str,
):
    """
    Transcribe a YouTube video.

    Downloads the video into a temporary directory, preprocesses the audio,
    and runs the Parakeet model on it.

    Args:
        yt_url: URL of the YouTube video
        return_timestamps: Whether to include timestamps
        timestamp_level: Level of timestamps ("segment" or "word")

    Yields:
        Tuples of (html_embed, transcription_text) for streaming updates:
        progress messages first, then the final transcription.

    Raises:
        gr.Error: If the URL is empty, yt-dlp is missing, the download fails,
            the GPU runs out of memory, or transcription fails for any other
            reason.
    """
    if not yt_url:
        raise gr.Error("Please provide a YouTube URL.")
    if youtube_dl is None:
        raise gr.Error("yt-dlp not installed. Please run: pip install yt-dlp")

    html_embed = _get_yt_html_embed(yt_url)

    # Initialize model
    _init_parakeet()
    model = _PARAKEET_STATE["model"]
    device = _PARAKEET_STATE["device"]

    # Download video to temp directory
    with tempfile.TemporaryDirectory() as tmpdir:
        filepath = os.path.join(tmpdir, "video.mp4")

        # Yield initial state while downloading
        yield html_embed, "Downloading video..."
        _download_yt_audio(yt_url, filepath)

        yield html_embed, "Processing audio..."

        processed_path = None
        long_audio_settings_applied = False
        try:
            # Preprocess audio (inside the try so decoding failures are
            # reported as gr.Error like every other failure here)
            processed_path, duration_sec = _load_and_preprocess_audio(filepath)

            # Apply long audio settings if needed
            if duration_sec > LONG_AUDIO_THRESHOLD_S:
                # Flag first, so a half-applied change is still reverted.
                long_audio_settings_applied = True
                try:
                    model.change_attention_model("rel_pos_local_attn", [256, 256])
                    model.change_subsampling_conv_chunking_factor(1)
                except Exception as e:
                    gr.Warning(f"Could not apply long audio settings: {e}")

            # Ensure model is on correct device
            if device == "cuda":
                model.to("cuda")
                model.to(torch.bfloat16)
            else:
                model.to("cpu")
                model.to(torch.float32)

            yield html_embed, "Transcribing audio..."

            # Transcribe
            output = model.transcribe([processed_path], timestamps=return_timestamps)
            if not output or not isinstance(output, list) or not output[0]:
                raise gr.Error("Transcription failed or produced unexpected output.")

            # Extract text
            transcription_text = output[0].text if hasattr(output[0], 'text') else str(output[0])

            # Handle timestamps if requested
            if return_timestamps and hasattr(output[0], 'timestamp') and output[0].timestamp:
                timestamps = output[0].timestamp
                if timestamp_level in timestamps:
                    ts_data = timestamps[timestamp_level]
                    if timestamp_level == "segment":
                        transcription_text = "\n".join(
                            f"[{ts.get('start', 0):.2f}s - {ts.get('end', 0):.2f}s] {ts.get('segment', '')}"
                            for ts in ts_data
                        )
                    elif timestamp_level == "word":
                        transcription_text = "\n".join(
                            f"[{ts.get('start', 0):.2f}s] {ts.get('word', '')}"
                            for ts in ts_data
                        )

            yield html_embed, transcription_text
        except gr.Error:
            raise
        except torch.cuda.OutOfMemoryError as e:
            raise gr.Error("CUDA out of memory. Please try a shorter audio file.") from e
        except Exception as e:
            raise gr.Error(f"Transcription failed: {str(e)[:200]}") from e
        finally:
            # Revert long audio settings
            if long_audio_settings_applied:
                try:
                    model.change_attention_model("rel_pos")
                    model.change_subsampling_conv_chunking_factor(-1)
                except Exception as e:
                    # Best effort, but leave a trace: the cached model is
                    # reused by later requests.
                    print(f"WARNING: Could not revert long audio settings: {e}")
            # Clean up temp file if different from original
            if processed_path and processed_path != filepath:
                try:
                    os.remove(processed_path)
                    os.rmdir(os.path.dirname(processed_path))
                except OSError:
                    pass
# Build the Gradio interface.
# Layout: a header, then tabs. The "Audio File" tab is always present; the
# "YouTube" tab is only built when not running on Hugging Face Spaces.
with gr.Blocks(title="Parakeet-ASR") as demo:
    # Header. The model name and language count come from the module
    # constants so the banner cannot drift from the actual configuration.
    gr.HTML(
        f"""
        <h1 style='text-align: center;'>Parakeet-ASR 🦜</h1>
        <p style='text-align: center;'>
        Powered by <code>{MODEL_NAME}</code> on
        <strong>{get_device_info().upper()}</strong>
        </p>
        <p style='text-align: center; font-size: 0.9em;'>
        Supports {len(SUPPORTED_LANGUAGES)} European languages with automatic detection, punctuation, and capitalization.
        </p>
        """
    )

    with gr.Tabs():
        # Tab 1: Audio File / Microphone
        with gr.TabItem("Audio File"):
            with gr.Row():
                with gr.Column():
                    audio_input = gr.Audio(
                        label="Audio Input",
                        sources=["microphone", "upload"],
                        type="filepath",
                    )
                    timestamps_checkbox = gr.Checkbox(
                        label="Return Timestamps",
                        value=False,
                    )
                    timestamp_level_radio = gr.Radio(
                        choices=["segment", "word", "char"],
                        value="segment",
                        label="Timestamp Level",
                        info="Level of detail for timestamps",
                        visible=False,
                    )
                    # Show/hide timestamp level based on checkbox
                    timestamps_checkbox.change(
                        fn=lambda x: gr.Radio(visible=x),
                        inputs=[timestamps_checkbox],
                        outputs=[timestamp_level_radio],
                    )
                    transcribe_btn = gr.Button("Transcribe", variant="primary")
                with gr.Column():
                    audio_output = gr.Textbox(
                        label="Transcription",
                        placeholder="Transcribed text will appear here...",
                        lines=12,
                    )
                    # Hidden until transcribe_audio returns a file for them.
                    with gr.Row():
                        download_csv_btn = gr.DownloadButton(
                            label="Download CSV",
                            visible=False,
                        )
                        download_srt_btn = gr.DownloadButton(
                            label="Download SRT",
                            visible=False,
                        )
            transcribe_btn.click(
                fn=transcribe_audio,
                inputs=[audio_input, timestamps_checkbox, timestamp_level_radio],
                outputs=[audio_output, download_csv_btn, download_srt_btn],
                api_name="transcribe",
            )

        # Tab 2: YouTube (only shown when running locally)
        if not IS_HF_SPACE:
            with gr.TabItem("YouTube"):
                with gr.Row():
                    with gr.Column():
                        yt_url_input = gr.Textbox(
                            label="YouTube URL",
                            placeholder="Paste a YouTube video URL here...",
                            lines=1,
                        )
                        yt_timestamps_checkbox = gr.Checkbox(
                            label="Return Timestamps",
                            value=False,
                        )
                        yt_timestamp_level_radio = gr.Radio(
                            choices=["segment", "word"],
                            value="segment",
                            label="Timestamp Level",
                            visible=False,
                        )
                        # Show/hide timestamp level based on checkbox
                        yt_timestamps_checkbox.change(
                            fn=lambda x: gr.Radio(visible=x),
                            inputs=[yt_timestamps_checkbox],
                            outputs=[yt_timestamp_level_radio],
                        )
                        yt_transcribe_btn = gr.Button("Transcribe YouTube", variant="primary")
                    with gr.Column():
                        yt_embed = gr.HTML(label="Video")
                        yt_output = gr.Textbox(
                            label="Transcription",
                            placeholder="Transcribed text will appear here...",
                            lines=10,
                        )
                yt_transcribe_btn.click(
                    fn=transcribe_youtube,
                    inputs=[yt_url_input, yt_timestamps_checkbox, yt_timestamp_level_radio],
                    outputs=[yt_embed, yt_output],
                    api_name="transcribe_youtube",
                )


if __name__ == "__main__":
    demo.queue().launch(theme="Nymbo/Nymbo_Theme")