| """ |
| Forced Alignment with Whisper |
| C. Max Bain |
| """ |
| import logging |
| import math |
|
|
| from dataclasses import dataclass |
| from typing import Iterable, Optional, Union, List |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| import torchaudio |
| from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor |
|
|
| from utils import SAMPLE_RATE, load_audio |
| from schema import ( |
| AlignedTranscriptionResult, |
| SingleSegment, |
| SingleAlignedSegment, |
| SingleWordSegment, |
| SegmentData, |
| ) |
| import nltk |
| from nltk.data import load as nltk_load |
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Language code -> name of the NLTK Punkt tokenizer used for sentence
# splitting. Codes missing from this table fall back to a default at the
# lookup site.
PUNKT_LANGUAGES = {
    "cs": "czech",
    "da": "danish",
    "de": "german",
    "el": "greek",
    "en": "english",
    "es": "spanish",
    "et": "estonian",
    "fi": "finnish",
    "fr": "french",
    "it": "italian",
    "nl": "dutch",
    "no": "norwegian",
    "pl": "polish",
    "pt": "portuguese",
    "sl": "slovene",
    "sv": "swedish",
    "tr": "turkish",
    "ml": "malayalam",
    "ru": "russian",
}

# Languages whose text is handled character by character because words are
# not separated by spaces.
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
|
|
# Language code -> default alignment model, given as the name of a
# torchaudio pipeline bundle.
DEFAULT_ALIGN_MODELS_TORCH = {
    "en": "WAV2VEC2_ASR_BASE_960H",
    "fr": "VOXPOPULI_ASR_BASE_10K_FR",
    "de": "VOXPOPULI_ASR_BASE_10K_DE",
    "es": "VOXPOPULI_ASR_BASE_10K_ES",
    "it": "VOXPOPULI_ASR_BASE_10K_IT",
}

# Language code -> default alignment model, given as a Hugging Face model id.
# Consulted only for languages without a torchaudio default.
DEFAULT_ALIGN_MODELS_HF = {
    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
    "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
    "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
    "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
    "cs": "comodoro/wav2vec2-xls-r-300m-cs-250",
    "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
    "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
    "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
    "fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
    "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
    "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
    "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
    "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
    "he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
    "vi": "nguyenvulebinh/wav2vec2-base-vi-vlsp2020",
    "ko": "kresnik/wav2vec2-large-xlsr-korean",
    "ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu",
    "te": "anuragshas/wav2vec2-large-xlsr-53-telugu",
    "hi": "theainerd/Wav2Vec2-large-xlsr-hindi",
    "ca": "softcatala/wav2vec2-large-xlsr-catala",
    "ml": "gvs/wav2vec2-large-xlsr-malayalam",
    "no": "NbAiLab/nb-wav2vec2-1b-bokmaal-v2",
    "nn": "NbAiLab/nb-wav2vec2-1b-nynorsk",
    "sk": "comodoro/wav2vec2-xls-r-300m-sk-cv8",
    "sl": "anton-l/wav2vec2-large-xlsr-53-slovenian",
    "hr": "classla/wav2vec2-xls-r-parlaspeech-hr",
    "ro": "gigant/romanian-wav2vec2",
    "eu": "stefan-it/wav2vec2-large-xlsr-53-basque",
    "gl": "ifrz/wav2vec2-large-xlsr-galician",
    "ka": "xsway/wav2vec2-large-xlsr-georgian",
    "lv": "jimregan/wav2vec2-large-xlsr-latvian-cv",
    "tl": "Khalsuu/filipino-wav2vec2-l-xls-r-300m-official",
    "sv": "KBLab/wav2vec2-large-voxrex-swedish",
}
|
|
|
|
def interpolate_nans(x, method='nearest'):
    """Fill missing values in a pandas Series.

    When more than one value is present, gaps are first interpolated with
    ``method``; any values still missing are then forward-filled and
    back-filled. With zero or one value present, interpolation is skipped and
    only the forward/backward fill is applied.
    """
    can_interpolate = x.notnull().sum() > 1
    filled = x.interpolate(method=method) if can_interpolate else x
    return filled.ffill().bfill()
