| | import gradio as gr |
| | import torch |
| | import tempfile |
| | import os |
| | import time |
| | import datetime |
| | import csv |
| | import warnings |
| | import numpy as np |
| |
|
| | |
# Silence noisy deprecation / CUDA warnings emitted at import time by the ML stack.
warnings.filterwarnings("ignore", message=".*deprecated.*")
warnings.filterwarnings("ignore", message=".*torch.cuda.*")

# NeMo is a heavy optional dependency; defer import failures until the model is
# actually requested so the UI can still start and surface a helpful error.
_NEMO_IMPORT_ERROR = None
try:
    from nemo.collections.asr.models import ASRModel
except Exception as e:
    ASRModel = None
    _NEMO_IMPORT_ERROR = str(e)

# pydub handles resampling/downmixing of input audio; optional at import time.
try:
    from pydub import AudioSegment
except ImportError:
    AudioSegment = None

# yt-dlp powers the (local-only) YouTube transcription section; also optional.
try:
    import yt_dlp as youtube_dl
except ImportError:
    youtube_dl = None

# Hugging Face model id for the Parakeet TDT checkpoint.
MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v3"
# Resample target (Hz) applied in _load_and_preprocess_audio.
SAMPLE_RATE = 16000
# Above this duration (seconds), local attention is enabled to bound memory use.
LONG_AUDIO_THRESHOLD_S = 480
# Refuse YouTube videos longer than this many seconds (one hour).
YT_LENGTH_LIMIT_S = 3600

# True when running on a Hugging Face Space (the YouTube section is hidden there).
IS_HF_SPACE = os.environ.get("SPACE_ID") is not None

# Languages the model supports.
# NOTE(review): not referenced elsewhere in this file — presumably intended for
# display in the UI; confirm before removing.
SUPPORTED_LANGUAGES = [
    "Bulgarian (bg)", "Croatian (hr)", "Czech (cs)", "Danish (da)",
    "Dutch (nl)", "English (en)", "Estonian (et)", "Finnish (fi)",
    "French (fr)", "German (de)", "Greek (el)", "Hungarian (hu)",
    "Italian (it)", "Latvian (lv)", "Lithuanian (lt)", "Maltese (mt)",
    "Polish (pl)", "Portuguese (pt)", "Romanian (ro)", "Slovak (sk)",
    "Slovenian (sl)", "Spanish (es)", "Swedish (sv)", "Russian (ru)",
    "Ukrainian (uk)"
]

# Lazy-initialization state shared by all request handlers:
# the loaded model instance and the device it lives on.
_PARAKEET_STATE = {"initialized": False, "model": None, "device": "cpu"}
| |
|
| |
|
def _init_parakeet() -> None:
    """Initialize the Parakeet model lazily on first use.

    Populates ``_PARAKEET_STATE`` with the loaded model and its device.
    Subsequent calls are no-ops.

    Raises:
        gr.Error: If the NeMo import failed earlier or model loading fails.
    """
    # Fast path: already loaded by a previous request.
    if _PARAKEET_STATE["initialized"]:
        return

    # Surface the deferred import failure with an actionable message.
    if ASRModel is None:
        reason = _NEMO_IMPORT_ERROR or "Unknown import error"
        raise gr.Error(
            f"NeMo toolkit import failed: {reason}. "
            "Please run: pip install nemo_toolkit[asr]"
        )

    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Initializing Parakeet model on device: {target_device}")

    try:
        asr_model = ASRModel.from_pretrained(model_name=MODEL_NAME)
        asr_model.eval()

        # On GPU, run in bfloat16 to cut memory use.
        if target_device == "cuda":
            asr_model.to("cuda")
            asr_model.to(torch.bfloat16)

        _PARAKEET_STATE["model"] = asr_model
        _PARAKEET_STATE["device"] = target_device
        _PARAKEET_STATE["initialized"] = True
        print("Parakeet model initialized successfully.")
    except Exception as e:
        # Truncate the message so huge tracebacks don't flood the UI toast.
        raise gr.Error(f"Failed to initialize Parakeet model: {str(e)[:200]}")
| |
|
| |
|
def get_device_info() -> str:
    """Get the current device being used for inference.

    Returns the device recorded at model init, or a prediction based on
    CUDA availability when the model has not been loaded yet.
    """
    if not _PARAKEET_STATE["initialized"]:
        return "cuda" if torch.cuda.is_available() else "cpu"
    return _PARAKEET_STATE["device"]
