import logging

import numpy as np
import torch
import torchaudio
from df.enhance import enhance, init_df
from rich.console import Console

from baseHandler import BaseHandler
from utils.utils import int2float
from VAD.vad_iterator import VADIterator

logger = logging.getLogger(__name__)

console = Console()


class VADHandler(BaseHandler):
    """
    Handles voice activity detection. When voice activity is detected, audio
    is accumulated until the end of speech is detected; the complete speech
    segment is then passed on to the next handler in the pipeline.
    """

    def setup(
        self,
        should_listen,
        thresh=0.3,
        sample_rate=16000,
        min_silence_ms=1000,
        min_speech_ms=500,
        max_speech_ms=float("inf"),
        speech_pad_ms=30,
        audio_enhancement=False,
    ):
        self.should_listen = should_listen
        self.sample_rate = sample_rate
        self.min_silence_ms = min_silence_ms
        self.min_speech_ms = min_speech_ms
        self.max_speech_ms = max_speech_ms
        # Silero VAD model, fetched from torch.hub.
        self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad")
        self.iterator = VADIterator(
            self.model,
            threshold=thresh,
            sampling_rate=sample_rate,
            min_silence_duration_ms=min_silence_ms,
            speech_pad_ms=speech_pad_ms,
        )
        self.audio_enhancement = audio_enhancement
        if audio_enhancement:
            # DeepFilterNet model and state, used to denoise detected speech.
            self.enhanced_model, self.df_state, _ = init_df()

    def process(self, audio_chunk):
        # Raw bytes -> int16 samples -> float32 in [-1, 1], as Silero VAD expects.
        audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16)
        audio_float32 = int2float(audio_int16)
        vad_output = self.iterator(torch.from_numpy(audio_float32))
        if vad_output is not None and len(vad_output) != 0:
            console.print("VAD: end of speech detected")
            logger.debug("VAD: end of speech detected")
            array = torch.cat(vad_output).cpu().numpy()
            duration_ms = len(array) / self.sample_rate * 1000
            if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms:
                console.print(
                    f"audio input of duration: {duration_ms / 1000:.2f}s, skipping"
                )
                logger.debug(
                    f"audio input of duration: {duration_ms / 1000:.2f}s, skipping"
                )
            else:
                self.should_listen.clear()
                logger.debug("Stop listening")
                if self.audio_enhancement:
                    if self.sample_rate != self.df_state.sr():
                        # DeepFilterNet runs at its own sample rate: resample the
                        # segment, enhance, then resample back.
                        audio_float32 = torchaudio.functional.resample(
                            torch.from_numpy(array),
                            orig_freq=self.sample_rate,
                            new_freq=self.df_state.sr(),
                        )
                        enhanced = enhance(
                            self.enhanced_model,
                            self.df_state,
                            audio_float32.unsqueeze(0),
                        )
                        enhanced = torchaudio.functional.resample(
                            enhanced,
                            orig_freq=self.df_state.sr(),
                            new_freq=self.sample_rate,
                        )
                    else:
                        # enhance() expects a (channels, samples) tensor of the full
                        # accumulated segment, not the last incoming chunk.
                        enhanced = enhance(
                            self.enhanced_model,
                            self.df_state,
                            torch.from_numpy(array).unsqueeze(0),
                        )
                    array = enhanced.numpy().squeeze()
                yield array

    @property
    def min_time_to_debug(self):
        return 0.00001
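

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the pipeline. It bypasses
    # BaseHandler.__init__ (whose queue/event wiring is pipeline-specific
    # and not shown in this file) and drives the handler directly. The
    # 512-sample chunk size is an assumption matching what Silero VAD
    # expects for 16 kHz input.
    from threading import Event

    handler = VADHandler.__new__(VADHandler)
    handler.setup(should_listen=Event())

    # Feed one chunk of int16 silence; no speech segment should be yielded.
    silence = np.zeros(512, dtype=np.int16).tobytes()
    for segment in handler.process(silence):
        print(f"speech segment of {len(segment) / handler.sample_rate:.2f}s")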