Spaces:
Runtime error
Runtime error
| import itertools | |
| import subprocess | |
| import warnings | |
| from dataclasses import dataclass | |
| from typing import TYPE_CHECKING, List | |
| import numba | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND | |
| from .tokenizer import Tokenizer | |
| if TYPE_CHECKING: | |
| from .model import Whisper | |
def median_filter(x: torch.Tensor, filter_width: int):
    """Run a sliding median of size `filter_width` over the last axis of `x`.

    The input is reflect-padded by `filter_width // 2` on both sides of the
    last axis, so the output has the same shape as the input. When the last
    axis is not longer than `filter_width // 2`, `x` is returned unchanged
    (and `filter_width` is not validated).

    Raises:
        AssertionError: if `filter_width` is not a positive odd number and the
            input was long enough to be filtered.
    """
    half_width = filter_width // 2
    if x.shape[-1] <= half_width:
        # reflect padding needs the pad width to be smaller than the padded axis
        return x

    # reflect-mode `F.pad` is used here with 3D/4D inputs only, so 1D and 2D
    # inputs get two leading singleton axes that are stripped again at the end
    needs_leading_axes = x.ndim <= 2
    if needs_leading_axes:
        x = x[None, None, :]

    assert (
        filter_width > 0 and filter_width % 2 == 1
    ), "`filter_width` should be an odd number"

    padded = F.pad(x, (half_width, half_width, 0, 0), mode="reflect")

    filtered = None
    if padded.is_cuda:
        try:
            from .triton_ops import median_filter_cuda

            filtered = median_filter_cuda(padded, filter_width)
        except (RuntimeError, subprocess.CalledProcessError):
            warnings.warn(
                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
                "falling back to a slower median kernel implementation..."
            )

    if filtered is None:
        # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
        windows = padded.unfold(-1, filter_width, 1)
        ordered = windows.sort()[0]
        filtered = ordered[..., half_width]

    return filtered[0, 0] if needs_leading_axes else filtered
def backtrace(trace: np.ndarray):
    """Walk a DTW step table from its bottom-right cell back to the origin.

    Step codes: 0 moves diagonally (row and column decrease), 1 decreases the
    row only, 2 decreases the column only. The first row and first column of
    `trace` are overwritten IN PLACE (with 2 and 1 respectively) so that the
    walk always ends at cell (0, 0).

    Returns:
        A `(2, path_length)` array; row 0 holds the visited row indices and
        row 1 the visited column indices, each shifted by -1 and ordered from
        the start of the path to its end.

    Raises:
        ValueError: if a visited cell holds a value other than 0, 1 or 2.
    """
    # NOTE(review): `numba` is imported by this file but no JIT decorator is
    # visible on this function - confirm whether one is expected here.
    row = trace.shape[0] - 1
    col = trace.shape[1] - 1

    # boundary cells: along the top edge only the column can shrink, along the
    # left edge only the row can shrink (cell (0, 0) ends up holding 1)
    trace[0, :] = 2
    trace[:, 0] = 1

    path = []
    while not (row <= 0 and col <= 0):
        path.append((row - 1, col - 1))
        step = trace[row, col]
        if step == 0:
            row, col = row - 1, col - 1
        elif step == 1:
            row -= 1
        elif step == 2:
            col -= 1
        else:
            raise ValueError("Unexpected trace[i, j]")

    # the walk was recorded end-to-start; flip it and put the axes first
    steps = np.array(path)
    return steps[::-1, :].T
def dtw_cpu(x: np.ndarray):
    """Compute a dynamic-time-warping path through the 2D cost matrix `x`.

    A cumulative-cost table and a step table, both one row and one column
    larger than `x`, are filled column by column. Each cell extends the
    cheapest of its diagonal, upper and left neighbours; when neither the
    diagonal nor the upper neighbour is strictly cheapest, the left neighbour
    is used. The step table is then handed to `backtrace`, whose result is
    returned.
    """
    # NOTE(review): `numba` is imported by this file but no JIT decorator is
    # visible on this function - confirm whether one is expected here.
    n_rows, n_cols = x.shape
    table_shape = (n_rows + 1, n_cols + 1)
    cost = np.full(table_shape, np.inf, dtype=np.float32)
    trace = np.full(table_shape, -1, dtype=np.float32)
    cost[0, 0] = 0

    for col in range(1, n_cols + 1):
        for row in range(1, n_rows + 1):
            diagonal = cost[row - 1, col - 1]
            upper = cost[row - 1, col]
            left = cost[row, col - 1]

            if diagonal < upper and diagonal < left:
                best, step = diagonal, 0
            elif upper < diagonal and upper < left:
                best, step = upper, 1
            else:
                best, step = left, 2

            cost[row, col] = x[row - 1, col - 1] + best
            trace[row, col] = step

    return backtrace(trace)
