Spaces:
Runtime error
Runtime error
| # Copyright 2021 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from collections import defaultdict | |
| from typing import TYPE_CHECKING, Dict, Optional, Union | |
| import numpy as np | |
| import requests | |
| from ..modelcard import ModelCard | |
| from ..tokenization_utils import PreTrainedTokenizer | |
| from ..utils import is_torch_available, is_torchaudio_available, logging | |
| from .audio_utils import ffmpeg_read | |
| from .base import ArgumentHandler, ChunkPipeline, infer_framework_load_model | |
# Imported only for static type checking: these names are referenced in string annotations
# (e.g. "PreTrainedModel", "SequenceFeatureExtractor", "BeamSearchDecoderCTC" in the pipeline's `__init__`).
if TYPE_CHECKING:
    from pyctcdecode import BeamSearchDecoderCTC

    from ..feature_extraction_sequence_utils import SequenceFeatureExtractor
    from ..modeling_utils import PreTrainedModel

logger = logging.get_logger(__name__)

# `torch` and the auto-model name mappings are only importable when PyTorch is installed. The mappings are used by
# the pipeline to tell seq2seq models from CTC models and to validate the model type.
if is_torch_available():
    import torch

    from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
def rescale_stride(stride, ratio):
    """
    Rescales the stride values from audio space to tokens/logits space.

    Each entry of `stride` is a `(input_n, left, right)` triplet: the total length of the piece and how much of it
    is left/right context. The total length is multiplied by `ratio` and the left/right strides keep their
    proportion of that total. For example, with `ratio = 1 / 80`,
    `(160_000, 16_000, 16_000) -> (2000, 200, 200)`.

    Args:
        stride: iterable of `(input_n, left, right)` triplets.
        ratio: output length divided by input length.

    Returns:
        A list of `(token_n, left, right)` integer tuples, one per input triplet.
    """
    # Shape is [B, SEQ] for tokens
    # [B, SEQ, V] for logits
    new_strides = []
    for input_n, left, right in stride:
        token_n = int(round(input_n * ratio))
        if input_n == 0:
            # Empty piece: there is nothing to scale the strides against, and dividing by `input_n`
            # would raise ZeroDivisionError.
            new_strides.append((token_n, 0, 0))
            continue
        left = int(round(left / input_n * token_n))
        right = int(round(right / input_n * token_n))
        new_strides.append((token_n, left, right))
    return new_strides
def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, rescale=True, dtype=None):
    """
    Split a 1-D waveform into overlapping chunks and run the feature extractor on each one.

    Consecutive chunks start `chunk_len - stride_left - stride_right` samples apart, so each chunk shares
    `stride_left` samples of context with the previous one and `stride_right` with the next.

    Args:
        inputs: waveform indexed along its first axis (`inputs.shape[0]` is the number of samples).
        feature_extractor: callable returning a mapping with `"input_features"` or `"input_values"`; its
            `sampling_rate` attribute is forwarded to the call.
        chunk_len: chunk length in samples.
        stride_left, stride_right: context lengths in samples.
        rescale: when the processed length differs from the chunk length, convert the stride into processed units
            with `rescale_stride`.
        dtype: if given, the processed output is cast with `.to(dtype=dtype)`.

    Yields:
        `{"is_last": bool, "stride": (length, left, right), **processed}` per chunk. The first chunk has no left
        stride and the last chunk has no right stride.

    Raises:
        ValueError: if `chunk_len` is not larger than `stride_left + stride_right` (the step between chunks would be
            zero or negative), or if the feature extractor output has neither `input_features` nor `input_values`.
    """
    inputs_len = inputs.shape[0]
    step = chunk_len - stride_left - stride_right
    if step <= 0:
        raise ValueError(
            f"Chunk length ({chunk_len}) must be larger than the sum of the left ({stride_left}) and right "
            f"({stride_right}) strides."
        )
    for chunk_start_idx in range(0, inputs_len, step):
        chunk_end_idx = chunk_start_idx + chunk_len
        chunk = inputs[chunk_start_idx:chunk_end_idx]
        processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
        if dtype is not None:
            processed = processed.to(dtype=dtype)
        _stride_left = 0 if chunk_start_idx == 0 else stride_left
        # all right strides must be full, otherwise it is the last item
        is_last = chunk_end_idx > inputs_len if stride_right > 0 else chunk_end_idx >= inputs_len
        _stride_right = 0 if is_last else stride_right

        # The final chunk may be shorter than `chunk_len`; record its real length without
        # overwriting the `chunk_len` parameter.
        actual_len = chunk.shape[0]
        stride = (actual_len, _stride_left, _stride_right)
        if "input_features" in processed:
            processed_len = processed["input_features"].shape[-1]
        elif "input_values" in processed:
            processed_len = processed["input_values"].shape[-1]
        else:
            raise ValueError(
                "The feature extractor output must contain `input_features` or `input_values`, got keys "
                f"{list(processed.keys())}"
            )
        if processed_len != chunk.shape[-1] and rescale:
            ratio = processed_len / actual_len
            stride = rescale_stride([stride], ratio)[0]
        # Skip a trailing chunk made only of left context (nothing new to transcribe).
        if chunk.shape[0] > _stride_left:
            yield {"is_last": is_last, "stride": stride, **processed}
        if is_last:
            break
