Spaces:
Paused
Paused
| import os | |
| import gradio as gr | |
| import spaces | |
| from huggingface_hub import get_collection, HfApi | |
| from transformers import pipeline, Pipeline | |
# Truthy when running inside a hosted Hugging Face Space; used below to hide
# the local-model-path option there. NOTE(review): IS_HF_SPACE looks like a
# custom env var set in the Space's configuration — confirm it is defined there.
is_hf_space = os.getenv("IS_HF_SPACE")
def get_dropdown_model_ids():
    """Build the list of model ids offered in the model-selection dropdown.

    Fetches every model in the mozilla-ai Common Voice Whisper collection,
    reads each model card to extract the language from the card's model name
    (expected to look like "... on <Language>"), and appends the stock
    OpenAI Whisper checkpoints.

    Returns:
        list[str]: An empty-string placeholder (no selection) followed by
        "<repo_id> (<language>)" entries.
    """
    api = HfApi()  # reuse one client instead of constructing one per model
    mozilla_ai_model_ids = []
    for item in get_collection(
        "mozilla-ai/common-voice-whisper-67b847a74ad7561781aa10fd"
    ).items:
        model_metadata = api.model_info(item.item_id)
        # Guard the "on <Language>" convention: one badly named card must not
        # crash the whole dropdown with an IndexError.
        parts = model_metadata.card_data.model_name.split("on ", 1)
        language = parts[1] if len(parts) > 1 else "Unknown"
        mozilla_ai_model_ids.append(f"{item.item_id} ({language})")
    return (
        [""]
        + mozilla_ai_model_ids
        + [
            "openai/whisper-tiny (Multilingual)",
            "openai/whisper-small (Multilingual)",
            "openai/whisper-medium (Multilingual)",
            "openai/whisper-large-v3 (Multilingual)",
            "openai/whisper-large-v3-turbo (Multilingual)",
        ]
    )
def _load_local_model(model_dir: str) -> Pipeline | str:
    """Load a Whisper model from a local directory into an ASR pipeline.

    Args:
        model_dir: Path to a directory holding a saved processor and model.

    Returns:
        An automatic-speech-recognition Pipeline on success, or the error
        message as a string on failure (callers detect failure via
        isinstance(result, str)).
    """
    from transformers import WhisperProcessor, WhisperForConditionalGeneration

    try:
        # from_pretrained can fail too (missing path, corrupt weights) — it
        # must be inside the try so the error is returned, not raised.
        processor = WhisperProcessor.from_pretrained(model_dir)
        model = WhisperForConditionalGeneration.from_pretrained(model_dir)
        return pipeline(
            task="automatic-speech-recognition",
            model=model,
            processor=processor,
            chunk_length_s=30,  # max input duration for whisper
        )
    except Exception as e:
        return str(e)
def _load_hf_model(model_repo_id: str) -> Pipeline | str:
    """Instantiate an ASR pipeline for a Hugging Face Hub model id.

    Returns:
        The Pipeline, or the error message as a string when loading fails
        (callers check the return type to detect failure).
    """
    try:
        asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=model_repo_id,
            chunk_length_s=30,  # max input duration for whisper
        )
    except Exception as error:
        return str(error)
    return asr_pipeline
# Adapted from https://github.com/openai/whisper/blob/517a43ecd132a2089d85f4ebc044728a71d49f6e/whisper/utils.py#L50
def format_timestamp(
    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
) -> str:
    """Format a duration in seconds as ``[HH:]MM:SS<marker>mmm``.

    Args:
        seconds: Non-negative timestamp in seconds.
        always_include_hours: Emit the hours field even when it is zero.
        decimal_marker: Separator between seconds and milliseconds.

    Returns:
        The formatted timestamp string.

    Raises:
        ValueError: If ``seconds`` is negative.
    """
    # Real exception instead of `assert`, which is stripped under `python -O`.
    if seconds < 0:
        raise ValueError("non-negative timestamp expected")
    milliseconds = round(seconds * 1000.0)

    hours, milliseconds = divmod(milliseconds, 3_600_000)
    minutes, milliseconds = divmod(milliseconds, 60_000)
    seconds, milliseconds = divmod(milliseconds, 1_000)

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return (
        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
    )
