MRI
/
venv
/lib
/python3.13
/site-packages
/transformers
/pipelines
/automatic_speech_recognition.py
| # Copyright 2021 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from collections import defaultdict | |
| from typing import TYPE_CHECKING, Any, Optional, Union | |
| import numpy as np | |
| import requests | |
| from ..generation import GenerationConfig | |
| from ..tokenization_utils import PreTrainedTokenizer | |
| from ..utils import is_torch_available, is_torchaudio_available, is_torchcodec_available, logging | |
| from .audio_utils import ffmpeg_read | |
| from .base import ChunkPipeline | |
| if TYPE_CHECKING: | |
| from pyctcdecode import BeamSearchDecoderCTC | |
| from ..feature_extraction_sequence_utils import SequenceFeatureExtractor | |
| from ..modeling_utils import PreTrainedModel | |
| logger = logging.get_logger(__name__) | |
| if is_torch_available(): | |
| import torch | |
| from ..models.auto.modeling_auto import MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES | |
def rescale_stride(stride, ratio):
    """
    Rescales the stride values from audio space to tokens/logits space.

    (160_000, 16_000, 16_000) -> (2000, 200, 200) for instance.
    """
    # Shape is [B, SEQ] for tokens
    # [B, SEQ, V] for logits

    def _scale(input_n, left, right):
        # Map the chunk length first, then re-express both paddings as a
        # fraction of it so the roundings stay mutually consistent.
        token_n = int(round(input_n * ratio))
        scaled_left = int(round(left / input_n * token_n))
        scaled_right = int(round(right / input_n * token_n))
        return (token_n, scaled_left, scaled_right)

    return [_scale(*item) for item in stride]
def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, dtype=None):
    """Yield overlapping, feature-extracted windows of `inputs` along with stride metadata."""
    inputs_len = inputs.shape[0]
    # Each window advances by its usable (non-stride) portion only, so
    # consecutive windows overlap by stride_left + stride_right samples.
    step = chunk_len - stride_left - stride_right
    for start in range(0, inputs_len, step):
        end = start + chunk_len
        chunk = inputs[start:end]
        processed = feature_extractor(
            chunk,
            sampling_rate=feature_extractor.sampling_rate,
            return_tensors="pt",
            return_attention_mask=True,
        )
        if dtype is not None:
            processed = processed.to(dtype=dtype)
        # The very first window has no left context; the last has no right context.
        left_pad = stride_left if start else 0
        is_last = end >= inputs_len
        right_pad = stride_right if not is_last else 0
        actual_len = chunk.shape[0]
        # Skip degenerate tail windows that are entirely covered by the left stride.
        if actual_len > left_pad:
            yield {"is_last": is_last, "stride": (actual_len, left_pad, right_pad), **processed}
        if is_last:
            break