| def _fast_find_longest_common_sequence(sequence_left, sequence_right): | |
| seq_len_left = len(sequence_left) | |
| seq_len_right = len(sequence_right) | |
| counter = [[0] * (seq_len_right + 1) for _ in range(seq_len_left + 1)] | |
| longest = 0 | |
| for i in range(seq_len_left): | |
| for j in range(seq_len_right): | |
| if sequence_left[i] == sequence_right[j]: | |
| previous_counter = counter[i][j] + 1 | |
| counter[i + 1][j + 1] = previous_counter | |
| if previous_counter > longest: | |
| longest = previous_counter | |
| counter = np.array(counter) | |
| # we return the idx of the first element of the longest common sequence in the left sequence | |
| index_left = np.argwhere(counter == longest)[-1][0] - longest if longest != 0 else -1 | |
| index_right = np.argwhere(counter == longest)[-1][1] - longest if longest != 0 else -1 | |
| return index_left, index_right, longest | |
| def _find_longest_common_sequence(sequences, tokenizer): | |
| # TODO Use a faster algorithm this can probably be done in O(n) | |
| # using suffix array. | |
| # It might be tedious to do because of fault tolerance. | |
| # We actually have a really good property which is that the total sequence | |
| # MUST be those subsequences in order. | |
| # Also the algorithm should be more tolerant to errors. | |
| sequence = [tok_id for tok_id in sequences[0][0].tolist() if tok_id not in tokenizer.all_special_ids] | |
| for new_seq in sequences[1:]: | |
| new_sequence = [tok_id for tok_id in new_seq[0].tolist() if tok_id not in tokenizer.all_special_ids] | |
| index = 0 | |
| max_ = 0.0 | |
| for i in range(1, len(new_sequence) + 1): | |
| # epsilon to favor long perfect matches | |
| eps = i / 10000.0 | |
| matches = np.sum(np.array(sequence[-i:]) == np.array(new_sequence[:i])) | |
| matching = matches / i + eps | |
| if matches > 1 and matching > max_: | |
| index = i | |
| max_ = matching | |
| sequence.extend(new_sequence[index:]) | |
| return np.array(sequence) | |
class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
    """
    Pipeline that aims at extracting spoken text contained within some audio.

    The input can be either a raw waveform or an audio file. In case of the audio file, ffmpeg should be installed
    to support multiple audio formats.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> transcriber = pipeline(model="openai/whisper-base")
    >>> transcriber("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac")
    {'text': ' He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered flour-fatten sauce.'}
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    Arguments:
        model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
            [`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
            [`PreTrainedTokenizer`].
        feature_extractor ([`SequenceFeatureExtractor`]):
            The feature extractor that will be used by the pipeline to encode waveform for the model.
        chunk_length_s (`float`, *optional*, defaults to 0):
            The input length for each chunk. If `chunk_length_s = 0` then chunking is disabled (default).

            <Tip>

            For more information on how to effectively use `chunk_length_s`, please have a look at the [ASR chunking
            blog post](https://huggingface.co/blog/asr-chunking).

            </Tip>

        stride_length_s (`float`, *optional*, defaults to `chunk_length_s / 6`):
            The length of stride on the left and right of each chunk. Used only with `chunk_length_s > 0`. This enables
            the model to *see* more context and infer letters better than without this context but the pipeline
            discards the stride bits at the end to make the final reconstitution as perfect as possible.

            <Tip>

            For more information on how to effectively use `stride_length_s`, please have a look at the [ASR chunking
            blog post](https://huggingface.co/blog/asr-chunking).

            </Tip>

        framework (`str`, *optional*):
            The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be
            installed. If no framework is specified, will default to the one currently installed. If no framework is
            specified and both frameworks are installed, will default to the framework of the `model`, or to PyTorch if
            no model is provided. Note that this pipeline only runs with PyTorch: a `"tf"` framework raises a
            `ValueError`.
        device (Union[`int`, `torch.device`], *optional*):
            Device ordinal for CPU/GPU supports. Setting this to `None` will leverage CPU, a positive will run the
            model on the associated CUDA device id.
        decoder (`pyctcdecode.BeamSearchDecoderCTC`, *optional*):
            [PyCTCDecode's
            BeamSearchDecoderCTC](https://github.com/kensho-technologies/pyctcdecode/blob/2fd33dc37c4111417e08d89ccd23d28e9b308d19/pyctcdecode/decoder.py#L180)
            can be passed for language model boosted decoding. See [`Wav2Vec2ProcessorWithLM`] for more information.
    """

    def __init__(
        self,
        model: "PreTrainedModel",
        feature_extractor: Union["SequenceFeatureExtractor", str] = None,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        decoder: Optional[Union["BeamSearchDecoderCTC", str]] = None,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        task: str = "",
        args_parser: ArgumentHandler = None,
        device: Union[int, "torch.device"] = None,
        torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
        binary_output: bool = False,
        **kwargs,
    ):
        """
        Build the pipeline around a speech model.

        This constructor sets up the base pipeline attributes itself rather than calling `super().__init__`. It
        resolves the framework, places the model on `device`, applies task-specific config and decides which
        decoding path (`self.type`) the pipeline uses: `"seq2seq_whisper"`, `"seq2seq"`, `"ctc_with_lm"` or `"ctc"`.

        Raises:
            ValueError: if `device` is given for a model loaded with an `accelerate` device map, or if the framework
                resolves to TensorFlow.
        """
        if framework is None:
            framework, model = infer_framework_load_model(model, config=model.config)

        self.task = task
        self.model = model
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.modelcard = modelcard
        self.framework = framework

        # `accelerate` device map
        hf_device_map = getattr(self.model, "hf_device_map", None)

        if hf_device_map is not None and device is not None:
            raise ValueError(
                "The model has been loaded with `accelerate` and therefore cannot be moved to a specific device. Please "
                "discard the `device` argument when creating your pipeline object."
            )

        if self.framework == "tf":
            raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")

        # We shouldn't call `model.to()` for models loaded with accelerate
        if device is not None and not (isinstance(device, int) and device < 0):
            self.model.to(device)

        if device is None:
            if hf_device_map is not None:
                # Take the first device used by `accelerate`.
                device = next(iter(hf_device_map.values()))
            else:
                device = -1

        if is_torch_available() and self.framework == "pt":
            if isinstance(device, torch.device):
                self.device = device
            elif isinstance(device, str):
                self.device = torch.device(device)
            elif device < 0:
                self.device = torch.device("cpu")
            else:
                self.device = torch.device(f"cuda:{device}")
        else:
            self.device = device if device is not None else -1
        self.torch_dtype = torch_dtype
        self.binary_output = binary_output

        # Update config and generation_config with task specific parameters
        task_specific_params = self.model.config.task_specific_params
        if task_specific_params is not None and task in task_specific_params:
            self.model.config.update(task_specific_params.get(task))
            if self.model.can_generate():
                self.model.generation_config.update(**task_specific_params.get(task))

        self.call_count = 0
        self._batch_size = kwargs.pop("batch_size", None)
        self._num_workers = kwargs.pop("num_workers", None)

        # set the model type so we can check we have the right pre- and post-processing parameters.
        # `getattr` lets a missing feature extractor (or one given as a string) fall back to plain CTC
        # instead of raising AttributeError here.
        processor_class = getattr(feature_extractor, "_processor_class", None)
        if self.model.config.model_type == "whisper":
            self.type = "seq2seq_whisper"
        elif self.model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
            self.type = "seq2seq"
        elif processor_class and processor_class.endswith("WithLM") and decoder is not None:
            self.decoder = decoder
            self.type = "ctc_with_lm"
        else:
            self.type = "ctc"

        self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)

        mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.copy()
        mapping.update(MODEL_FOR_CTC_MAPPING_NAMES)
        self.check_model_type(mapping)

    def __call__(
        self,
        inputs: Union[np.ndarray, bytes, str],
        **kwargs,
    ):
        """
        Transcribe the audio sequence(s) given as inputs to text. See the [`AutomaticSpeechRecognitionPipeline`]
        documentation for more information.

        Args:
            inputs (`np.ndarray` or `bytes` or `str` or `dict`):
                The inputs is either :
                    - `str` that is either the filename of a local audio file, or a public URL address to download the
                      audio file. The file will be read at the correct sampling rate to get the waveform using
                      *ffmpeg*. This requires *ffmpeg* to be installed on the system.
                    - `bytes` it is supposed to be the content of an audio file and is interpreted by *ffmpeg* in the
                      same way.
                    - (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
                        Raw audio at the correct sampling rate (no further check will be done)
                    - `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
                      pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "raw":
                      np.array}` with optionally a `"stride": (left: int, right: int)` that can ask the pipeline to
                      treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
                      inference to provide more context to the model). Only use `stride` with CTC models. The dict
                      passed in is not modified.
            return_timestamps (*optional*, `str` or `bool`):
                Only available for pure CTC models (Wav2Vec2, HuBERT, etc) and the Whisper model. Not available for
                other sequence-to-sequence models.

                For CTC models, timestamps can take one of two formats:
                    - `"char"`: the pipeline will return timestamps along the text for every character in the text. For
                        instance, if you get `[{"text": "h", "timestamp": (0.5, 0.6)}, {"text": "i", "timestamp": (0.7,
                        0.9)}]`, then it means the model predicts that the letter "h" was spoken after `0.5` and before
                        `0.6` seconds.
                    - `"word"`: the pipeline will return timestamps along the text for every word in the text. For
                        instance, if you get `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text": "there", "timestamp":
                        (1.0, 1.5)}]`, then it means the model predicts that the word "hi" was spoken after `0.5` and
                        before `0.9` seconds.

                For the Whisper model, timestamps can take one of two formats:
                    - `"word"`: same as above for word-level CTC timestamps. Word-level timestamps are predicted
                        through the *dynamic-time warping (DTW)* algorithm, an approximation to word-level timestamps
                        by inspecting the cross-attention weights.
                    - `True`: the pipeline will return timestamps along the text for *segments* of words in the text.
                        For instance, if you get `[{"text": " Hi there!", "timestamp": (0.5, 1.5)}]`, then it means the
                        model predicts that the segment "Hi there!" was spoken after `0.5` and before `1.5` seconds.
                        Note that a segment of text refers to a sequence of one or more words, rather than individual
                        words as with word-level timestamps.
            generate_kwargs (`dict`, *optional*):
                The dictionary of ad-hoc parametrization of `generation_config` to be used for the generation call. For
                a complete overview of generate, check the [following
                guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
            max_new_tokens (`int`, *optional*):
                The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.

        Return:
            `Dict`: A dictionary with the following keys:
                - **text** (`str`): The recognized text.
                - **chunks** (*optional*, `List[Dict]`)
                    When using `return_timestamps`, the `chunks` will become a list containing all the various text
                    chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text":
                    "there", "timestamp": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
                    `"".join(chunk["text"] for chunk in output["chunks"])`.
        """
        return super().__call__(inputs, **kwargs)

    def _sanitize_parameters(
        self,
        chunk_length_s=None,
        stride_length_s=None,
        ignore_warning=None,
        decoder_kwargs=None,
        return_timestamps=None,
        return_language=None,
        generate_kwargs=None,
        max_new_tokens=None,
    ):
        """
        Split user keyword arguments into `(preprocess_params, forward_params, postprocess_params)`.

        Incompatible `return_timestamps` / `return_language` settings for the current `self.type` are rejected here,
        before any forward pass.

        Raises:
            ValueError: on conflicting `max_new_tokens`, or unsupported timestamp/language options for the model.
        """
        preprocess_params = {}
        if chunk_length_s is not None:
            if self.type == "seq2seq" and not ignore_warning:
                logger.warning(
                    "Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
                    " be entirely accurate and will have caveats. More information:"
                    " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
                    " ignore_warning=True)"
                )
            preprocess_params["chunk_length_s"] = chunk_length_s
        if stride_length_s is not None:
            preprocess_params["stride_length_s"] = stride_length_s

        forward_params = defaultdict(dict)
        if max_new_tokens is not None:
            forward_params["generate_kwargs"]["max_new_tokens"] = max_new_tokens
        if generate_kwargs is not None:
            if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
                raise ValueError(
                    "`max_new_tokens` is defined both as an argument and inside `generate_kwargs` argument, please use"
                    " only 1 version"
                )
            forward_params["generate_kwargs"].update(generate_kwargs)

        postprocess_params = {}
        if decoder_kwargs is not None:
            postprocess_params["decoder_kwargs"] = decoder_kwargs
        if return_timestamps is not None:
            # Check whether we have a valid setting for return_timestamps and throw an error before we perform a forward pass
            if self.type == "seq2seq" and return_timestamps:
                raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!")
            if self.type == "ctc_with_lm" and return_timestamps != "word":
                raise ValueError("CTC with LM can only predict word level timestamps, set `return_timestamps='word'`")
            if self.type == "ctc" and return_timestamps not in ["char", "word"]:
                raise ValueError(
                    "CTC can either predict character level timestamps, or word level timestamps. "
                    "Set `return_timestamps='char'` or `return_timestamps='word'` as required."
                )
            if self.type == "seq2seq_whisper" and return_timestamps == "char":
                raise ValueError(
                    "Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
                    "Use `return_timestamps='word'` or `return_timestamps=True` respectively."
                )
            forward_params["return_timestamps"] = return_timestamps
            postprocess_params["return_timestamps"] = return_timestamps
        if return_language is not None:
            if self.type != "seq2seq_whisper":
                raise ValueError("Only Whisper can return language for now.")
            postprocess_params["return_language"] = return_language

        return preprocess_params, forward_params, postprocess_params

    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
        """
        Load/decode a single input and yield feature-extracted model inputs.

        Accepts a URL or local path (read as bytes), audio-file bytes (decoded with ffmpeg at the feature extractor's
        sampling rate), a 1-D numpy waveform, or a dict with `"sampling_rate"` and `"raw"`/`"array"` plus an optional
        `"stride"`. The dict is resampled if needed, and its remaining keys are passed through. When `chunk_length_s`
        is truthy, the waveform is split with `chunk_iter` and one item is yielded per chunk. Otherwise a single item
        with `"is_last": True` is yielded.

        Raises:
            ValueError: on a malformed dict, a stride that does not fit the input, non-ndarray or multi-channel audio,
                a chunk length not larger than the total stride, or a stride used with a generic seq2seq model.
            ImportError: if resampling is needed but torchaudio is not installed.
        """
        if isinstance(inputs, str):
            if inputs.startswith("http://") or inputs.startswith("https://"):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
                # like http_huggingface_co.png
                inputs = requests.get(inputs).content
            else:
                with open(inputs, "rb") as f:
                    inputs = f.read()

        if isinstance(inputs, bytes):
            inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)

        stride = None
        extra = {}
        if isinstance(inputs, dict):
            # Work on a shallow copy: keys are popped below, and the caller's dict (e.g. a `datasets` sample)
            # must stay intact so it can be fed to the pipeline again.
            inputs = inputs.copy()
            stride = inputs.pop("stride", None)
            # Accepting `"array"` which is the key defined in `datasets` for
            # better integration
            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
                raise ValueError(
                    "When passing a dictionary to AutomaticSpeechRecognitionPipeline, the dict needs to contain a "
                    '"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
                    "containing the sampling_rate associated with that array"
                )

            _inputs = inputs.pop("raw", None)
            if _inputs is None:
                # Remove path which will not be used from `datasets`.
                inputs.pop("path", None)
                _inputs = inputs.pop("array", None)
            in_sampling_rate = inputs.pop("sampling_rate")
            extra = inputs
            inputs = _inputs
            if in_sampling_rate != self.feature_extractor.sampling_rate:
                if is_torchaudio_available():
                    from torchaudio import functional as F
                else:
                    raise ImportError(
                        "torchaudio is required to resample audio samples in AutomaticSpeechRecognitionPipeline. "
                        "The torchaudio package can be installed through: `pip install torchaudio`."
                    )

                inputs = F.resample(
                    torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
                ).numpy()
                ratio = self.feature_extractor.sampling_rate / in_sampling_rate
            else:
                ratio = 1
            if stride is not None:
                # Convert the stride to the (possibly resampled) rate *before* validating it, so it is compared
                # against the length of the audio it will actually be applied to.
                left_samples = int(round(stride[0] * ratio))
                right_samples = int(round(stride[1] * ratio))
                if left_samples + right_samples > inputs.shape[0]:
                    raise ValueError("Stride is too large for input")

                # Stride needs to get the chunk length here, it's going to get
                # swallowed by the `feature_extractor` later, and then batching
                # can add extra data in the inputs, so we need to keep track
                # of the original length in the stride so we can cut properly.
                stride = (inputs.shape[0], left_samples, right_samples)

        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
        if len(inputs.shape) != 1:
            raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")

        if chunk_length_s:
            if stride_length_s is None:
                stride_length_s = chunk_length_s / 6

            if isinstance(stride_length_s, (int, float)):
                stride_length_s = [stride_length_s, stride_length_s]

            # NOTE: `inputs_to_logits_ratio` comes from CTC model configs; seq2seq configs may not define it,
            # hence the default of 1 (no alignment).
            align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
            chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
            stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
            stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)

            # Equality must be rejected too: it would leave a zero step between consecutive chunks.
            if chunk_len <= stride_left + stride_right:
                raise ValueError("Chunk length must be superior to stride length")

            # Whisper strides are kept in audio samples (no rescaling); `postprocess` converts them to seconds.
            rescale = self.type != "seq2seq_whisper"
            yield from chunk_iter(
                inputs, self.feature_extractor, chunk_len, stride_left, stride_right, rescale, self.torch_dtype
            )
        else:
            processed = self.feature_extractor(
                inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
            )
            if self.torch_dtype is not None:
                processed = processed.to(dtype=self.torch_dtype)
            if stride is not None:
                if self.type == "seq2seq":
                    raise ValueError("Stride is only usable with CTC models, try removing it !")

                processed["stride"] = stride
            yield {"is_last": True, **processed, **extra}

    def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
        """
        Run the model on one (possibly batched) preprocessed item.

        Seq2seq models (including Whisper) call `generate` on precomputed encoder outputs and return `"tokens"`, plus
        `"token_timestamps"` for Whisper word-level timestamps. CTC models return greedy `"tokens"` or, with a
        language-model decoder, raw `"logits"`. `"is_last"`, the stride (rescaled to logits space for CTC) and any
        leftover keys of `model_inputs` are passed through to `postprocess`.
        """
        # Copy: keys are added below, and the incoming dict may be reused by the caller for later calls
        # (e.g. the one built once by `_sanitize_parameters`).
        generate_kwargs = {} if generate_kwargs is None else dict(generate_kwargs)

        attention_mask = model_inputs.pop("attention_mask", None)
        stride = model_inputs.pop("stride", None)
        is_last = model_inputs.pop("is_last")

        if self.type in {"seq2seq", "seq2seq_whisper"}:
            encoder = self.model.get_encoder()
            # Consume values so we can let extra information flow freely through
            # the pipeline (important for `partial` in microphone)
            if "input_features" in model_inputs:
                inputs = model_inputs.pop("input_features")
            elif "input_values" in model_inputs:
                inputs = model_inputs.pop("input_values")
            else:
                raise ValueError(
                    "Seq2Seq speech recognition model requires either a "
                    f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
                )

            # custom processing for Whisper timestamps and word-level timestamps
            if return_timestamps and self.type == "seq2seq_whisper":
                generate_kwargs["return_timestamps"] = return_timestamps
                if return_timestamps == "word":
                    generate_kwargs["return_token_timestamps"] = True

                    if stride is not None:
                        hop_length = self.feature_extractor.hop_length
                        if isinstance(stride, tuple):
                            generate_kwargs["num_frames"] = stride[0] // hop_length
                        else:
                            # Batched input: one stride tuple per sample (as also handled in the CTC branch below).
                            generate_kwargs["num_frames"] = [s[0] // hop_length for s in stride]

            tokens = self.model.generate(
                encoder_outputs=encoder(inputs, attention_mask=attention_mask),
                attention_mask=attention_mask,
                **generate_kwargs,
            )
            if return_timestamps == "word" and self.type == "seq2seq_whisper":
                out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
            else:
                out = {"tokens": tokens}
            if self.type == "seq2seq_whisper":
                if stride is not None:
                    out["stride"] = stride

        else:
            input_values = model_inputs.pop("input_values")
            outputs = self.model(input_values=input_values, attention_mask=attention_mask)
            logits = outputs.logits

            if self.type == "ctc_with_lm":
                out = {"logits": logits}
            else:
                out = {"tokens": logits.argmax(dim=-1)}
            if stride is not None:
                # Send stride to `postprocess`.
                # it needs to be handled there where
                # the pieces are to be concatenated.
                ratio = 1 / self.model.config.inputs_to_logits_ratio
                if isinstance(stride, tuple):
                    out["stride"] = rescale_stride([stride], ratio)[0]
                else:
                    out["stride"] = rescale_stride(stride, ratio)
        # Leftover
        extra = model_inputs
        return {"is_last": is_last, **out, **extra}

    def postprocess(
        self, model_outputs, decoder_kwargs: Optional[Dict] = None, return_timestamps=None, return_language=None
    ):
        """
        Merge the per-chunk model outputs and decode them into text.

        CTC outputs are trimmed by their strides and concatenated. Chunked generic seq2seq outputs are merged with
        `_find_longest_common_sequence`. Whisper outputs go to the tokenizer's `_decode_asr` after their strides are
        converted to seconds.

        Returns:
            `{"text": str, **optional, **extra}`. `optional` may hold `"chunks"` (timestamps) or whatever
            `_decode_asr` returns. `extra` collects, per key, the leftover values of every model output.
        """
        # Optional return types
        optional = {}

        final_items = []
        key = "logits" if self.type == "ctc_with_lm" else "tokens"
        stride = None
        for outputs in model_outputs:
            items = outputs[key].numpy()
            stride = outputs.get("stride", None)
            if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
                total_n, left, right = stride
                # Total_n might be < logits.shape[1]
                # because of padding, that's why
                # we need to reconstruct this information
                # This won't work with left padding (which doesn't exist right now)
                right_n = total_n - right
                items = items[:, left:right_n]
            final_items.append(items)

        if stride and self.type == "seq2seq":
            items = _find_longest_common_sequence(final_items, self.tokenizer)
        elif self.type == "seq2seq_whisper":
            time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
            # Send the chunking back to seconds, it's easier to handle in whisper
            sampling_rate = self.feature_extractor.sampling_rate
            for output in model_outputs:
                if "stride" in output:
                    chunk_len, stride_left, stride_right = output["stride"]
                    # Go back in seconds
                    chunk_len /= sampling_rate
                    stride_left /= sampling_rate
                    stride_right /= sampling_rate
                    output["stride"] = chunk_len, stride_left, stride_right

            text, optional = self.tokenizer._decode_asr(
                model_outputs,
                return_timestamps=return_timestamps,
                return_language=return_language,
                time_precision=time_precision,
            )
        else:
            items = np.concatenate(final_items, axis=1)
            items = items.squeeze(0)

        if self.type == "ctc_with_lm":
            if decoder_kwargs is None:
                decoder_kwargs = {}
            beams = self.decoder.decode_beams(items, **decoder_kwargs)
            text = beams[0][0]
            if return_timestamps:
                # Simply cast from pyctcdecode format to wav2vec2 format to leverage
                # pre-existing code later
                chunk_offset = beams[0][2]
                offsets = []
                for word, (start_offset, end_offset) in chunk_offset:
                    offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
        elif self.type != "seq2seq_whisper":
            skip_special_tokens = self.type != "ctc"
            text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
            if return_timestamps:
                offsets = self.tokenizer.decode(
                    items, skip_special_tokens=skip_special_tokens, output_char_offsets=True
                )["char_offsets"]
                if return_timestamps == "word":
                    offsets = self.tokenizer._get_word_offsets(offsets, self.tokenizer.replace_word_delimiter_char)

        if return_timestamps and self.type not in {"seq2seq", "seq2seq_whisper"}:
            # Offsets are in logit frames; convert them to seconds of audio.
            chunks = []
            for item in offsets:
                start = item["start_offset"] * self.model.config.inputs_to_logits_ratio
                start /= self.feature_extractor.sampling_rate

                stop = item["end_offset"] * self.model.config.inputs_to_logits_ratio
                stop /= self.feature_extractor.sampling_rate

                chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
            optional["chunks"] = chunks

        extra = defaultdict(list)
        for output in model_outputs:
            output.pop("tokens", None)
            output.pop("logits", None)
            output.pop("is_last", None)
            output.pop("stride", None)
            output.pop("token_timestamps", None)
            for k, v in output.items():
                extra[k].append(v)
        return {"text": text, **optional, **extra}
def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions):
    """
    Computes the final sequences by merging the end of the nth sequence with the beginning of the n+1th sequence. Since
    `WhisperForConditionalGeneration` produces the timestamps pairwise, we filter the consecutive timestamps and only
    iterate over them. We keep track of the `time` which indicates the actual starting time of the chunk that is
    processed. We need to make sure to offset the timestamps tokens by the `time` in order for the tokenizer to
    properly compute the final `offset`.

    Args:
        sequences: iterable of `(sequence, stride)` pairs, one per chunk. `sequence` is an array (or list) of token
            ids with a leading dimension of size 1 (it is `squeeze(0)`-ed). `stride` is
            `(chunk_len, stride_left, stride_right)`. These values are divided by `feature_extractor.sampling_rate`,
            so they are presumably in audio samples — TODO confirm against callers.
        tokenizer: must resolve `"<|notimestamps|>"` via `convert_tokens_to_ids`. Every id greater than that one is
            treated as a timestamp token.
        feature_extractor: provides `chunk_length` and `sampling_rate`.
        max_source_positions: combined with `feature_extractor.chunk_length` to get the seconds covered by one
            timestamp-token step.

    Returns:
        A flat list of token ids (text tokens and offset timestamp tokens) for the whole input.

    Note:
        The pipeline's `postprocess` in this file decodes Whisper output with `tokenizer._decode_asr` rather than
        this helper.
    """
    # index of the first timestamp token
    timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
    items = []
    # approximation of the token to time ratio : ~0.2seconds
    time_precision = feature_extractor.chunk_length / max_source_positions
    # Running start position of the current chunk (same units as `stride`).
    time = 0
    for seq_idx, item in enumerate(sequences):
        sequence, stride = item
        if isinstance(sequence, list):
            sequence = np.array(sequence)
        chunk_len, stride_left, stride_right = stride
        sequence = sequence.squeeze(0)
        # get rid of the `forced_decoder_idx` that are use to parametrize the generation
        begin_idx = np.where(sequence == timestamp_begin)[0][0] if timestamp_begin in sequence else 0
        sequence = sequence[begin_idx:]

        timestamp_tokens = sequence >= timestamp_begin
        if seq_idx != 0 and sum(timestamp_tokens) > 0:
            # Positions where two timestamp tokens follow each other (segment boundaries), plus the last timestamp.
            consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
            last_timestamp = np.where(timestamp_tokens)[0][-1]
            consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
            # Step back over the strides shared with the previous chunk.
            time -= stride_left + stride_right
            # Chunk start and left-stride length, both expressed in timestamp-token steps.
            offset = int((time / feature_extractor.sampling_rate) / time_precision)
            overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
            # relevant timestamps are in the overlapping part
            relevant_timestamp = np.where(sequence[consecutive] >= timestamp_begin + overlap_time)[0]
            if relevant_timestamp.shape[0] > 0:
                relevant_timestamp = (
                    consecutive[relevant_timestamp[0] - 1] if relevant_timestamp[0] > 0 else consecutive[0]
                )
                # if a big stride is used, we need to check some of the previous items for the best overlap
                best_match = 0
                sliced_sequence = []
                for idx, previous_sequence in enumerate(reversed(items)):
                    previous_tokens = previous_sequence[1:-1]
                    if previous_sequence[0] < (timestamp_begin + offset - overlap_time) and idx != 0:
                        break  # the previous sequence is too far in the past
                    if len(previous_tokens) > 0:
                        # find the longest common sequence between the overlapping parts
                        index_left, index_right, match_length = _fast_find_longest_common_sequence(
                            sequence[1:relevant_timestamp], previous_tokens
                        )
                        # don't do anything if only 1 token was matched
                        if match_length > 1 and match_length > best_match:
                            best_match = match_length
                            best_idx = idx
                            end_of_curr_sequence_idx = (
                                np.where(sequence[index_left + 1 :] >= timestamp_begin)[0][0] + 1
                            )
                            end_of_curr_sequence_idx = end_of_curr_sequence_idx + 1 + index_left
                            # if all the tokens are matched, suffix
                            if index_left == 0 and match_length == len(previous_tokens):
                                sliced_sequence = np.insert(
                                    sequence[index_left + 1 : end_of_curr_sequence_idx], 0, previous_sequence[0]
                                )
                                sliced_sequence[-1] = previous_sequence[-1]
                            # if part of the previous sequence is not taken
                            elif index_left >= 0:
                                sliced_sequence = sequence[index_left + 1 : end_of_curr_sequence_idx]
                                # let's insert the missing part of the previous sequence
                                previous_slice = (
                                    previous_sequence[: index_right + 1] if index_right > 0 else [previous_sequence[0]]
                                )
                                sliced_sequence = np.insert(sliced_sequence, 0, previous_slice)
                                sliced_sequence[-1] += offset

                # `best_idx` and `end_of_curr_sequence_idx` are only read here, and `sliced_sequence` is only
                # non-empty after both were assigned in the loop above.
                if len(sliced_sequence) > 0:
                    items[len(items) - best_idx - 1] = sliced_sequence
                    items = items[: len(items) - best_idx]
                    sequence = sequence[end_of_curr_sequence_idx:]

        # sequence might have changed
        timestamp_tokens = sequence >= timestamp_begin
        consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
        if sum(timestamp_tokens) > 0:
            last_timestamp = np.where(timestamp_tokens)[0][-1]
            consecutive = (
                np.append(consecutive, last_timestamp + 1) if last_timestamp not in consecutive else consecutive
            )

        if len(consecutive) > 0:
            last_slice = 0
            for current_slice in consecutive:
                # Each segment's timestamps are re-based onto the end of the previously emitted segment.
                # NOTE(review): if no segment was emitted for earlier chunks (`items` still empty) and
                # `seq_idx != 0`, `items[-1]` raises IndexError — confirm callers guarantee a first-chunk segment.
                actual_offset = items[-1][-1] if seq_idx != 0 or last_slice != 0 else sequence[0]
                sliced_tokens = sequence[last_slice:current_slice]
                duration = sliced_tokens[-1] - sliced_tokens[0]
                sliced_tokens[0] = actual_offset
                sliced_tokens[-1] = actual_offset + duration
                items.append(sliced_tokens)
                last_slice = current_slice

        time += chunk_len
    # Flatten the kept segments into a single list of token ids.
    result = []
    for i in range(len(items)):
        result += items[i].tolist()
    return result