mazesmazes committed
Commit 32d7b9c · verified · 1 Parent(s): 6f0e8d4

Training in progress - step 19000

Files changed (1): asr_pipeline.py +0 -398
asr_pipeline.py CHANGED
@@ -1,6 +1,5 @@
 from typing import Any
 
-import numpy as np
 import torch
 import transformers
 
@@ -10,278 +9,6 @@ except ImportError:
     from asr_modeling import ASRModel  # type: ignore[no-redef]
 
 
-class ForcedAligner:
-    """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2."""
-
-    _bundle = None
-    _model = None
-    _labels = None
-    _dictionary = None
-
-    @classmethod
-    def get_instance(cls, device: str = "cuda"):
-        if cls._model is None:
-            import torchaudio
-
-            cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
-            cls._model = cls._bundle.get_model().to(device)
-            cls._model.eval()
-            cls._labels = cls._bundle.get_labels()
-            cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
-        return cls._model, cls._labels, cls._dictionary
-
-    @classmethod
-    def align(
-        cls,
-        audio: np.ndarray,
-        text: str,
-        sample_rate: int = 16000,
-        language: str = "eng",
-        batch_size: int = 16,
-    ) -> list[dict]:
-        """Align transcript to audio and return word-level timestamps.
-
-        Args:
-            audio: Audio waveform as numpy array
-            text: Transcript text to align
-            sample_rate: Audio sample rate (default 16000)
-            language: ISO-639-3 language code (default "eng" for English, unused)
-            batch_size: Batch size for alignment model (unused)
-
-        Returns:
-            List of dicts with 'word', 'start', 'end' keys
-        """
-        import torchaudio
-        from torchaudio.functional import forced_align, merge_tokens
-
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        model, labels, dictionary = cls.get_instance(device)
-
-        # Convert audio to tensor (copy to ensure array is writable)
-        if isinstance(audio, np.ndarray):
-            waveform = torch.from_numpy(audio.copy()).float()
-        else:
-            waveform = audio.clone().float()
-
-        # Ensure 2D (channels, time)
-        if waveform.dim() == 1:
-            waveform = waveform.unsqueeze(0)
-
-        # Resample if needed (wav2vec2 expects 16kHz)
-        if sample_rate != cls._bundle.sample_rate:
-            waveform = torchaudio.functional.resample(
-                waveform, sample_rate, cls._bundle.sample_rate
-            )
-
-        waveform = waveform.to(device)
-
-        # Get emissions from model
-        with torch.inference_mode():
-            emissions, _ = model(waveform)
-            emissions = torch.log_softmax(emissions, dim=-1)
-
-        emission = emissions[0].cpu()
-
-        # Normalize text: uppercase, keep only valid characters
-        transcript = text.upper()
-        # Build tokens from transcript
-        tokens = []
-        for char in transcript:
-            if char in dictionary:
-                tokens.append(dictionary[char])
-            elif char == " ":
-                tokens.append(dictionary.get("|", dictionary.get(" ", 0)))
-
-        if not tokens:
-            return []
-
-        targets = torch.tensor([tokens], dtype=torch.int32)
-
-        # Run forced alignment
-        # Note: forced_align is deprecated in torchaudio 2.6+ and will be removed in 2.9 (late 2025)
-        # No official replacement announced yet. See https://github.com/pytorch/audio/issues/3902
-        aligned_tokens, scores = forced_align(emission.unsqueeze(0), targets, blank=0)
-
-        # Use torchaudio's merge_tokens to get token spans (removes blanks and merges repeats)
-        token_spans = merge_tokens(aligned_tokens[0], scores[0])
-
-        # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
-        frame_duration = 320 / cls._bundle.sample_rate
-
-        # Group token spans into words based on pipe separator
-        words = text.split()
-        word_timestamps = []
-        current_word_start = None
-        current_word_end = None
-        word_idx = 0
-
-        for span in token_spans:
-            token_char = labels[span.token]
-            if token_char == "|":  # Word separator
-                if current_word_start is not None and word_idx < len(words):
-                    word_timestamps.append({
-                        "word": words[word_idx],
-                        "start": current_word_start * frame_duration,
-                        "end": current_word_end * frame_duration,
-                    })
-                    word_idx += 1
-                current_word_start = None
-                current_word_end = None
-            else:
-                if current_word_start is None:
-                    current_word_start = span.start
-                current_word_end = span.end
-
-        # Don't forget the last word
-        if current_word_start is not None and word_idx < len(words):
-            word_timestamps.append({
-                "word": words[word_idx],
-                "start": current_word_start * frame_duration,
-                "end": current_word_end * frame_duration,
-            })
-
-        return word_timestamps
-
-
-class SpeakerDiarizer:
-    """Lazy-loaded speaker diarization using pyannote-audio."""
-
-    _pipeline = None
-
-    @classmethod
-    def get_instance(cls, hf_token: str | None = None):
-        """Get or create the diarization pipeline.
-
-        Args:
-            hf_token: HuggingFace token with access to pyannote models.
-                Can also be set via HF_TOKEN environment variable.
-        """
-        if cls._pipeline is None:
-            from pyannote.audio import Pipeline
-
-            cls._pipeline = Pipeline.from_pretrained(
-                "pyannote/speaker-diarization-3.1",
-            )
-
-            # Move to GPU if available
-            if torch.cuda.is_available():
-                cls._pipeline.to(torch.device("cuda"))
-            elif torch.backends.mps.is_available():
-                cls._pipeline.to(torch.device("mps"))
-
-        return cls._pipeline
-
-    @classmethod
-    def diarize(
-        cls,
-        audio: np.ndarray | str,
-        sample_rate: int = 16000,
-        num_speakers: int | None = None,
-        min_speakers: int | None = None,
-        max_speakers: int | None = None,
-        hf_token: str | None = None,
-    ) -> list[dict]:
-        """Run speaker diarization on audio.
-
-        Args:
-            audio: Audio waveform as numpy array or path to audio file
-            sample_rate: Audio sample rate (default 16000)
-            num_speakers: Exact number of speakers (if known)
-            min_speakers: Minimum number of speakers
-            max_speakers: Maximum number of speakers
-            hf_token: HuggingFace token for pyannote models
-
-        Returns:
-            List of dicts with 'speaker', 'start', 'end' keys
-        """
-        pipeline = cls.get_instance(hf_token)
-
-        # Prepare audio input
-        if isinstance(audio, np.ndarray):
-            # pyannote expects {"waveform": tensor, "sample_rate": int}
-            waveform = torch.from_numpy(audio).unsqueeze(0)  # Add channel dim
-            if waveform.dim() == 1:
-                waveform = waveform.unsqueeze(0)
-            audio_input = {"waveform": waveform, "sample_rate": sample_rate}
-        else:
-            # File path
-            audio_input = audio
-
-        # Run diarization
-        diarization_args = {}
-        if num_speakers is not None:
-            diarization_args["num_speakers"] = num_speakers
-        if min_speakers is not None:
-            diarization_args["min_speakers"] = min_speakers
-        if max_speakers is not None:
-            diarization_args["max_speakers"] = max_speakers
-
-        diarization = pipeline(audio_input, **diarization_args)
-
-        # Handle different pyannote return types
-        # pyannote 3.x returns DiarizeOutput dataclass, older versions return Annotation
-        if hasattr(diarization, "itertracks"):
-            annotation = diarization
-        elif hasattr(diarization, "speaker_diarization"):
-            # pyannote 3.x DiarizeOutput dataclass
-            annotation = diarization.speaker_diarization
-        elif isinstance(diarization, tuple):
-            # Some versions return (annotation, embeddings) tuple
-            annotation = diarization[0]
-        else:
-            raise TypeError(f"Unexpected diarization output type: {type(diarization)}")
-
-        # Convert to simple format
-        segments = []
-        for turn, _, speaker in annotation.itertracks(yield_label=True):
-            segments.append({
-                "speaker": speaker,
-                "start": turn.start,
-                "end": turn.end,
-            })
-
-        return segments
-
-    @classmethod
-    def assign_speakers_to_words(
-        cls,
-        words: list[dict],
-        speaker_segments: list[dict],
-    ) -> list[dict]:
-        """Assign speaker labels to words based on timestamp overlap.
-
-        Args:
-            words: List of word dicts with 'word', 'start', 'end' keys
-            speaker_segments: List of speaker dicts with 'speaker', 'start', 'end' keys
-
-        Returns:
-            Words list with 'speaker' key added to each word
-        """
-        for word in words:
-            word_mid = (word["start"] + word["end"]) / 2
-
-            # Find the speaker segment that contains this word's midpoint
-            best_speaker = None
-            for seg in speaker_segments:
-                if seg["start"] <= word_mid <= seg["end"]:
-                    best_speaker = seg["speaker"]
-                    break
-
-            # If no exact match, find closest segment
-            if best_speaker is None and speaker_segments:
-                min_dist = float("inf")
-                for seg in speaker_segments:
-                    seg_mid = (seg["start"] + seg["end"]) / 2
-                    dist = abs(word_mid - seg_mid)
-                    if dist < min_dist:
-                        min_dist = dist
-                        best_speaker = seg["speaker"]
-
-            word["speaker"] = best_speaker
-
-        return words
-
-
 class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
     """ASR Pipeline for audio-to-text transcription."""
 
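For readers without the pre-commit file, the removed ForcedAligner boils down to torchaudio's CTC forced-alignment API. Below is a minimal standalone sketch, not the author's code: the function name align_words is mine, and it assumes torchaudio >= 2.1 and < 2.9, where forced_align and merge_tokens still exist.

import torch
import torchaudio
from torchaudio.functional import forced_align, merge_tokens

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().eval()
labels = bundle.get_labels()  # ('-', '|', 'E', 'T', ...); '-' is the CTC blank
dictionary = {c: i for i, c in enumerate(labels)}


def align_words(waveform: torch.Tensor, text: str) -> list[dict]:
    """waveform: (1, time) at bundle.sample_rate (16 kHz)."""
    with torch.inference_mode():
        emissions, _ = model(waveform)  # (1, frames, vocab)
    log_probs = torch.log_softmax(emissions, dim=-1).cpu()

    # Spell the transcript in the model's character vocabulary; '|' marks word ends.
    tokens = [dictionary[c] for c in text.upper().replace(" ", "|") if c in dictionary]
    targets = torch.tensor([tokens], dtype=torch.int32)

    aligned, scores = forced_align(log_probs, targets, blank=0)
    spans = merge_tokens(aligned[0], scores[0])  # drop blanks, merge repeats

    frame = 320 / bundle.sample_rate  # wav2vec2 frame stride: 320 samples = 20 ms
    words, out, start, end = text.split(), [], None, None
    for span in spans:
        if labels[span.token] == "|":  # word boundary: flush the current word
            if start is not None and len(out) < len(words):
                out.append({"word": words[len(out)],
                            "start": start * frame, "end": end * frame})
            start = None
        else:
            start = span.start if start is None else start
            end = span.end
    if start is not None and len(out) < len(words):  # last word has no trailing '|'
        out.append({"word": words[len(out)],
                    "start": start * frame, "end": end * frame})
    return out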
@@ -297,131 +24,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         super().__init__(
             model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
         )
-        self._current_audio = None
-
-    def _sanitize_parameters(self, **kwargs):
-        """Intercept our custom parameters before parent class validates them."""
-        # Remove our custom parameters so parent doesn't see them
-        kwargs.pop("return_timestamps", None)
-        kwargs.pop("return_speakers", None)
-        kwargs.pop("num_speakers", None)
-        kwargs.pop("min_speakers", None)
-        kwargs.pop("max_speakers", None)
-        kwargs.pop("hf_token", None)
-
-        return super()._sanitize_parameters(**kwargs)
-
-    def __call__(
-        self,
-        inputs,
-        **kwargs,
-    ):
-        """Transcribe audio with optional word-level timestamps and speaker diarization.
-
-        Args:
-            inputs: Audio input (file path, dict with array/sampling_rate, etc.)
-            return_timestamps: If True, return word-level timestamps using forced alignment
-            return_speakers: If True, return speaker labels for each word
-            num_speakers: Exact number of speakers (if known, for diarization)
-            min_speakers: Minimum number of speakers (for diarization)
-            max_speakers: Maximum number of speakers (for diarization)
-            hf_token: HuggingFace token for pyannote models (or set HF_TOKEN env var)
-            **kwargs: Additional arguments passed to the pipeline
-
-        Returns:
-            Dict with 'text' key, 'words' key if return_timestamps=True,
-            and speaker labels on words if return_speakers=True
-        """
-        # Extract our params before super().__call__ (which will also call _sanitize_parameters)
-        return_timestamps = kwargs.pop("return_timestamps", False)
-        return_speakers = kwargs.pop("return_speakers", False)
-        diarization_params = {
-            "num_speakers": kwargs.pop("num_speakers", None),
-            "min_speakers": kwargs.pop("min_speakers", None),
-            "max_speakers": kwargs.pop("max_speakers", None),
-            "hf_token": kwargs.pop("hf_token", None),
-        }
-
-        if return_speakers:
-            return_timestamps = True
-
-        # Store audio for timestamp alignment and diarization
-        if return_timestamps or return_speakers:
-            self._current_audio = self._extract_audio(inputs)
-
-        # Run standard transcription
-        result = super().__call__(inputs, **kwargs)
-
-        # Add timestamps if requested
-        if return_timestamps and self._current_audio is not None:
-            text = result.get("text", "")
-            if text:
-                try:
-                    words = ForcedAligner.align(
-                        self._current_audio["array"],
-                        text,
-                        sample_rate=self._current_audio.get("sampling_rate", 16000),
-                    )
-                    result["words"] = words
-                except Exception as e:
-                    result["words"] = []
-                    result["timestamp_error"] = str(e)
-            else:
-                result["words"] = []
-
-        # Add speaker diarization if requested
-        if return_speakers and self._current_audio is not None:
-            try:
-                # Run diarization
-                speaker_segments = SpeakerDiarizer.diarize(
-                    self._current_audio["array"],
-                    sample_rate=self._current_audio.get("sampling_rate", 16000),
-                    **{k: v for k, v in diarization_params.items() if v is not None},
-                )
-                result["speaker_segments"] = speaker_segments
-
-                # Assign speakers to words
-                if result.get("words"):
-                    result["words"] = SpeakerDiarizer.assign_speakers_to_words(
-                        result["words"],
-                        speaker_segments,
-                    )
-            except Exception as e:
-                result["speaker_segments"] = []
-                result["diarization_error"] = str(e)
-
-        # Clean up
-        self._current_audio = None
-
-        return result
-
-    def _extract_audio(self, inputs) -> dict | None:
-        """Extract audio array from various input formats using HF utilities."""
-        from transformers.pipelines.audio_utils import ffmpeg_read
-
-        if isinstance(inputs, dict):
-            if "array" in inputs:
-                return {
-                    "array": inputs["array"],
-                    "sampling_rate": inputs.get("sampling_rate", 16000),
-                }
-            if "raw" in inputs:
-                return {
-                    "array": inputs["raw"],
-                    "sampling_rate": inputs.get("sampling_rate", 16000),
-                }
-        elif isinstance(inputs, str):
-            # File path - load audio using ffmpeg (same as HF pipeline)
-            with open(inputs, "rb") as f:
-                audio = ffmpeg_read(f.read(), sampling_rate=16000)
-            return {"array": audio, "sampling_rate": 16000}
-        elif isinstance(inputs, bytes):
-            audio = ffmpeg_read(inputs, sampling_rate=16000)
-            return {"array": audio, "sampling_rate": 16000}
-        elif isinstance(inputs, np.ndarray):
-            return {"array": inputs, "sampling_rate": 16000}
-
-        return None
 
     def preprocess(self, inputs, **preprocess_params):
         # Handle dict with "array" key (from datasets)
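For context, before this commit the custom __call__ accepted the extra keyword arguments documented in the removed docstring above. A hypothetical usage sketch of the pre-commit pipeline follows; the checkpoint id and audio file name are placeholders, not values from this repo.

from transformers import pipeline

# "your-org/your-asr-checkpoint" and "meeting.wav" are placeholders.
asr = pipeline(
    "automatic-speech-recognition",
    model="your-org/your-asr-checkpoint",
    trust_remote_code=True,  # the repo ships this custom ASRPipeline
)
result = asr(
    "meeting.wav",
    return_timestamps=True,  # word timestamps via the (now removed) ForcedAligner
    return_speakers=True,    # speaker labels via pyannote; implies timestamps
    num_speakers=2,          # optional hint forwarded to diarization
)
print(result["text"])
for w in result.get("words", []):
    print(f'{w.get("speaker")}: {w["word"]} [{w["start"]:.2f}-{w["end"]:.2f}s]')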
 