Update custom model files, README, and requirements
asr_pipeline.py  CHANGED  +255 -7
@@ -282,6 +282,160 @@ class SpeakerDiarizer:
         return words


+class VoiceActivityDetector:
+    """Voice Activity Detection using pyannote for improved transcription quality.
+
+    Based on WhisperX implementation. Detects speech regions in audio and chunks
+    them for more accurate transcription of long audio files.
+    """
+
+    _model = None
+    _pipeline = None
+
+    @classmethod
+    def get_instance(cls, vad_onset: float = 0.5, vad_offset: float = 0.363):
+        """Get or create the VAD pipeline instance.
+
+        Args:
+            vad_onset: Threshold for speech start detection (default 0.5)
+            vad_offset: Threshold for speech end detection (default 0.363)
+        """
+        if cls._pipeline is None:
+            from pyannote.audio import Model
+            from pyannote.audio.pipelines import VoiceActivityDetection
+
+            # Load the segmentation model
+            cls._model = Model.from_pretrained(
+                "pyannote/segmentation-3.0",
+            )
+
+            # Create VAD pipeline with hyperparameters
+            cls._pipeline = VoiceActivityDetection(segmentation=cls._model)
+            cls._pipeline.instantiate({
+                "onset": vad_onset,
+                "offset": vad_offset,
+                "min_duration_on": 0.1,   # Min speech duration (100ms)
+                "min_duration_off": 0.1,  # Min silence duration (100ms)
+            })
+
+            # Move to GPU if available
+            if torch.cuda.is_available():
+                cls._pipeline.to(torch.device("cuda"))
+            elif torch.backends.mps.is_available():
+                cls._pipeline.to(torch.device("mps"))
+
+        return cls._pipeline
+
+    @classmethod
+    def detect(
+        cls,
+        audio: np.ndarray,
+        sample_rate: int = 16000,
+        vad_onset: float = 0.5,
+        vad_offset: float = 0.363,
+    ) -> list[dict]:
+        """Detect speech regions in audio.
+
+        Args:
+            audio: Audio waveform as numpy array
+            sample_rate: Audio sample rate (default 16000)
+            vad_onset: Threshold for speech start detection
+            vad_offset: Threshold for speech end detection
+
+        Returns:
+            List of dicts with 'start', 'end' keys (in seconds)
+        """
+        pipeline = cls.get_instance(vad_onset, vad_offset)
+
+        # Prepare audio input
+        waveform = torch.from_numpy(audio).float()
+        if waveform.dim() == 1:
+            waveform = waveform.unsqueeze(0)
+
+        audio_input = {"waveform": waveform, "sample_rate": sample_rate}
+
+        # Run VAD
+        vad_result = pipeline(audio_input)
+
+        # Convert to list of segments
+        segments = []
+        for speech_turn in vad_result.get_timeline():
+            segments.append({
+                "start": speech_turn.start,
+                "end": speech_turn.end,
+            })
+
+        return segments
+
+    @classmethod
+    def merge_chunks(
+        cls,
+        segments: list[dict],
+        chunk_size: float = 30.0,
+    ) -> list[dict]:
+        """Merge VAD segments into larger chunks for batched processing.
+
+        Args:
+            segments: List of VAD segments with 'start', 'end' keys
+            chunk_size: Maximum chunk duration in seconds (default 30)
+
+        Returns:
+            List of chunks with 'start', 'end', 'segments' keys
+        """
+        if not segments:
+            return []
+
+        merged = []
+        curr_start = segments[0]["start"]
+        curr_end = segments[0]["end"]
+        curr_segments = []
+
+        for seg in segments:
+            # If adding this segment exceeds chunk_size, finalize current chunk
+            if seg["end"] - curr_start > chunk_size and curr_segments:
+                merged.append({
+                    "start": curr_start,
+                    "end": curr_end,
+                    "segments": curr_segments,
+                })
+                curr_start = seg["start"]
+                curr_segments = []
+
+            curr_end = seg["end"]
+            curr_segments.append((seg["start"], seg["end"]))
+
+        # Add final chunk
+        if curr_segments:
+            merged.append({
+                "start": curr_start,
+                "end": curr_end,
+                "segments": curr_segments,
+            })
+
+        return merged
+
+    @classmethod
+    def extract_chunk_audio(
+        cls,
+        audio: np.ndarray,
+        chunk: dict,
+        sample_rate: int = 16000,
+    ) -> np.ndarray:
+        """Extract audio for a specific chunk.
+
+        Args:
+            audio: Full audio waveform
+            chunk: Chunk dict with 'start', 'end' keys
+            sample_rate: Audio sample rate
+
+        Returns:
+            Audio chunk as numpy array
+        """
+        start_sample = int(chunk["start"] * sample_rate)
+        end_sample = int(chunk["end"] * sample_rate)
+        return audio[start_sample:end_sample]
+
+
 class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
     """ASR Pipeline for audio-to-text transcription."""

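For context, a minimal sketch (not part of this commit) of how the new VoiceActivityDetector could be exercised on its own. It assumes the gated pyannote/segmentation-3.0 checkpoint is already accessible to pyannote on this machine, and uses soundfile to load a hypothetical 16 kHz mono recording; any loader that yields a float32 numpy array works.

import soundfile as sf  # assumed available; not required by the diff itself

# Hypothetical input file; the detector expects mono audio at the given sample rate.
audio, sr = sf.read("example_16k_mono.wav", dtype="float32")

# Detect speech regions, then merge them into chunks of at most 30 s for batched transcription.
segments = VoiceActivityDetector.detect(audio, sample_rate=sr)
chunks = VoiceActivityDetector.merge_chunks(segments, chunk_size=30.0)

if chunks:
    # Slice the raw samples belonging to the first chunk.
    first_chunk_audio = VoiceActivityDetector.extract_chunk_audio(audio, chunks[0], sample_rate=sr)
    print(f"{len(segments)} speech regions merged into {len(chunks)} chunks; "
          f"first chunk {chunks[0]['start']:.2f}-{chunks[0]['end']:.2f}s, "
          f"{len(first_chunk_audio)} samples")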
@@ -308,6 +462,10 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         kwargs.pop("min_speakers", None)
         kwargs.pop("max_speakers", None)
         kwargs.pop("hf_token", None)
+        kwargs.pop("use_vad", None)
+        kwargs.pop("vad_onset", None)
+        kwargs.pop("vad_offset", None)
+        kwargs.pop("chunk_size", None)

         return super()._sanitize_parameters(**kwargs)

@@ -316,10 +474,14 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         inputs,
         **kwargs,
     ):
-        """Transcribe audio with optional
+        """Transcribe audio with optional VAD, timestamps, and speaker diarization.

         Args:
             inputs: Audio input (file path, dict with array/sampling_rate, etc.)
+            use_vad: If True, use Voice Activity Detection to chunk audio (recommended for long audio)
+            vad_onset: VAD speech start threshold (default 0.5)
+            vad_offset: VAD speech end threshold (default 0.363)
+            chunk_size: Maximum chunk duration in seconds for VAD (default 30)
             return_timestamps: If True, return word-level timestamps using forced alignment
             return_speakers: If True, return speaker labels for each word
             num_speakers: Exact number of speakers (if known, for diarization)
@@ -330,9 +492,13 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):

         Returns:
             Dict with 'text' key, 'words' key if return_timestamps=True,
-            and speaker labels on words if return_speakers=True
+            'vad_segments' if use_vad=True, and speaker labels on words if return_speakers=True
         """
         # Extract our params before super().__call__ (which will also call _sanitize_parameters)
+        use_vad = kwargs.pop("use_vad", False)
+        vad_onset = kwargs.pop("vad_onset", 0.5)
+        vad_offset = kwargs.pop("vad_offset", 0.363)
+        chunk_size = kwargs.pop("chunk_size", 30.0)
         return_timestamps = kwargs.pop("return_timestamps", False)
         return_speakers = kwargs.pop("return_speakers", False)
         diarization_params = {
@@ -345,12 +511,25 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         if return_speakers:
             return_timestamps = True

+        # Extract audio for VAD, timestamps, and diarization
+        audio_data = self._extract_audio(inputs)
+
+        # Use VAD to chunk and transcribe long audio
+        if use_vad and audio_data is not None:
+            result = self._transcribe_with_vad(
+                audio_data,
+                vad_onset=vad_onset,
+                vad_offset=vad_offset,
+                chunk_size=chunk_size,
+                **kwargs,
+            )
+        else:
+            # Store audio for timestamp alignment and diarization
+            if return_timestamps or return_speakers:
+                self._current_audio = audio_data

+            # Run standard transcription
+            result = super().__call__(inputs, **kwargs)

         # Add timestamps if requested
         if return_timestamps and self._current_audio is not None:
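As a usage sketch (not part of the diff), the new parameters are passed straight to __call__. The repo id below is a placeholder, and this assumes the repository wires ASRPipeline up as a Hub custom pipeline loaded with trust_remote_code; the exact loading mechanism is not shown in these hunks.

from transformers import pipeline

# Placeholder repo id; assumes it resolves to this custom ASRPipeline implementation.
asr = pipeline(
    "automatic-speech-recognition",
    model="your-namespace/your-asr-model",
    trust_remote_code=True,
)

result = asr(
    "long_recording.wav",  # hypothetical long audio file
    use_vad=True,          # chunk the audio with VoiceActivityDetector before transcribing
    chunk_size=30.0,       # merge speech regions into chunks of at most 30 s
    vad_onset=0.5,
    vad_offset=0.363,
)
print(result["text"])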
@@ -423,6 +602,75 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):

         return None

+    def _transcribe_with_vad(
+        self,
+        audio_data: dict,
+        vad_onset: float = 0.5,
+        vad_offset: float = 0.363,
+        chunk_size: float = 30.0,
+        **kwargs,
+    ) -> dict:
+        """Transcribe audio using VAD to chunk long audio.
+
+        Args:
+            audio_data: Dict with 'array' and 'sampling_rate' keys
+            vad_onset: VAD speech start threshold
+            vad_offset: VAD speech end threshold
+            chunk_size: Maximum chunk duration in seconds
+            **kwargs: Additional arguments passed to transcription
+
+        Returns:
+            Dict with 'text', 'vad_segments', and 'chunks' keys
+        """
+        audio = audio_data["array"]
+        sample_rate = audio_data.get("sampling_rate", 16000)
+
+        # Run VAD to detect speech regions
+        vad_segments = VoiceActivityDetector.detect(
+            audio,
+            sample_rate=sample_rate,
+            vad_onset=vad_onset,
+            vad_offset=vad_offset,
+        )
+
+        if not vad_segments:
+            return {"text": "", "vad_segments": [], "chunks": []}
+
+        # Merge segments into chunks
+        chunks = VoiceActivityDetector.merge_chunks(vad_segments, chunk_size)
+
+        # Transcribe each chunk
+        all_text = []
+        chunk_results = []
+
+        for chunk in chunks:
+            # Extract chunk audio
+            chunk_audio = VoiceActivityDetector.extract_chunk_audio(
+                audio, chunk, sample_rate
+            )
+
+            # Transcribe chunk
+            chunk_input = {"raw": chunk_audio, "sampling_rate": sample_rate}
+            chunk_result = super().__call__(chunk_input, **kwargs)
+
+            chunk_text = chunk_result.get("text", "").strip()
+            all_text.append(chunk_text)
+
+            chunk_results.append({
+                "start": chunk["start"],
+                "end": chunk["end"],
+                "text": chunk_text,
+            })
+
+        # Store audio for potential timestamp/diarization
+        self._current_audio = audio_data
+
+        return {
+            "text": " ".join(all_text),
+            "vad_segments": vad_segments,
+            "chunks": chunk_results,
+        }
+
     def preprocess(self, inputs, **preprocess_params):
         # Handle dict with "array" key (from datasets)
         if isinstance(inputs, dict) and "array" in inputs:
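When use_vad=True, the dict returned by __call__ carries the extra keys produced by _transcribe_with_vad. A small hypothetical example of consuming them, reusing `result` from the call sketch above:

# Each chunk carries its own start/end (in seconds) and transcript.
for chunk in result["chunks"]:
    print(f"[{chunk['start']:7.2f}s - {chunk['end']:7.2f}s] {chunk['text']}")

print(f"Detected {len(result['vad_segments'])} raw speech regions before merging")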