Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| import uuid | |
| import tempfile | |
| import subprocess | |
| import re | |
| import gradio as gr | |
| import pytube as pt | |
| import nemo.collections.asr as nemo_asr | |
| import speech_to_text_buffered_infer_ctc as buffered_ctc | |
| import speech_to_text_buffered_infer_rnnt as buffered_rnnt | |
# Point NeMo's download/cache directory at a writable location under /tmp.
# The environment variable's name is taken from NeMo's own constants module.
from nemo import constants

_NEMO_CACHE_DIR = "/tmp/nemo"
os.environ[constants.NEMO_ENV_CACHE_DIR] = _NEMO_CACHE_DIR
# Target sample rate (Hz) that long recordings are transcoded to before buffered inference.
SAMPLE_RATE = 16_000

# Text shown in the Gradio page header.
TITLE = "NeMo ASR Inference on Hugging Face"
DESCRIPTION = "Demo of all languages supported by NeMo ASR"

# Model preselected in the "Models" dropdown when English is chosen.
DEFAULT_EN_MODEL = "nvidia/stt_en_conformer_transducer_xlarge"

# Markdown header: title as H1, description as H2, wrapped in blank lines.
MARKDOWN = "\n".join(["", f"# {TITLE}", f"## {DESCRIPTION}", ""])

# Page-level CSS; the `big` class is used by the footer paragraph below.
CSS = """
p.big {
font-size: 20px;
}
"""

# Footer HTML with links to the NeMo ASR docs and the GitHub repository.
ARTICLE = """
<br><br>
<p class='big' style='text-align: center'>
<a href='https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/intro.html' target='_blank'>NeMo ASR</a>
|
<a href='https://github.com/NVIDIA/NeMo#nvidia-nemo' target='_blank'>Github Repo</a>
</p>
"""
SUPPORTED_LANGUAGES = set()
SUPPORTED_MODEL_NAMES = set()

# HF models, grouped by language identifier
hf_filter = nemo_asr.models.ASRModel.get_hf_model_filter()
hf_filter.task = "automatic-speech-recognition"
hf_infos = nemo_asr.models.ASRModel.search_huggingface_models(model_filter=hf_filter)
for info in hf_infos:
    # Language ids are parsed from names shaped like "<org>/<prefix>_<lang>_<rest>",
    # e.g. "nvidia/stt_en_conformer_transducer_xlarge" -> "en".
    # Ids without that shape are skipped:
    # - with no "_", `split("_")[1]` would raise IndexError at import time;
    # - with a single "_", the parsed "language" would have no models below, and
    #   selecting it in the UI would fail.
    name_parts = info.modelId.split("_")
    if len(name_parts) < 3:
        continue
    lang_id = name_parts[1]
    SUPPORTED_LANGUAGES.add(lang_id)
    SUPPORTED_MODEL_NAMES.add(info.modelId)

SUPPORTED_MODEL_NAMES = sorted(SUPPORTED_MODEL_NAMES)

# One Gradio interface per model (HF hosted inference), all loaded at startup.
model_dict = {model_name: gr.Interface.load(f'models/{model_name}') for model_name in SUPPORTED_MODEL_NAMES}

# Map each language to the models whose id contains "_<lang>_". Because
# SUPPORTED_MODEL_NAMES is already sorted, every per-language list is built in
# sorted order.
SUPPORTED_LANG_MODEL_DICT = {}
for lang in SUPPORTED_LANGUAGES:
    for model_id in SUPPORTED_MODEL_NAMES:
        if f"_{lang}_" in model_id:
            SUPPORTED_LANG_MODEL_DICT.setdefault(lang, []).append(model_id)