| |
|
| |
|
def _load_and_preprocess_audio(audio_path: str) -> tuple[str, float]:
    """
    Load audio file, resample to 16kHz mono if needed.
    Returns (processed_path, duration_seconds).

    When no conversion is needed the original path is returned unchanged;
    otherwise a WAV is written to a fresh temp dir that the CALLER must
    clean up.
    """
    if AudioSegment is None:
        raise gr.Error("pydub not installed. Please run: pip install pydub")

    segment = AudioSegment.from_file(audio_path)
    duration = segment.duration_seconds

    modified = False

    # Resample to the model's expected rate.
    if segment.frame_rate != SAMPLE_RATE:
        segment = segment.set_frame_rate(SAMPLE_RATE)
        modified = True

    # Downmix multi-channel audio to mono.
    if segment.channels > 1:
        segment = segment.set_channels(1)
        modified = True

    if not modified:
        return audio_path, duration

    out_path = os.path.join(tempfile.mkdtemp(), "processed_audio.wav")
    segment.export(out_path, format="wav")
    return out_path, duration
| |
|
| |
|
| | def _format_srt_time(seconds: float) -> str: |
| | """Convert seconds to SRT time format HH:MM:SS,mmm.""" |
| | sanitized = max(0.0, seconds) |
| | delta = datetime.timedelta(seconds=sanitized) |
| | total_int_seconds = int(delta.total_seconds()) |
| |
|
| | hours = total_int_seconds // 3600 |
| | minutes = (total_int_seconds % 3600) // 60 |
| | secs = total_int_seconds % 60 |
| | ms = delta.microseconds // 1000 |
| |
|
| | return f"{hours:02d}:{minutes:02d}:{secs:02d},{ms:03d}" |
| |
|
| |
|
def _generate_srt_content(segment_timestamps: list) -> str:
    """Render segment timestamps as an SRT-formatted subtitle string.

    Each entry dict must provide 'start', 'end' (seconds) and 'segment' (text).
    """
    blocks = []
    # SRT cue numbering is 1-based; each cue is index, time window, text, blank.
    for index, entry in enumerate(segment_timestamps, start=1):
        window = f"{_format_srt_time(entry['start'])} --> {_format_srt_time(entry['end'])}"
        blocks.extend([str(index), window, entry['segment'], ""])
    return "\n".join(blocks)
