Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from sys import platform | |
| from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor | |
| from transformers.utils import is_flash_attn_2_available | |
| from subtitle_manager import Subtitle | |
| import logging | |
# Print the installed Gradio version to stdout at import time.
print(gr.__version__)

logging.basicConfig(level=logging.INFO)

# Global state shared by the event handlers below: the currently loaded ASR
# pipeline and the model id it was built from. Both are None until a model
# has been loaded.
pipe = last_model = None
def get_language_names():
    """Return the language codes offered in the language dropdown.

    A new list is built on every call, so callers may sort or mutate the
    result freely.
    """
    codes = (
        "af am ar as az ba be bg bn bo br bs "
        "ca cs cy da de el en es et eu fa fi "
        "fo fr gl gu ha haw he hi hr ht hu hy "
        "id is it ja jw ka kk km kn ko la lb "
        "ln lo lt lv mg mi mk ml mn mr ms mt "
        "my ne nl nn no oc pa pl ps pt ro ru "
        "sa sd si sk sl sn so sq sr su sv sw "
        "ta te tg th tk tl tr tt uk ur uz vi "
        "yi yo zh"
    )
    return codes.split()
def create_pipe(model_id, flash):
    """Build an automatic-speech-recognition pipeline for ``model_id``.

    Args:
        model_id: Checkpoint id handed to ``from_pretrained`` for both the
            model and its processor.
        flash: When truthy and ``is_flash_attn_2_available()`` reports
            support, the model is loaded with ``"flash_attention_2"``;
            otherwise ``"sdpa"`` attention is used.

    Returns:
        The object returned by ``transformers.pipeline`` for the
        ``"automatic-speech-recognition"`` task.
    """
    # Half precision on the first CUDA device when one is present,
    # full precision on CPU otherwise.
    if torch.cuda.is_available():
        device = "cuda:0"
    else:
        device = "cpu"
    if torch.cuda.is_available():
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    if flash and is_flash_attn_2_available():
        attention = "flash_attention_2"
    else:
        attention = "sdpa"

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation=attention,
    )
    model = model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        device=device,
        torch_dtype=torch_dtype,
    )
    return asr_pipeline
def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task, flash,
                                     chunk_length_s, batch_size, progress=gr.Progress()):
    """Transcribe every supplied input and write SRT, VTT and TXT files.

    Args:
        modelName: Model id; the global pipeline is rebuilt through
            ``create_pipe`` when it differs from the last one used (or when
            no pipeline exists yet).
        languageName: Language code, or ``"Automatic Detection"`` to let the
            model decide.
        urlData: Optional URL string, passed to the pipeline as-is.
        multipleFiles: Optional list of uploaded file paths.
        microphoneData: Optional path of a recorded/uploaded audio file.
        task: Value forwarded as the ``task`` generation argument.
        flash: Forwarded to ``create_pipe``; it only takes effect when the
            pipeline is (re)built.
        chunk_length_s: Forwarded to the pipeline call.
        batch_size: Forwarded to the pipeline call after conversion to int.
        progress: Gradio progress tracker.

    Returns:
        ``(files_out, vtt, txt)``: the list of written file names, then the
        VTT and plain-text renderings of the last processed input. Both
        strings are empty when no input was supplied.
    """
    global last_model, pipe

    progress(0, desc="Loading Audio...")

    if last_model != modelName or pipe is None:
        torch.cuda.empty_cache()
        progress(0.1, desc="Loading Model...")
        pipe = create_pipe(modelName, flash)
        last_model = modelName

    files = []
    if multipleFiles:
        files += multipleFiles
    if urlData:
        files.append(urlData)
    if microphoneData:
        files.append(microphoneData)

    # NOTE(review): English-only checkpoints (ids ending in ".en") are expected
    # to reject an explicit task/language in generate(); omitting both is safe
    # for them either way. Confirm against the installed transformers version.
    generate_kwargs = {}
    if not modelName.endswith(".en"):
        generate_kwargs["language"] = languageName if languageName != "Automatic Detection" else None
        generate_kwargs["task"] = task

    # One renderer per output format, keyed by file extension.
    renderers = {
        "srt": Subtitle("srt"),
        "vtt": Subtitle("vtt"),
        "txt": Subtitle("txt"),
    }

    files_out = []
    # Pre-bind the text results so an empty input list returns cleanly instead
    # of raising UnboundLocalError at the return statement.
    vtt = txt = ""

    for file in progress.tqdm(files, desc="Working..."):
        outputs = pipe(
            file,
            chunk_length_s=chunk_length_s,
            # A gr.Number may deliver a float; coerce to int for batching.
            batch_size=int(batch_size),
            generate_kwargs=generate_kwargs,
            return_timestamps=True,
        )

        # Output files are written to the working directory, named after the
        # last path component of the input.
        file_out = file.split('/')[-1]
        rendered = {
            ext: renderer.get_subtitle(outputs["chunks"])
            for ext, renderer in renderers.items()
        }
        for ext, content in rendered.items():
            out_name = file_out + "." + ext
            with open(out_name, 'w', encoding='utf-8') as f:
                f.write(content)
            files_out.append(out_name)

        vtt, txt = rendered["vtt"], rendered["txt"]

    progress(1, desc="Completed!")

    # NOTE(review): the click handler maps these onto the "Download",
    # "Transcription" and "Segments" widgets in this order, so the VTT text
    # lands in "Transcription" and the plain text in "Segments" — confirm
    # that this is intended.
    return files_out, vtt, txt
