# NOTE: the three lines that originally appeared here ("Spaces:" /
# "Runtime error" / "Runtime error") were a pasted web-page banner, not
# Python code; they are preserved as this comment so the module parses.
| import json | |
| import logging | |
| import os | |
| import shutil | |
| import nltk | |
| import wget | |
| from omegaconf import OmegaConf | |
| from whisperx.alignment import DEFAULT_ALIGN_MODELS_HF, DEFAULT_ALIGN_MODELS_TORCH | |
| from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE | |
# Languages supported by the punctuation-restoration model.
punct_model_langs = ["en", "fr", "de", "es", "it", "nl", "pt", "bg", "pl", "cs", "sk", "sl"]
# Languages that have a wav2vec2 forced-alignment model available in whisperx
# (torch-hosted models first, then HuggingFace-hosted ones).
wav2vec2_langs = [*DEFAULT_ALIGN_MODELS_TORCH, *DEFAULT_ALIGN_MODELS_HF]

# Every identifier Whisper accepts on the command line: sorted ISO codes
# followed by sorted title-cased language names.
whisper_langs = sorted(LANGUAGES.keys()) + sorted(k.title() for k in TO_LANGUAGE_CODE.keys())
# Map ISO 639-1 two-letter codes to their ISO 639-2/B three-letter equivalents.
langs_to_iso = {
    "aa": "aar", "ab": "abk", "ae": "ave", "af": "afr", "ak": "aka", "am": "amh",
    "an": "arg", "ar": "ara", "as": "asm", "av": "ava", "ay": "aym", "az": "aze",
    "ba": "bak", "be": "bel", "bg": "bul", "bh": "bih", "bi": "bis", "bm": "bam",
    "bn": "ben", "bo": "tib", "br": "bre", "bs": "bos", "ca": "cat", "ce": "che",
    "ch": "cha", "co": "cos", "cr": "cre", "cs": "cze", "cu": "chu", "cv": "chv",
    "cy": "wel", "da": "dan", "de": "ger", "dv": "div", "dz": "dzo", "ee": "ewe",
    "el": "gre", "en": "eng", "eo": "epo", "es": "spa", "et": "est", "eu": "baq",
    "fa": "per", "ff": "ful", "fi": "fin", "fj": "fij", "fo": "fao", "fr": "fre",
    "fy": "fry", "ga": "gle", "gd": "gla", "gl": "glg", "gn": "grn", "gu": "guj",
    "gv": "glv", "ha": "hau", "he": "heb", "hi": "hin", "ho": "hmo", "hr": "hrv",
    "ht": "hat", "hu": "hun", "hy": "arm", "hz": "her", "ia": "ina", "id": "ind",
    "ie": "ile", "ig": "ibo", "ii": "iii", "ik": "ipk", "io": "ido", "is": "ice",
    "it": "ita", "iu": "iku", "ja": "jpn", "jv": "jav", "ka": "geo", "kg": "kon",
    "ki": "kik", "kj": "kua", "kk": "kaz", "kl": "kal", "km": "khm", "kn": "kan",
    "ko": "kor", "kr": "kau", "ks": "kas", "ku": "kur", "kv": "kom", "kw": "cor",
    "ky": "kir", "la": "lat", "lb": "ltz", "lg": "lug", "li": "lim", "ln": "lin",
    "lo": "lao", "lt": "lit", "lu": "lub", "lv": "lav", "mg": "mlg", "mh": "mah",
    "mi": "mao", "mk": "mac", "ml": "mal", "mn": "mon", "mr": "mar", "ms": "may",
    "mt": "mlt", "my": "bur", "na": "nau", "nb": "nob", "nd": "nde", "ne": "nep",
    "ng": "ndo", "nl": "dut", "nn": "nno", "no": "nor", "nr": "nbl", "nv": "nav",
    "ny": "nya", "oc": "oci", "oj": "oji", "om": "orm", "or": "ori", "os": "oss",
    "pa": "pan", "pi": "pli", "pl": "pol", "ps": "pus", "pt": "por", "qu": "que",
    "rm": "roh", "rn": "run", "ro": "rum", "ru": "rus", "rw": "kin", "sa": "san",
    "sc": "srd", "sd": "snd", "se": "sme", "sg": "sag", "si": "sin", "sk": "slo",
    "sl": "slv", "sm": "smo", "sn": "sna", "so": "som", "sq": "alb", "sr": "srp",
    "ss": "ssw", "st": "sot", "su": "sun", "sv": "swe", "sw": "swa", "ta": "tam",
    "te": "tel", "tg": "tgk", "th": "tha", "ti": "tir", "tk": "tuk", "tl": "tgl",
    "tn": "tsn", "to": "ton", "tr": "tur", "ts": "tso", "tt": "tat", "tw": "twi",
    "ty": "tah", "ug": "uig", "uk": "ukr", "ur": "urd", "uz": "uzb", "ve": "ven",
    "vi": "vie", "vo": "vol", "wa": "wln", "wo": "wol", "xh": "xho", "yi": "yid",
    "yo": "yor", "za": "zha", "zh": "chi", "zu": "zul",
}
def create_config(output_dir, domain_type="telephonic"):
    """Build the NeMo MSDD diarization config used by this pipeline.

    Downloads the stock NeMo inference YAML for *domain_type* into a local
    cache directory on first use, writes a single-entry input manifest
    pointing at ``<output_dir>/mono_file.wav``, and overrides the diarizer
    settings (VAD, speaker embeddings, MSDD model).

    Args:
        output_dir: Directory for intermediate files and prediction outputs.
        domain_type: NeMo inference domain — "meeting", "telephonic", or
            "general". Default keeps the previously hard-coded behavior.

    Returns:
        The populated OmegaConf config object.
    """
    config_dir = "nemo_msdd_configs"
    config_file_name = f"diar_infer_{domain_type}.yaml"
    model_config_path = os.path.join(config_dir, config_file_name)
    if not os.path.exists(model_config_path):
        os.makedirs(config_dir, exist_ok=True)
        config_url = (
            "https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/"
            f"speaker_tasks/diarization/conf/inference/{config_file_name}"
        )
        model_config_path = wget.download(config_url, model_config_path)

    config = OmegaConf.load(model_config_path)

    data_dir = os.path.join(output_dir, "data")
    os.makedirs(data_dir, exist_ok=True)
    meta = {
        "audio_filepath": os.path.join(output_dir, "mono_file.wav"),
        "offset": 0,
        "duration": None,  # let NeMo derive the duration from the audio itself
        "label": "infer",
        "text": "-",
        "rttm_filepath": None,  # no ground-truth RTTM: fully unsupervised
        "uem_filepath": None,
    }
    manifest_path = os.path.join(data_dir, "input_manifest.json")
    with open(manifest_path, "w") as fp:
        json.dump(meta, fp)
        fp.write("\n")

    pretrained_vad = "vad_multilingual_marblenet"
    pretrained_speaker_model = "titanet_large"

    config.num_workers = 0
    config.diarizer.manifest_filepath = manifest_path
    # Directory to store intermediate files and prediction outputs.
    config.diarizer.out_dir = output_dir
    config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
    # Compute VAD with the model set below rather than using oracle VAD.
    config.diarizer.oracle_vad = False
    config.diarizer.clustering.parameters.oracle_num_speakers = False
    # Pretrained NeMo VAD model and its detection thresholds.
    config.diarizer.vad.model_path = pretrained_vad
    config.diarizer.vad.parameters.onset = 0.8
    config.diarizer.vad.parameters.offset = 0.6
    config.diarizer.vad.parameters.pad_offset = -0.05
    # NOTE(review): the MSDD model stays the telephonic one regardless of
    # domain_type, matching the original behavior — confirm before changing.
    config.diarizer.msdd_model.model_path = "diar_msdd_telephonic"

    return config