| def _find_longest_common_sequence(sequences, tokenizer): | |
| # TODO Use a faster algorithm this can probably be done in O(n) | |
| # using suffix array. | |
| # It might be tedious to do because of fault tolerance. | |
| # We actually have a really good property which is that the total sequence | |
| # MUST be those subsequences in order. | |
| # Also the algorithm should be more tolerant to errors. | |
| sequence = [tok_id for tok_id in sequences[0][0].tolist() if tok_id not in tokenizer.all_special_ids] | |
| for new_seq in sequences[1:]: | |
| new_sequence = [tok_id for tok_id in new_seq[0].tolist() if tok_id not in tokenizer.all_special_ids] | |
| index = 0 | |
| max_ = 0.0 | |
| for i in range(1, len(new_sequence) + 1): | |
| # epsilon to favor long perfect matches | |
| eps = i / 10000.0 | |
| matches = np.sum(np.array(sequence[-i:]) == np.array(new_sequence[:i])) | |
| matching = matches / i + eps | |
| if matches > 1 and matching > max_: | |
| index = i | |
| max_ = matching | |
| sequence.extend(new_sequence[index:]) | |
| return np.array(sequence) | |
class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
    """
    Pipeline that aims at extracting spoken text contained within some audio.

    The input can be either a raw waveform or an audio file. In the case of an audio file, ffmpeg should be
    installed to support multiple audio formats.

    Unless the model you're using explicitly sets these generation parameters in its configuration files
    (`generation_config.json`), the following default values will be used:
    - max_new_tokens: 256
    - num_beams: 5

    Example:

    ```python
    >>> from transformers import pipeline

    >>> transcriber = pipeline(model="openai/whisper-base")
    >>> transcriber("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac")
    {'text': ' He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered flour-fatten sauce.'}
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    Arguments:
        model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
            [`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
        feature_extractor ([`SequenceFeatureExtractor`]):
            The feature extractor that will be used by the pipeline to encode waveform for the model.
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
            [`PreTrainedTokenizer`].
        decoder (`pyctcdecode.BeamSearchDecoderCTC`, *optional*):
            [PyCTCDecode's
            BeamSearchDecoderCTC](https://github.com/kensho-technologies/pyctcdecode/blob/2fd33dc37c4111417e08d89ccd23d28e9b308d19/pyctcdecode/decoder.py#L180)
            can be passed for language model boosted decoding. See [`Wav2Vec2ProcessorWithLM`] for more information.
        chunk_length_s (`float`, *optional*, defaults to 0):
            The input length for each chunk. If `chunk_length_s = 0` then chunking is disabled (default).

            <Tip>

            For more information on how to effectively use `chunk_length_s`, please have a look at the [ASR chunking
            blog post](https://huggingface.co/blog/asr-chunking).

            </Tip>

        stride_length_s (`float`, *optional*, defaults to `chunk_length_s / 6`):
            The length of stride on the left and right of each chunk. Used only with `chunk_length_s > 0`. This enables
            the model to *see* more context and infer letters better than without this context but the pipeline
            discards the stride bits at the end to make the final reconstitution as perfect as possible.

            <Tip>

            For more information on how to effectively use `stride_length_s`, please have a look at the [ASR chunking
            blog post](https://huggingface.co/blog/asr-chunking).

            </Tip>

        framework (`str`, *optional*):
            The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be
            installed. If no framework is specified, will default to the one currently installed. If no framework is
            specified and both frameworks are installed, will default to the framework of the `model`, or to PyTorch if
            no model is provided.
        device (Union[`int`, `torch.device`], *optional*):
            Device ordinal for CPU/GPU supports. Setting this to `None` will leverage CPU, a positive integer will run
            the model on the associated CUDA device id.
    """

    _pipeline_calls_generate = True
    # Flags telling the pipeline loader which processing components to instantiate.
    _load_processor = False
    _load_image_processor = False
    _load_feature_extractor = True
    _load_tokenizer = True

    # Make sure the docstring is updated when the default generation config is changed
    _default_generation_config = GenerationConfig(
        max_new_tokens=256,
        num_beams=5,  # follows openai's whisper implementation
    )