def parse_duration(audio_file):
    """
    Return the duration of `audio_file` in seconds, as reported by ffmpeg.

    FFMPEG is used to calculate durations. Libraries can do it too, but filetypes
    cause different libraries to behave differently.

    Args:
        audio_file: Path to the audio file to probe.

    Returns:
        Duration in seconds as a float.

    Raises:
        OSError: if the ffmpeg executable cannot be started.
        ValueError: if ffmpeg's output contains no parsable "Duration: HH:MM:SS.ss,"
            line (e.g. a missing/unreadable file or a "Duration: N/A" stream).
    """
    # ffmpeg is run with an input only; just its log is needed, so the exit
    # status is not checked. stdin is closed so ffmpeg never waits on it.
    process = subprocess.run(
        ['ffmpeg', '-i', audio_file],
        stdin=subprocess.DEVNULL,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    # Filenames/metadata echoed in the log are not guaranteed to be valid UTF-8.
    log = process.stdout.decode(errors='replace')
    match = re.search(r"Duration:\s(?P<hours>\d+):(?P<minutes>\d+):(?P<seconds>\d+\.\d+),", log)
    if match is None:
        raise ValueError(f"Could not determine the duration of audio file: {audio_file}")
    return (
        float(match['hours']) * 3600.0
        + float(match['minutes']) * 60.0
        + float(match['seconds'])
    )
def resolve_model_type(model_name: str) -> "str | None":
    """
    Map model name to a decoder family, without loading the model.

    Has some hardcoded assumptions in the semantics of model naming.

    Args:
        model_name: Model name, e.g. "nvidia/stt_en_conformer_transducer_large".

    Returns:
        'hybrid', 'transducer' or 'ctc'; None if the name matches no known pattern.
    """
    # Loss-specific markers take priority. 'hybrid' also covers the
    # 'hybrid_ctc' / 'hybrid_transducer' spellings.
    if 'hybrid' in model_name:
        return 'hybrid'
    if 'transducer' in model_name or 'rnnt' in model_name:
        return 'transducer'
    if 'ctc' in model_name:
        return 'ctc'
    # Architecture-specific fallbacks for names without a loss marker.
    # NOTE(review): published 'contextnet' checkpoints may be RNNT models rather
    # than CTC; this mapping is kept as-is - confirm against the model cards.
    if any(arch in model_name for arch in ('jasper', 'quartznet', 'citrinet', 'contextnet')):
        return 'ctc'
    # Unknown model type
    return None
def resolve_model_stride(model_name) -> int:
    """
    Model specific pre-calc of stride (encoder downsampling) levels.

    Does not load the model to get this info; it is derived from the name only.

    Args:
        model_name: Model name, e.g. "nvidia/stt_en_citrinet_1024".

    Returns:
        The stride for the first architecture substring found in the name, or -1
        if none matches.
    """
    # Checked in order. 'fastconformer' must come before 'conformer' because it
    # contains it, and FastConformer downsamples by 8 rather than Conformer's 4.
    stride_by_arch = (
        ('jasper', 2),
        ('quartznet', 2),
        ('fastconformer', 8),
        ('conformer', 4),
        ('squeezeformer', 4),
        ('citrinet', 8),
        ('contextnet', 8),
    )
    for arch, stride in stride_by_arch:
        if arch in model_name:
            return stride
    return -1
def convert_audio(audio_filepath):
    """
    Transcode an audio file to a mono, SAMPLE_RATE Hz ``.wav`` file next to the input.

    Files whose extension is already exactly ``.wav`` are returned unchanged.

    Args:
        audio_filepath: Path to the input audio file (e.g. mp3 or wav).

    Returns:
        Path to the wav file, or None if ffmpeg did not produce it.
    """
    stem, ext = os.path.splitext(audio_filepath)
    # splitext keeps the leading dot. The comparison is case-sensitive on purpose,
    # so e.g. ".WAV" inputs still get a lowercase ".wav" copy.
    if ext == '.wav':
        return audio_filepath
    # `stem` already includes the directory; joining it onto the directory again
    # would duplicate relative directory components.
    out_filename = stem + '.wav'
    process = subprocess.run(
        # -y: overwrite a stale output instead of prompting on stdin.
        ['ffmpeg', '-y', '-i', audio_filepath, '-ac', '1', '-ar', str(SAMPLE_RATE), out_filename],
        stdin=subprocess.DEVNULL,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    if process.returncode == 0 and os.path.exists(out_filename):
        return out_filename
    return None
def extract_result_from_manifest(filepath, model_name) -> "tuple[bool, str]":
    """
    Parse the manifest written by the buffered inference process.

    Args:
        filepath: Path to the manifest; each line is parsed as one JSON object.
        model_name: Model name, used only in the failure message.

    Returns:
        ``(True, text)`` with the ``pred_text`` of the first line that has one,
        or ``(False, message)`` when no line yields a prediction.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                return True, json.loads(line)['pred_text']
            except (ValueError, KeyError, TypeError):
                # Skip malformed JSON (ValueError), objects without 'pred_text'
                # (KeyError) and non-object JSON values (TypeError).
                continue
    return False, f"Could not perform inference on model with name : {model_name}"
def _run_buffered_inference(buffered_module, model_name, audio_dir, model_stride):
    """
    Run one buffered-inference script module and parse the manifest it writes.

    Args:
        buffered_module: `buffered_ctc` or `buffered_rnnt`; both are driven through
            `TranscriptionConfig(...)` + `main(config)` with the same options here.
        model_name: Pretrained model name passed to the script.
        audio_dir: Directory holding the wav file to transcribe.
        model_stride: Encoder downsampling factor for the model.

    Returns:
        The ``(success, text)`` pair from `extract_result_from_manifest`.
    """
    # Drop a manifest left by an earlier request, so a run that writes nothing
    # cannot be mistaken for a fresh result.
    try:
        os.remove('output.json')
    except FileNotFoundError:
        pass
    config = buffered_module.TranscriptionConfig(
        pretrained_name=model_name,
        audio_dir=audio_dir,
        output_filename="output.json",
        audio_type="wav",
        overwrite_transcripts=True,
        model_stride=model_stride,
        chunk_len_in_secs=20.0,
        total_buffer_in_secs=30.0,
    )
    buffered_module.main(config)
    return extract_result_from_manifest('output.json', model_name)


def infer_audio(model_name: str, audio_file: str) -> str:
    """
    Main method that switches from HF inference for short audio files to
    Buffered CTC/RNNT mode for long audio files (over 60 seconds).

    Args:
        model_name: Str name of the model (potentially with / to denote HF models)
        audio_file: Path to an audio file (mp3 or wav)

    Returns:
        str which is the transcription if successful, otherwise an error message.
    """
    duration = parse_duration(audio_file)

    if duration <= 60.0:
        # Short audio: use the HF interface loaded for this model at startup.
        model = model_dict.get(model_name)
        if model is None:
            return (
                f"Could not find model {model_name} in list of available models : "
                f"{list(model_dict.keys())}"
            )
        return model(audio_file)

    # Longer than one minute: use buffered mode on a wav version of the audio
    # (possibly transcoded from YouTube audio).
    audio_file = convert_audio(audio_file)
    if audio_file is None:
        return "Failed to convert audio file to wav."
    # NOTE(review): the scripts are given the whole directory, and only the first
    # manifest prediction is returned. If the directory holds several wav files,
    # that prediction may belong to a different file - confirm.
    audio_dir = os.path.split(audio_file)[0]

    model_stride = resolve_model_stride(model_name)
    if model_stride < 0:
        return f"Failed to compute the model stride for model with name : {model_name}"

    model_type = resolve_model_type(model_name)
    if model_type == 'ctc':
        return _run_buffered_inference(buffered_ctc, model_name, audio_dir, model_stride)[-1]
    if model_type == 'transducer':
        return _run_buffered_inference(buffered_rnnt, model_name, audio_dir, model_stride)[-1]
    if model_type is None:
        # Type could not be inferred from the name: try CTC, then RNNT, and
        # return the first successful transcription.
        for buffered_module in (buffered_ctc, buffered_rnnt):
            try:
                ok, text = _run_buffered_inference(buffered_module, model_name, audio_dir, model_stride)
            except Exception:
                # Deliberate best effort: a decoder mismatch raises, so fall
                # through to the next option.
                continue
            if ok:
                return text
    # Hybrid models, or unknown models that neither script could transcribe.
    return f"Could not parse model type; failed to perform inference with model {model_name}!"
def transcribe(microphone, audio_file, model_name):
    """
    Gradio handler for the "Transcribe Audio" tab.

    Args:
        microphone: Filepath of the microphone recording, or None.
        audio_file: Filepath of the uploaded audio, or None.
        model_name: Name of the model selected in the dropdown.

    Returns:
        Any warning text followed by the transcription, or an error message.
    """
    if microphone is None and audio_file is None:
        return "ERROR: You have to either use the microphone or upload an audio file"

    warn_output = ""
    if microphone is not None and audio_file is not None:
        warn_output = (
            "WARNING: You've uploaded an audio file and used the microphone. "
            "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
        )
    # The microphone recording takes priority over an uploaded file.
    audio_data = microphone if microphone is not None else audio_file

    try:
        transcriptions = infer_audio(model_name, audio_data)
    except Exception as e:
        # UI boundary: report the failure in the output instead of raising.
        # The message assumes the model is still loading, so the actual error is
        # appended to make other failures (e.g. ffmpeg/parsing) diagnosable.
        transcriptions = ""
        warn_output += "\n\n"
        warn_output += (
            f"The model `{model_name}` is currently loading and cannot be used "
            f"for transcription.\n"
            f"Please try another model or wait a few minutes."
            f"\n(Error: {type(e).__name__}: {e})"
        )
    # Guard the concatenation against a None/non-str inference result.
    if transcriptions is None:
        transcriptions = ""
    return warn_output + str(transcriptions)
| def _return_yt_html_embed(yt_url): | |
| video_id = yt_url.split("?v=")[-1] | |
| HTML_str = ( | |
| f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>' | |
| " </center>" | |
| ) | |
| return HTML_str | |
def yt_transcribe(yt_url, model_name):
    """
    Gradio handler for the "Transcribe Youtube" tab.

    Downloads the first audio-only stream of the video into a temporary directory,
    transcribes it with `infer_audio`, and returns ``(embed_html, transcript)``.
    The temporary directory and its files are removed before returning.
    """
    video = pt.YouTube(yt_url)
    embed_html = _return_yt_html_embed(yt_url)

    with tempfile.TemporaryDirectory() as workdir:
        audio_path = f"{workdir}/{uuid.uuid4().hex}.mp3"
        audio_stream = video.streams.filter(only_audio=True)[0]
        audio_stream.download(filename=audio_path)
        transcript_text = infer_audio(model_name, audio_path)

    return embed_html, transcript_text
def create_lang_selector_component(default_en_model=DEFAULT_EN_MODEL):
    """
    Build the linked "Languages" / "Models" dropdowns.

    Changing the language repopulates the model dropdown with that language's
    models. The preselected model is `default_en_model` for English; otherwise it
    is the first model listed, or None if the language has none.

    Args:
        default_en_model: Model preselected when English is chosen.

    Returns:
        ``(lang_selector, models_in_lang)`` Gradio dropdown components.
    """
    lang_selector = gr.components.Dropdown(
        choices=sorted(SUPPORTED_LANGUAGES), value="en", type="value", label="Languages", interactive=True,
    )
    models_in_lang = gr.components.Dropdown(
        choices=sorted(SUPPORTED_LANG_MODEL_DICT.get("en", [])),
        value=default_en_model,
        label="Models",
        interactive=True,
    )

    def update_models_with_lang(lang):
        # A language may be listed in SUPPORTED_LANGUAGES without any entry in
        # SUPPORTED_LANG_MODEL_DICT; avoid KeyError/IndexError in that case.
        models_names = sorted(SUPPORTED_LANG_MODEL_DICT.get(lang, []))
        default = models_names[0] if models_names else None
        if lang == 'en':
            default = default_en_model
        return models_in_lang.update(choices=models_names, value=default)

    lang_selector.change(update_models_with_lang, inputs=[lang_selector], outputs=[models_in_lang])
    return lang_selector, models_in_lang
# Gradio UI: one tab for uploaded/recorded audio and one for YouTube URLs.
demo = gr.Blocks(title=TITLE, css=CSS)

with demo:
    gr.Markdown(MARKDOWN)

    with gr.Tab("Transcribe Audio"):
        with gr.Row():
            file_upload = gr.components.Audio(source="upload", type='filepath', label='Upload File')
            microphone = gr.components.Audio(source="microphone", type='filepath', label='Microphone')

        lang_selector, models_in_lang = create_lang_selector_component()

        transcript = gr.components.Label(label='Transcript')

        run = gr.components.Button('Transcribe')
        run.click(transcribe, inputs=[microphone, file_upload, models_in_lang], outputs=[transcript])

    with gr.Tab("Transcribe Youtube"):
        yt_url = gr.components.Textbox(
            lines=1, label="Youtube URL", placeholder="Paste the URL to a YouTube video here"
        )

        lang_selector_yt, models_in_lang_yt = create_lang_selector_component(
            default_en_model='nvidia/stt_en_conformer_transducer_large'
        )

        embedded_video = gr.components.HTML()
        transcript = gr.components.Label(label='Transcript')

        run = gr.components.Button('Transcribe YouTube')
        run.click(yt_transcribe, inputs=[yt_url, models_in_lang_yt], outputs=[embedded_video, transcript])

    gr.components.HTML(ARTICLE)

# Queueing is configured once here (one request at a time, since the buffered
# path shares a single output manifest). The deprecated and redundant
# `launch(enable_queue=True)` is not needed.
demo.queue(concurrency_count=1)
demo.launch()