def get_word_ts_anchor(s, e, option="start"):
    """Pick a representative timestamp for a word spanning [s, e].

    "end" returns e, "mid" the midpoint; anything else (including the
    default "start") returns s.
    """
    anchors = {"end": e, "mid": (s + e) / 2}
    return anchors.get(option, s)
def get_words_speaker_mapping(wrd_ts, spk_ts, word_anchor_option="start"):
    """Attach a speaker label to every word using the speaker-turn timeline.

    Args:
        wrd_ts: word dicts with "start"/"end" in seconds and "text".
        spk_ts: speaker turns as (start_ms, end_ms, speaker) triples.
        word_anchor_option: which point of the word ("start", "mid", "end")
            is compared against turn boundaries.

    Returns:
        A list of dicts with "word", "start_time"/"end_time" (integer
        milliseconds) and "speaker".
    """
    _, turn_end, speaker = spk_ts[0]
    turn_idx = 0
    mapping = []
    for entry in wrd_ts:
        word_start = int(entry["start"] * 1000)
        word_end = int(entry["end"] * 1000)
        word = entry["text"]
        anchor = get_word_ts_anchor(word_start, word_end, word_anchor_option)
        # Advance through speaker turns until the anchor falls inside one.
        while anchor > float(turn_end):
            turn_idx = min(turn_idx + 1, len(spk_ts) - 1)
            _, turn_end, speaker = spk_ts[turn_idx]
            if turn_idx == len(spk_ts) - 1:
                # Already on the last turn: stretch its end to cover this
                # word so the loop terminates.
                turn_end = get_word_ts_anchor(word_start, word_end, option="end")
        mapping.append(
            {"word": word, "start_time": word_start, "end_time": word_end, "speaker": speaker}
        )
    return mapping