| def __init__( | |
| self, | |
| model: "PreTrainedModel", | |
| feature_extractor: Optional[Union["SequenceFeatureExtractor", str]] = None, | |
| tokenizer: Optional[PreTrainedTokenizer] = None, | |
| decoder: Optional[Union["BeamSearchDecoderCTC", str]] = None, | |
| device: Optional[Union[int, "torch.device"]] = None, | |
| **kwargs, | |
| ): | |
| # set the model type so we can check we have the right pre- and post-processing parameters | |
| if model.config.model_type == "whisper": | |
| self.type = "seq2seq_whisper" | |
| elif model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values(): | |
| self.type = "seq2seq" | |
| elif ( | |
| feature_extractor._processor_class | |
| and feature_extractor._processor_class.endswith("WithLM") | |
| and decoder is not None | |
| ): | |
| self.decoder = decoder | |
| self.type = "ctc_with_lm" | |
| else: | |
| self.type = "ctc" | |
| super().__init__(model, tokenizer, feature_extractor, device=device, **kwargs) | |
    def __call__(self, inputs: Union[np.ndarray, bytes, str, dict], **kwargs: Any) -> list[dict[str, Any]]:
        """
        Transcribe the audio sequence(s) given as inputs to text. See the [`AutomaticSpeechRecognitionPipeline`]
        documentation for more information.

        Args:
            inputs (`np.ndarray` or `bytes` or `str` or `dict`):
                The inputs is either :
                    - `str` that is either the filename of a local audio file, or a public URL address to download the
                      audio file. The file will be read at the correct sampling rate to get the waveform using
                      *ffmpeg*. This requires *ffmpeg* to be installed on the system.
                    - `bytes` it is supposed to be the content of an audio file and is interpreted by *ffmpeg* in the
                      same way.
                    - (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
                      Raw audio at the correct sampling rate (no further check will be done)
                    - `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
                      pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "raw":
                      np.array}` with optionally a `"stride": (left: int, right: int)` than can ask the pipeline to
                      treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
                      inference to provide more context to the model). Only use `stride` with CTC models.
            return_timestamps (*optional*, `str` or `bool`):
                Only available for pure CTC models (Wav2Vec2, HuBERT, etc) and the Whisper model. Not available for
                other sequence-to-sequence models.

                For CTC models, timestamps can take one of two formats:
                    - `"char"`: the pipeline will return timestamps along the text for every character in the text. For
                      instance, if you get `[{"text": "h", "timestamp": (0.5, 0.6)}, {"text": "i", "timestamp": (0.7,
                      0.9)}]`, then it means the model predicts that the letter "h" was spoken after `0.5` and before
                      `0.6` seconds.
                    - `"word"`: the pipeline will return timestamps along the text for every word in the text. For
                      instance, if you get `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text": "there", "timestamp":
                      (1.0, 1.5)}]`, then it means the model predicts that the word "hi" was spoken after `0.5` and
                      before `0.9` seconds.

                For the Whisper model, timestamps can take one of two formats:
                    - `"word"`: same as above for word-level CTC timestamps. Word-level timestamps are predicted
                      through the *dynamic-time warping (DTW)* algorithm, an approximation to word-level timestamps
                      by inspecting the cross-attention weights.
                    - `True`: the pipeline will return timestamps along the text for *segments* of words in the text.
                      For instance, if you get `[{"text": " Hi there!", "timestamp": (0.5, 1.5)}]`, then it means the
                      model predicts that the segment "Hi there!" was spoken after `0.5` and before `1.5` seconds.
                      Note that a segment of text refers to a sequence of one or more words, rather than individual
                      words as with word-level timestamps.
            generate_kwargs (`dict`, *optional*):
                The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
                complete overview of generate, check the [following
                guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).

        Return:
            `Dict`: A dictionary with the following keys:
                - **text** (`str`): The recognized text.
                - **chunks** (*optional*, `list[Dict]`)
                  When using `return_timestamps`, the `chunks` will become a list containing all the various text
                  chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text":
                  "there", "timestamp": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
                  `"".join(chunk["text"] for chunk in output["chunks"])`.
        """
        # All real work happens in preprocess / _forward / postprocess, driven by the base class.
        return super().__call__(inputs, **kwargs)
    def _sanitize_parameters(
        self,
        chunk_length_s=None,
        stride_length_s=None,
        ignore_warning=None,
        decoder_kwargs=None,
        return_timestamps=None,
        return_language=None,
        **generate_kwargs,
    ):
        """
        Route pipeline-call kwargs to the preprocess / forward / postprocess steps.

        Returns:
            A 3-tuple `(preprocess_params, forward_params, postprocess_params)` consumed by the base
            `Pipeline` machinery.

        Raises:
            ValueError: when `return_language` or `return_timestamps` is incompatible with the model type.
        """
        preprocess_params = {}
        forward_params = {}
        postprocess_params = {}

        # Preprocess params
        if chunk_length_s is not None:
            # Chunking is only well-defined for CTC models; it is experimental for seq2seq.
            if self.type in ["seq2seq", "seq2seq_whisper"] and not ignore_warning:
                type_warning = (
                    "Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
                    " be entirely accurate and will have caveats. More information:"
                    " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
                    " ignore_warning=True)."
                )
                if self.type == "seq2seq_whisper":
                    type_warning += (
                        " To use Whisper for long-form transcription, use rather the model's `generate` method directly "
                        "as the model relies on it's own chunking mechanism (cf. Whisper original paper, section 3.8. "
                        "Long-form Transcription)."
                    )
                logger.warning(type_warning)
            preprocess_params["chunk_length_s"] = chunk_length_s
        if stride_length_s is not None:
            preprocess_params["stride_length_s"] = stride_length_s

        # Forward params
        # BC: accept a dictionary of generation kwargs (as opposed to **generate_kwargs)
        if "generate_kwargs" in generate_kwargs:
            forward_params.update(generate_kwargs.pop("generate_kwargs"))
        # Default use for kwargs: they are generation-time kwargs
        forward_params.update(generate_kwargs)
        # Assisted (speculative) decoding support, configured elsewhere on the pipeline instance.
        if getattr(self, "assistant_model", None) is not None:
            forward_params["assistant_model"] = self.assistant_model
        if getattr(self, "assistant_tokenizer", None) is not None:
            forward_params["tokenizer"] = self.tokenizer
            forward_params["assistant_tokenizer"] = self.assistant_tokenizer

        # Postprocess params
        if decoder_kwargs is not None:
            postprocess_params["decoder_kwargs"] = decoder_kwargs
        if return_language is not None:
            if self.type != "seq2seq_whisper":
                raise ValueError("Only Whisper can return language for now.")
            postprocess_params["return_language"] = return_language

        # Parameter used in more than one place
        # in some models like whisper, the generation config has a `return_timestamps` key
        if hasattr(self, "generation_config") and hasattr(self.generation_config, "return_timestamps"):
            return_timestamps = return_timestamps or self.generation_config.return_timestamps
        if return_timestamps is not None:
            # Check whether we have a valid setting for return_timestamps and throw an error before we perform a forward pass
            if self.type == "seq2seq" and return_timestamps:
                raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!")
            if self.type == "ctc_with_lm" and return_timestamps != "word":
                raise ValueError("CTC with LM can only predict word level timestamps, set `return_timestamps='word'`")
            if self.type == "ctc" and return_timestamps not in ["char", "word"]:
                raise ValueError(
                    "CTC can either predict character level timestamps, or word level timestamps. "
                    "Set `return_timestamps='char'` or `return_timestamps='word'` as required."
                )
            if self.type == "seq2seq_whisper" and return_timestamps == "char":
                raise ValueError(
                    "Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
                    "Use `return_timestamps='word'` or `return_timestamps=True` respectively."
                )
            # Needed both at generation time (forward) and when rebuilding text (postprocess).
            forward_params["return_timestamps"] = return_timestamps
            postprocess_params["return_timestamps"] = return_timestamps

        return preprocess_params, forward_params, postprocess_params
