mazesmazes committed
Commit 4ef930f · verified · 1 parent: 88033ca

Training in progress - step 15000

Files changed (1)
  1. asr_pipeline.py +0 -325
asr_pipeline.py CHANGED
@@ -1,6 +1,5 @@
 from typing import Any

-import numpy as np
 import torch
 import transformers

@@ -10,211 +9,6 @@ except ImportError:
     from asr_modeling import ASRModel  # type: ignore[no-redef]


-class ForcedAligner:
-    """Lazy-loaded forced aligner for word-level timestamps."""
-
-    _instance = None
-    _model = None
-    _tokenizer = None
-
-    @classmethod
-    def get_instance(cls, device: str = "cuda"):
-        if cls._model is None:
-            from ctc_forced_aligner import load_alignment_model
-
-            dtype = torch.float16 if device == "cuda" else torch.float32
-            cls._model, cls._tokenizer = load_alignment_model(device, dtype=dtype)
-        return cls._model, cls._tokenizer
-
-    @classmethod
-    def align(
-        cls,
-        audio: np.ndarray,
-        text: str,
-        sample_rate: int = 16000,
-        language: str = "eng",
-        batch_size: int = 16,
-    ) -> list[dict]:
-        """Align transcript to audio and return word-level timestamps.
-
-        Args:
-            audio: Audio waveform as numpy array
-            text: Transcript text to align
-            sample_rate: Audio sample rate (default 16000)
-            language: ISO-639-3 language code (default "eng" for English)
-            batch_size: Batch size for alignment model
-
-        Returns:
-            List of dicts with 'word', 'start', 'end' keys
-        """
-        from ctc_forced_aligner import (
-            generate_emissions,
-            get_alignments,
-            get_spans,
-            postprocess_results,
-            preprocess_text,
-        )
-
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        model, tokenizer = cls.get_instance(device)
-
-        # Convert audio to tensor
-        if isinstance(audio, np.ndarray):
-            audio_tensor = torch.from_numpy(audio).to(model.dtype).to(model.device)
-        else:
-            audio_tensor = audio.to(model.dtype).to(model.device)
-
-        # Ensure 1D
-        if audio_tensor.dim() > 1:
-            audio_tensor = audio_tensor.squeeze()
-
-        # Generate emissions
-        emissions, stride = generate_emissions(model, audio_tensor, batch_size=batch_size)
-
-        # Preprocess text
-        tokens_starred, text_starred = preprocess_text(text, romanize=True, language=language)
-
-        # Get alignments
-        segments, scores, blank_token = get_alignments(emissions, tokens_starred, tokenizer)
-
-        # Get spans
-        spans = get_spans(tokens_starred, segments, blank_token)
-
-        # Get word timestamps
-        word_timestamps = postprocess_results(text_starred, spans, stride, scores)
-
-        # Convert to simple format
-        return [{"word": w["word"], "start": w["start"], "end": w["end"]} for w in word_timestamps]
-
-
-class SpeakerDiarizer:
-    """Lazy-loaded speaker diarization using pyannote-audio."""
-
-    _pipeline = None
-
-    @classmethod
-    def get_instance(cls, hf_token: str | None = None):
-        """Get or create the diarization pipeline.
-
-        Args:
-            hf_token: HuggingFace token with access to pyannote models.
-                Can also be set via HF_TOKEN environment variable.
-        """
-        if cls._pipeline is None:
-            import os
-
-            from pyannote.audio import Pipeline
-
-            token = hf_token or os.environ.get("HF_TOKEN")
-            cls._pipeline = Pipeline.from_pretrained(
-                "pyannote/speaker-diarization-3.1",
-                use_auth_token=token,
-            )
-
-            # Move to GPU if available
-            if torch.cuda.is_available():
-                cls._pipeline.to(torch.device("cuda"))
-
-        return cls._pipeline
-
-    @classmethod
-    def diarize(
-        cls,
-        audio: np.ndarray | str,
-        sample_rate: int = 16000,
-        num_speakers: int | None = None,
-        min_speakers: int | None = None,
-        max_speakers: int | None = None,
-        hf_token: str | None = None,
-    ) -> list[dict]:
-        """Run speaker diarization on audio.
-
-        Args:
-            audio: Audio waveform as numpy array or path to audio file
-            sample_rate: Audio sample rate (default 16000)
-            num_speakers: Exact number of speakers (if known)
-            min_speakers: Minimum number of speakers
-            max_speakers: Maximum number of speakers
-            hf_token: HuggingFace token for pyannote models
-
-        Returns:
-            List of dicts with 'speaker', 'start', 'end' keys
-        """
-        pipeline = cls.get_instance(hf_token)
-
-        # Prepare audio input
-        if isinstance(audio, np.ndarray):
-            # pyannote expects {"waveform": tensor, "sample_rate": int}
-            waveform = torch.from_numpy(audio).unsqueeze(0)  # Add channel dim
-            if waveform.dim() == 1:
-                waveform = waveform.unsqueeze(0)
-            audio_input = {"waveform": waveform, "sample_rate": sample_rate}
-        else:
-            # File path
-            audio_input = audio

-        # Run diarization
-        diarization_args = {}
-        if num_speakers is not None:
-            diarization_args["num_speakers"] = num_speakers
-        if min_speakers is not None:
-            diarization_args["min_speakers"] = min_speakers
-        if max_speakers is not None:
-            diarization_args["max_speakers"] = max_speakers
-
-        diarization = pipeline(audio_input, **diarization_args)
-
-        # Convert to simple format
-        segments = []
-        for turn, _, speaker in diarization.itertracks(yield_label=True):
-            segments.append({
-                "speaker": speaker,
-                "start": turn.start,
-                "end": turn.end,
-            })
-
-        return segments
-
-    @classmethod
-    def assign_speakers_to_words(
-        cls,
-        words: list[dict],
-        speaker_segments: list[dict],
-    ) -> list[dict]:
-        """Assign speaker labels to words based on timestamp overlap.
-
-        Args:
-            words: List of word dicts with 'word', 'start', 'end' keys
-            speaker_segments: List of speaker dicts with 'speaker', 'start', 'end' keys
-
-        Returns:
-            Words list with 'speaker' key added to each word
-        """
-        for word in words:
-            word_mid = (word["start"] + word["end"]) / 2
-
-            # Find the speaker segment that contains this word's midpoint
-            best_speaker = None
-            for seg in speaker_segments:
-                if seg["start"] <= word_mid <= seg["end"]:
-                    best_speaker = seg["speaker"]
-                    break
-
-            # If no exact match, find closest segment
-            if best_speaker is None and speaker_segments:
-                min_dist = float("inf")
-                for seg in speaker_segments:
-                    seg_mid = (seg["start"] + seg["end"]) / 2
-                    dist = abs(word_mid - seg_mid)
-                    if dist < min_dist:
-                        min_dist = dist
-                        best_speaker = seg["speaker"]
-
-            word["speaker"] = best_speaker
-
-        return words
-
-
 class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
     """ASR Pipeline for audio-to-text transcription."""

@@ -230,125 +24,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         super().__init__(
             model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
         )
-        self._current_audio = None
-        self._return_timestamps = False
-        self._return_speakers = False
-        self._diarization_params = {}
-
-    def _sanitize_parameters(self, **kwargs):
-        """Intercept our custom parameters before parent class validates them."""
-        # Extract our custom parameters before parent sees them
-        self._return_timestamps = kwargs.pop("return_timestamps", False)
-        self._return_speakers = kwargs.pop("return_speakers", False)
-        self._diarization_params = {
-            "num_speakers": kwargs.pop("num_speakers", None),
-            "min_speakers": kwargs.pop("min_speakers", None),
-            "max_speakers": kwargs.pop("max_speakers", None),
-            "hf_token": kwargs.pop("hf_token", None),
-        }
-
-        # return_speakers requires return_timestamps
-        if self._return_speakers:
-            self._return_timestamps = True
-
-        # Now let parent sanitize remaining params
-        return super()._sanitize_parameters(**kwargs)
-
-    def __call__(
-        self,
-        inputs,
-        **kwargs,
-    ):
-        """Transcribe audio with optional word-level timestamps and speaker diarization.
-
-        Args:
-            inputs: Audio input (file path, dict with array/sampling_rate, etc.)
-            return_timestamps: If True, return word-level timestamps using forced alignment
-            return_speakers: If True, return speaker labels for each word
-            num_speakers: Exact number of speakers (if known, for diarization)
-            min_speakers: Minimum number of speakers (for diarization)
-            max_speakers: Maximum number of speakers (for diarization)
-            hf_token: HuggingFace token for pyannote models (or set HF_TOKEN env var)
-            **kwargs: Additional arguments passed to the pipeline
-
-        Returns:
-            Dict with 'text' key, 'words' key if return_timestamps=True,
-            and speaker labels on words if return_speakers=True
-        """
-        # Store audio for timestamp alignment and diarization
-        if self._return_timestamps or self._return_speakers:
-            self._current_audio = self._extract_audio(inputs)
-
-        # Run standard transcription
-        result = super().__call__(inputs, **kwargs)
-
-        # Add timestamps if requested
-        if self._return_timestamps and self._current_audio is not None:
-            text = result.get("text", "")
-            if text:
-                try:
-                    words = ForcedAligner.align(
-                        self._current_audio["array"],
-                        text,
-                        sample_rate=self._current_audio.get("sampling_rate", 16000),
-                    )
-                    result["words"] = words
-                except Exception as e:
-                    result["words"] = []
-                    result["timestamp_error"] = str(e)
-            else:
-                result["words"] = []
-
-        # Add speaker diarization if requested
-        if self._return_speakers and self._current_audio is not None:
-            try:
-                # Run diarization
-                speaker_segments = SpeakerDiarizer.diarize(
-                    self._current_audio["array"],
-                    sample_rate=self._current_audio.get("sampling_rate", 16000),
-                    **{k: v for k, v in self._diarization_params.items() if v is not None},
-                )
-                result["speaker_segments"] = speaker_segments
-
-                # Assign speakers to words
-                if result.get("words"):
-                    result["words"] = SpeakerDiarizer.assign_speakers_to_words(
-                        result["words"],
-                        speaker_segments,
-                    )
-            except Exception as e:
-                result["speaker_segments"] = []
-                result["diarization_error"] = str(e)
-
-        # Clean up
-        if self._return_timestamps or self._return_speakers:
-            self._current_audio = None
-
-        return result
-
-    def _extract_audio(self, inputs) -> dict | None:
-        """Extract audio array from various input formats."""
-        import librosa
-
-        if isinstance(inputs, dict):
-            if "array" in inputs:
-                return {
-                    "array": inputs["array"],
-                    "sampling_rate": inputs.get("sampling_rate", 16000),
-                }
-            if "raw" in inputs:
-                return {
-                    "array": inputs["raw"],
-                    "sampling_rate": inputs.get("sampling_rate", 16000),
-                }
-        elif isinstance(inputs, str):
-            # File path - load audio
-            audio, sr = librosa.load(inputs, sr=16000)
-            return {"array": audio, "sampling_rate": sr}
-        elif isinstance(inputs, np.ndarray):
-            return {"array": inputs, "sampling_rate": 16000}
-
-        return None

     def preprocess(self, inputs, **preprocess_params):
         # Handle dict with "array" key (from datasets)
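
For reference, the removed functionality was driven entirely through pipeline keyword arguments. A minimal usage sketch of the pre-removal API, reconstructed from the deleted docstrings; the checkpoint path, the trust_remote_code flag, and the silent one-second test waveform are illustrative assumptions, not part of this repo:

import numpy as np
from transformers import pipeline

# Hypothetical checkpoint path; trust_remote_code=True is assumed so that
# the repo's custom ASRPipeline class is loaded.
asr = pipeline(
    "automatic-speech-recognition",
    model="./checkpoint-15000",
    trust_remote_code=True,
)

# Before this commit, _sanitize_parameters intercepted these kwargs:
result = asr(
    {"array": np.zeros(16000, dtype=np.float32), "sampling_rate": 16000},
    return_timestamps=True,  # word-level timestamps via ctc-forced-aligner
    return_speakers=True,    # implies return_timestamps; runs pyannote diarization
    max_speakers=2,          # optional diarization hint
)

# Per the deleted __call__ docstring, the result carried:
#   result["text"]             -> transcript string
#   result["words"]            -> [{"word", "start", "end", "speaker"}, ...]
#   result["speaker_segments"] -> [{"speaker", "start", "end"}, ...]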
 
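Speaker labels on words came from the midpoint heuristic in the deleted assign_speakers_to_words: each word takes the speaker of the first diarization segment containing the word's temporal midpoint, with a nearest-midpoint fallback when no segment contains it. A self-contained illustration with made-up timings:

words = [
    {"word": "hello", "start": 0.10, "end": 0.40},  # midpoint 0.25
    {"word": "there", "start": 0.45, "end": 0.80},  # midpoint 0.625
]
segments = [
    {"speaker": "SPEAKER_00", "start": 0.00, "end": 0.50},
    {"speaker": "SPEAKER_01", "start": 0.50, "end": 1.00},
]

for word in words:
    mid = (word["start"] + word["end"]) / 2
    # First containing segment wins; the deleted code additionally fell back
    # to the segment with the nearest midpoint when none contained the word.
    word["speaker"] = next(
        (s["speaker"] for s in segments if s["start"] <= mid <= s["end"]),
        None,
    )

# "hello" -> SPEAKER_00 (0.25 falls in [0.00, 0.50])
# "there" -> SPEAKER_01 (0.625 falls in [0.50, 1.00])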