# Characters that mark the end of a sentence for the realignment helpers below.
sentence_ending_punctuations = ".?!"
def get_first_word_idx_of_sentence(word_idx, word_list, speaker_list, max_words):
    """Scan left from word_idx for the first word of the current sentence.

    Walks backwards while the previous word belongs to the same speaker, is
    within max_words of word_idx, and does not end a sentence.

    Returns:
        The index of the sentence's first word, or -1 if no sentence start
        was found within those limits.
    """

    # PEP 8: a named predicate reads better than a lambda assigned to a name.
    def is_word_sentence_end(idx):
        # A word ends a sentence when its last character is one of ".?!".
        return idx >= 0 and word_list[idx][-1] in sentence_ending_punctuations

    left_idx = word_idx
    while (
        left_idx > 0
        and word_idx - left_idx < max_words
        and speaker_list[left_idx - 1] == speaker_list[left_idx]
        and not is_word_sentence_end(left_idx - 1)
    ):
        left_idx -= 1

    return left_idx if left_idx == 0 or is_word_sentence_end(left_idx - 1) else -1
def get_last_word_idx_of_sentence(word_idx, word_list, max_words):
    """Scan right from word_idx for the last word of the current sentence.

    Walks forward while within max_words of word_idx and no sentence-ending
    punctuation has been seen.

    Returns:
        The index of the sentence's last word, or -1 if no sentence end was
        found within max_words.
    """

    # PEP 8: a named predicate reads better than a lambda assigned to a name.
    def is_word_sentence_end(idx):
        # A word ends a sentence when its last character is one of ".?!".
        return idx >= 0 and word_list[idx][-1] in sentence_ending_punctuations

    right_idx = word_idx
    while (
        right_idx < len(word_list) - 1
        and right_idx - word_idx < max_words
        and not is_word_sentence_end(right_idx)
    ):
        right_idx += 1

    return (
        right_idx
        if right_idx == len(word_list) - 1 or is_word_sentence_end(right_idx)
        else -1
    )
def get_realigned_ws_mapping_with_punctuation(
    word_speaker_mapping, max_words_in_sentence=50
):
    """Smooth spurious speaker flips using sentence boundaries.

    When the speaker label changes in the middle of a sentence (i.e. not
    right after sentence-ending punctuation), the whole sentence — bounded
    by punctuation and limited to max_words_in_sentence words — is
    reassigned to its most frequent speaker.

    Args:
        word_speaker_mapping: dicts with "word" and "speaker" keys, as
            produced by get_words_speaker_mapping.
        max_words_in_sentence: search window for locating sentence bounds.

    Returns:
        A new list of copied word dicts with smoothed "speaker" labels.
    """
    # True when word x ends with sentence-ending punctuation (".?!").
    is_word_sentence_end = (
        lambda x: x >= 0
        and word_speaker_mapping[x]["word"][-1] in sentence_ending_punctuations
    )
    wsp_len = len(word_speaker_mapping)

    # Parallel lists of words and speaker labels for fast slicing below.
    words_list, speaker_list = [], []
    for k, line_dict in enumerate(word_speaker_mapping):
        word, speaker = line_dict["word"], line_dict["speaker"]
        words_list.append(word)
        speaker_list.append(speaker)

    k = 0
    while k < len(word_speaker_mapping):
        line_dict = word_speaker_mapping[k]
        # A mid-sentence speaker change: the label flips at k+1 but word k
        # does not end a sentence.
        if (
            k < wsp_len - 1
            and speaker_list[k] != speaker_list[k + 1]
            and not is_word_sentence_end(k)
        ):
            # Locate the sentence containing word k.
            left_idx = get_first_word_idx_of_sentence(
                k, words_list, speaker_list, max_words_in_sentence
            )
            # Remaining word budget after walking left from k.
            right_idx = (
                get_last_word_idx_of_sentence(
                    k, words_list, max_words_in_sentence - k + left_idx - 1
                )
                if left_idx > -1
                else -1
            )
            # Either bound not found within limits: leave this word alone.
            if min(left_idx, right_idx) == -1:
                k += 1
                continue

            # Majority speaker over the sentence span.
            spk_labels = speaker_list[left_idx : right_idx + 1]
            mod_speaker = max(set(spk_labels), key=spk_labels.count)
            # Skip if the winner covers less than half of the sentence.
            if spk_labels.count(mod_speaker) < len(spk_labels) // 2:
                k += 1
                continue

            # Reassign the whole sentence and jump past it.
            speaker_list[left_idx : right_idx + 1] = [mod_speaker] * (
                right_idx - left_idx + 1
            )
            k = right_idx

        k += 1

    # Rebuild the mapping with the smoothed labels (copies, not in place).
    k, realigned_list = 0, []
    while k < len(word_speaker_mapping):
        line_dict = word_speaker_mapping[k].copy()
        line_dict["speaker"] = speaker_list[k]
        realigned_list.append(line_dict)
        k += 1

    return realigned_list
