import logging
import os

import gradio as gr
import numpy as np
import torch
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers.utils import is_flash_attn_2_available

from subtitle_manager import Subtitle

logging.basicConfig(level=logging.INFO)
logging.info("Gradio version: %s", gr.__version__)

# Global state: the loaded ASR pipeline and the model id it was built from.
pipe = None
last_model = None


def get_language_names():
    return [
        "af", "am", "ar", "as", "az", "ba", "be", "bg", "bn", "bo", "br", "bs",
        "ca", "cs", "cy", "da", "de", "el", "en", "es", "et", "eu", "fa", "fi",
        "fo", "fr", "gl", "gu", "ha", "haw", "he", "hi", "hr", "ht", "hu", "hy",
        "id", "is", "it", "ja", "jw", "ka", "kk", "km", "kn", "ko", "la", "lb",
        "ln", "lo", "lt", "lv", "mg", "mi", "mk", "ml", "mn", "mr", "ms", "mt",
        "my", "ne", "nl", "nn", "no", "oc", "pa", "pl", "ps", "pt", "ro", "ru",
        "sa", "sd", "si", "sk", "sl", "sn", "so", "sq", "sr", "su", "sv", "sw",
        "ta", "te", "tg", "th", "tk", "tl", "tr", "tt", "uk", "ur", "uz", "vi",
        "yi", "yo", "zh",
    ]


def create_pipe(model_id, flash):
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        # Fall back to PyTorch's SDPA kernels when Flash Attention 2 is
        # unavailable or not requested.
        attn_implementation="flash_attention_2" if flash and is_flash_attn_2_available() else "sdpa",
    ).to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        device=device,
        torch_dtype=torch_dtype,
    )


def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles,
                                     microphoneData, task, flash,
                                     chunk_length_s, batch_size,
                                     progress=gr.Progress()):
    global last_model, pipe

    progress(0, desc="Loading Audio...")

    # (Re)load the model only when the selection changed or nothing is loaded.
    if last_model != modelName or pipe is None:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        progress(0.1, desc="Loading Model...")
        pipe = create_pipe(modelName, flash)
        last_model = modelName

    files = []
    if multipleFiles:
        files += multipleFiles
    if urlData:
        files.append(urlData)
    if microphoneData:
        files.append(microphoneData)

    srt_sub = Subtitle("srt")
    vtt_sub = Subtitle("vtt")
    txt_sub = Subtitle("txt")

    files_out = []
    vtt = txt = ""  # keep the return values defined even if no files were given
    for file in progress.tqdm(files, desc="Working..."):
        outputs = pipe(
            file,
            chunk_length_s=chunk_length_s,
            batch_size=int(batch_size),  # the pipeline expects an int, not a float
            generate_kwargs={
                "language": languageName if languageName != "Automatic Detection" else None,
                "task": task,
            },
            return_timestamps=True,
        )
        file_out = os.path.basename(file)  # portable, unlike file.split('/')[-1]
        srt = srt_sub.get_subtitle(outputs["chunks"])
        vtt = vtt_sub.get_subtitle(outputs["chunks"])
        txt = txt_sub.get_subtitle(outputs["chunks"])
        with open(file_out + ".srt", "w", encoding="utf-8") as f:
            f.write(srt)
        with open(file_out + ".vtt", "w", encoding="utf-8") as f:
            f.write(vtt)
        with open(file_out + ".txt", "w", encoding="utf-8") as f:
            f.write(txt)
        files_out += [file_out + ".srt", file_out + ".vtt", file_out + ".txt"]

    progress(1, desc="Completed!")
    return files_out, vtt, txt


# Realtime STT
def transcribe_stream(buffer, new_chunk):
    if pipe is None:
        # Model is preloaded via demo.load(); skip chunks that arrive earlier.
        return buffer, ""
    sr, chunk = new_chunk
    # Downmix to mono and normalize to float32 in [-1, 1].
    if chunk.ndim > 1:
        chunk = chunk.mean(axis=1)
    chunk = chunk.astype(np.float32)
    peak = np.max(np.abs(chunk))
    if peak > 0:
        chunk /= peak
    # Accumulate audio and re-transcribe the whole buffer on every chunk;
    # simple, but cost grows with the recording length.
    buffer = chunk if buffer is None else np.concatenate([buffer, chunk])
    text = pipe({"sampling_rate": sr, "raw": buffer})["text"]
    return buffer, text
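# Optional sanity check without the UI (a minimal sketch, not part of the app
# flow; "sample.wav" is a placeholder path you would supply yourself):
#
#   test_pipe = create_pipe("openai/whisper-tiny", flash=False)
#   print(test_pipe("sample.wav", return_timestamps=True)["text"])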
"openai/whisper-tiny", "openai/whisper-tiny.en", "openai/whisper-base", "openai/whisper-base.en", "openai/whisper-small", "openai/whisper-small.en", "distil-whisper/distil-small.en", "openai/whisper-medium", "openai/whisper-medium.en", "distil-whisper/distil-medium.en", "openai/whisper-large", "openai/whisper-large-v1", "openai/whisper-large-v2", "distil-whisper/distil-large-v2", "openai/whisper-large-v3", "distil-whisper/distil-large-v3", ] with gr.Tab("File Transcription"): with gr.Row(): with gr.Column(): model_dropdown = gr.Dropdown( whisper_models, value="distil-whisper/distil-large-v2", label="Model" ) language_dropdown = gr.Dropdown( ["Automatic Detection"] + sorted(get_language_names()), value="Automatic Detection", label="Language" ) url_input = gr.Text(label="URL (YouTube, etc.)") file_input = gr.File(label="Upload Files", file_count="multiple") audio_input = gr.Audio( sources=["upload", "microphone"], type="filepath", label="Audio Input" ) task_dropdown = gr.Dropdown( ["transcribe", "translate"], label="Task", value="transcribe" ) flash_checkbox = gr.Checkbox(label='Flash', info='Use Flash Attention 2') chunk_length = gr.Number(label='chunk_length_s', value=30) batch_size = gr.Number(label='batch_size', value=24) transcribe_button = gr.Button("Transcribe") with gr.Column(): output_files = gr.File(label="Download") output_text = gr.Text(label="Transcription") output_segments = gr.Text(label="Segments") transcribe_button.click( fn=transcribe_webui_simple_progress, inputs=[ model_dropdown, language_dropdown, url_input, file_input, audio_input, task_dropdown, flash_checkbox, chunk_length, batch_size ], outputs=[output_files, output_text, output_segments] ) with gr.Tab("Real-time Transcription"): st_buffer = gr.State() mic_rt = gr.Audio( sources=["microphone"], type="numpy", streaming=True, label="🎤 Speak Now (Live Transcription)" ) txt_rt = gr.Textbox(label="Real-time Transcription") mic_rt.stream( fn=transcribe_stream, inputs=[st_buffer, mic_rt], outputs=[st_buffer, txt_rt] ) # Preload model for Hugging Face spaces def load_model(): global pipe, last_model last_model = "distil-whisper/distil-large-v2" pipe = create_pipe(last_model, flash=False) demo.load(load_model) # Launch the app if __name__ == "__main__": demo.launch()