mazesmazes committed · verified
Commit 9fad9c2 · 1 Parent(s): d57208e

Update custom model files, README, and requirements

Files changed (1):
  asr_pipeline.py  +398 -0
asr_pipeline.py CHANGED
@@ -1,5 +1,6 @@
 from typing import Any
 
+import numpy as np
 import torch
 import transformers
 
@@ -9,6 +10,278 @@ except ImportError:
     from asr_modeling import ASRModel  # type: ignore[no-redef]
 
 
+class ForcedAligner:
+    """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2."""
+
+    _bundle = None
+    _model = None
+    _labels = None
+    _dictionary = None
+
+    @classmethod
+    def get_instance(cls, device: str = "cuda"):
+        if cls._model is None:
+            import torchaudio
+
+            cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
+            cls._model = cls._bundle.get_model().to(device)
+            cls._model.eval()
+            cls._labels = cls._bundle.get_labels()
+            cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
+        return cls._model, cls._labels, cls._dictionary
+
+    @classmethod
+    def align(
+        cls,
+        audio: np.ndarray,
+        text: str,
+        sample_rate: int = 16000,
+        language: str = "eng",
+        batch_size: int = 16,
+    ) -> list[dict]:
+        """Align transcript to audio and return word-level timestamps.
+
+        Args:
+            audio: Audio waveform as numpy array
+            text: Transcript text to align
+            sample_rate: Audio sample rate (default 16000)
+            language: ISO-639-3 language code (default "eng" for English, unused)
+            batch_size: Batch size for alignment model (unused)
+
+        Returns:
+            List of dicts with 'word', 'start', 'end' keys
+        """
+        import torchaudio
+        from torchaudio.functional import forced_align, merge_tokens
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        model, labels, dictionary = cls.get_instance(device)
+
+        # Convert audio to tensor (copy to ensure array is writable)
+        if isinstance(audio, np.ndarray):
+            waveform = torch.from_numpy(audio.copy()).float()
+        else:
+            waveform = audio.clone().float()
+
+        # Ensure 2D (channels, time)
+        if waveform.dim() == 1:
+            waveform = waveform.unsqueeze(0)
+
+        # Resample if needed (wav2vec2 expects 16kHz)
+        if sample_rate != cls._bundle.sample_rate:
+            waveform = torchaudio.functional.resample(
+                waveform, sample_rate, cls._bundle.sample_rate
+            )
+
+        waveform = waveform.to(device)
+
+        # Get emissions from model
+        with torch.inference_mode():
+            emissions, _ = model(waveform)
+            emissions = torch.log_softmax(emissions, dim=-1)
+
+        emission = emissions[0].cpu()
+
+        # Normalize text: uppercase, keep only valid characters
+        transcript = text.upper()
+        # Build tokens from transcript
+        tokens = []
+        for char in transcript:
+            if char in dictionary:
+                tokens.append(dictionary[char])
+            elif char == " ":
+                tokens.append(dictionary.get("|", dictionary.get(" ", 0)))
+
+        if not tokens:
+            return []
+
+        targets = torch.tensor([tokens], dtype=torch.int32)
+
+        # Run forced alignment
+        # Note: forced_align is deprecated in torchaudio 2.6+ and will be removed in 2.9 (late 2025)
+        # No official replacement announced yet. See https://github.com/pytorch/audio/issues/3902
+        aligned_tokens, scores = forced_align(emission.unsqueeze(0), targets, blank=0)
+
+        # Use torchaudio's merge_tokens to get token spans (removes blanks and merges repeats)
+        token_spans = merge_tokens(aligned_tokens[0], scores[0])
+
+        # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
+        frame_duration = 320 / cls._bundle.sample_rate
+
+        # Group token spans into words based on pipe separator
+        words = text.split()
+        word_timestamps = []
+        current_word_start = None
+        current_word_end = None
+        word_idx = 0
+
+        for span in token_spans:
+            token_char = labels[span.token]
+            if token_char == "|":  # Word separator
+                if current_word_start is not None and word_idx < len(words):
+                    word_timestamps.append({
+                        "word": words[word_idx],
+                        "start": current_word_start * frame_duration,
+                        "end": current_word_end * frame_duration,
+                    })
+                    word_idx += 1
+                current_word_start = None
+                current_word_end = None
+            else:
+                if current_word_start is None:
+                    current_word_start = span.start
+                current_word_end = span.end
+
+        # Don't forget the last word
+        if current_word_start is not None and word_idx < len(words):
+            word_timestamps.append({
+                "word": words[word_idx],
+                "start": current_word_start * frame_duration,
+                "end": current_word_end * frame_duration,
+            })
+
+        return word_timestamps
+
+
+class SpeakerDiarizer:
+    """Lazy-loaded speaker diarization using pyannote-audio."""
+
+    _pipeline = None
+
+    @classmethod
+    def get_instance(cls, hf_token: str | None = None):
+        """Get or create the diarization pipeline.
+
+        Args:
+            hf_token: HuggingFace token with access to pyannote models.
+                Can also be set via HF_TOKEN environment variable.
+        """
+        if cls._pipeline is None:
+            from pyannote.audio import Pipeline
+
+            cls._pipeline = Pipeline.from_pretrained(
+                "pyannote/speaker-diarization-3.1",
+            )
+
+            # Move to GPU if available
+            if torch.cuda.is_available():
+                cls._pipeline.to(torch.device("cuda"))
+            elif torch.backends.mps.is_available():
+                cls._pipeline.to(torch.device("mps"))
+
+        return cls._pipeline
+
+    @classmethod
+    def diarize(
+        cls,
+        audio: np.ndarray | str,
+        sample_rate: int = 16000,
+        num_speakers: int | None = None,
+        min_speakers: int | None = None,
+        max_speakers: int | None = None,
+        hf_token: str | None = None,
+    ) -> list[dict]:
+        """Run speaker diarization on audio.
+
+        Args:
+            audio: Audio waveform as numpy array or path to audio file
+            sample_rate: Audio sample rate (default 16000)
+            num_speakers: Exact number of speakers (if known)
+            min_speakers: Minimum number of speakers
+            max_speakers: Maximum number of speakers
+            hf_token: HuggingFace token for pyannote models
+
+        Returns:
+            List of dicts with 'speaker', 'start', 'end' keys
+        """
+        pipeline = cls.get_instance(hf_token)
+
+        # Prepare audio input
+        if isinstance(audio, np.ndarray):
+            # pyannote expects {"waveform": tensor, "sample_rate": int}
+            waveform = torch.from_numpy(audio).unsqueeze(0)  # Add channel dim
+            if waveform.dim() == 1:
+                waveform = waveform.unsqueeze(0)
+            audio_input = {"waveform": waveform, "sample_rate": sample_rate}
+        else:
+            # File path
+            audio_input = audio
+
+        # Run diarization
+        diarization_args = {}
+        if num_speakers is not None:
+            diarization_args["num_speakers"] = num_speakers
+        if min_speakers is not None:
+            diarization_args["min_speakers"] = min_speakers
+        if max_speakers is not None:
+            diarization_args["max_speakers"] = max_speakers
+
+        diarization = pipeline(audio_input, **diarization_args)
+
+        # Handle different pyannote return types
+        # pyannote 3.x returns DiarizeOutput dataclass, older versions return Annotation
+        if hasattr(diarization, "itertracks"):
+            annotation = diarization
+        elif hasattr(diarization, "speaker_diarization"):
+            # pyannote 3.x DiarizeOutput dataclass
+            annotation = diarization.speaker_diarization
+        elif isinstance(diarization, tuple):
+            # Some versions return (annotation, embeddings) tuple
+            annotation = diarization[0]
+        else:
+            raise TypeError(f"Unexpected diarization output type: {type(diarization)}")
+
+        # Convert to simple format
+        segments = []
+        for turn, _, speaker in annotation.itertracks(yield_label=True):
+            segments.append({
+                "speaker": speaker,
+                "start": turn.start,
+                "end": turn.end,
+            })
+
+        return segments
+
+    @classmethod
+    def assign_speakers_to_words(
+        cls,
+        words: list[dict],
+        speaker_segments: list[dict],
+    ) -> list[dict]:
+        """Assign speaker labels to words based on timestamp overlap.
+
+        Args:
+            words: List of word dicts with 'word', 'start', 'end' keys
+            speaker_segments: List of speaker dicts with 'speaker', 'start', 'end' keys
+
+        Returns:
+            Words list with 'speaker' key added to each word
+        """
+        for word in words:
+            word_mid = (word["start"] + word["end"]) / 2
+
+            # Find the speaker segment that contains this word's midpoint
+            best_speaker = None
+            for seg in speaker_segments:
+                if seg["start"] <= word_mid <= seg["end"]:
+                    best_speaker = seg["speaker"]
+                    break
+
+            # If no exact match, find closest segment
+            if best_speaker is None and speaker_segments:
+                min_dist = float("inf")
+                for seg in speaker_segments:
+                    seg_mid = (seg["start"] + seg["end"]) / 2
+                    dist = abs(word_mid - seg_mid)
+                    if dist < min_dist:
+                        min_dist = dist
+                        best_speaker = seg["speaker"]
+
+            word["speaker"] = best_speaker
+
+        return words
+
+
 class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
     """ASR Pipeline for audio-to-text transcription."""
 
@@ -24,6 +297,131 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         super().__init__(
             model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
         )
+        self._current_audio = None
+
+    def _sanitize_parameters(self, **kwargs):
+        """Intercept our custom parameters before parent class validates them."""
+        # Remove our custom parameters so parent doesn't see them
+        kwargs.pop("return_timestamps", None)
+        kwargs.pop("return_speakers", None)
+        kwargs.pop("num_speakers", None)
+        kwargs.pop("min_speakers", None)
+        kwargs.pop("max_speakers", None)
+        kwargs.pop("hf_token", None)
+
+        return super()._sanitize_parameters(**kwargs)
+
+ def __call__(
315
+ self,
316
+ inputs,
317
+ **kwargs,
318
+ ):
319
+ """Transcribe audio with optional word-level timestamps and speaker diarization.
320
+
321
+ Args:
322
+ inputs: Audio input (file path, dict with array/sampling_rate, etc.)
323
+ return_timestamps: If True, return word-level timestamps using forced alignment
324
+ return_speakers: If True, return speaker labels for each word
325
+ num_speakers: Exact number of speakers (if known, for diarization)
326
+ min_speakers: Minimum number of speakers (for diarization)
327
+ max_speakers: Maximum number of speakers (for diarization)
328
+ hf_token: HuggingFace token for pyannote models (or set HF_TOKEN env var)
329
+ **kwargs: Additional arguments passed to the pipeline
330
+
331
+ Returns:
332
+ Dict with 'text' key, 'words' key if return_timestamps=True,
333
+ and speaker labels on words if return_speakers=True
334
+ """
335
+ # Extract our params before super().__call__ (which will also call _sanitize_parameters)
336
+ return_timestamps = kwargs.pop("return_timestamps", False)
337
+ return_speakers = kwargs.pop("return_speakers", False)
338
+ diarization_params = {
339
+ "num_speakers": kwargs.pop("num_speakers", None),
340
+ "min_speakers": kwargs.pop("min_speakers", None),
341
+ "max_speakers": kwargs.pop("max_speakers", None),
342
+ "hf_token": kwargs.pop("hf_token", None),
343
+ }
344
+
345
+ if return_speakers:
346
+ return_timestamps = True
347
+
348
+ # Store audio for timestamp alignment and diarization
349
+ if return_timestamps or return_speakers:
350
+ self._current_audio = self._extract_audio(inputs)
351
+
352
+ # Run standard transcription
353
+ result = super().__call__(inputs, **kwargs)
354
+
355
+ # Add timestamps if requested
356
+ if return_timestamps and self._current_audio is not None:
357
+ text = result.get("text", "")
358
+ if text:
359
+ try:
360
+ words = ForcedAligner.align(
361
+ self._current_audio["array"],
362
+ text,
363
+ sample_rate=self._current_audio.get("sampling_rate", 16000),
364
+ )
365
+ result["words"] = words
366
+ except Exception as e:
367
+ result["words"] = []
368
+ result["timestamp_error"] = str(e)
369
+ else:
370
+ result["words"] = []
371
+
372
+ # Add speaker diarization if requested
373
+ if return_speakers and self._current_audio is not None:
374
+ try:
375
+ # Run diarization
376
+ speaker_segments = SpeakerDiarizer.diarize(
377
+ self._current_audio["array"],
378
+ sample_rate=self._current_audio.get("sampling_rate", 16000),
379
+ **{k: v for k, v in diarization_params.items() if v is not None},
380
+ )
381
+ result["speaker_segments"] = speaker_segments
382
+
383
+ # Assign speakers to words
384
+ if result.get("words"):
385
+ result["words"] = SpeakerDiarizer.assign_speakers_to_words(
386
+ result["words"],
387
+ speaker_segments,
388
+ )
389
+ except Exception as e:
390
+ result["speaker_segments"] = []
391
+ result["diarization_error"] = str(e)
392
+
393
+ # Clean up
394
+ self._current_audio = None
395
+
396
+ return result
397
+
+    def _extract_audio(self, inputs) -> dict | None:
+        """Extract audio array from various input formats using HF utilities."""
+        from transformers.pipelines.audio_utils import ffmpeg_read
+
+        if isinstance(inputs, dict):
+            if "array" in inputs:
+                return {
+                    "array": inputs["array"],
+                    "sampling_rate": inputs.get("sampling_rate", 16000),
+                }
+            if "raw" in inputs:
+                return {
+                    "array": inputs["raw"],
+                    "sampling_rate": inputs.get("sampling_rate", 16000),
+                }
+        elif isinstance(inputs, str):
+            # File path - load audio using ffmpeg (same as HF pipeline)
+            with open(inputs, "rb") as f:
+                audio = ffmpeg_read(f.read(), sampling_rate=16000)
+            return {"array": audio, "sampling_rate": 16000}
+        elif isinstance(inputs, bytes):
+            audio = ffmpeg_read(inputs, sampling_rate=16000)
+            return {"array": audio, "sampling_rate": 16000}
+        elif isinstance(inputs, np.ndarray):
+            return {"array": inputs, "sampling_rate": 16000}
+
+        return None
 
     def preprocess(self, inputs, **preprocess_params):
         # Handle dict with "array" key (from datasets)
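
Usage sketch (not part of this commit): based on the parameters documented in ASRPipeline.__call__, a model repo that registers this asr_pipeline.py as its custom pipeline could be invoked roughly as below. The repo id and audio path are placeholders, and access to pyannote/speaker-diarization-3.1 (via an HF token) is assumed for the speaker options.

import transformers

# Placeholder repo id; assumes asr_pipeline.py is registered as the model's custom pipeline.
pipe = transformers.pipeline(
    "automatic-speech-recognition",
    model="your-org/your-asr-model",
    trust_remote_code=True,
)

# Plain transcription
print(pipe("sample.wav")["text"])

# Word-level timestamps (forced alignment) plus speaker labels (diarization)
result = pipe(
    "sample.wav",
    return_timestamps=True,
    return_speakers=True,
    min_speakers=1,
    max_speakers=4,
)
for word in result.get("words", []):
    print(word.get("speaker"), word["word"], word["start"], word["end"])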