def get_sentences_speaker_mapping(word_speaker_mapping, spk_ts):
    """Group speaker-attributed words into sentence-level segments.

    A new segment starts whenever the speaker changes or NLTK's Punkt
    tokenizer detects a sentence break in the accumulated text.

    Args:
        word_speaker_mapping: dicts with "word", "start_time", "end_time"
            and "speaker" (output of get_words_speaker_mapping).
        spk_ts: speaker turns as (start, end, speaker) triples; only the
            first turn seeds the initial segment's times and label.

    Returns:
        list of dicts with "speaker" ("Speaker N"), "start_time",
        "end_time", and "text" (each word followed by a space).
    """
    # Punkt-based check: True once the accumulated text contains a sentence
    # boundary.
    sentence_checker = nltk.tokenize.PunktSentenceTokenizer().text_contains_sentbreak
    s, e, spk = spk_ts[0]
    prev_spk = spk

    snts = []
    snt = {"speaker": f"Speaker {spk}", "start_time": s, "end_time": e, "text": ""}

    for wrd_dict in word_speaker_mapping:
        wrd, spk = wrd_dict["word"], wrd_dict["speaker"]
        s, e = wrd_dict["start_time"], wrd_dict["end_time"]
        if spk != prev_spk or sentence_checker(snt["text"] + " " + wrd):
            # Speaker change or sentence break: close the current segment and
            # open a new one anchored at this word's timestamps.
            snts.append(snt)
            snt = {
                "speaker": f"Speaker {spk}",
                "start_time": s,
                "end_time": e,
                "text": "",
            }
        else:
            # Same segment: just extend its end time.
            snt["end_time"] = e
        snt["text"] += wrd + " "
        prev_spk = spk

    # Flush the final, still-open segment.
    snts.append(snt)
    return snts
def get_speaker_aware_transcript(sentences_speaker_mapping, f):
    """Write sentences to file object *f*, one paragraph per speaker run.

    Each paragraph starts with "Speaker: " and a blank line separates
    consecutive speakers.
    """
    current_speaker = sentences_speaker_mapping[0]["speaker"]
    f.write(f"{current_speaker}: ")
    for entry in sentences_speaker_mapping:
        if entry["speaker"] != current_speaker:
            # Speaker changed — open a new paragraph with the new label.
            current_speaker = entry["speaker"]
            f.write(f"\n\n{current_speaker}: ")
        # Always append the sentence itself.
        f.write(entry["text"] + " ")
def format_timestamp(
    milliseconds: float, always_include_hours: bool = False, decimal_marker: str = "."
) -> str:
    """Format a non-negative millisecond timestamp as ``[HH:]MM:SS<m>mmm``.

    Args:
        milliseconds: timestamp in milliseconds; floats are truncated.
        always_include_hours: emit the "HH:" part even when hours == 0.
        decimal_marker: separator before the milliseconds field ("," for SRT).

    Returns:
        The formatted timestamp string.
    """
    assert milliseconds >= 0, "non-negative timestamp expected"
    # Truncate to int: float input would otherwise make // produce floats and
    # crash the :02d/:03d format specifiers below.
    milliseconds = int(milliseconds)
    hours, milliseconds = divmod(milliseconds, 3_600_000)
    minutes, milliseconds = divmod(milliseconds, 60_000)
    seconds, milliseconds = divmod(milliseconds, 1_000)
    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