| # Realtime STT | |
def transcribe_stream(buffer, new_chunk):
    """Append a streamed microphone chunk to the buffer and transcribe it all.

    Args:
        buffer: Audio accumulated so far as a 1-D float32 array, or ``None``
            on the first call.
        new_chunk: ``(sampling_rate, samples)`` tuple. Samples with more than
            one dimension are averaged over axis 1 before use.

    Returns:
        ``(buffer, text)``: the updated buffer and the transcription of the
        whole buffer. ``text`` is ``""`` when there is nothing to transcribe
        or no pipeline has been loaded yet.
    """
    sr, chunk = new_chunk
    if chunk.ndim > 1:
        chunk = chunk.mean(axis=1)
    chunk = chunk.astype(np.float32)

    # An empty chunk has no peak (np.max would raise), so skip it entirely.
    if chunk.size:
        peak = np.max(np.abs(chunk))
        if peak > 0:
            # Scale this chunk so its largest absolute sample is 1.0.
            chunk /= peak
        buffer = chunk if buffer is None else np.concatenate([buffer, chunk])

    if buffer is None or buffer.size == 0:
        return buffer, ""

    # The pipeline is created asynchronously on page load; keep buffering
    # audio until it exists instead of failing on a None call.
    if pipe is None:
        return buffer, ""

    # The buffer grows for the whole session, so enable chunked inference
    # using the same default chunk length as the file-transcription tab.
    # NOTE(review): confirm that un-chunked calls fail or truncate on long
    # buffers with the installed transformers version.
    text = pipe({"sampling_rate": sr, "raw": buffer}, chunk_length_s=30)["text"]
    return buffer, text
| # Gradio UI | |
# Build the two-tab interface: batch file transcription and live microphone
# transcription. Widget nesting below was reconstructed from the call order.
with gr.Blocks(title="Insanely Fast Whisper") as demo:
    gr.Markdown("## 🎙️ Insanely Fast Whisper + Real-time STT")

    whisper_models = [
        "openai/whisper-tiny", "openai/whisper-tiny.en",
        "openai/whisper-base", "openai/whisper-base.en",
        "openai/whisper-small", "openai/whisper-small.en",
        "distil-whisper/distil-small.en",
        "openai/whisper-medium", "openai/whisper-medium.en",
        "distil-whisper/distil-medium.en",
        "openai/whisper-large", "openai/whisper-large-v1",
        "openai/whisper-large-v2", "distil-whisper/distil-large-v2",
        "openai/whisper-large-v3", "distil-whisper/distil-large-v3",
    ]
    # Single source of truth for the model that is both pre-selected in the
    # dropdown and preloaded on page load.
    default_model = "distil-whisper/distil-large-v2"

    with gr.Tab("File Transcription"):
        with gr.Row():
            with gr.Column():
                model_dropdown = gr.Dropdown(
                    whisper_models,
                    value=default_model,
                    label="Model"
                )
                language_dropdown = gr.Dropdown(
                    ["Automatic Detection"] + sorted(get_language_names()),
                    value="Automatic Detection",
                    label="Language"
                )
                url_input = gr.Text(label="URL (YouTube, etc.)")
                file_input = gr.File(label="Upload Files", file_count="multiple")
                audio_input = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",
                    label="Audio Input"
                )
                task_dropdown = gr.Dropdown(
                    ["transcribe", "translate"],
                    label="Task",
                    value="transcribe"
                )
                flash_checkbox = gr.Checkbox(label='Flash', info='Use Flash Attention 2')
                chunk_length = gr.Number(label='chunk_length_s', value=30)
                # precision=0 makes the component deliver an int, which is
                # what the pipeline's batching expects.
                batch_size = gr.Number(label='batch_size', value=24, precision=0)
                transcribe_button = gr.Button("Transcribe")
            with gr.Column():
                output_files = gr.File(label="Download")
                output_text = gr.Text(label="Transcription")
                output_segments = gr.Text(label="Segments")

        transcribe_button.click(
            fn=transcribe_webui_simple_progress,
            inputs=[
                model_dropdown, language_dropdown, url_input,
                file_input, audio_input, task_dropdown,
                flash_checkbox, chunk_length, batch_size
            ],
            outputs=[output_files, output_text, output_segments]
        )

    with gr.Tab("Real-time Transcription"):
        # Per-session audio buffer threaded through transcribe_stream.
        st_buffer = gr.State()
        mic_rt = gr.Audio(
            sources=["microphone"], type="numpy", streaming=True,
            label="🎤 Speak Now (Live Transcription)"
        )
        txt_rt = gr.Textbox(label="Real-time Transcription")
        mic_rt.stream(
            fn=transcribe_stream,
            inputs=[st_buffer, mic_rt],
            outputs=[st_buffer, txt_rt]
        )

    # Preload model for Hugging Face spaces
    def load_model():
        """Create the default pipeline once, the first time a page loads.

        The load event fires for every browser page load, so this returns
        early when a pipeline already exists; otherwise each visitor would
        rebuild the model and reset ``last_model``.
        """
        global pipe, last_model
        if pipe is not None:
            return
        pipe = create_pipe(default_model, flash=False)
        last_model = default_model

    demo.load(load_model)

# Launch the app
if __name__ == "__main__":
    demo.launch()