def transcribe(
    dropdown_model_id: str,
    hf_model_id: str,
    local_model_id: str,
    audio: gr.Audio,
    show_timestamps: bool,
) -> str:
    """Transcribe an audio file with exactly one user-selected Whisper model.

    Args:
        dropdown_model_id: A "repo_id (Language)" entry picked from the dropdown.
        hf_model_id: A Hub repo id pasted by the user.
        local_model_id: A local model directory path pasted by the user.
        audio: Filepath to the recorded or uploaded audio.
        show_timestamps: Prefix each transcribed chunk with its time range.

    Returns:
        The transcription text, or a message starting with "⚠️ Error" when
        model selection or loading fails.
    """
    # Exactly one of the three options must be provided.
    provided = [m for m in (dropdown_model_id, hf_model_id, local_model_id) if m]
    if len(provided) != 1:
        return (
            "⚠️ Error: Please select or fill at least and only one of the options above"
        )

    if dropdown_model_id:
        # Strip the " (Language)" suffix added by get_dropdown_model_ids.
        pipe = _load_hf_model(dropdown_model_id.split(" (")[0])
    elif hf_model_id:
        pipe = _load_hf_model(hf_model_id)
    else:
        pipe = _load_local_model(local_model_id)

    if isinstance(pipe, str):
        # Loaders return the exception message as a string on failure.
        return f"⚠️ Error: {pipe}"

    output = pipe(
        audio,
        generate_kwargs={"task": "transcribe"},
        batch_size=16,
        return_timestamps=show_timestamps,
    )
    if not show_timestamps:
        return output["text"]

    # NOTE(review): ASR pipelines can emit a None end timestamp for the final
    # chunk — confirm format_timestamp's input before relying on this in prod.
    lines = [
        f"[{format_timestamp(chunk['timestamp'][0])} -> "
        f"{format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
        for chunk in output["chunks"]
    ]
    return "\n".join(lines)
def setup_gradio_demo():
    """Build and launch the Gradio transcription demo.

    Lays out three mutually exclusive model inputs (collection dropdown,
    Hub model id, local path), an audio input with a timestamps toggle,
    and wires the Transcribe button to `transcribe`.
    """
    with gr.Blocks() as demo:
        gr.Markdown(
            """ # 🗣️ Speech-to-Text Transcription
            ### 1. Select which model to use from one of the options below.
            ### 2. Record a message or upload an audio file.
            ### 3. Click Transcribe to see the transcription generated by the model.
            """
        )
        ### Model selection ###
        model_ids = get_dropdown_model_ids()
        with gr.Row():
            with gr.Column():
                dropdown_model = gr.Dropdown(
                    choices=model_ids, label="Option 1: Select a model"
                )
            with gr.Column():
                user_model = gr.Textbox(
                    label="Option 2: Paste HF model id",
                    placeholder="my-username/my-whisper-tiny",
                )
            # Local paths only make sense outside a hosted Space, so this
            # column is hidden when is_hf_space is truthy.
            with gr.Column(visible=not is_hf_space):
                local_model = gr.Textbox(
                    label="Option 3: Paste local path to model directory",
                    placeholder="artifacts/my-whisper-tiny",
                )

        ### Transcription ###
        with gr.Group():
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",  # transcribe() receives a file path, not raw samples
                label="Record a message / Upload audio file",
                show_download_button=True,
            )
            timestamps_check = gr.Checkbox(label="Show timestamps")
        transcribe_button = gr.Button("Transcribe")
        transcribe_output = gr.Text(label="Output")
        # Inputs are passed positionally in the same order transcribe() declares.
        transcribe_button.click(
            fn=transcribe,
            inputs=[
                dropdown_model,
                user_model,
                local_model,
                audio_input,
                timestamps_check,
            ],
            outputs=transcribe_output,
        )
    demo.launch()
if __name__ == "__main__":
    # Launch the demo only when executed as a script, not on import.
    setup_gradio_demo()