| def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None): | |
| if isinstance(inputs, str): | |
| if inputs.startswith("http://") or inputs.startswith("https://"): | |
| # We need to actually check for a real protocol, otherwise it's impossible to use a local file | |
| # like http_huggingface_co.png | |
| inputs = requests.get(inputs).content | |
| else: | |
| with open(inputs, "rb") as f: | |
| inputs = f.read() | |
| if isinstance(inputs, bytes): | |
| inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate) | |
| stride = None | |
| extra = {} | |
| if is_torch_available(): | |
| import torch | |
| if isinstance(inputs, torch.Tensor): | |
| inputs = inputs.cpu().numpy() | |
| if is_torchcodec_available(): | |
| import torchcodec | |
| if isinstance(inputs, torchcodec.decoders.AudioDecoder): | |
| _audio_samples = inputs.get_all_samples() | |
| # torchcodec always returns (num_channels, num_samples) | |
| # while before (datasets < 4.0) we had (2, num_samples) if stereo, (num_samples,) if mono | |
| _array = _audio_samples.data | |
| _array = _array[0] if _array.ndim == 2 and _array.shape[0] == 1 else _array | |
| inputs = {"array": _array, "sampling_rate": _audio_samples.sample_rate} | |
| if isinstance(inputs, dict): | |
| stride = inputs.pop("stride", None) | |
| # Accepting `"array"` which is the key defined in `datasets` for | |
| # better integration | |
| if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)): | |
| raise ValueError( | |
| "When passing a dictionary to AutomaticSpeechRecognitionPipeline, the dict needs to contain a " | |
| '"raw" key containing the numpy array or torch tensor representing the audio and a "sampling_rate" key, ' | |
| "containing the sampling_rate associated with that array" | |
| ) | |
| _inputs = inputs.pop("raw", None) | |
| if _inputs is None: | |
| # Remove path which will not be used from `datasets`. | |
| inputs.pop("path", None) | |
| _inputs = inputs.pop("array", None) | |
| in_sampling_rate = inputs.pop("sampling_rate") | |
| extra = inputs | |
| inputs = _inputs | |
| if in_sampling_rate != self.feature_extractor.sampling_rate: | |
| if is_torchaudio_available(): | |
| from torchaudio import functional as F | |
| else: | |
| raise ImportError( | |
| "torchaudio is required to resample audio samples in AutomaticSpeechRecognitionPipeline. " | |
| "The torchaudio package can be installed through: `pip install torchaudio`." | |
| ) | |
| inputs = F.resample( | |
| torch.from_numpy(inputs) if isinstance(inputs, np.ndarray) else inputs, | |
| in_sampling_rate, | |
| in_sampling_rate, | |
| self.feature_extractor.sampling_rate, | |
| ).numpy() | |
| ratio = self.feature_extractor.sampling_rate / in_sampling_rate | |
| else: | |
| ratio = 1 | |
| if stride is not None: | |
| if stride[0] + stride[1] > inputs.shape[0]: | |
| raise ValueError("Stride is too large for input") | |
| # Stride needs to get the chunk length here, it's going to get | |
| # swallowed by the `feature_extractor` later, and then batching | |
| # can add extra data in the inputs, so we need to keep track | |
| # of the original length in the stride so we can cut properly. | |
| stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio))) | |
| if not isinstance(inputs, (np.ndarray, torch.Tensor)): | |
| raise TypeError(f"We expect a numpy ndarray or torch tensor as input, got `{type(inputs)}`") | |
| if inputs.ndim != 1: | |
| logger.warning( | |
| f"We expect a single channel audio input for AutomaticSpeechRecognitionPipeline, got {inputs.ndim}. Taking the mean of the channels for mono conversion." | |
| ) | |
| inputs = inputs.mean(axis=0) | |
| if chunk_length_s: | |
| if stride_length_s is None: | |
| stride_length_s = chunk_length_s / 6 | |
| if isinstance(stride_length_s, (int, float)): | |
| stride_length_s = [stride_length_s, stride_length_s] | |
| # XXX: Carefully, this variable will not exist in `seq2seq` setting. | |
| # Currently chunking is not possible at this level for `seq2seq` so | |
| # it's ok. | |
| align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1) | |
| chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to) | |
| stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to) | |
| stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to) | |
| if chunk_len < stride_left + stride_right: | |
| raise ValueError("Chunk length must be superior to stride length") | |
| for item in chunk_iter(inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.dtype): | |
| yield {**item, **extra} | |
| else: | |
| if self.type == "seq2seq_whisper" and inputs.shape[0] > self.feature_extractor.n_samples: | |
| processed = self.feature_extractor( | |
| inputs, | |
| sampling_rate=self.feature_extractor.sampling_rate, | |
| truncation=False, | |
| padding="longest", | |
| return_tensors="pt", | |
| return_attention_mask=True, | |
| ) | |
| else: | |
| if self.type == "seq2seq_whisper" and stride is None: | |
| processed = self.feature_extractor( | |
| inputs, | |
| sampling_rate=self.feature_extractor.sampling_rate, | |
| return_tensors="pt", | |
| return_token_timestamps=True, | |
| return_attention_mask=True, | |
| ) | |
| extra["num_frames"] = processed.pop("num_frames") | |
| else: | |
| processed = self.feature_extractor( | |
| inputs, | |
| sampling_rate=self.feature_extractor.sampling_rate, | |
| return_tensors="pt", | |
| return_attention_mask=True, | |
| ) | |
| if self.dtype is not None: | |
| processed = processed.to(dtype=self.dtype) | |
| if stride is not None: | |
| if self.type == "seq2seq": | |
| raise ValueError("Stride is only usable with CTC models, try removing it !") | |
| processed["stride"] = stride | |
| yield {"is_last": True, **processed, **extra} | |
    def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
        """
        Run one model pass on a (possibly chunked) batch of preprocessed features.

        For seq2seq/Whisper models this calls `model.generate`; for CTC models a plain
        forward pass produces logits. Extra keys in `model_inputs` flow through untouched.

        Args:
            model_inputs (`dict`): one item yielded by `preprocess` (features + bookkeeping keys).
            return_timestamps (`bool` or `str`): timestamp mode forwarded from `_sanitize_parameters`.
            generate_kwargs: extra kwargs passed to `model.generate` (seq2seq path only).

        Returns:
            `dict` with `is_last`, plus `tokens` or `logits` (and `stride`/`token_timestamps`
            where applicable), consumed by `postprocess`.
        """
        attention_mask = model_inputs.pop("attention_mask", None)
        stride = model_inputs.pop("stride", None)
        num_frames = model_inputs.pop("num_frames", None)
        is_last = model_inputs.pop("is_last")

        # `num_frames` and `stride` are mutually exclusive (set on different preprocess paths).
        if stride is not None and num_frames is not None:
            raise ValueError("num_frames must be used only when stride is None")

        if self.type in {"seq2seq", "seq2seq_whisper"}:
            # Consume values so we can let extra information flow freely through
            # the pipeline (important for `partial` in microphone)
            if "input_features" in model_inputs:
                inputs = model_inputs.pop("input_features")
            elif "input_values" in model_inputs:
                inputs = model_inputs.pop("input_values")
            else:
                raise ValueError(
                    "Seq2Seq speech recognition model requires either a "
                    f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
                )

            # custom processing for Whisper timestamps and word-level timestamps
            return_timestamps = return_timestamps or getattr(self.generation_config, "return_timestamps", False)
            if return_timestamps and self.type == "seq2seq_whisper":
                generate_kwargs["return_timestamps"] = bool(return_timestamps)
                if return_timestamps == "word":
                    generate_kwargs["return_token_timestamps"] = True
                    generate_kwargs["return_segments"] = True

            # User-defined `generation_config` passed to the pipeline call take precedence
            if "generation_config" not in generate_kwargs:
                generate_kwargs["generation_config"] = self.generation_config

            main_input_name = self.model.main_input_name if hasattr(self.model, "main_input_name") else "inputs"
            generate_kwargs = {
                main_input_name: inputs,
                "attention_mask": attention_mask,
                **generate_kwargs,
            }
            tokens = self.model.generate(**generate_kwargs)

            # whisper longform generation stores timestamps in "segments"
            if return_timestamps == "word" and self.type == "seq2seq_whisper":
                if "segments" not in tokens:
                    out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
                else:
                    # Flatten each batch item's per-segment timestamps into one tensor.
                    token_timestamps = [
                        torch.cat([segment["token_timestamps"] for segment in segment_list])
                        for segment_list in tokens["segments"]
                    ]
                    out = {"tokens": tokens["sequences"], "token_timestamps": token_timestamps}
            else:
                out = {"tokens": tokens}
            if self.type == "seq2seq_whisper":
                if stride is not None:
                    # Whisper strides stay in audio space; `postprocess` converts them to seconds.
                    out["stride"] = stride

        else:
            # CTC path: single forward pass, greedy argmax unless an LM decoder is attached.
            inputs = {
                self.model.main_input_name: model_inputs.pop(self.model.main_input_name),
                "attention_mask": attention_mask,
            }
            outputs = self.model(**inputs)
            logits = outputs.logits

            if self.type == "ctc_with_lm":
                out = {"logits": logits}
            else:
                out = {"tokens": logits.argmax(dim=-1)}
            if stride is not None:
                # Send stride to `postprocess`.
                # it needs to be handled there where
                # the pieces are to be concatenated.
                ratio = 1 / self.model.config.inputs_to_logits_ratio
                if isinstance(stride, tuple):
                    out["stride"] = rescale_stride([stride], ratio)[0]
                else:
                    out["stride"] = rescale_stride(stride, ratio)
        # Leftover
        extra = model_inputs
        return {"is_last": is_last, **out, **extra}
    def postprocess(
        self, model_outputs, decoder_kwargs: Optional[dict] = None, return_timestamps=None, return_language=None
    ):
        """
        Reassemble per-chunk model outputs into the final `{"text": ..., **optional}` dict.

        Strategy depends on model type: CTC chunks are trimmed by their strides and concatenated;
        seq2seq chunks are merged via longest-common-sequence matching; Whisper decoding is
        delegated to the tokenizer's `_decode_asr`.

        Args:
            model_outputs: iterable of dicts produced by `_forward`.
            decoder_kwargs (`dict`, *optional*): kwargs for the pyctcdecode beam-search decoder.
            return_timestamps: timestamp mode (`True`, `"char"`, or `"word"` depending on model type).
            return_language: Whisper-only flag to include the detected language.
        """
        # Optional return types
        optional = {}
        final_items = []
        # CTC-with-LM decodes from raw logits; every other type works on token ids.
        key = "logits" if self.type == "ctc_with_lm" else "tokens"
        stride = None
        for outputs in model_outputs:
            # numpy has no bf16/fp16-safe conversion path; upcast before `.numpy()`.
            if self.framework == "pt" and outputs[key].dtype in (torch.bfloat16, torch.float16):
                items = outputs[key].to(torch.float32).numpy()
            else:
                items = outputs[key].numpy()
            stride = outputs.get("stride", None)
            if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
                total_n, left, right = stride
                # Total_n might be < logits.shape[1]
                # because of padding, that's why
                # we need to reconstruct this information
                # This won't work with left padding (which doesn't exist right now)
                right_n = total_n - right
                items = items[:, left:right_n]
            final_items.append(items)
        if stride and self.type == "seq2seq":
            items = _find_longest_common_sequence(final_items, self.tokenizer)
        elif self.type == "seq2seq_whisper":
            time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
            # Send the chunking back to seconds, it's easier to handle in whisper
            sampling_rate = self.feature_extractor.sampling_rate
            for output in model_outputs:
                if "stride" in output:
                    chunk_len, stride_left, stride_right = output["stride"]
                    # Go back in seconds
                    chunk_len /= sampling_rate
                    stride_left /= sampling_rate
                    stride_right /= sampling_rate
                    output["stride"] = chunk_len, stride_left, stride_right

            text, optional = self.tokenizer._decode_asr(
                model_outputs,
                return_timestamps=return_timestamps,
                return_language=return_language,
                time_precision=time_precision,
            )
        else:
            # CTC: stride-trimmed chunks simply concatenate along the time axis.
            items = np.concatenate(final_items, axis=1)
            items = items.squeeze(0)

        if self.type == "ctc_with_lm":
            if decoder_kwargs is None:
                decoder_kwargs = {}
            beams = self.decoder.decode_beams(items, **decoder_kwargs)
            text = beams[0][0]
            if return_timestamps:
                # Simply cast from pyctcdecode format to wav2vec2 format to leverage
                # pre-existing code later
                chunk_offset = beams[0][2]
                offsets = []
                for word, (start_offset, end_offset) in chunk_offset:
                    offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
        elif self.type != "seq2seq_whisper":
            skip_special_tokens = self.type != "ctc"
            text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
            if return_timestamps:
                offsets = self.tokenizer.decode(
                    items, skip_special_tokens=skip_special_tokens, output_char_offsets=True
                )["char_offsets"]
                if return_timestamps == "word":
                    offsets = self.tokenizer._get_word_offsets(offsets, self.tokenizer.replace_word_delimiter_char)

        if return_timestamps and self.type not in {"seq2seq", "seq2seq_whisper"}:
            chunks = []
            for item in offsets:
                # Convert logit-frame offsets back to seconds via the model's input/logits ratio.
                start = item["start_offset"] * self.model.config.inputs_to_logits_ratio
                start /= self.feature_extractor.sampling_rate

                stop = item["end_offset"] * self.model.config.inputs_to_logits_ratio
                stop /= self.feature_extractor.sampling_rate

                chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
            optional["chunks"] = chunks

        # Pass through any extra keys that survived the pipeline, minus internal bookkeeping.
        extra = defaultdict(list)
        for output in model_outputs:
            output.pop("tokens", None)
            output.pop("logits", None)
            output.pop("is_last", None)
            output.pop("stride", None)
            output.pop("token_timestamps", None)
            for k, v in output.items():
                extra[k].append(v)
        return {"text": text, **optional, **extra}