| |
|
| |
|
| | def _generate_csv_content(segment_timestamps: list) -> str: |
| | """Generate CSV formatted string from segment timestamps.""" |
| | import io |
| | output = io.StringIO() |
| | writer = csv.writer(output) |
| | writer.writerow(["Start (s)", "End (s)", "Segment"]) |
| | for ts in segment_timestamps: |
| | writer.writerow([f"{ts['start']:.2f}", f"{ts['end']:.2f}", ts['segment']]) |
| | return output.getvalue() |
| |
|
| |
|
def transcribe_audio(
    audio_path: str,
    return_timestamps: bool,
    timestamp_level: str,
):
    """
    Transcribe audio file using Parakeet.

    Args:
        audio_path: Path to the audio file
        return_timestamps: Whether to include timestamps
        timestamp_level: Level of timestamps ("word", "segment", or "char")

    Returns:
        Tuple of (transcription_text, csv_file_path, srt_file_path)
    """
    if not audio_path:
        raise gr.Error("Please provide an audio file to transcribe.")

    # Lazily load the model on the first request (may download the checkpoint).
    _init_parakeet()
    model = _PARAKEET_STATE["model"]
    device = _PARAKEET_STATE["device"]

    processed_path = None
    long_audio_settings_applied = False

    try:
        # Resample to 16 kHz mono if needed; may return a temp file to clean up.
        gr.Info("Loading and preprocessing audio...")
        processed_path, duration_sec = _load_and_preprocess_audio(audio_path)

        # Long recordings: switch the encoder to local attention to bound memory.
        if duration_sec > LONG_AUDIO_THRESHOLD_S:
            gr.Info(f"Audio is {duration_sec:.0f}s (>{LONG_AUDIO_THRESHOLD_S}s). Applying local attention for long audio.")
            try:
                model.change_attention_model("rel_pos_local_attn", [256, 256])
                model.change_subsampling_conv_chunking_factor(1)
                long_audio_settings_applied = True
            except Exception as e:
                # Best-effort: continue with default attention if the switch fails.
                gr.Warning(f"Could not apply long audio settings: {e}")

        # Re-assert device/dtype on every call; a previous request may have
        # left the shared model in a different state.
        if device == "cuda":
            model.to("cuda")
            model.to(torch.bfloat16)
        else:
            model.to("cpu")
            model.to(torch.float32)

        gr.Info("Transcribing audio...")
        print(f"DEBUG: Calling transcribe with timestamps={return_timestamps}")
        output = model.transcribe([processed_path], timestamps=return_timestamps)
        print(f"DEBUG: Transcription complete, got output type: {type(output)}")

        if not output or not isinstance(output, list) or not output[0]:
            raise gr.Error("Transcription failed or produced unexpected output.")

        # Result items expose a .text attribute; fall back to str() otherwise.
        transcription_text = output[0].text if hasattr(output[0], 'text') else str(output[0])
        print(f"DEBUG: Extracted text: {transcription_text[:100] if transcription_text else 'empty'}...")

        # CSV/SRT downloads are produced only for segment-level timestamps.
        csv_path = None
        srt_path = None

        if return_timestamps and hasattr(output[0], 'timestamp') and output[0].timestamp:
            timestamps = output[0].timestamp

            if timestamp_level in timestamps:
                ts_data = timestamps[timestamp_level]

                # Segment level: one "[start - end] text" line per segment,
                # plus downloadable CSV and SRT files.
                if timestamp_level == "segment":
                    lines = []
                    for ts in ts_data:
                        start = ts.get('start', 0)
                        end = ts.get('end', 0)
                        text = ts.get('segment', '')
                        lines.append(f"[{start:.2f}s - {end:.2f}s] {text}")
                    transcription_text = "\n".join(lines)

                    # Export files go in a temp dir handed to Gradio's
                    # download buttons (not cleaned up here).
                    temp_dir = tempfile.mkdtemp()

                    csv_content = _generate_csv_content(ts_data)
                    csv_path = os.path.join(temp_dir, "transcription.csv")
                    with open(csv_path, 'w', encoding='utf-8') as f:
                        f.write(csv_content)

                    srt_content = _generate_srt_content(ts_data)
                    srt_path = os.path.join(temp_dir, "transcription.srt")
                    with open(srt_path, 'w', encoding='utf-8') as f:
                        f.write(srt_content)

                # Word level: one "[start] word" line per word.
                elif timestamp_level == "word":
                    lines = []
                    for ts in ts_data:
                        start = ts.get('start', 0)
                        end = ts.get('end', 0)
                        word = ts.get('word', '')
                        lines.append(f"[{start:.2f}s] {word}")
                    transcription_text = "\n".join(lines)

                # Char level: one "[start] char" line per character (finer precision).
                elif timestamp_level == "char":
                    lines = []
                    for ts in ts_data:
                        start = ts.get('start', 0)
                        char = ts.get('char', '')
                        lines.append(f"[{start:.3f}s] {char}")
                    transcription_text = "\n".join(lines)

        gr.Info("Transcription complete!")
        print(f"DEBUG: Returning transcription of length {len(transcription_text)}")

        # Download buttons become visible only when files were generated.
        return (
            transcription_text,
            gr.update(value=csv_path, visible=csv_path is not None),
            gr.update(value=srt_path, visible=srt_path is not None),
        )

    except gr.Error:
        # Already user-facing; re-raise unchanged.
        raise
    except torch.cuda.OutOfMemoryError:
        raise gr.Error("CUDA out of memory. Please try a shorter audio file.")
    except Exception as e:
        # Truncate so huge tracebacks don't flood the UI toast.
        raise gr.Error(f"Transcription failed: {str(e)[:200]}")

    finally:
        # Restore full attention so later short-audio requests are unaffected.
        if long_audio_settings_applied:
            try:
                model.change_attention_model("rel_pos")
                model.change_subsampling_conv_chunking_factor(-1)
            except Exception:
                pass

        # Remove the preprocessing temp file and its directory, if created.
        if processed_path and processed_path != audio_path:
            try:
                os.remove(processed_path)
                os.rmdir(os.path.dirname(processed_path))
            except Exception:
                pass
