import queue
import threading
import time
from typing import Callable, List, Optional, Tuple

import numpy as np
import torch
import torchaudio
from speechbrain.inference import VAD

from config.settings import settings
|
|
class SpeechBrainVAD:
    """Real-time voice-activity detection built on a SpeechBrain VAD model.

    Audio chunks are fed through :meth:`process_stream`; consecutive speech
    chunks accumulate in ``speech_buffer`` and, once silence has lasted at
    least ``min_silence_duration`` seconds, the buffered segment is handed to
    the user-supplied callback as ``callback(audio, sample_rate)``.

    If the SpeechBrain model cannot be loaded (or raises at inference time),
    a simple energy-threshold detector is used as a fallback.
    """

    def __init__(self):
        # Populated by _initialize_model(); None means "use the energy fallback".
        self.vad_model = None
        self.sample_rate = settings.SAMPLE_RATE
        self.threshold = settings.VAD_THRESHOLD
        self.min_silence_duration = settings.VAD_MIN_SILENCE_DURATION
        # NOTE(review): speech_pad_duration is read from settings but never
        # used anywhere in this class — confirm whether segment padding was
        # intended and either implement or drop the setting.
        self.speech_pad_duration = settings.VAD_SPEECH_PAD_DURATION
        self.is_running = False
        self.audio_queue = queue.Queue()
        self.speech_buffer = []          # accumulated mono samples of the current utterance
        self.silence_start_time = None   # wall-clock time when the current silence began
        self.callback = None

        self._initialize_model()

    def _initialize_model(self):
        """Load the SpeechBrain VAD model; on failure keep ``vad_model=None``."""
        try:
            print("🔄 Đang tải mô hình SpeechBrain VAD...")
            self.vad_model = VAD.from_hparams(
                source=settings.VAD_MODEL,
                savedir=f"pretrained_models/{settings.VAD_MODEL}"
            )
            print("✅ Đã tải mô hình VAD thành công")
        except Exception as e:
            # Deliberate best-effort: fall back to the energy detector
            # instead of crashing the whole pipeline.
            print(f"❌ Lỗi tải mô hình VAD: {e}")
            self.vad_model = None

    def preprocess_audio(self, audio_data: np.ndarray, original_sr: int) -> np.ndarray:
        """Downmix to mono, resample to ``self.sample_rate``, peak-normalize.

        Args:
            audio_data: 1-D mono signal or 2-D ``(channels, samples)`` array.
            original_sr: sample rate of ``audio_data`` in Hz.

        Returns:
            Mono array at ``self.sample_rate`` with peak amplitude 1.0
            (an all-zero chunk is returned unchanged to avoid division by 0).
        """
        # Downmix to mono unconditionally. The previous implementation only
        # downmixed on the resampling path, so multi-channel audio already at
        # the target rate slipped through as a 2-D array.
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=0)

        if original_sr != self.sample_rate:
            audio_tensor = torch.from_numpy(audio_data).float()
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_sr,
                new_freq=self.sample_rate
            )
            audio_data = resampler(audio_tensor).numpy()

        # Peak-normalize so downstream thresholds see a consistent scale.
        peak = np.max(np.abs(audio_data))
        if peak > 0:
            audio_data = audio_data / peak

        return audio_data

    def detect_voice_activity(self, audio_chunk: np.ndarray) -> bool:
        """Return True when ``audio_chunk`` likely contains speech.

        Uses the SpeechBrain model when loaded; model errors (and a missing
        model) fall back to :meth:`_energy_based_vad`.
        """
        if self.vad_model is None:
            return self._energy_based_vad(audio_chunk)

        try:
            # The model expects a batched tensor of shape (batch, samples).
            audio_tensor = torch.from_numpy(audio_chunk).float().unsqueeze(0)

            with torch.no_grad():
                prob = self.vad_model.get_speech_prob_chunk(audio_tensor)

            # get_speech_prob_chunk yields per-frame probabilities; averaging
            # handles chunks spanning several frames, where the original
            # prob.item() would raise on a multi-element tensor.
            return prob.mean().item() > self.threshold

        except Exception as e:
            print(f"❌ Lỗi VAD detection: {e}")
            return self._energy_based_vad(audio_chunk)

    def _energy_based_vad(self, audio_chunk: np.ndarray, threshold: float = 0.01) -> bool:
        """Fallback detector: mean signal energy above ``threshold`` is speech.

        ``threshold`` defaults to the previous hard-coded 0.01 so existing
        callers are unaffected.
        """
        energy = float(np.mean(audio_chunk ** 2))
        # float(...) > threshold yields a real Python bool (the old code
        # returned np.bool_ despite the annotation).
        return energy > threshold

    def process_stream(self, audio_chunk: np.ndarray, original_sr: int) -> Optional[bool]:
        """Process one chunk of a live stream.

        Speech chunks extend the internal buffer; once silence has lasted
        ``min_silence_duration`` seconds, the buffered segment is flushed to
        the callback.

        Returns:
            True/False for speech/silence, or None when the stream is stopped.
        """
        if not self.is_running:
            return None

        processed_audio = self.preprocess_audio(audio_chunk, original_sr)
        is_speech = self.detect_voice_activity(processed_audio)

        if is_speech:
            # Any speech resets the silence timer and grows the segment.
            self.silence_start_time = None
            self.speech_buffer.extend(processed_audio)
            print("🎤 Đang nói...")
        else:
            if self.silence_start_time is None:
                # First silent chunk: start timing, decide on a later chunk.
                self.silence_start_time = time.time()
            elif len(self.speech_buffer) > 0:
                silence_duration = time.time() - self.silence_start_time
                if silence_duration >= self.min_silence_duration:
                    self._process_speech_segment()

        return is_speech

    def _process_speech_segment(self):
        """Flush the buffered speech segment to the callback and reset state."""
        if len(self.speech_buffer) == 0:
            return

        speech_audio = np.array(self.speech_buffer)

        if self.callback and callable(self.callback):
            self.callback(speech_audio, self.sample_rate)

        self.speech_buffer = []
        self.silence_start_time = None

        print("✅ Đã xử lý segment giọng nói")

    def start_stream(self, callback: Callable[[np.ndarray, int], None]):
        """Begin stream processing.

        Args:
            callback: invoked as ``callback(audio, sample_rate)`` for each
                completed speech segment.
        """
        self.is_running = True
        self.callback = callback
        self.speech_buffer = []
        self.silence_start_time = None
        print("🎙️ Bắt đầu stream VAD...")

    def stop_stream(self):
        """Stop stream processing, flushing any partially buffered speech."""
        self.is_running = False
        if len(self.speech_buffer) > 0:
            self._process_speech_segment()
        print("🛑 Đã dừng stream VAD")

    def get_audio_chunk_from_stream(self, stream, chunk_size: int = 1024) -> Optional[np.ndarray]:
        """Read one chunk from a PyAudio-style stream as float32 in [-1, 1].

        Returns:
            The normalized samples, or None when the read fails.
        """
        try:
            data = stream.read(chunk_size, exception_on_overflow=False)
            audio_data = np.frombuffer(data, dtype=np.int16)
            # int16 → float32, normalized by the int16 full scale.
            return audio_data.astype(np.float32) / 32768.0
        except Exception as e:
            print(f"❌ Lỗi đọc audio stream: {e}")
            return None