Spaces:
Runtime error
Runtime error
Commit
·
fc4914b
1
Parent(s):
ff9897a
Update asr_diarizer.py
Browse files- asr_diarizer.py +35 -16
asr_diarizer.py
CHANGED
|
@@ -16,14 +16,15 @@ class ASRDiarizationPipeline:
|
|
| 16 |
diarization_pipeline,
|
| 17 |
):
|
| 18 |
self.asr_pipeline = asr_pipeline
|
| 19 |
-
self.
|
| 20 |
|
| 21 |
-
self.
|
| 22 |
|
| 23 |
@classmethod
|
| 24 |
def from_pretrained(
|
| 25 |
cls,
|
| 26 |
-
asr_model: Optional[str] = "openai/whisper-
|
|
|
|
| 27 |
diarizer_model: Optional[str] = "pyannote/speaker-diarization",
|
| 28 |
chunk_length_s: Optional[int] = 30,
|
| 29 |
use_auth_token: Optional[Union[str, bool]] = True,
|
|
@@ -37,7 +38,7 @@ class ASRDiarizationPipeline:
|
|
| 37 |
**kwargs,
|
| 38 |
)
|
| 39 |
diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=use_auth_token)
|
| 40 |
-
cls(asr_pipeline, diarization_pipeline)
|
| 41 |
|
| 42 |
def __call__(
|
| 43 |
self,
|
|
@@ -46,7 +47,13 @@ class ASRDiarizationPipeline:
|
|
| 46 |
**kwargs,
|
| 47 |
):
|
| 48 |
"""
|
| 49 |
-
Transcribe the audio sequence(s) given as inputs to text.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
Args:
|
| 52 |
inputs (`np.ndarray` or `bytes` or `str` or `dict`):
|
|
@@ -62,15 +69,16 @@ class ASRDiarizationPipeline:
|
|
| 62 |
np.array}` with optionally a `"stride": (left: int, right: int)` than can ask the pipeline to
|
| 63 |
treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
|
| 64 |
inference to provide more context to the model). Only use `stride` with CTC models.
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
Return:
|
| 67 |
-
|
|
|
|
| 68 |
- **text** (`str` ) -- The recognized text.
|
| 69 |
-
- **
|
| 70 |
-
|
| 71 |
-
chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamps": (0.5,0.9), {"text":
|
| 72 |
-
"there", "timestamps": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
|
| 73 |
-
`"".join(chunk["text"] for chunk in output["chunks"])`.
|
| 74 |
"""
|
| 75 |
inputs, diarizer_inputs = self.preprocess(inputs)
|
| 76 |
|
|
@@ -81,13 +89,17 @@ class ASRDiarizationPipeline:
|
|
| 81 |
|
| 82 |
segments = diarization.for_json()["content"]
|
| 83 |
|
|
|
|
|
|
|
| 84 |
new_segments = []
|
| 85 |
prev_segment = cur_segment = segments[0]
|
| 86 |
|
| 87 |
for i in range(1, len(segments)):
|
| 88 |
cur_segment = segments[i]
|
| 89 |
|
|
|
|
| 90 |
if cur_segment["label"] != prev_segment["label"] and i < len(segments):
|
|
|
|
| 91 |
new_segments.append(
|
| 92 |
{
|
| 93 |
"segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]},
|
|
@@ -96,6 +108,7 @@ class ASRDiarizationPipeline:
|
|
| 96 |
)
|
| 97 |
prev_segment = segments[i]
|
| 98 |
|
|
|
|
| 99 |
new_segments.append(
|
| 100 |
{
|
| 101 |
"segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["end"]},
|
|
@@ -110,11 +123,15 @@ class ASRDiarizationPipeline:
|
|
| 110 |
)
|
| 111 |
transcript = asr_out["chunks"]
|
| 112 |
|
|
|
|
| 113 |
end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript])
|
| 114 |
segmented_preds = []
|
| 115 |
|
|
|
|
| 116 |
for segment in new_segments:
|
|
|
|
| 117 |
end_time = segment["segment"]["end"]
|
|
|
|
| 118 |
upto_idx = np.argmin(np.abs(end_timestamps - end_time))
|
| 119 |
|
| 120 |
if group_by_speaker:
|
|
@@ -122,21 +139,21 @@ class ASRDiarizationPipeline:
|
|
| 122 |
{
|
| 123 |
"speaker": segment["speaker"],
|
| 124 |
"text": "".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]),
|
| 125 |
-
"timestamp":
|
| 126 |
-
"start": transcript[0]["timestamp"][0],
|
| 127 |
-
"end": transcript[upto_idx]["timestamp"][1],
|
| 128 |
-
},
|
| 129 |
}
|
| 130 |
)
|
| 131 |
else:
|
| 132 |
for i in range(upto_idx + 1):
|
| 133 |
segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})
|
| 134 |
|
|
|
|
| 135 |
transcript = transcript[upto_idx + 1 :]
|
| 136 |
end_timestamps = end_timestamps[upto_idx + 1 :]
|
| 137 |
|
| 138 |
return segmented_preds
|
| 139 |
|
|
|
|
|
|
|
| 140 |
def preprocess(self, inputs):
|
| 141 |
if isinstance(inputs, str):
|
| 142 |
if inputs.startswith("http://") or inputs.startswith("https://"):
|
|
@@ -174,6 +191,8 @@ class ASRDiarizationPipeline:
|
|
| 174 |
if len(inputs.shape) != 1:
|
| 175 |
raise ValueError("We expect a single channel audio input for ASRDiarizePipeline")
|
| 176 |
|
| 177 |
-
|
|
|
|
|
|
|
| 178 |
|
| 179 |
return inputs, diarizer_inputs
|
|
|
|
| 16 |
diarization_pipeline,
|
| 17 |
):
|
| 18 |
self.asr_pipeline = asr_pipeline
|
| 19 |
+
self.sampling_rate = asr_pipeline.feature_extractor.sampling_rate
|
| 20 |
|
| 21 |
+
self.diarization_pipeline = diarization_pipeline
|
| 22 |
|
| 23 |
@classmethod
|
| 24 |
def from_pretrained(
|
| 25 |
cls,
|
| 26 |
+
asr_model: Optional[str] = "openai/whisper-medium",
|
| 27 |
+
*,
|
| 28 |
diarizer_model: Optional[str] = "pyannote/speaker-diarization",
|
| 29 |
chunk_length_s: Optional[int] = 30,
|
| 30 |
use_auth_token: Optional[Union[str, bool]] = True,
|
|
|
|
| 38 |
**kwargs,
|
| 39 |
)
|
| 40 |
diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=use_auth_token)
|
| 41 |
+
return cls(asr_pipeline, diarization_pipeline)
|
| 42 |
|
| 43 |
def __call__(
|
| 44 |
self,
|
|
|
|
| 47 |
**kwargs,
|
| 48 |
):
|
| 49 |
"""
|
| 50 |
+
Transcribe the audio sequence(s) given as inputs to text and label with speaker information. The input audio
|
| 51 |
+
is first passed to the speaker diarization pipeline, which returns timestamps for 'who spoke when'. The audio
|
| 52 |
+
is then passed to the ASR pipeline, which returns utterance-level transcriptions and their corresponding
|
| 53 |
+
timestamps. The speaker diarizer timestamps are aligned with the ASR transcription timestamps to give
|
| 54 |
+
speaker-labelled transcriptions. We cannot use the speaker diarization timestamps alone to partition the
|
| 55 |
+
transcriptions, as these timestamps may straddle across transcribed utterances from the ASR output. Thus, we
|
| 56 |
+
find the diarizer timestamps that are closest to the ASR timestamps and partition here.
|
| 57 |
|
| 58 |
Args:
|
| 59 |
inputs (`np.ndarray` or `bytes` or `str` or `dict`):
|
|
|
|
| 69 |
np.array}` with optionally a `"stride": (left: int, right: int)` than can ask the pipeline to
|
| 70 |
treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
|
| 71 |
inference to provide more context to the model). Only use `stride` with CTC models.
|
| 72 |
+
group_by_speaker (`bool`):
|
| 73 |
+
Whether to group consecutive utterances by one speaker into a single segment. If False, will return
|
| 74 |
+
transcriptions on a chunk-by-chunk basis.
|
| 75 |
|
| 76 |
Return:
|
| 77 |
+
A list of transcriptions. Each list item corresponds to one chunk / segment of transcription, and is a
|
| 78 |
+
dictionary with the following keys:
|
| 79 |
- **text** (`str` ) -- The recognized text.
|
| 80 |
+
- **speaker** (`str`) -- The associated speaker.
|
| 81 |
+
- **timestamps** (`tuple`) -- The start and end time for the chunk / segment.
|
|
|
|
|
|
|
|
|
|
| 82 |
"""
|
| 83 |
inputs, diarizer_inputs = self.preprocess(inputs)
|
| 84 |
|
|
|
|
| 89 |
|
| 90 |
segments = diarization.for_json()["content"]
|
| 91 |
|
| 92 |
+
# diarizer output may contain consecutive segments from the same speaker (e.g. {(0 -> 1, speaker_1), (1 -> 1.5, speaker_1), ...})
|
| 93 |
+
# we combine these segments to give overall timestamps for each speaker's turn (e.g. {(0 -> 1.5, speaker_1), ...})
|
| 94 |
new_segments = []
|
| 95 |
prev_segment = cur_segment = segments[0]
|
| 96 |
|
| 97 |
for i in range(1, len(segments)):
|
| 98 |
cur_segment = segments[i]
|
| 99 |
|
| 100 |
+
# check if we have changed speaker ("label")
|
| 101 |
if cur_segment["label"] != prev_segment["label"] and i < len(segments):
|
| 102 |
+
# add the start/end times for the super-segment to the new list
|
| 103 |
new_segments.append(
|
| 104 |
{
|
| 105 |
"segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]},
|
|
|
|
| 108 |
)
|
| 109 |
prev_segment = segments[i]
|
| 110 |
|
| 111 |
+
# add the last segment(s) if there was no speaker change
|
| 112 |
new_segments.append(
|
| 113 |
{
|
| 114 |
"segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["end"]},
|
|
|
|
| 123 |
)
|
| 124 |
transcript = asr_out["chunks"]
|
| 125 |
|
| 126 |
+
# get the end timestamps for each chunk from the ASR output
|
| 127 |
end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript])
|
| 128 |
segmented_preds = []
|
| 129 |
|
| 130 |
+
# align the diarizer timestamps and the ASR timestamps
|
| 131 |
for segment in new_segments:
|
| 132 |
+
# get the diarizer end timestamp
|
| 133 |
end_time = segment["segment"]["end"]
|
| 134 |
+
# find the ASR end timestamp that is closest to the diarizer's end timestamp and cut the transcript to here
|
| 135 |
upto_idx = np.argmin(np.abs(end_timestamps - end_time))
|
| 136 |
|
| 137 |
if group_by_speaker:
|
|
|
|
| 139 |
{
|
| 140 |
"speaker": segment["speaker"],
|
| 141 |
"text": "".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]),
|
| 142 |
+
"timestamp": (transcript[0]["timestamp"][0], transcript[upto_idx]["timestamp"][1]),
|
|
|
|
|
|
|
|
|
|
| 143 |
}
|
| 144 |
)
|
| 145 |
else:
|
| 146 |
for i in range(upto_idx + 1):
|
| 147 |
segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})
|
| 148 |
|
| 149 |
+
# crop the transcripts and timestamp lists according to the latest timestamp (for faster argmin)
|
| 150 |
transcript = transcript[upto_idx + 1 :]
|
| 151 |
end_timestamps = end_timestamps[upto_idx + 1 :]
|
| 152 |
|
| 153 |
return segmented_preds
|
| 154 |
|
| 155 |
+
# Adapted from transformers.pipelines.automatic_speech_recognition.AutomaticSpeechRecognitionPipeline.preprocess
|
| 156 |
+
# (see https://github.com/huggingface/transformers/blob/238449414f88d94ded35e80459bb6412d8ab42cf/src/transformers/pipelines/automatic_speech_recognition.py#L417)
|
| 157 |
def preprocess(self, inputs):
|
| 158 |
if isinstance(inputs, str):
|
| 159 |
if inputs.startswith("http://") or inputs.startswith("https://"):
|
|
|
|
| 191 |
if len(inputs.shape) != 1:
|
| 192 |
raise ValueError("We expect a single channel audio input for ASRDiarizePipeline")
|
| 193 |
|
| 194 |
+
# diarization model expects float32 torch tensor of shape `(channels, seq_len)`
|
| 195 |
+
diarizer_inputs = torch.from_numpy(inputs).float()
|
| 196 |
+
diarizer_inputs = diarizer_inputs.unsqueeze(0)
|
| 197 |
|
| 198 |
return inputs, diarizer_inputs
|