| |
|
| | |
| | |
| |
|
| |
|
| | def _get_yt_html_embed(yt_url: str) -> str: |
| | """Generate YouTube embed HTML for display.""" |
| | video_id = yt_url.split("?v=")[-1].split("&")[0] |
| | return ( |
| | f'<center><iframe width="500" height="320" ' |
| | f'src="https://www.youtube.com/embed/{video_id}"></iframe></center>' |
| | ) |
| |
|
| |
|
def _download_yt_audio(yt_url: str, filepath: str) -> None:
    """Download audio from a YouTube URL.

    Args:
        yt_url: The video URL to fetch.
        filepath: Destination path for the downloaded media file.

    Raises:
        gr.Error: If yt-dlp is missing, metadata extraction fails, the video
            exceeds YT_LENGTH_LIMIT_S, or the download itself fails.
    """
    if youtube_dl is None:
        raise gr.Error("yt-dlp not installed. Please run: pip install yt-dlp")

    info_loader = youtube_dl.YoutubeDL()

    try:
        info = info_loader.extract_info(yt_url, download=False)
    except youtube_dl.utils.DownloadError as err:
        err_str = str(err)
        # DNS failures are the signature of a sandboxed/blocked environment.
        if "Failed to resolve" in err_str or "No address associated" in err_str:
            raise gr.Error(
                "YouTube download failed due to network restrictions. "
                "This feature requires running the app locally. "
                "On Hugging Face Spaces, outbound connections to YouTube are blocked."
            )
        raise gr.Error(str(err))

    # Prefer the numeric duration (seconds) reported by yt-dlp. Fall back to
    # parsing the human-readable "duration_string" only when it is absent;
    # the old string-only parsing broke on missing keys and day-long videos.
    file_length_s = info.get("duration")
    if file_length_s is None:
        parts = [int(p) for p in str(info.get("duration_string", "0")).split(":")]
        while len(parts) < 3:
            parts.insert(0, 0)
        # Use the last three fields (H, M, S) so longer strings don't misparse.
        file_length_s = parts[-3] * 3600 + parts[-2] * 60 + parts[-1]

    if file_length_s > YT_LENGTH_LIMIT_S:
        yt_limit_hms = time.strftime("%H:%M:%S", time.gmtime(YT_LENGTH_LIMIT_S))
        file_hms = time.strftime("%H:%M:%S", time.gmtime(file_length_s))
        raise gr.Error(f"Maximum YouTube length is {yt_limit_hms}, got {file_hms}.")

    ydl_opts = {
        "outtmpl": filepath,
        "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best",
    }

    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        try:
            ydl.download([yt_url])
        # DownloadError covers network/HTTP failures that the previously
        # caught ExtractorError missed, so those now surface as gr.Error too.
        except (youtube_dl.utils.ExtractorError, youtube_dl.utils.DownloadError) as err:
            raise gr.Error(str(err))
| |
|
| |
|
def transcribe_youtube(
    yt_url: str,
    return_timestamps: bool,
    timestamp_level: str,
):
    """
    Transcribe a YouTube video.

    Yields tuples of (html_embed, transcription_text) for streaming updates.
    """
    if not yt_url:
        raise gr.Error("Please provide a YouTube URL.")

    if youtube_dl is None:
        raise gr.Error("yt-dlp not installed. Please run: pip install yt-dlp")

    html_embed = _get_yt_html_embed(yt_url)

    # Lazily load the model on the first request (may download the checkpoint).
    _init_parakeet()
    model = _PARAKEET_STATE["model"]
    device = _PARAKEET_STATE["device"]

    # The downloaded video lives in a temp dir removed automatically on exit.
    with tempfile.TemporaryDirectory() as tmpdir:
        filepath = os.path.join(tmpdir, "video.mp4")

        # Progress messages stream to the UI through successive yields.
        yield html_embed, "Downloading video..."

        _download_yt_audio(yt_url, filepath)

        yield html_embed, "Processing audio..."

        # Resample to 16 kHz mono if needed; may return a temp file to clean up.
        processed_path, duration_sec = _load_and_preprocess_audio(filepath)

        long_audio_settings_applied = False

        try:
            # Long recordings: switch to local attention to bound memory use.
            if duration_sec > LONG_AUDIO_THRESHOLD_S:
                try:
                    model.change_attention_model("rel_pos_local_attn", [256, 256])
                    model.change_subsampling_conv_chunking_factor(1)
                    long_audio_settings_applied = True
                except Exception:
                    # Best-effort: continue with default attention on failure.
                    pass

            # Re-assert device/dtype; a previous request may have changed them.
            if device == "cuda":
                model.to("cuda")
                model.to(torch.bfloat16)
            else:
                model.to("cpu")
                model.to(torch.float32)

            yield html_embed, "Transcribing audio..."

            output = model.transcribe([processed_path], timestamps=return_timestamps)

            if not output or not isinstance(output, list) or not output[0]:
                raise gr.Error("Transcription failed or produced unexpected output.")

            # Result items expose a .text attribute; fall back to str() otherwise.
            transcription_text = output[0].text if hasattr(output[0], 'text') else str(output[0])

            # Optionally reformat output with segment- or word-level stamps
            # (char level is not offered for YouTube).
            if return_timestamps and hasattr(output[0], 'timestamp') and output[0].timestamp:
                timestamps = output[0].timestamp
                if timestamp_level in timestamps:
                    ts_data = timestamps[timestamp_level]
                    if timestamp_level == "segment":
                        lines = []
                        for ts in ts_data:
                            start = ts.get('start', 0)
                            end = ts.get('end', 0)
                            text = ts.get('segment', '')
                            lines.append(f"[{start:.2f}s - {end:.2f}s] {text}")
                        transcription_text = "\n".join(lines)
                    elif timestamp_level == "word":
                        lines = []
                        for ts in ts_data:
                            start = ts.get('start', 0)
                            word = ts.get('word', '')
                            lines.append(f"[{start:.2f}s] {word}")
                        transcription_text = "\n".join(lines)

            # Final yield delivers the completed transcription.
            yield html_embed, transcription_text

        finally:
            # Restore full attention so later short-audio requests are unaffected.
            if long_audio_settings_applied:
                try:
                    model.change_attention_model("rel_pos")
                    model.change_subsampling_conv_chunking_factor(-1)
                except Exception:
                    pass

            # Remove the preprocessing temp file and its directory, if created.
            if processed_path != filepath:
                try:
                    os.remove(processed_path)
                    os.rmdir(os.path.dirname(processed_path))
                except Exception:
                    pass
