mazesmazes committed (verified)
Commit 0549714 · 1 parent: 3878fe5

Update custom model files, README, and requirements

Files changed (1)
  1. asr_pipeline.py +255 -7
asr_pipeline.py CHANGED
@@ -282,6 +282,160 @@ class SpeakerDiarizer:
         return words


+class VoiceActivityDetector:
+    """Voice Activity Detection using pyannote for improved transcription quality.
+
+    Based on WhisperX implementation. Detects speech regions in audio and chunks
+    them for more accurate transcription of long audio files.
+    """
+
+    _model = None
+    _pipeline = None
+
+    @classmethod
+    def get_instance(cls, vad_onset: float = 0.5, vad_offset: float = 0.363):
+        """Get or create the VAD pipeline instance.
+
+        Args:
+            vad_onset: Threshold for speech start detection (default 0.5)
+            vad_offset: Threshold for speech end detection (default 0.363)
+        """
+        if cls._pipeline is None:
+            from pyannote.audio import Model
+            from pyannote.audio.pipelines import VoiceActivityDetection
+
+            # Load the segmentation model
+            cls._model = Model.from_pretrained(
+                "pyannote/segmentation-3.0",
+            )
+
+            # Create VAD pipeline with hyperparameters
+            cls._pipeline = VoiceActivityDetection(segmentation=cls._model)
+            cls._pipeline.instantiate({
+                "onset": vad_onset,
+                "offset": vad_offset,
+                "min_duration_on": 0.1,  # Min speech duration (100ms)
+                "min_duration_off": 0.1,  # Min silence duration (100ms)
+            })
+
+            # Move to GPU if available
+            if torch.cuda.is_available():
+                cls._pipeline.to(torch.device("cuda"))
+            elif torch.backends.mps.is_available():
+                cls._pipeline.to(torch.device("mps"))
+
+        return cls._pipeline
+
+    @classmethod
+    def detect(
+        cls,
+        audio: np.ndarray,
+        sample_rate: int = 16000,
+        vad_onset: float = 0.5,
+        vad_offset: float = 0.363,
+    ) -> list[dict]:
+        """Detect speech regions in audio.
+
+        Args:
+            audio: Audio waveform as numpy array
+            sample_rate: Audio sample rate (default 16000)
+            vad_onset: Threshold for speech start detection
+            vad_offset: Threshold for speech end detection
+
+        Returns:
+            List of dicts with 'start', 'end' keys (in seconds)
+        """
+        pipeline = cls.get_instance(vad_onset, vad_offset)
+
+        # Prepare audio input
+        waveform = torch.from_numpy(audio).float()
+        if waveform.dim() == 1:
+            waveform = waveform.unsqueeze(0)
+
+        audio_input = {"waveform": waveform, "sample_rate": sample_rate}
+
+        # Run VAD
+        vad_result = pipeline(audio_input)
+
+        # Convert to list of segments
+        segments = []
+        for speech_turn in vad_result.get_timeline():
+            segments.append({
+                "start": speech_turn.start,
+                "end": speech_turn.end,
+            })
+
+        return segments
+
+    @classmethod
+    def merge_chunks(
+        cls,
+        segments: list[dict],
+        chunk_size: float = 30.0,
+    ) -> list[dict]:
+        """Merge VAD segments into larger chunks for batched processing.
+
+        Args:
+            segments: List of VAD segments with 'start', 'end' keys
+            chunk_size: Maximum chunk duration in seconds (default 30)
+
+        Returns:
+            List of chunks with 'start', 'end', 'segments' keys
+        """
+        if not segments:
+            return []
+
+        merged = []
+        curr_start = segments[0]["start"]
+        curr_end = segments[0]["end"]
+        curr_segments = []
+
+        for seg in segments:
+            # If adding this segment exceeds chunk_size, finalize current chunk
+            if seg["end"] - curr_start > chunk_size and curr_segments:
+                merged.append({
+                    "start": curr_start,
+                    "end": curr_end,
+                    "segments": curr_segments,
+                })
+                curr_start = seg["start"]
+                curr_segments = []
+
+            curr_end = seg["end"]
+            curr_segments.append((seg["start"], seg["end"]))
+
+        # Add final chunk
+        if curr_segments:
+            merged.append({
+                "start": curr_start,
+                "end": curr_end,
+                "segments": curr_segments,
+            })
+
+        return merged
+
+    @classmethod
+    def extract_chunk_audio(
+        cls,
+        audio: np.ndarray,
+        chunk: dict,
+        sample_rate: int = 16000,
+    ) -> np.ndarray:
+        """Extract audio for a specific chunk.
+
+        Args:
+            audio: Full audio waveform
+            chunk: Chunk dict with 'start', 'end' keys
+            sample_rate: Audio sample rate
+
+        Returns:
+            Audio chunk as numpy array
+        """
+        start_sample = int(chunk["start"] * sample_rate)
+        end_sample = int(chunk["end"] * sample_rate)
+        return audio[start_sample:end_sample]
+
+
 class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
     """ASR Pipeline for audio-to-text transcription."""

@@ -308,6 +462,10 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         kwargs.pop("min_speakers", None)
         kwargs.pop("max_speakers", None)
         kwargs.pop("hf_token", None)
+        kwargs.pop("use_vad", None)
+        kwargs.pop("vad_onset", None)
+        kwargs.pop("vad_offset", None)
+        kwargs.pop("chunk_size", None)

         return super()._sanitize_parameters(**kwargs)

@@ -316,10 +474,14 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         inputs,
         **kwargs,
     ):
-        """Transcribe audio with optional word-level timestamps and speaker diarization.
+        """Transcribe audio with optional VAD, timestamps, and speaker diarization.

         Args:
             inputs: Audio input (file path, dict with array/sampling_rate, etc.)
+            use_vad: If True, use Voice Activity Detection to chunk audio (recommended for long audio)
+            vad_onset: VAD speech start threshold (default 0.5)
+            vad_offset: VAD speech end threshold (default 0.363)
+            chunk_size: Maximum chunk duration in seconds for VAD (default 30)
             return_timestamps: If True, return word-level timestamps using forced alignment
             return_speakers: If True, return speaker labels for each word
             num_speakers: Exact number of speakers (if known, for diarization)
@@ -330,9 +492,13 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):

         Returns:
             Dict with 'text' key, 'words' key if return_timestamps=True,
-            and speaker labels on words if return_speakers=True
+            'vad_segments' if use_vad=True, and speaker labels on words if return_speakers=True
         """
         # Extract our params before super().__call__ (which will also call _sanitize_parameters)
+        use_vad = kwargs.pop("use_vad", False)
+        vad_onset = kwargs.pop("vad_onset", 0.5)
+        vad_offset = kwargs.pop("vad_offset", 0.363)
+        chunk_size = kwargs.pop("chunk_size", 30.0)
         return_timestamps = kwargs.pop("return_timestamps", False)
         return_speakers = kwargs.pop("return_speakers", False)
         diarization_params = {
@@ -345,12 +511,25 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         if return_speakers:
             return_timestamps = True

-        # Store audio for timestamp alignment and diarization
-        if return_timestamps or return_speakers:
-            self._current_audio = self._extract_audio(inputs)
+        # Extract audio for VAD, timestamps, and diarization
+        audio_data = self._extract_audio(inputs)
+
+        # Use VAD to chunk and transcribe long audio
+        if use_vad and audio_data is not None:
+            result = self._transcribe_with_vad(
+                audio_data,
+                vad_onset=vad_onset,
+                vad_offset=vad_offset,
+                chunk_size=chunk_size,
+                **kwargs,
+            )
+        else:
+            # Store audio for timestamp alignment and diarization
+            if return_timestamps or return_speakers:
+                self._current_audio = audio_data

-        # Run standard transcription
-        result = super().__call__(inputs, **kwargs)
+            # Run standard transcription
+            result = super().__call__(inputs, **kwargs)

         # Add timestamps if requested
         if return_timestamps and self._current_audio is not None:
@@ -423,6 +602,75 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):

         return None

+    def _transcribe_with_vad(
+        self,
+        audio_data: dict,
+        vad_onset: float = 0.5,
+        vad_offset: float = 0.363,
+        chunk_size: float = 30.0,
+        **kwargs,
+    ) -> dict:
+        """Transcribe audio using VAD to chunk long audio.
+
+        Args:
+            audio_data: Dict with 'array' and 'sampling_rate' keys
+            vad_onset: VAD speech start threshold
+            vad_offset: VAD speech end threshold
+            chunk_size: Maximum chunk duration in seconds
+            **kwargs: Additional arguments passed to transcription
+
+        Returns:
+            Dict with 'text', 'vad_segments', and 'chunks' keys
+        """
+        audio = audio_data["array"]
+        sample_rate = audio_data.get("sampling_rate", 16000)
+
+        # Run VAD to detect speech regions
+        vad_segments = VoiceActivityDetector.detect(
+            audio,
+            sample_rate=sample_rate,
+            vad_onset=vad_onset,
+            vad_offset=vad_offset,
+        )
+
+        if not vad_segments:
+            return {"text": "", "vad_segments": [], "chunks": []}
+
+        # Merge segments into chunks
+        chunks = VoiceActivityDetector.merge_chunks(vad_segments, chunk_size)
+
+        # Transcribe each chunk
+        all_text = []
+        chunk_results = []
+
+        for chunk in chunks:
+            # Extract chunk audio
+            chunk_audio = VoiceActivityDetector.extract_chunk_audio(
+                audio, chunk, sample_rate
+            )
+
+            # Transcribe chunk
+            chunk_input = {"raw": chunk_audio, "sampling_rate": sample_rate}
+            chunk_result = super().__call__(chunk_input, **kwargs)
+
+            chunk_text = chunk_result.get("text", "").strip()
+            all_text.append(chunk_text)
+
+            chunk_results.append({
+                "start": chunk["start"],
+                "end": chunk["end"],
+                "text": chunk_text,
+            })
+
+        # Store audio for potential timestamp/diarization
+        self._current_audio = audio_data
+
+        return {
+            "text": " ".join(all_text),
+            "vad_segments": vad_segments,
+            "chunks": chunk_results,
+        }
+
     def preprocess(self, inputs, **preprocess_params):
         # Handle dict with "array" key (from datasets)
         if isinstance(inputs, dict) and "array" in inputs:
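
Usage sketch (not part of the commit): one way the VAD options introduced in this diff might be invoked through the repository's custom pipeline. The pipeline() loading call and the model id below are assumptions for illustration; only the use_vad, vad_onset, vad_offset, and chunk_size keyword arguments, and the "text"/"chunks" keys of the result, come from asr_pipeline.py above.

    # Minimal sketch, assuming the custom ASRPipeline in asr_pipeline.py is
    # exposed via trust_remote_code; the model id is a placeholder, not real.
    from transformers import pipeline

    asr = pipeline(
        "automatic-speech-recognition",
        model="mazesmazes/<model-repo>",  # placeholder model id (assumption)
        trust_remote_code=True,
    )

    result = asr(
        "long_recording.wav",
        use_vad=True,       # chunk long audio with pyannote VAD before transcribing
        vad_onset=0.5,      # speech start threshold
        vad_offset=0.363,   # speech end threshold
        chunk_size=30.0,    # merge speech regions into chunks of at most 30 s
    )

    print(result["text"])
    for chunk in result.get("chunks", []):
        print(f"[{chunk['start']:.1f}s-{chunk['end']:.1f}s] {chunk['text']}")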