|
|
|
|
def load_align_model(language_code: str, device: str, model_name: Optional[str] = None, model_dir=None):
    """Load an alignment model together with its character dictionary.

    Args:
        language_code: Language code used to pick a default model when
            ``model_name`` is None; also stored in the returned metadata.
        device: Device the loaded model is moved to.
        model_name: Name of a torchaudio pipeline bundle or a Hugging Face
            model id. When None, the default for ``language_code`` is taken
            from ``DEFAULT_ALIGN_MODELS_TORCH`` first, then
            ``DEFAULT_ALIGN_MODELS_HF``.
        model_dir: Passed as the download directory (torchaudio) or
            ``cache_dir`` (Hugging Face).

    Returns:
        A tuple ``(align_model, align_metadata)``. ``align_metadata`` is a dict
        with keys ``"language"``, ``"dictionary"`` (lower-cased label ->
        index) and ``"type"`` (``"torchaudio"`` or ``"huggingface"``).

    Raises:
        ValueError: If no default model exists for ``language_code``, or the
            Hugging Face processor/model cannot be loaded.
    """
    if model_name is None:
        if language_code in DEFAULT_ALIGN_MODELS_TORCH:
            model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
        elif language_code in DEFAULT_ALIGN_MODELS_HF:
            model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
        else:
            logger.error(f"No default alignment model for language: {language_code}. "
                         f"Please find a wav2vec2.0 model finetuned on this language at https://huggingface.co/models, "
                         f"then pass the model name via --align_model [MODEL_NAME]")
            raise ValueError(f"No default align-model for language: {language_code}")

    if model_name in torchaudio.pipelines.__all__:
        pipeline_type = "torchaudio"
        bundle = getattr(torchaudio.pipelines, model_name)
        align_model = bundle.get_model(dl_kwargs={"model_dir": model_dir}).to(device)
        labels = bundle.get_labels()
        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
    else:
        try:
            processor = Wav2Vec2Processor.from_pretrained(model_name, cache_dir=model_dir)
            align_model = Wav2Vec2ForCTC.from_pretrained(model_name, cache_dir=model_dir)
        except Exception as e:
            # Report through the module logger (as the branch above does) and
            # chain the cause so the original failure stays in the traceback.
            logger.error(e)
            logger.error("Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
            raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)') from e
        pipeline_type = "huggingface"
        align_model = align_model.to(device)
        vocab = processor.tokenizer.get_vocab()
        align_dictionary = {char.lower(): code for char, code in vocab.items()}

    align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}

    return align_model, align_metadata
|
|
|
|
def _load_sentence_splitter(language_code):
    """Load the NLTK Punkt sentence tokenizer for ``language_code``.

    Languages without an entry in ``PUNKT_LANGUAGES`` use the English
    tokenizer. If the resource is missing (``LookupError``), ``punkt_tab`` is
    downloaded once and the load is retried.
    """
    punkt_lang = PUNKT_LANGUAGES.get(language_code, 'english')
    resource = f'tokenizers/punkt_tab/{punkt_lang}.pickle'
    try:
        return nltk_load(resource)
    except LookupError:
        nltk.download('punkt_tab', quiet=True)
        return nltk_load(resource)