def dtw_cuda(x, BLOCK_SIZE=1024):
    """Compute a dynamic-time-warping path for the 2D tensor `x` with a Triton kernel.

    Args:
        x: 2D cost tensor of shape `(M, N)`; the work tensors are allocated on
            `x.device`.
        BLOCK_SIZE: forwarded to `dtw_kernel`; `M` must be smaller than it.

    Returns:
        The result of `backtrace` applied to the step table produced by the
        kernel (copied to the CPU as a NumPy array).

    Raises:
        AssertionError: if `M >= BLOCK_SIZE`.
    """
    from .triton_ops import dtw_kernel

    M, N = x.shape
    assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"

    # Skew the matrix: each row is padded to width N + M + 1 with inf, and the
    # flattened data is re-read with width N + M, which shifts row r to the
    # right by r positions. After the transpose each row of `x_skew` holds one
    # anti-diagonal of `x`.
    x_skew = (
        F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
    )
    x_skew = x_skew.T.contiguous()

    cost = torch.ones(N + M + 2, M + 2) * np.inf
    cost[0, 0] = 0
    # Keep the work tensors on the same device as `x`; `.cuda()` would pick the
    # current default CUDA device, which can differ from `x.device` when
    # several GPUs are present.
    cost = cost.to(x.device)
    trace = torch.zeros_like(cost, dtype=torch.int32)

    # single program instance; presumably the kernel fills `cost` and `trace`
    # one anti-diagonal at a time - confirm against `triton_ops.dtw_kernel`
    dtw_kernel[(1,)](
        cost,
        trace,
        x_skew,
        x_skew.stride(0),
        cost.stride(0),
        trace.stride(0),
        N,
        M,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # Undo the skew: re-reading the flattened transpose with a width that is
    # one larger shifts row r back to the left by r positions; only the first
    # N + 1 columns are kept.
    trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
        :, : N + 1
    ]
    return backtrace(trace.cpu().numpy())
def dtw(x: torch.Tensor) -> np.ndarray:
    """Return the DTW path for `x`, using `dtw_cuda` when possible.

    CUDA tensors are first handed to `dtw_cuda`; if that raises `RuntimeError`
    or `subprocess.CalledProcessError` a warning is issued and the computation
    is repeated with `dtw_cpu`. Non-CUDA tensors go straight to `dtw_cpu`,
    which receives the data as a float64 NumPy array.
    """
    path = None
    if x.is_cuda:
        try:
            path = dtw_cuda(x)
        except (RuntimeError, subprocess.CalledProcessError):
            warnings.warn(
                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
                "falling back to a slower DTW implementation..."
            )

    if path is None:
        path = dtw_cpu(x.double().cpu().numpy())
    return path
@dataclass
class WordTiming:
    """Timing and confidence information for one aligned word.

    Instances are built positionally as
    `WordTiming(word, tokens, start, end, probability)` and are mutated in
    place by later processing steps, so the dataclass must stay non-frozen.
    """

    # text of the word
    word: str
    # token ids that make up the word
    tokens: List[int]
    # start time of the word
    start: float
    # end time of the word
    end: float
    # probability assigned to the word
    probability: float