| |
|
| |
|
| | |
# UI definition: left column holds inputs, right column shows results.
with gr.Blocks(title="Parakeet-ASR") as demo:
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                label="Audio Input",
                sources=["microphone", "upload"],
                type="filepath",
            )

            timestamps_checkbox = gr.Checkbox(
                label="Return Timestamps",
                value=False,
            )

            # Hidden until timestamps are requested (see .change wiring below).
            timestamp_level_radio = gr.Radio(
                choices=["segment", "word", "char"],
                value="segment",
                label="Timestamp Level",
                info="Level of detail for timestamps",
                visible=False,
            )

            # Toggle the level selector's visibility with the checkbox.
            timestamps_checkbox.change(
                fn=lambda x: gr.Radio(visible=x),
                inputs=[timestamps_checkbox],
                outputs=[timestamp_level_radio],
            )

            transcribe_btn = gr.Button("Transcribe", variant="primary")

        with gr.Column():
            audio_output = gr.Textbox(
                label="Transcription",
                placeholder="Transcribed text will appear here...",
                lines=12,
            )

            # Hidden until transcribe_audio returns generated export files.
            with gr.Row():
                download_csv_btn = gr.DownloadButton(
                    label="Download CSV",
                    visible=False,
                )
                download_srt_btn = gr.DownloadButton(
                    label="Download SRT",
                    visible=False,
                )

    transcribe_btn.click(
        fn=transcribe_audio,
        inputs=[audio_input, timestamps_checkbox, timestamp_level_radio],
        outputs=[audio_output, download_csv_btn, download_srt_btn],
        api_name="transcribe",
    )

    # The YouTube section is only built when running locally; on Spaces,
    # outbound connections to YouTube are blocked (see _download_yt_audio).
    if not IS_HF_SPACE:
        with gr.Row():
            with gr.Column():
                yt_url_input = gr.Textbox(
                    label="YouTube URL",
                    placeholder="Paste a YouTube video URL here...",
                    lines=1,
                )

                yt_timestamps_checkbox = gr.Checkbox(
                    label="Return Timestamps",
                    value=False,
                )

                # "char" level is intentionally not offered for YouTube
                # (transcribe_youtube only formats segment/word levels).
                yt_timestamp_level_radio = gr.Radio(
                    choices=["segment", "word"],
                    value="segment",
                    label="Timestamp Level",
                    visible=False,
                )

                yt_timestamps_checkbox.change(
                    fn=lambda x: gr.Radio(visible=x),
                    inputs=[yt_timestamps_checkbox],
                    outputs=[yt_timestamp_level_radio],
                )

                yt_transcribe_btn = gr.Button("Transcribe YouTube", variant="primary")

            with gr.Column():
                yt_embed = gr.HTML(label="Video")
                yt_output = gr.Textbox(
                    label="Transcription",
                    placeholder="Transcribed text will appear here...",
                    lines=10,
                )

        # transcribe_youtube is a generator, so progress text streams to yt_output.
        yt_transcribe_btn.click(
            fn=transcribe_youtube,
            inputs=[yt_url_input, yt_timestamps_checkbox, yt_timestamp_level_radio],
            outputs=[yt_embed, yt_output],
            api_name="transcribe_youtube",
        )
| |
|
| |
|
| |
|
if __name__ == "__main__":
    # BUG FIX: Blocks.launch() has no `theme` parameter -- passing one raises
    # TypeError at startup on current Gradio. A theme must instead be passed to
    # the constructor, e.g. gr.Blocks(theme="Nymbo/Instantly_Theme", ...).
    demo.queue().launch()
| |
|