def align(
    transcript: Iterable[SingleSegment],
    model: torch.nn.Module,
    align_model_metadata: dict,
    audio: Union[str, np.ndarray, torch.Tensor],
    device: str,
    interpolate_method: str = "nearest",
    return_char_alignments: bool = False,
    print_progress: bool = False,
    combined_progress: bool = False,
) -> AlignedTranscriptionResult:
    """
    Align phoneme recognition predictions to known transcription.

    Args:
        transcript: Segments, each read via its ``"text"``, ``"start"`` and
            ``"end"`` keys. Any iterable is accepted; it is materialized into
            a list because it is traversed twice.
        model: Alignment model; called according to the metadata ``"type"``.
        align_model_metadata: Dict with keys ``"dictionary"``, ``"language"``
            and ``"type"`` (``"torchaudio"`` or ``"huggingface"``).
        audio: A path (loaded with ``load_audio``), a numpy array, or a
            tensor. A 1-D input gets a leading dimension added.
        device: Device the waveform is moved to before the model is called.
        interpolate_method: Interpolation method used to fill missing
            sentence start/end times.
        return_char_alignments: When True, each output segment also carries a
            ``"chars"`` list.
        print_progress: When True, print a progress percentage per segment.
        combined_progress: When True, map the printed progress into 50-100%.

    Returns:
        A dict with ``"segments"`` (aligned segments) and ``"word_segments"``
        (the concatenation of every segment's ``"words"``).

    Raises:
        NotImplementedError: If the metadata ``"type"`` is not supported.
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)
    if len(audio.shape) == 1:
        audio = audio.unsqueeze(0)

    MAX_DURATION = audio.shape[1] / SAMPLE_RATE

    model_dictionary = align_model_metadata["dictionary"]
    model_lang = align_model_metadata["language"]
    model_type = align_model_metadata["type"]
    has_spaces = model_lang not in LANGUAGES_WITHOUT_SPACES

    # The transcript is traversed twice and its length is needed for progress
    # reporting, so materialize it (a generator would otherwise break).
    transcript = list(transcript)
    total_segments = len(transcript)

    # Loop-invariant: the tokenizer depends only on the model language. It is
    # not loaded for an empty transcript, so no download is triggered then.
    sentence_splitter = _load_sentence_splitter(model_lang) if transcript else None

    # Loop-invariant: the blank token is the dictionary's pad entry when one
    # exists (the last match wins), otherwise index 0.
    blank_id = 0
    for char, code in model_dictionary.items():
        if char == '[pad]' or char == '<pad>':
            blank_id = code

    # 1. Preprocess every segment: keep only characters the model knows.
    segment_data: dict[int, SegmentData] = {}
    for sdx, segment in enumerate(transcript):
        if print_progress:
            base_progress = ((sdx + 1) / total_segments) * 100
            percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
            print(f"Progress: {percent_complete:.2f}%...")

        text = segment["text"]
        num_leading = len(text) - len(text.lstrip())
        num_trailing = len(text) - len(text.rstrip())

        # Words are space-separated, except in languages written without
        # spaces where every character counts as a word.
        per_word = text.split(" ") if has_spaces else text

        clean_char, clean_cdx = [], []
        last_kept = len(text) - num_trailing - 1
        for cdx, char in enumerate(text):
            # Skip leading / trailing whitespace of the segment.
            if cdx < num_leading or cdx > last_kept:
                continue
            char_ = char.lower()
            if has_spaces:
                # Spaces are looked up in the dictionary as "|".
                char_ = char_.replace(" ", "|")
            # Characters unknown to the model become the "*" wildcard.
            clean_char.append(char_ if char_ in model_dictionary else '*')
            clean_cdx.append(cdx)

        # Every word index is kept, whether or not the word contains
        # characters from the dictionary.
        clean_wdx = list(range(len(per_word)))

        sentence_spans = list(sentence_splitter.span_tokenize(text))

        segment_data[sdx] = {
            "clean_char": clean_char,
            "clean_cdx": clean_cdx,
            "clean_wdx": clean_wdx,
            "sentence_spans": sentence_spans
        }

    aligned_segments: List[SingleAlignedSegment] = []

    # 2. Run the alignment model on each segment and align its characters.
    for sdx, segment in enumerate(transcript):

        t1 = segment["start"]
        t2 = segment["end"]
        text = segment["text"]

        aligned_seg: SingleAlignedSegment = {
            "start": t1,
            "end": t2,
            "text": text,
            "words": [],
            "chars": None,
        }

        if return_char_alignments:
            aligned_seg["chars"] = []

        # Segments that cannot be aligned are passed through with their
        # original timestamps and no words.
        if len(segment_data[sdx]["clean_char"]) == 0:
            logger.warning(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original')
            aligned_segments.append(aligned_seg)
            continue

        if t1 >= MAX_DURATION:
            logger.warning(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping')
            aligned_segments.append(aligned_seg)
            continue

        text_clean = "".join(segment_data[sdx]["clean_char"])
        # -1 marks a wildcard token (character missing from the dictionary).
        tokens = [model_dictionary.get(c, -1) for c in text_clean]

        f1 = int(t1 * SAMPLE_RATE)
        f2 = int(t2 * SAMPLE_RATE)

        waveform_segment = audio[:, f1:f2]
        # Segments shorter than 400 samples are zero-padded up to 400
        # (presumably the model's minimum input length - TODO confirm); the
        # true length is passed on to torchaudio models.
        if waveform_segment.shape[-1] < 400:
            lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
            waveform_segment = torch.nn.functional.pad(
                waveform_segment, (0, 400 - waveform_segment.shape[-1])
            )
        else:
            lengths = None

        with torch.inference_mode():
            if model_type == "torchaudio":
                emissions, _ = model(waveform_segment.to(device), lengths=lengths)
            elif model_type == "huggingface":
                emissions = model(waveform_segment.to(device)).logits
            else:
                raise NotImplementedError(f"Align model of type {model_type} not supported.")
            emissions = torch.log_softmax(emissions, dim=-1)

        emission = emissions[0].cpu().detach()

        trellis = get_trellis(emission, tokens, blank_id)
        path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)

        if path is None:
            logger.warning(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original')
            aligned_segments.append(aligned_seg)
            continue

        char_segments = merge_repeats(path, text_clean)

        duration = t2 - t1
        # Seconds per trellis frame. A single-frame emission would make the
        # denominator zero, so it is clamped to 1 instead of crashing.
        ratio = duration * waveform_segment.size(0) / max(trellis.size(0) - 1, 1)

        # Position of each kept character inside char_segments, for O(1)
        # lookup instead of a list scan per character.
        clean_pos = {cdx: pos for pos, cdx in enumerate(segment_data[sdx]["clean_cdx"])}

        # Assign timestamps to every character of the original text.
        char_segments_arr = []
        word_idx = 0
        for cdx, char in enumerate(text):
            start, end, score = None, None, None
            pos = clean_pos.get(cdx)
            if pos is not None:
                char_seg = char_segments[pos]
                start = round(char_seg.start * ratio + t1, 3)
                end = round(char_seg.end * ratio + t1, 3)
                score = round(char_seg.score, 3)

            char_segments_arr.append(
                {
                    "char": char,
                    "start": start,
                    "end": end,
                    "score": score,
                    "word-idx": word_idx,
                }
            )

            # Advance the word counter after every character for languages
            # without spaces, otherwise just before a space / at the end.
            if model_lang in LANGUAGES_WITHOUT_SPACES:
                word_idx += 1
            elif cdx == len(text) - 1 or text[cdx+1] == " ":
                word_idx += 1

        char_segments_arr = pd.DataFrame(char_segments_arr)

        aligned_subsegments = []
        char_segments_arr["sentence-idx"] = None
        for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]):
            in_sentence = (char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)
            curr_chars = char_segments_arr.loc[in_sentence]
            char_segments_arr.loc[in_sentence, "sentence-idx"] = sdx2

            sentence_text = text[sstart:send]
            sentence_start = curr_chars["start"].min()
            end_chars = curr_chars[curr_chars["char"] != ' ']
            sentence_end = end_chars["end"].max()
            sentence_words = []

            for word_idx in curr_chars["word-idx"].unique():
                word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx]
                word_text = "".join(word_chars["char"].tolist()).strip()
                if len(word_text) == 0:
                    continue

                # Spaces carry no timing of their own; exclude them.
                word_chars = word_chars[word_chars["char"] != " "]

                word_start = word_chars["start"].min()
                word_end = word_chars["end"].max()
                word_score = round(word_chars["score"].mean(), 3)

                # Timing fields are only emitted when they could be computed.
                word_segment = {"word": word_text}

                if not np.isnan(word_start):
                    word_segment["start"] = word_start
                if not np.isnan(word_end):
                    word_segment["end"] = word_end
                if not np.isnan(word_score):
                    word_segment["score"] = word_score

                sentence_words.append(word_segment)

            aligned_subsegments.append({
                "text": sentence_text,
                "start": sentence_start,
                "end": sentence_end,
                "words": sentence_words,
            })

            if return_char_alignments:
                # fillna returns a new frame, avoiding an in-place write to a
                # slice of char_segments_arr.
                curr_chars = curr_chars[["char", "start", "end", "score"]].fillna(-1)
                curr_chars = curr_chars.to_dict("records")
                # Drop the fields that were missing (marked with -1 above).
                curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars]
                aligned_subsegments[-1]["chars"] = curr_chars

        aligned_subsegments = pd.DataFrame(aligned_subsegments)
        aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
        aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)

        # Sentences that ended up with identical (start, end) are merged.
        agg_dict = {"text": " ".join, "words": "sum"}
        if model_lang in LANGUAGES_WITHOUT_SPACES:
            agg_dict["text"] = "".join
        if return_char_alignments:
            agg_dict["chars"] = "sum"
        aligned_subsegments = aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
        aligned_subsegments = aligned_subsegments.to_dict('records')
        aligned_segments += aligned_subsegments

    # Flatten the per-segment word lists.
    word_segments: List[SingleWordSegment] = []
    for segment in aligned_segments:
        word_segments += segment["words"]

    return {"segments": aligned_segments, "word_segments": word_segments}
|
|
| """ |
| source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html |
| """ |
|
|
|
|
def get_trellis(emission, tokens, blank_id=0):
    """Build the forced-alignment trellis for ``tokens`` over ``emission``.

    Args:
        emission: 2-D tensor indexed as ``[frame, label]``.
        tokens: Sequence of label indices; ``-1`` marks a wildcard that is
            scored with the best non-blank label of the frame.
        blank_id: Index of the blank label.

    Returns:
        Tensor of shape ``(num_frame, num_tokens)``. Column 0 holds the
        cumulative blank score; for more than one token its last
        ``num_tokens - 1`` rows are set to ``+inf``, and row 0 of every other
        column is ``-inf``.
    """
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    trellis = torch.zeros((num_frame, num_tokens))
    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
    trellis[0, 1:] = -float("inf")
    # With a single token the slice start ``-num_tokens + 1`` is 0, which
    # would overwrite the whole first column; only mark the tail when there
    # is more than one token.
    if num_tokens > 1:
        trellis[-(num_tokens - 1):, 0] = float("inf")

    # Score of advancing into tokens[1:] at every frame, computed once rather
    # than rebuilding the token tensor and cloning the frame per time step.
    advance_tokens = torch.as_tensor(tokens[1:], dtype=torch.long)
    is_wildcard = advance_tokens == -1
    token_scores = emission[:, advance_tokens.clamp(min=0)]
    non_blank = emission.clone()
    non_blank[:, blank_id] = float("-inf")
    best_non_blank = non_blank.max(dim=1).values.unsqueeze(1)
    advance_scores = torch.where(is_wildcard, best_non_blank, token_scores)

    for t in range(num_frame - 1):
        trellis[t + 1, 1:] = torch.maximum(
            # Stay on the same token: emit blank.
            trellis[t, 1:] + emission[t, blank_id],
            # Move on from the previous token: emit the next token.
            trellis[t, :-1] + advance_scores[t],
        )
    return trellis
|
|
|
|
def get_wildcard_emission(frame_emission, tokens, blank_id):
    """Score each token for one frame, with wildcard support (vectorized).

    Args:
        frame_emission: 1-D emission vector for the current frame
        tokens: List or tensor of token indices; ``-1`` marks a wildcard
        blank_id: ID of the blank token

    Returns:
        tensor: For every token position, the emission of that token, or for
        a wildcard the highest emission among all non-blank entries.
    """
    assert 0 <= blank_id < len(frame_emission)

    token_ids = tokens if isinstance(tokens, torch.Tensor) else torch.tensor(tokens)

    # Wildcards hold -1, which is not a valid index; clamp to 0 for the
    # gather and replace those positions afterwards.
    is_wildcard = token_ids == -1
    gather_index = token_ids.clamp(min=0).long()
    per_token_score = frame_emission[gather_index]

    # Best score over every label except blank.
    non_blank = frame_emission.clone()
    non_blank[blank_id] = float('-inf')
    best_non_blank = non_blank.max()

    return torch.where(is_wildcard, best_non_blank, per_token_score)
|
|
|
|
@dataclass
class Point:
    """One step of an alignment path: a (token, frame) cell and its score."""

    # Index into the token sequence.
    token_index: int
    # Index of the emission frame.
    time_index: int
    # Score attached to this step by the backtracking routine.
    score: float
|
|
|
|
def backtrack(trellis, emission, tokens, blank_id=0):
    """Greedily trace the best path backwards through the trellis.

    Starts at the last frame / last token. At each step it compares staying
    on the current token (blank emission) with having moved from the previous
    token (token emission, wildcard-aware) and takes the higher one. Once
    token 0 is reached, the remaining frames are filled with blank steps.

    Returns:
        List of ``Point`` ordered from the first frame to the last; each
        score is the exponentiated emission of the chosen step.
    """
    frame = trellis.size(0) - 1
    tok = trellis.size(1) - 1

    reversed_path = [Point(tok, frame, emission[frame, blank_id].exp().item())]
    while tok > 0:
        # Running out of frames before reaching token 0 means no alignment.
        assert frame > 0

        prev = frame - 1
        blank_logp = emission[prev, blank_id]
        token_logp = get_wildcard_emission(emission[prev], [tokens[tok]], blank_id)[0]

        stay_total = trellis[prev, tok] + blank_logp
        move_total = trellis[prev, tok - 1] + token_logp
        moved = bool(move_total > stay_total)

        frame = prev
        if moved:
            tok -= 1
            chosen_logp = token_logp
        else:
            chosen_logp = blank_logp
        reversed_path.append(Point(tok, frame, chosen_logp.exp().item()))

    # Token 0 reached: pad the path back to frame 0 with blank steps.
    for prev in range(frame - 1, -1, -1):
        reversed_path.append(Point(tok, prev, emission[prev, blank_id].exp().item()))

    reversed_path.reverse()
    return reversed_path
|
|
|
|
|
|
@dataclass
class Path:
    """A sequence of alignment points together with an overall score."""

    # Steps that make up the path.
    points: List[Point]
    # Score associated with the whole path.
    score: float
|
|
|
|
@dataclass
class BeamState:
    """One hypothesis tracked during beam-search backtracking."""

    # Token position the hypothesis currently sits on.
    token_index: int
    # Frame position the hypothesis currently sits on.
    time_index: int
    # Value used to rank hypotheses against each other.
    score: float
    # Points visited so far, in the order they were appended.
    path: List[Point]
|
|
|
|
def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
    """Backtrack through the trellis while keeping several hypotheses.

    Args:
        trellis (torch.Tensor): Trellis indexed as ``[frame, token]``.
        emission (torch.Tensor): Emission scores indexed as ``[frame, label]``.
        tokens (List[int]): Token indices being aligned.
        blank_id (int, optional): The ID of the blank token. Defaults to 0.
        beam_width (int, optional): Number of hypotheses kept after each step.
            Defaults to 5.

    Returns:
        List[Point] ordered from the first frame to the last, or ``None``
        when every hypothesis is dropped before token 0 is reached.
    """
    last_frame = trellis.size(0) - 1
    last_token = trellis.size(1) - 1

    beams = [
        BeamState(
            token_index=last_token,
            time_index=last_frame,
            score=trellis[last_frame, last_token],
            path=[Point(last_token, last_frame, emission[last_frame, blank_id].exp().item())],
        )
    ]

    # Step backwards one frame at a time until the top hypothesis is on token 0.
    while beams and beams[0].token_index > 0:
        candidates = []

        for state in beams:
            frame, tok = state.time_index, state.token_index

            # A hypothesis that has run out of frames is dropped.
            if frame <= 0:
                continue

            prev = frame - 1
            blank_logp = emission[prev, blank_id]
            token_logp = get_wildcard_emission(emission[prev], [tokens[tok]], blank_id)[0]

            # Possible moves, "stay" first: (next token, trellis score, log-prob).
            moves = [(tok, trellis[prev, tok], blank_logp)]
            if tok > 0:
                moves.append((tok - 1, trellis[prev, tok - 1], token_logp))

            for next_tok, next_score, step_logp in moves:
                # Infinite trellis cells are not expanded.
                if math.isinf(next_score):
                    continue
                extended = state.path.copy()
                extended.append(Point(next_tok, prev, step_logp.exp().item()))
                candidates.append(BeamState(
                    token_index=next_tok,
                    time_index=prev,
                    score=next_score,
                    path=extended,
                ))

        # Keep the highest-scoring hypotheses; the sort is stable, so ties
        # keep their insertion order.
        candidates.sort(key=lambda s: s.score, reverse=True)
        beams = candidates[:beam_width]

    if not beams:
        return None

    # Pad the winning hypothesis back to frame 0 with blank steps.
    winner = beams[0]
    tok = winner.token_index
    for prev in range(winner.time_index - 1, -1, -1):
        winner.path.append(Point(tok, prev, emission[prev, blank_id].exp().item()))

    return winner.path[::-1]
|
|
|
|
| |
@dataclass
class Segment:
    """A labelled span ``[start, end)`` with a score."""

    # Text of the span (a character or a word).
    label: str
    # First index of the span (inclusive).
    start: int
    # Index just past the span (exclusive).
    end: int
    # Score attached to the span.
    score: float

    def __repr__(self):
        """Return a tab-separated, fixed-width rendering of the segment."""
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        """Number of indices covered by the span."""
        span = self.end - self.start
        return span
|
|
def merge_repeats(path, transcript):
    """Collapse consecutive path points on the same token into segments.

    Each run of points sharing a ``token_index`` becomes one ``Segment``
    labelled ``transcript[token_index]``. It spans from the first point's
    ``time_index`` to one past the last point's, and is scored with the mean
    of the points' scores.
    """
    merged = []
    total = len(path)
    run_start = 0
    while run_start < total:
        token = path[run_start].token_index

        # Extend the run while the token index stays the same.
        run_end = run_start + 1
        while run_end < total and path[run_end].token_index == token:
            run_end += 1

        run = path[run_start:run_end]
        mean_score = sum(point.score for point in run) / len(run)
        merged.append(
            Segment(
                transcript[token],
                run[0].time_index,
                run[-1].time_index + 1,
                mean_score,
            )
        )
        run_start = run_end
    return merged
|
|
def merge_words(segments, separator="|"):
    """Join character segments into word segments, split at ``separator``.

    Separator segments are dropped and empty words (adjacent separators) are
    skipped. A word's score is the length-weighted mean of its characters'
    scores; its span runs from the first character's start to the last
    character's end.
    """
    words = []
    count = len(segments)
    word_start = 0
    # Visit one position past the end so the final word is flushed too.
    for pos in range(count + 1):
        at_boundary = pos == count or segments[pos].label == separator
        if not at_boundary:
            continue
        if pos != word_start:
            chunk = segments[word_start:pos]
            text = "".join([seg.label for seg in chunk])
            weighted_total = sum(seg.score * seg.length for seg in chunk)
            total_length = sum(seg.length for seg in chunk)
            words.append(Segment(text, chunk[0].start, chunk[-1].end, weighted_total / total_length))
        word_start = pos + 1
    return words