def find_alignment(
    model: "Whisper",
    tokenizer: Tokenizer,
    text_tokens: List[int],
    mel: torch.Tensor,
    num_frames: int,
    *,
    medfilt_width: int = 7,
    qk_scale: float = 1.0,
) -> List[WordTiming]:
    """Align `text_tokens` to audio frames using the decoder's cross-attention.

    The tokens are run through `model` once while forward hooks record the
    output of every decoder block's `cross_attn` module. The heads listed in
    `model.alignment_heads` are normalised, median-filtered and averaged, and
    `dtw` turns the resulting token-by-frame matrix into a path from which the
    start and end time of each word is read.

    Args:
        model: model whose `decoder.blocks` expose `cross_attn` modules.
        tokenizer: supplies `sot_sequence`, `no_timestamps`, `eot` and
            `split_to_word_tokens`.
        text_tokens: token ids of the text to align.
        mel: model input; passed on as `mel.unsqueeze(0)`.
        num_frames: only the first `num_frames // 2` attention frames are used.
        medfilt_width: width of the median filter applied to the weights.
        qk_scale: factor applied to the recorded weights before the softmax.

    Returns:
        One `WordTiming` per word reported by `tokenizer.split_to_word_tokens`,
        without the trailing end-of-text entry. Times are path frame indices
        divided by `TOKENS_PER_SECOND`. An empty list is returned when
        `text_tokens` is empty or when only the end-of-text entry is produced.
    """
    if len(text_tokens) == 0:
        return []

    tokens = torch.tensor(
        [
            *tokenizer.sot_sequence,
            tokenizer.no_timestamps,
            *text_tokens,
            tokenizer.eot,
        ]
    ).to(model.device)

    # imported before any hook is installed so that an import failure cannot
    # leave hooks behind on the model
    from .model import disable_sdpa

    # install hooks on the cross attention layers to retrieve the attention weights
    QKs = [None] * model.dims.n_text_layer
    hooks = [
        block.cross_attn.register_forward_hook(
            lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
        )
        for i, block in enumerate(model.decoder.blocks)
    ]

    try:
        with torch.no_grad(), disable_sdpa():
            logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
            # rows from the end of the sot sequence onwards, restricted to
            # token ids below `eot`
            sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
            token_probs = sampled_logits.softmax(dim=-1)
            text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
            text_token_probs = text_token_probs.tolist()
    finally:
        # always detach the hooks, even when the forward pass fails; otherwise
        # they would stay registered on the model and fire on later calls
        for hook in hooks:
            hook.remove()

    # heads * tokens * frames
    weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
    weights = weights[:, :, : num_frames // 2]
    weights = (weights * qk_scale).softmax(dim=-1)
    # standardise each head over the token axis
    std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
    weights = (weights - mean) / std
    weights = median_filter(weights, medfilt_width)

    # average over heads, then drop the sot-sequence rows and the final row
    matrix = weights.mean(axis=0)
    matrix = matrix[len(tokenizer.sot_sequence) : -1]
    # negated so that high attention corresponds to low DTW cost
    text_indices, time_indices = dtw(-matrix)

    words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
    if len(word_tokens) <= 1:
        # return on eot only
        # >>> np.pad([], (1, 0))
        # array([0.])
        # This results in crashes when we lookup jump_times with float, like
        # IndexError: arrays used as indices must be of integer (or boolean) type
        return []
    # cumulative token counts: entry k is the index of the first token of word k
    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))

    # a "jump" is a path step on which the token index changes; the first step
    # always counts as one
    jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
    jump_times = time_indices[jumps] / TOKENS_PER_SECOND
    start_times = jump_times[word_boundaries[:-1]]
    end_times = jump_times[word_boundaries[1:]]
    word_probabilities = [
        np.mean(text_token_probs[i:j])
        for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
    ]

    # zip stops at the shortest input, which drops the trailing eot word
    return [
        WordTiming(word, tokens, start, end, probability)
        for word, tokens, start, end, probability in zip(
            words, word_tokens, start_times, end_times, word_probabilities
        )
    ]
def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
    """Fold stand-alone punctuation entries of `alignment` into neighbouring words.

    The list is modified in place and keeps its length: a merged entry gives
    its `word` and `tokens` to the neighbour and is left with `word == ""` and
    `tokens == []`.

    Args:
        alignment: entries exposing mutable `word` and `tokens` attributes.
        prepended: characters that attach to the word that follows them.
        appended: characters that attach to the word that precedes them.
    """
    # Pass 1, right to left: an entry that starts with a space and whose
    # stripped text occurs in `prepended` is glued onto the front of the
    # nearest entry to its right that was not itself merged away.
    receiver_index = len(alignment) - 1
    for index in range(len(alignment) - 2, -1, -1):
        mark = alignment[index]
        receiver = alignment[receiver_index]
        if not (mark.word.startswith(" ") and mark.word.strip() in prepended):
            receiver_index = index
            continue
        receiver.word = mark.word + receiver.word
        receiver.tokens = mark.tokens + receiver.tokens
        mark.word = ""
        mark.tokens = []

    # Pass 2, left to right: an entry whose text occurs in `appended` is glued
    # onto the end of the nearest surviving entry to its left, unless that
    # entry ends with a space.
    receiver_index = 0
    for index in range(1, len(alignment)):
        receiver = alignment[receiver_index]
        mark = alignment[index]
        if receiver.word.endswith(" ") or mark.word not in appended:
            receiver_index = index
            continue
        receiver.word = receiver.word + mark.word
        receiver.tokens = receiver.tokens + mark.tokens
        mark.word = ""
        mark.tokens = []
