mazesmazes committed
Commit f6305ff · verified · 1 Parent(s): 0a45964

Update custom model files, README, and requirements

Files changed (2):
  1. asr_pipeline.py +337 -0
  2. requirements.txt +5 -13
asr_pipeline.py CHANGED
@@ -1,5 +1,6 @@
  from typing import Any

+ import numpy as np
  import torch
  import transformers

@@ -9,6 +10,211 @@ except ImportError:
      from asr_modeling import ASRModel  # type: ignore[no-redef]


+ class ForcedAligner:
+     """Lazy-loaded forced aligner for word-level timestamps."""
+
+     _instance = None
+     _model = None
+     _tokenizer = None
+
+     @classmethod
+     def get_instance(cls, device: str = "cuda"):
+         if cls._model is None:
+             from ctc_forced_aligner import load_alignment_model
+
+             dtype = torch.float16 if device == "cuda" else torch.float32
+             cls._model, cls._tokenizer = load_alignment_model(device, dtype=dtype)
+         return cls._model, cls._tokenizer
+
+     @classmethod
+     def align(
+         cls,
+         audio: np.ndarray,
+         text: str,
+         sample_rate: int = 16000,
+         language: str = "eng",
+         batch_size: int = 16,
+     ) -> list[dict]:
+         """Align transcript to audio and return word-level timestamps.
+
+         Args:
+             audio: Audio waveform as numpy array
+             text: Transcript text to align
+             sample_rate: Audio sample rate (default 16000)
+             language: ISO-639-3 language code (default "eng" for English)
+             batch_size: Batch size for alignment model
+
+         Returns:
+             List of dicts with 'word', 'start', 'end' keys
+         """
+         from ctc_forced_aligner import (
+             generate_emissions,
+             get_alignments,
+             get_spans,
+             postprocess_results,
+             preprocess_text,
+         )
+
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         model, tokenizer = cls.get_instance(device)
+
+         # Convert audio to tensor
+         if isinstance(audio, np.ndarray):
+             audio_tensor = torch.from_numpy(audio).to(model.dtype).to(model.device)
+         else:
+             audio_tensor = audio.to(model.dtype).to(model.device)
+
+         # Ensure 1D
+         if audio_tensor.dim() > 1:
+             audio_tensor = audio_tensor.squeeze()
+
+         # Generate emissions
+         emissions, stride = generate_emissions(model, audio_tensor, batch_size=batch_size)
+
+         # Preprocess text
+         tokens_starred, text_starred = preprocess_text(text, romanize=True, language=language)
+
+         # Get alignments
+         segments, scores, blank_token = get_alignments(emissions, tokens_starred, tokenizer)
+
+         # Get spans
+         spans = get_spans(tokens_starred, segments, blank_token)
+
+         # Get word timestamps
+         word_timestamps = postprocess_results(text_starred, spans, stride, scores)
+
+         # Convert to simple format
+         return [{"word": w["word"], "start": w["start"], "end": w["end"]} for w in word_timestamps]
+
+
+ class SpeakerDiarizer:
+     """Lazy-loaded speaker diarization using pyannote-audio."""
+
+     _pipeline = None
+
+     @classmethod
+     def get_instance(cls, hf_token: str | None = None):
+         """Get or create the diarization pipeline.
+
+         Args:
+             hf_token: HuggingFace token with access to pyannote models.
+                 Can also be set via HF_TOKEN environment variable.
+         """
+         if cls._pipeline is None:
+             import os
+
+             from pyannote.audio import Pipeline
+
+             token = hf_token or os.environ.get("HF_TOKEN")
+             cls._pipeline = Pipeline.from_pretrained(
+                 "pyannote/speaker-diarization-3.1",
+                 use_auth_token=token,
+             )
+
+             # Move to GPU if available
+             if torch.cuda.is_available():
+                 cls._pipeline.to(torch.device("cuda"))
+
+         return cls._pipeline
+
+     @classmethod
+     def diarize(
+         cls,
+         audio: np.ndarray | str,
+         sample_rate: int = 16000,
+         num_speakers: int | None = None,
+         min_speakers: int | None = None,
+         max_speakers: int | None = None,
+         hf_token: str | None = None,
+     ) -> list[dict]:
+         """Run speaker diarization on audio.
+
+         Args:
+             audio: Audio waveform as numpy array or path to audio file
+             sample_rate: Audio sample rate (default 16000)
+             num_speakers: Exact number of speakers (if known)
+             min_speakers: Minimum number of speakers
+             max_speakers: Maximum number of speakers
+             hf_token: HuggingFace token for pyannote models
+
+         Returns:
+             List of dicts with 'speaker', 'start', 'end' keys
+         """
+         pipeline = cls.get_instance(hf_token)
+
+         # Prepare audio input
+         if isinstance(audio, np.ndarray):
+             # pyannote expects {"waveform": tensor, "sample_rate": int}
+             waveform = torch.from_numpy(audio)
+             if waveform.dim() == 1:
+                 waveform = waveform.unsqueeze(0)  # add channel dim -> (1, time)
+             audio_input = {"waveform": waveform, "sample_rate": sample_rate}
+         else:
+             # File path
+             audio_input = audio
+
+         # Run diarization
+         diarization_args = {}
+         if num_speakers is not None:
+             diarization_args["num_speakers"] = num_speakers
+         if min_speakers is not None:
+             diarization_args["min_speakers"] = min_speakers
+         if max_speakers is not None:
+             diarization_args["max_speakers"] = max_speakers
+
+         diarization = pipeline(audio_input, **diarization_args)
+
+         # Convert to simple format
+         segments = []
+         for turn, _, speaker in diarization.itertracks(yield_label=True):
+             segments.append({
+                 "speaker": speaker,
+                 "start": turn.start,
+                 "end": turn.end,
+             })
+
+         return segments
+
+     @classmethod
+     def assign_speakers_to_words(
+         cls,
+         words: list[dict],
+         speaker_segments: list[dict],
+     ) -> list[dict]:
+         """Assign speaker labels to words based on timestamp overlap.
+
+         Args:
+             words: List of word dicts with 'word', 'start', 'end' keys
+             speaker_segments: List of speaker dicts with 'speaker', 'start', 'end' keys
+
+         Returns:
+             Words list with 'speaker' key added to each word
+         """
+         for word in words:
+             word_mid = (word["start"] + word["end"]) / 2
+
+             # Find the speaker segment that contains this word's midpoint
+             best_speaker = None
+             for seg in speaker_segments:
+                 if seg["start"] <= word_mid <= seg["end"]:
+                     best_speaker = seg["speaker"]
+                     break
+
+             # If no exact match, find closest segment
+             if best_speaker is None and speaker_segments:
+                 min_dist = float("inf")
+                 for seg in speaker_segments:
+                     seg_mid = (seg["start"] + seg["end"]) / 2
+                     dist = abs(word_mid - seg_mid)
+                     if dist < min_dist:
+                         min_dist = dist
+                         best_speaker = seg["speaker"]
+
+             word["speaker"] = best_speaker
+
+         return words
+
+
  class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
      """ASR Pipeline for audio-to-text transcription."""

@@ -24,6 +230,137 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
          super().__init__(
              model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
          )
+         self._current_audio = None
+         self._return_timestamps = False
+         self._return_speakers = False
+         self._diarization_params = {}
+
+     def _sanitize_parameters(self, **kwargs):
+         """Intercept our custom parameters before parent class validates them."""
+         # Extract our custom parameters before parent sees them
+         self._return_timestamps = kwargs.pop("return_timestamps", False)
+         self._return_speakers = kwargs.pop("return_speakers", False)
+         self._diarization_params = {
+             "num_speakers": kwargs.pop("num_speakers", None),
+             "min_speakers": kwargs.pop("min_speakers", None),
+             "max_speakers": kwargs.pop("max_speakers", None),
+             "hf_token": kwargs.pop("hf_token", None),
+         }
+
+         # return_speakers requires return_timestamps
+         if self._return_speakers:
+             self._return_timestamps = True
+
+         # Now let parent sanitize remaining params
+         return super()._sanitize_parameters(**kwargs)
+
+     def __call__(
+         self,
+         inputs,
+         **kwargs,
+     ):
+         """Transcribe audio with optional word-level timestamps and speaker diarization.
+
+         Args:
+             inputs: Audio input (file path, dict with array/sampling_rate, etc.)
+             return_timestamps: If True, return word-level timestamps using forced alignment
+             return_speakers: If True, return speaker labels for each word
+             num_speakers: Exact number of speakers (if known, for diarization)
+             min_speakers: Minimum number of speakers (for diarization)
+             max_speakers: Maximum number of speakers (for diarization)
+             hf_token: HuggingFace token for pyannote models (or set HF_TOKEN env var)
+             **kwargs: Additional arguments passed to the pipeline
+
+         Returns:
+             Dict with 'text' key, 'words' key if return_timestamps=True,
+             and speaker labels on words if return_speakers=True
+         """
+         # Extract our params before super().__call__ (which will also call _sanitize_parameters)
+         return_timestamps = kwargs.pop("return_timestamps", False)
+         return_speakers = kwargs.pop("return_speakers", False)
+         diarization_params = {
+             "num_speakers": kwargs.pop("num_speakers", None),
+             "min_speakers": kwargs.pop("min_speakers", None),
+             "max_speakers": kwargs.pop("max_speakers", None),
+             "hf_token": kwargs.pop("hf_token", None),
+         }
+
+         if return_speakers:
+             return_timestamps = True
+
+         # Store audio for timestamp alignment and diarization
+         if return_timestamps or return_speakers:
+             self._current_audio = self._extract_audio(inputs)
+
+         # Run standard transcription
+         result = super().__call__(inputs, **kwargs)
+
+         # Add timestamps if requested
+         if return_timestamps and self._current_audio is not None:
+             text = result.get("text", "")
+             if text:
+                 try:
+                     words = ForcedAligner.align(
+                         self._current_audio["array"],
+                         text,
+                         sample_rate=self._current_audio.get("sampling_rate", 16000),
+                     )
+                     result["words"] = words
+                 except Exception as e:
+                     result["words"] = []
+                     result["timestamp_error"] = str(e)
+             else:
+                 result["words"] = []
+
+         # Add speaker diarization if requested
+         if return_speakers and self._current_audio is not None:
+             try:
+                 # Run diarization
+                 speaker_segments = SpeakerDiarizer.diarize(
+                     self._current_audio["array"],
+                     sample_rate=self._current_audio.get("sampling_rate", 16000),
+                     **{k: v for k, v in diarization_params.items() if v is not None},
+                 )
+                 result["speaker_segments"] = speaker_segments
+
+                 # Assign speakers to words
+                 if result.get("words"):
+                     result["words"] = SpeakerDiarizer.assign_speakers_to_words(
+                         result["words"],
+                         speaker_segments,
+                     )
+             except Exception as e:
+                 result["speaker_segments"] = []
+                 result["diarization_error"] = str(e)
+
+         # Clean up
+         self._current_audio = None
+
+         return result
+
+     def _extract_audio(self, inputs) -> dict | None:
+         """Extract audio array from various input formats."""
+         import librosa
+
+         if isinstance(inputs, dict):
+             if "array" in inputs:
+                 return {
+                     "array": inputs["array"],
+                     "sampling_rate": inputs.get("sampling_rate", 16000),
+                 }
+             if "raw" in inputs:
+                 return {
+                     "array": inputs["raw"],
+                     "sampling_rate": inputs.get("sampling_rate", 16000),
+                 }
+         elif isinstance(inputs, str):
+             # File path - load audio
+             audio, sr = librosa.load(inputs, sr=16000)
+             return {"array": audio, "sampling_rate": sr}
+         elif isinstance(inputs, np.ndarray):
+             return {"array": inputs, "sampling_rate": 16000}
+
+         return None

      def preprocess(self, inputs, **preprocess_params):
          # Handle dict with "array" key (from datasets)
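
For reference, a minimal usage sketch of the options this commit adds. It is illustrative only: the repo id and audio file name are placeholders, loading via trust_remote_code is assumed to be how this custom pipeline is wired up, and the speaker path additionally needs an HF token with access to pyannote/speaker-diarization-3.1:

    import transformers

    pipe = transformers.pipeline(
        "automatic-speech-recognition",
        model="<model-repo-id>",  # placeholder repo id
        trust_remote_code=True,   # assumed: loads ASRPipeline from asr_pipeline.py
    )

    # Plain transcription (unchanged behavior)
    print(pipe("sample.wav")["text"])

    # Word-level timestamps via forced alignment
    result = pipe("sample.wav", return_timestamps=True)
    for w in result["words"]:
        print(f"{w['start']:.2f}-{w['end']:.2f}  {w['word']}")

    # Speaker-attributed words (implies return_timestamps=True)
    result = pipe("sample.wav", return_speakers=True, max_speakers=4)
    for w in result["words"]:
        print(w["speaker"], w["word"])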
requirements.txt CHANGED
@@ -1,14 +1,6 @@
- # Use latest compatible versions
- gradio
- transformers>=4.57.1
- torch
- soundfile
- librosa
- peft
- truecase
+ # Core dependencies for tiny-audio model inference
+ # This file is pushed to the HuggingFace model repository

- # Forced alignment for word-level timestamps
- ctc-forced-aligner @ git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
-
- # Speaker diarization
- pyannote-audio>=3.1.0
+ # Transformers - main library for model loading and inference
+ transformers>=4.57.0
+ truecase
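
Note that the pruned requirements no longer cover everything asr_pipeline.py imports: torch and numpy are imported at module level, and librosa, ctc-forced-aligner, and pyannote-audio are imported lazily on the timestamp/diarization paths. In an environment that does not already provide them, a sketch of the extra requirements, reusing the removed pins where they existed, might look like:

    # Extras for return_timestamps / return_speakers (pins carried over from the removed lines)
    torch
    numpy
    librosa
    ctc-forced-aligner @ git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
    pyannote-audio>=3.1.0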