mazesmazes committed
Commit 29f8a60 · verified · Parent: 0549714

Update custom model files, README, and requirements

Files changed (1):
  asr_pipeline.py +7 -255
asr_pipeline.py CHANGED

@@ -282,160 +282,6 @@ class SpeakerDiarizer:
         return words
 
 
-class VoiceActivityDetector:
-    """Voice Activity Detection using pyannote for improved transcription quality.
-
-    Based on WhisperX implementation. Detects speech regions in audio and chunks
-    them for more accurate transcription of long audio files.
-    """
-
-    _model = None
-    _pipeline = None
-
-    @classmethod
-    def get_instance(cls, vad_onset: float = 0.5, vad_offset: float = 0.363):
-        """Get or create the VAD pipeline instance.
-
-        Args:
-            vad_onset: Threshold for speech start detection (default 0.5)
-            vad_offset: Threshold for speech end detection (default 0.363)
-        """
-        if cls._pipeline is None:
-            from pyannote.audio import Model
-            from pyannote.audio.pipelines import VoiceActivityDetection
-
-            # Load the segmentation model
-            cls._model = Model.from_pretrained(
-                "pyannote/segmentation-3.0",
-            )
-
-            # Create VAD pipeline with hyperparameters
-            cls._pipeline = VoiceActivityDetection(segmentation=cls._model)
-            cls._pipeline.instantiate({
-                "onset": vad_onset,
-                "offset": vad_offset,
-                "min_duration_on": 0.1,  # Min speech duration (100ms)
-                "min_duration_off": 0.1,  # Min silence duration (100ms)
-            })
-
-            # Move to GPU if available
-            if torch.cuda.is_available():
-                cls._pipeline.to(torch.device("cuda"))
-            elif torch.backends.mps.is_available():
-                cls._pipeline.to(torch.device("mps"))
-
-        return cls._pipeline
-
-    @classmethod
-    def detect(
-        cls,
-        audio: np.ndarray,
-        sample_rate: int = 16000,
-        vad_onset: float = 0.5,
-        vad_offset: float = 0.363,
-    ) -> list[dict]:
-        """Detect speech regions in audio.
-
-        Args:
-            audio: Audio waveform as numpy array
-            sample_rate: Audio sample rate (default 16000)
-            vad_onset: Threshold for speech start detection
-            vad_offset: Threshold for speech end detection
-
-        Returns:
-            List of dicts with 'start', 'end' keys (in seconds)
-        """
-        pipeline = cls.get_instance(vad_onset, vad_offset)
-
-        # Prepare audio input
-        waveform = torch.from_numpy(audio).float()
-        if waveform.dim() == 1:
-            waveform = waveform.unsqueeze(0)
-
-        audio_input = {"waveform": waveform, "sample_rate": sample_rate}
-
-        # Run VAD
-        vad_result = pipeline(audio_input)
-
-        # Convert to list of segments
-        segments = []
-        for speech_turn in vad_result.get_timeline():
-            segments.append({
-                "start": speech_turn.start,
-                "end": speech_turn.end,
-            })
-
-        return segments
-
-    @classmethod
-    def merge_chunks(
-        cls,
-        segments: list[dict],
-        chunk_size: float = 30.0,
-    ) -> list[dict]:
-        """Merge VAD segments into larger chunks for batched processing.
-
-        Args:
-            segments: List of VAD segments with 'start', 'end' keys
-            chunk_size: Maximum chunk duration in seconds (default 30)
-
-        Returns:
-            List of chunks with 'start', 'end', 'segments' keys
-        """
-        if not segments:
-            return []
-
-        merged = []
-        curr_start = segments[0]["start"]
-        curr_end = segments[0]["end"]
-        curr_segments = []
-
-        for seg in segments:
-            # If adding this segment exceeds chunk_size, finalize current chunk
-            if seg["end"] - curr_start > chunk_size and curr_segments:
-                merged.append({
-                    "start": curr_start,
-                    "end": curr_end,
-                    "segments": curr_segments,
-                })
-                curr_start = seg["start"]
-                curr_segments = []
-
-            curr_end = seg["end"]
-            curr_segments.append((seg["start"], seg["end"]))
-
-        # Add final chunk
-        if curr_segments:
-            merged.append({
-                "start": curr_start,
-                "end": curr_end,
-                "segments": curr_segments,
-            })
-
-        return merged
-
-    @classmethod
-    def extract_chunk_audio(
-        cls,
-        audio: np.ndarray,
-        chunk: dict,
-        sample_rate: int = 16000,
-    ) -> np.ndarray:
-        """Extract audio for a specific chunk.
-
-        Args:
-            audio: Full audio waveform
-            chunk: Chunk dict with 'start', 'end' keys
-            sample_rate: Audio sample rate
-
-        Returns:
-            Audio chunk as numpy array
-        """
-        start_sample = int(chunk["start"] * sample_rate)
-        end_sample = int(chunk["end"] * sample_rate)
-        return audio[start_sample:end_sample]
-
-
 class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
     """ASR Pipeline for audio-to-text transcription."""
 
@@ -462,10 +308,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         kwargs.pop("min_speakers", None)
         kwargs.pop("max_speakers", None)
         kwargs.pop("hf_token", None)
-        kwargs.pop("use_vad", None)
-        kwargs.pop("vad_onset", None)
-        kwargs.pop("vad_offset", None)
-        kwargs.pop("chunk_size", None)
 
         return super()._sanitize_parameters(**kwargs)
 
@@ -474,14 +316,10 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         inputs,
         **kwargs,
     ):
-        """Transcribe audio with optional VAD, timestamps, and speaker diarization.
+        """Transcribe audio with optional word-level timestamps and speaker diarization.
 
         Args:
             inputs: Audio input (file path, dict with array/sampling_rate, etc.)
-            use_vad: If True, use Voice Activity Detection to chunk audio (recommended for long audio)
-            vad_onset: VAD speech start threshold (default 0.5)
-            vad_offset: VAD speech end threshold (default 0.363)
-            chunk_size: Maximum chunk duration in seconds for VAD (default 30)
             return_timestamps: If True, return word-level timestamps using forced alignment
             return_speakers: If True, return speaker labels for each word
             num_speakers: Exact number of speakers (if known, for diarization)
@@ -492,13 +330,9 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
 
         Returns:
             Dict with 'text' key, 'words' key if return_timestamps=True,
-            'vad_segments' if use_vad=True, and speaker labels on words if return_speakers=True
+            and speaker labels on words if return_speakers=True
         """
         # Extract our params before super().__call__ (which will also call _sanitize_parameters)
-        use_vad = kwargs.pop("use_vad", False)
-        vad_onset = kwargs.pop("vad_onset", 0.5)
-        vad_offset = kwargs.pop("vad_offset", 0.363)
-        chunk_size = kwargs.pop("chunk_size", 30.0)
         return_timestamps = kwargs.pop("return_timestamps", False)
         return_speakers = kwargs.pop("return_speakers", False)
         diarization_params = {
@@ -511,25 +345,12 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         if return_speakers:
             return_timestamps = True
 
-        # Extract audio for VAD, timestamps, and diarization
-        audio_data = self._extract_audio(inputs)
-
-        # Use VAD to chunk and transcribe long audio
-        if use_vad and audio_data is not None:
-            result = self._transcribe_with_vad(
-                audio_data,
-                vad_onset=vad_onset,
-                vad_offset=vad_offset,
-                chunk_size=chunk_size,
-                **kwargs,
-            )
-        else:
-            # Store audio for timestamp alignment and diarization
-            if return_timestamps or return_speakers:
-                self._current_audio = audio_data
+        # Store audio for timestamp alignment and diarization
+        if return_timestamps or return_speakers:
+            self._current_audio = self._extract_audio(inputs)
 
-            # Run standard transcription
-            result = super().__call__(inputs, **kwargs)
+        # Run standard transcription
+        result = super().__call__(inputs, **kwargs)
 
         # Add timestamps if requested
         if return_timestamps and self._current_audio is not None:
@@ -602,75 +423,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
 
         return None
 
-    def _transcribe_with_vad(
-        self,
-        audio_data: dict,
-        vad_onset: float = 0.5,
-        vad_offset: float = 0.363,
-        chunk_size: float = 30.0,
-        **kwargs,
-    ) -> dict:
-        """Transcribe audio using VAD to chunk long audio.
-
-        Args:
-            audio_data: Dict with 'array' and 'sampling_rate' keys
-            vad_onset: VAD speech start threshold
-            vad_offset: VAD speech end threshold
-            chunk_size: Maximum chunk duration in seconds
-            **kwargs: Additional arguments passed to transcription
-
-        Returns:
-            Dict with 'text', 'vad_segments', and 'chunks' keys
-        """
-        audio = audio_data["array"]
-        sample_rate = audio_data.get("sampling_rate", 16000)
-
-        # Run VAD to detect speech regions
-        vad_segments = VoiceActivityDetector.detect(
-            audio,
-            sample_rate=sample_rate,
-            vad_onset=vad_onset,
-            vad_offset=vad_offset,
-        )
-
-        if not vad_segments:
-            return {"text": "", "vad_segments": [], "chunks": []}
-
-        # Merge segments into chunks
-        chunks = VoiceActivityDetector.merge_chunks(vad_segments, chunk_size)
-
-        # Transcribe each chunk
-        all_text = []
-        chunk_results = []
-
-        for chunk in chunks:
-            # Extract chunk audio
-            chunk_audio = VoiceActivityDetector.extract_chunk_audio(
-                audio, chunk, sample_rate
-            )
-
-            # Transcribe chunk
-            chunk_input = {"raw": chunk_audio, "sampling_rate": sample_rate}
-            chunk_result = super().__call__(chunk_input, **kwargs)
-
-            chunk_text = chunk_result.get("text", "").strip()
-            all_text.append(chunk_text)
-
-            chunk_results.append({
-                "start": chunk["start"],
-                "end": chunk["end"],
-                "text": chunk_text,
-            })
-
-        # Store audio for potential timestamp/diarization
-        self._current_audio = audio_data
-
-        return {
-            "text": " ".join(all_text),
-            "vad_segments": vad_segments,
-            "chunks": chunk_results,
-        }
-
     def preprocess(self, inputs, **preprocess_params):
         # Handle dict with "array" key (from datasets)
         if isinstance(inputs, dict) and "array" in inputs:
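
After this change, __call__ keeps only the timestamp and diarization options. A minimal usage sketch follows; the repo id and audio path are placeholders, and loading the custom ASRPipeline through transformers.pipeline with trust_remote_code=True is an assumption about how this repo is consumed:

from transformers import pipeline

# Placeholder repo id; trust_remote_code is assumed to be needed because
# ASRPipeline is custom pipeline code shipped with the model repo.
asr = pipeline(
    "automatic-speech-recognition",
    model="mazesmazes/your-asr-model",  # hypothetical repo id
    trust_remote_code=True,
)

result = asr(
    "meeting.wav",               # placeholder audio file
    return_timestamps=True,      # word-level timestamps via forced alignment
    return_speakers=True,        # per-word speaker labels (implies timestamps)
    min_speakers=2,              # diarization hints, per the kept kwargs;
    max_speakers=4,              # hf_token is also accepted for gated models
)

print(result["text"])
for word in result.get("words", []):
    print(word)

Note that use_vad, vad_onset, vad_offset, and chunk_size are no longer filtered out by _sanitize_parameters; passing them after this commit would reach the base transformers pipeline and likely raise an error.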
 
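Callers that relied on the removed use_vad path can reproduce the chunked transcription outside the pipeline. A minimal sketch, assuming pyannote.audio is installed; it mirrors the segmentation model, thresholds, and merge logic of the deleted VoiceActivityDetector, and the helper name speech_chunks is hypothetical:

import numpy as np
import torch
from pyannote.audio import Model
from pyannote.audio.pipelines import VoiceActivityDetection

# Same segmentation model and hyperparameters as the deleted
# VoiceActivityDetector.get_instance.
model = Model.from_pretrained("pyannote/segmentation-3.0")
vad = VoiceActivityDetection(segmentation=model)
vad.instantiate({
    "onset": 0.5,             # speech start threshold
    "offset": 0.363,          # speech end threshold
    "min_duration_on": 0.1,   # min speech duration (100ms)
    "min_duration_off": 0.1,  # min silence duration (100ms)
})

def speech_chunks(audio: np.ndarray, sample_rate: int = 16000,
                  chunk_size: float = 30.0) -> list[np.ndarray]:
    """Detect speech and slice the waveform into chunks of at most
    chunk_size seconds, following the deleted merge_chunks logic."""
    waveform = torch.from_numpy(audio).float().unsqueeze(0)
    timeline = vad({"waveform": waveform, "sample_rate": sample_rate}).get_timeline()

    spans, start, end = [], None, None
    for turn in timeline:
        if start is None:
            start, end = turn.start, turn.end
        elif turn.end - start > chunk_size:
            # Adding this region would exceed chunk_size: flush and restart
            spans.append((start, end))
            start, end = turn.start, turn.end
        else:
            end = turn.end
    if start is not None:
        spans.append((start, end))

    return [audio[int(s * sample_rate):int(e * sample_rate)] for s, e in spans]

Each chunk can then be fed to the pipeline as {"raw": chunk, "sampling_rate": sample_rate}, as the removed _transcribe_with_vad did, and the per-chunk texts joined with spaces.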