def add_word_timestamps(
    *,
    segments: List[dict],
    model: "Whisper",
    tokenizer: Tokenizer,
    mel: torch.Tensor,
    num_frames: int,
    prepend_punctuations: str = "\"'“¿([{-",
    append_punctuations: str = "\"'.。,,!!??::”)]}、",
    last_speech_timestamp: float,
    **kwargs,
):
    """Attach word-level timings to `segments` in place.

    The text tokens of all segments are aligned in one `find_alignment` call,
    punctuation is merged into neighbouring words, and the words are then
    handed out to the segments according to each segment's token count. Every
    segment receives a `"words"` list of dicts with `word`, `start`, `end` and
    `probability`; a segment's `"start"` and `"end"` may be overwritten with
    the times of its first and last word.

    Args:
        segments: dicts providing `"tokens"`, `"seek"`, `"start"` and `"end"`.
        model, tokenizer, mel, num_frames: forwarded to `find_alignment`.
        prepend_punctuations: characters merged into the following word.
        append_punctuations: characters merged into the preceding word.
        last_speech_timestamp: end time of the previously processed speech,
            used to detect a pause before a segment's first word.
        **kwargs: forwarded to `find_alignment`.

    Returns:
        None; `segments` is modified in place. Nothing happens when
        `segments` is empty.
    """
    if len(segments) == 0:
        return

    # tokens at or above `eot` are excluded from the alignment
    tokens_by_segment = [
        [token for token in segment["tokens"] if token < tokenizer.eot]
        for segment in segments
    ]
    all_text_tokens = list(itertools.chain.from_iterable(tokens_by_segment))
    alignment = find_alignment(
        model, tokenizer, all_text_tokens, mel, num_frames, **kwargs
    )

    # median over the non-zero word durations, capped at 0.7
    word_durations = np.array([t.end - t.start for t in alignment])
    word_durations = word_durations[word_durations.nonzero()]
    if len(word_durations) > 0:
        median_duration = np.median(word_durations)
    else:
        median_duration = 0.0
    median_duration = min(0.7, float(median_duration))
    max_duration = median_duration * 2

    # hack: truncate long words at sentence boundaries.
    # a better segmentation algorithm based on VAD should be able to replace this.
    if len(word_durations) > 0:
        sentence_end_marks = ".。!!??"
        # ensure words at sentence boundaries are not longer than twice the median word duration.
        for before, current in zip(alignment, alignment[1:]):
            if not (current.end - current.start > max_duration):
                continue
            if current.word in sentence_end_marks:
                current.end = current.start + max_duration
            elif before.word in sentence_end_marks:
                current.start = current.end - max_duration

    merge_punctuations(alignment, prepend_punctuations, append_punctuations)

    # all word times are shifted by the offset of the first segment's seek position
    time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
    cursor = 0

    for segment, segment_tokens in zip(segments, tokens_by_segment):
        consumed_tokens = 0
        words = []

        # take alignment entries until this segment's token count is covered;
        # entries emptied by `merge_punctuations` carry no tokens and no word
        while cursor < len(alignment) and consumed_tokens < len(segment_tokens):
            timing = alignment[cursor]
            cursor += 1
            consumed_tokens += len(timing.tokens)
            if timing.word:
                words.append(
                    dict(
                        word=timing.word,
                        start=round(time_offset + timing.start, 2),
                        end=round(time_offset + timing.end, 2),
                        probability=timing.probability,
                    )
                )

        # hack: truncate long words at segment boundaries.
        # a better segmentation algorithm based on VAD should be able to replace this.
        if len(words) > 0:
            first = words[0]
            second = words[1] if len(words) > 1 else None
            last = words[-1]

            # ensure the first and second word after a pause is not longer than
            # twice the median word duration.
            follows_pause = first["end"] - last_speech_timestamp > median_duration * 4
            first_too_long = first["end"] - first["start"] > max_duration
            pair_too_long = (
                second is not None
                and second["end"] - first["start"] > max_duration * 2
            )
            if follows_pause and (first_too_long or pair_too_long):
                if (
                    second is not None
                    and second["end"] - second["start"] > max_duration
                ):
                    boundary = max(second["end"] / 2, second["end"] - max_duration)
                    first["end"] = second["start"] = boundary
                first["start"] = max(0, first["end"] - max_duration)

            # prefer the segment-level start timestamp if the first word is too long.
            if (
                segment["start"] < first["end"]
                and segment["start"] - 0.5 > first["start"]
            ):
                first["start"] = max(
                    0, min(first["end"] - median_duration, segment["start"])
                )
            else:
                segment["start"] = first["start"]

            # prefer the segment-level end timestamp if the last word is too long.
            if (
                segment["end"] > last["start"]
                and segment["end"] + 0.5 < last["end"]
            ):
                last["end"] = max(last["start"] + median_duration, segment["end"])
            else:
                segment["end"] = last["end"]

            last_speech_timestamp = segment["end"]

        segment["words"] = words