from typing import List, Optional, Union

import numpy as np
import requests
import torch
from pyannote.audio import Pipeline
from torchaudio import functional as F
from transformers import pipeline
from transformers.pipelines.audio_utils import ffmpeg_read


class ASRDiarizationPipeline:
    def __init__(
        self,
        asr_pipeline,
        diarization_pipeline,
    ):
        self.asr_pipeline = asr_pipeline
        # both pipelines are fed audio at the sampling rate expected by the ASR feature extractor
        self.sampling_rate = asr_pipeline.feature_extractor.sampling_rate
        self.diarization_pipeline = diarization_pipeline

    @classmethod
    def from_pretrained(
        cls,
        asr_model: Optional[str] = "openai/whisper-medium",
        *,
        diarizer_model: Optional[str] = "pyannote/speaker-diarization",
        chunk_length_s: Optional[int] = 30,
        use_auth_token: Optional[Union[str, bool]] = True,
        **kwargs,
    ):
        # build the chunked ASR pipeline from transformers
        asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=asr_model,
            chunk_length_s=chunk_length_s,
            use_auth_token=use_auth_token,
            **kwargs,
        )
        # load the pyannote speaker diarization pipeline
        diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=use_auth_token)
        return cls(asr_pipeline, diarization_pipeline)

    def __call__(
        self,
        inputs: Union[np.ndarray, List[np.ndarray]],
        group_by_speaker: bool = True,
        **kwargs,
    ):
        """
        Transcribe the audio sequence(s) given as inputs to text and label them with speaker information. The input
        audio is first passed to the speaker diarization pipeline, which returns timestamps for 'who spoke when'. The
        audio is then passed to the ASR pipeline, which returns utterance-level transcriptions and their corresponding
        timestamps. The speaker diarization timestamps are aligned with the ASR transcription timestamps to give
        speaker-labelled transcriptions. We cannot use the speaker diarization timestamps alone to partition the
        transcriptions, as these timestamps may straddle transcribed utterances from the ASR output. Thus, we find
        the diarization timestamps that are closest to the ASR timestamps and partition there.

        Args:
            inputs (`np.ndarray` or `bytes` or `str` or `dict`):
                The input is either:
                    - `str`: the filename of an audio file. The file will be read at the correct sampling rate to get
                      the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system.
                    - `bytes`: the content of an audio file, interpreted by *ffmpeg* in the same way.
                    - `np.ndarray` of shape (n,) and type `np.float32` or `np.float64`: raw audio at the correct
                      sampling rate (no further check will be done).
                    - `dict`: raw audio sampled at an arbitrary `sampling_rate`, which this pipeline will resample.
                      The dict must be in the format `{"sampling_rate": int, "raw": np.array}`, optionally with a
                      `"stride": (left: int, right: int)` key that asks the pipeline to ignore the first `left` and
                      last `right` samples in decoding (but use them at inference to provide more context to the
                      model). Only use `stride` with CTC models.
            group_by_speaker (`bool`):
                Whether to group consecutive utterances by one speaker into a single segment. If False, transcriptions
                are returned on a chunk-by-chunk basis.

        Return:
            A list of transcriptions. Each list item corresponds to one chunk / segment of transcription, and is a
            dictionary with the following keys:
                - **text** (`str`) -- The recognized text.
                - **speaker** (`str`) -- The associated speaker.
                - **timestamp** (`tuple`) -- The start and end time for the chunk / segment.
        """
        inputs, diarizer_inputs = self.preprocess(inputs)

        # run the speaker diarization over the full audio to get the raw speaker turns
        diarization = self.diarization_pipeline(
            {"waveform": diarizer_inputs, "sample_rate": self.sampling_rate},
            **kwargs,
        )
        segments = diarization.for_json()["content"]

        # diarizer output may contain consecutive segments from the same speaker (e.g. {(0 -> 1, speaker_1), (1 -> 1.5, speaker_1), ...})
        # we combine these segments to give overall timestamps for each speaker's turn (e.g. {(0 -> 1.5, speaker_1), ...})
        new_segments = []
        prev_segment = cur_segment = segments[0]

        for i in range(1, len(segments)):
            cur_segment = segments[i]

            # check if we have changed speaker ("label")
            if cur_segment["label"] != prev_segment["label"] and i < len(segments):
                # add the start/end times for the super-segment to the new list
                new_segments.append(
                    {
                        "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]},
                        "speaker": prev_segment["label"],
                    }
                )
                prev_segment = segments[i]

        # add the last segment(s) if there was no speaker change
        new_segments.append(
            {
                "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["end"]},
                "speaker": prev_segment["label"],
            }
        )
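        # at this point `new_segments` holds one entry per speaker turn, e.g. (illustrative values only):
        # [{"segment": {"start": 0.0, "end": 8.3}, "speaker": "SPEAKER_00"},
        #  {"segment": {"start": 8.3, "end": 15.1}, "speaker": "SPEAKER_01"}, ...]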

        # run the ASR pipeline with utterance-level timestamps
        asr_out = self.asr_pipeline(
            {"array": inputs, "sampling_rate": self.sampling_rate},
            return_timestamps=True,
            **kwargs,
        )
        transcript = asr_out["chunks"]

        # get the end timestamps for each chunk from the ASR output
        end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript])
        segmented_preds = []

        # align the diarizer timestamps and the ASR timestamps
        for segment in new_segments:
            # get the diarizer end timestamp
            end_time = segment["segment"]["end"]
            # find the ASR end timestamp that is closest to the diarizer's end timestamp and cut the transcript to here
            upto_idx = np.argmin(np.abs(end_timestamps - end_time))

            if group_by_speaker:
                segmented_preds.append(
                    {
                        "speaker": segment["speaker"],
                        "text": "".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]),
                        "timestamp": (transcript[0]["timestamp"][0], transcript[upto_idx]["timestamp"][1]),
                    }
                )
            else:
                for i in range(upto_idx + 1):
                    segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})

            # crop the transcripts and timestamp lists according to the latest timestamp (for faster argmin)
            transcript = transcript[upto_idx + 1 :]
            end_timestamps = end_timestamps[upto_idx + 1 :]

        return segmented_preds

    # Adapted from transformers.pipelines.automatic_speech_recognition.AutomaticSpeechRecognitionPipeline.preprocess
    # (see https://github.com/huggingface/transformers/blob/238449414f88d94ded35e80459bb6412d8ab42cf/src/transformers/pipelines/automatic_speech_recognition.py#L417)
    def preprocess(self, inputs):
        if isinstance(inputs, str):
            if inputs.startswith("http://") or inputs.startswith("https://"):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
                # like http_huggingface_co.png
                inputs = requests.get(inputs).content
            else:
                with open(inputs, "rb") as f:
                    inputs = f.read()

        if isinstance(inputs, bytes):
            inputs = ffmpeg_read(inputs, self.sampling_rate)

        if isinstance(inputs, dict):
            # Accepting `"array"` which is the key defined in `datasets` for better integration
            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
                raise ValueError(
                    "When passing a dictionary to ASRDiarizationPipeline, the dict needs to contain a "
                    '"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
                    "containing the sampling rate associated with that array"
                )

            _inputs = inputs.pop("raw", None)
            if _inputs is None:
                # Remove path which will not be used from `datasets`.
                inputs.pop("path", None)
                _inputs = inputs.pop("array", None)
            in_sampling_rate = inputs.pop("sampling_rate")
            inputs = _inputs
            if in_sampling_rate != self.sampling_rate:
                # resample the input audio to the sampling rate expected by the ASR model
                inputs = F.resample(torch.from_numpy(inputs), in_sampling_rate, self.sampling_rate).numpy()

        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
        if len(inputs.shape) != 1:
            raise ValueError("We expect a single channel audio input for ASRDiarizationPipeline")

        # diarization model expects float32 torch tensor of shape `(channels, seq_len)`
        diarizer_inputs = torch.from_numpy(inputs).float()
        diarizer_inputs = diarizer_inputs.unsqueeze(0)

        return inputs, diarizer_inputs
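

# Example usage (a minimal sketch, not part of the pipeline class above). The model names mirror
# the defaults of `from_pretrained`, and "audio.mp3" is a placeholder path; running this requires
# accepting the pyannote model's terms on the Hugging Face Hub and being logged in (e.g. via
# `huggingface-cli login`) or passing a token through `use_auth_token`.
if __name__ == "__main__":
    pipe = ASRDiarizationPipeline.from_pretrained(
        asr_model="openai/whisper-medium",
        diarizer_model="pyannote/speaker-diarization",
    )

    # any of the input types documented in `__call__` works here: a filename, raw bytes,
    # a 1-D float numpy array at the ASR sampling rate, or a dict with "raw"/"array" and "sampling_rate"
    outputs = pipe("audio.mp3", group_by_speaker=True)

    for segment in outputs:
        start, end = segment["timestamp"]
        print(f"{segment['speaker']} ({start:.1f}s -> {end:.1f}s): {segment['text']}")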