def write_srt(transcript, file):
    """Write a transcript to *file* in SRT subtitle format.

    Each segment becomes a numbered cue with comma decimal markers; any
    literal "-->" inside the text is softened to "->" so it cannot be
    mistaken for a cue timing line.
    """
    for index, segment in enumerate(transcript, start=1):
        start = format_timestamp(
            segment["start_time"], always_include_hours=True, decimal_marker=","
        )
        end = format_timestamp(
            segment["end_time"], always_include_hours=True, decimal_marker=","
        )
        text = segment["text"].strip().replace("-->", "->")
        print(
            f"{index}\n{start} --> {end}\n{segment['speaker']}: {text}\n",
            file=file,
            flush=True,
        )
def find_numeral_symbol_tokens(tokenizer):
    """Collect vocab ids of tokens containing digits or currency symbols.

    Returns a list starting with -1 followed by every token id whose text
    contains any of 0-9, %, $, or £, in vocabulary iteration order.
    """
    symbols = set("0123456789%$£")
    return [-1] + [
        token_id
        for token, token_id in tokenizer.get_vocab().items()
        if not symbols.isdisjoint(token)
    ]
| def _get_next_start_timestamp(word_timestamps, current_word_index, final_timestamp): | |
| # if current word is the last word | |
| if current_word_index == len(word_timestamps) - 1: | |
| return word_timestamps[current_word_index]["start"] | |
| next_word_index = current_word_index + 1 | |
| while current_word_index < len(word_timestamps) - 1: | |
| if word_timestamps[next_word_index].get("start") is None: | |
| # if next word doesn't have a start timestamp | |
| # merge it with the current word and delete it | |
| word_timestamps[current_word_index]["word"] += ( | |
| " " + word_timestamps[next_word_index]["word"] | |
| ) | |
| word_timestamps[next_word_index]["word"] = None | |
| next_word_index += 1 | |
| if next_word_index == len(word_timestamps): | |
| return final_timestamp | |
| else: | |
| return word_timestamps[next_word_index]["start"] | |
def filter_missing_timestamps(
    word_timestamps, initial_timestamp=0, final_timestamp=None
):
    """Fill in missing word start/end times and drop merged-away words.

    The first word, if untimed, is anchored at initial_timestamp; every
    other untimed word spans from the previous word's end to the next
    timed word's start. Entries whose "word" was set to None by the merge
    step are excluded from the result.
    """
    # Handle the first word up front: it has no predecessor to lean on.
    first = word_timestamps[0]
    if first.get("start") is None:
        first["start"] = initial_timestamp if initial_timestamp is not None else 0
        first["end"] = _get_next_start_timestamp(word_timestamps, 0, final_timestamp)

    result = [first]
    for index, entry in enumerate(word_timestamps[1:], start=1):
        if entry.get("start") is None and entry.get("word") is not None:
            # Missing timestamps: bridge from the previous word's end to the
            # next available start.
            entry["start"] = word_timestamps[index - 1]["end"]
            entry["end"] = _get_next_start_timestamp(
                word_timestamps, index, final_timestamp
            )
        if entry["word"] is not None:
            result.append(entry)
    return result
def cleanup(path: str):
    """Remove *path* from disk; the path may be relative or absolute.

    Files and symlinks are unlinked; directories are removed recursively.
    (The file/symlink check runs first so a symlink to a directory deletes
    only the link, not its target.)

    Raises:
        ValueError: if *path* is neither an existing file/link nor a directory.
    """
    if os.path.isfile(path) or os.path.islink(path):
        os.remove(path)
        return
    if os.path.isdir(path):
        shutil.rmtree(path)
        return
    raise ValueError("Path {} is not a file or dir.".format(path))
def process_language_arg(language: str, model_name: str):
    """Validate *language* and normalize language names to Whisper codes.

    Lower-cases the input and maps full language names to their codes;
    English-only models (names ending in ".en") force the result to "en",
    warning if a different language was requested.

    Returns:
        The normalized language code, or None if *language* was None and
        the model is not English-only.

    Raises:
        ValueError: if the language is not recognized.
    """
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language not in TO_LANGUAGE_CODE:
                raise ValueError(f"Unsupported language: {language}")
            language = TO_LANGUAGE_CODE[language]

    if model_name.endswith(".en") and language != "en":
        if language is not None:
            # Only warn when the caller explicitly asked for something else.
            logging.warning(
                f"{model_name} is an English-only model but received '{language}'; using English instead."
            )
        language = "en"

    return language