datbkpro committed on
Commit
b96139a
·
verified ·
1 Parent(s): 14e4b97

Create silero_vad

Browse files
Files changed (1) hide show
  1. core/silero_vad +194 -0
core/silero_vad ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from typing import Optional, Callable
4
+ from config.settings import settings
5
+
6
class SileroVAD:
    """Streaming voice-activity detector built on a Silero VAD TorchScript model.

    Two ways of using the detector are offered:

    * Streaming: call ``start_stream(callback)``, then feed audio through
      ``process_stream``. Audio is accumulated at ``self.sample_rate`` and
      analysed in one-second windows that overlap by 0.3 seconds. Whenever
      the model's speech probability for a window exceeds
      ``settings.VAD_THRESHOLD`` the callback is invoked as
      ``callback(window, sample_rate)`` with a float32 window.
    * One-shot: ``is_speech`` / ``get_speech_probability`` score a single
      chunk.

    Input conventions: integer arrays are assumed to be 16-bit PCM and are
    divided by 32768; floating-point arrays are assumed to already lie in
    [-1, 1] and are only cast to float32. Input is expected to be 1-D (mono).
    """

    # Source and local cache location of the TorchScript model file.
    _MODEL_URL = 'https://raw.githubusercontent.com/snakers4/silero-vad/master/files/model.jit'
    _MODEL_PATH = 'silero_vad.jit'
    # Length of the analysis window and of the overlap kept between windows.
    _WINDOW_SECONDS = 1.0
    _OVERLAP_SECONDS = 0.3
    # One-shot checks zero-pad shorter chunks up to this many samples.
    _MIN_SAMPLES = 512

    def __init__(self):
        """Create the detector and try to load the model immediately."""
        self.model = None
        self.sample_rate = 16000  # rate at which audio is handed to the model
        self.is_streaming = False
        self.speech_callback: Optional[Callable] = None
        # Pending samples (floats in [-1, 1]) at ``self.sample_rate``.
        self.audio_buffer = []
        self._initialize_model()

    def _initialize_model(self):
        """Load the Silero VAD model, downloading it only if not cached.

        On any failure the error is printed and ``self.model`` is left as
        ``None``; the public methods then fall back to their no-model
        behaviour instead of raising.
        """
        import os  # local import: only needed for the cache check below

        try:
            print("🔄 Đang tải Silero VAD model...")
            # Reuse a previously downloaded file so that creating another
            # instance (or starting offline) does not hit the network again.
            if not os.path.exists(self._MODEL_PATH):
                torch.hub.download_url_to_file(self._MODEL_URL, self._MODEL_PATH)
            self.model = torch.jit.load(self._MODEL_PATH)
            self.model.eval()
            print("✅ Đã tải Silero VAD model thành công")
        except Exception as e:
            print(f"❌ Lỗi tải Silero VAD model: {e}")
            self.model = None

    def start_stream(self, speech_callback: Callable):
        """Begin a streaming session.

        Args:
            speech_callback: called as ``speech_callback(window, sample_rate)``
                for every analysed window classified as speech.

        Returns:
            ``True`` if streaming started, ``False`` if no model is loaded.
        """
        if self.model is None:
            print("❌ Silero VAD model chưa được khởi tạo")
            return False

        self.is_streaming = True
        self.speech_callback = speech_callback
        self.audio_buffer = []
        print("🎙️ Bắt đầu Silero VAD streaming...")
        return True

    def stop_stream(self):
        """End the streaming session and discard any buffered audio."""
        self.is_streaming = False
        self.speech_callback = None
        self.audio_buffer = []
        print("🛑 Đã dừng Silero VAD streaming")

    def process_stream(self, audio_chunk: np.ndarray, sample_rate: int):
        """Feed one chunk of audio into the active streaming session.

        The chunk is normalised, resampled to ``self.sample_rate`` if needed
        and appended to the buffer; once at least one second is buffered the
        pending windows are analysed. Does nothing when not streaming or when
        no model is loaded. Errors are printed, never raised.
        """
        if not self.is_streaming or self.model is None:
            return

        try:
            samples = self._prepare_audio(audio_chunk, sample_rate)
            self.audio_buffer.extend(samples.tolist())

            buffer_duration = len(self.audio_buffer) / self.sample_rate
            if buffer_duration >= self._WINDOW_SECONDS:
                self._process_buffer()

        except Exception as e:
            print(f"❌ Lỗi xử lý Silero VAD: {e}")

    def _pop_window(self):
        """Return the next one-second window, or ``None`` if not enough audio.

        The buffer is advanced past the window except for its final 0.3
        seconds, which are kept as overlap. Samples that arrived after the
        window stay in the buffer, so no audio is skipped.
        """
        chunk_size = int(self.sample_rate * self._WINDOW_SECONDS)
        if len(self.audio_buffer) < chunk_size:
            return None

        window = np.asarray(self.audio_buffer[:chunk_size], dtype=np.float32)
        keep_samples = int(self.sample_rate * self._OVERLAP_SECONDS)
        self.audio_buffer = self.audio_buffer[chunk_size - keep_samples:]
        return window

    def _process_buffer(self):
        """Run the model over every complete window currently buffered.

        For each window whose speech probability exceeds
        ``settings.VAD_THRESHOLD`` the registered callback is invoked. If
        anything raises (including the callback) the error is printed and
        the buffer is cleared.
        """
        try:
            audio_chunk = self._pop_window()
            while audio_chunk is not None:
                speech_prob = self._speech_probability(audio_chunk)
                print(f"🎯 Silero VAD speech probability: {speech_prob:.3f}")

                if speech_prob > settings.VAD_THRESHOLD:
                    print(f"🎯 Silero VAD phát hiện speech: {speech_prob:.3f}")
                    if self.speech_callback:
                        self.speech_callback(audio_chunk, self.sample_rate)

                audio_chunk = self._pop_window()

        except Exception as e:
            print(f"❌ Lỗi xử lý Silero VAD buffer: {e}")
            self.audio_buffer = []

    @staticmethod
    def _to_float32(audio) -> np.ndarray:
        """Convert raw audio to float32 samples in [-1, 1].

        Integer input is treated as 16-bit PCM and divided by 32768.
        Floating-point input is assumed to be normalised already and is only
        cast, so float64 audio is no longer scaled down a second time.
        """
        audio = np.asarray(audio)
        if np.issubdtype(audio.dtype, np.integer):
            return audio.astype(np.float32) / 32768.0
        return audio.astype(np.float32, copy=False)

    def _prepare_audio(self, audio_chunk, sample_rate: int, min_samples: int = 0) -> np.ndarray:
        """Normalise, resample and optionally zero-pad a chunk for the model.

        Normalisation happens before resampling so that the float64 output
        of the interpolation is never mistaken for unnormalised PCM.

        Args:
            audio_chunk: 1-D audio samples (integer PCM or float).
            sample_rate: sample rate of ``audio_chunk`` in Hz.
            min_samples: if the result is shorter, pad with trailing zeros
                up to this length.

        Returns:
            A contiguous float32 array at ``self.sample_rate``.
        """
        samples = self._to_float32(audio_chunk)
        if sample_rate != self.sample_rate:
            samples = self._resample_audio(samples, sample_rate, self.sample_rate)
        # The resampler returns float64; the model is fed float32.
        samples = np.ascontiguousarray(samples, dtype=np.float32)
        if len(samples) < min_samples:
            # np.pad keeps the float32 dtype (np.zeros would promote to float64).
            samples = np.pad(samples, (0, min_samples - len(samples)))
        return samples

    def _speech_probability(self, samples: np.ndarray) -> float:
        """Run the model on prepared float32 samples and return its score."""
        audio_tensor = torch.from_numpy(samples).unsqueeze(0)  # add batch dim
        with torch.no_grad():
            return self.model(audio_tensor, self.sample_rate).item()

    def _resample_audio(self, audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
        """Resample 1-D audio by linear interpolation.

        Returns the input unchanged when the rates match, when the input is
        empty, or when interpolation fails (a warning is printed in the
        failure case).
        """
        if orig_sr == target_sr:
            return audio

        try:
            orig_length = len(audio)
            if orig_length == 0:
                return audio
            new_length = int(orig_length * target_sr / orig_sr)

            # Map both sample grids onto [0, 1] and interpolate between them.
            x_old = np.linspace(0, 1, orig_length)
            x_new = np.linspace(0, 1, new_length)
            return np.interp(x_new, x_old, audio)
        except Exception as e:
            print(f"⚠️ Lỗi resample: {e}")
            return audio

    def is_speech(self, audio_chunk: np.ndarray, sample_rate: int) -> bool:
        """Tell whether a chunk is classified as speech.

        Returns ``True`` when the model's probability exceeds
        ``settings.VAD_THRESHOLD``. Also returns ``True`` as a permissive
        fallback when no model is loaded or when scoring fails.
        """
        if self.model is None:
            return True  # fallback: treat everything as speech

        try:
            samples = self._prepare_audio(audio_chunk, sample_rate, self._MIN_SAMPLES)
            return self._speech_probability(samples) > settings.VAD_THRESHOLD
        except Exception as e:
            print(f"❌ Lỗi kiểm tra speech với Silero: {e}")
            return True

    def get_speech_probability(self, audio_chunk: np.ndarray, sample_rate: int) -> float:
        """Return the model's speech probability for a chunk (for debugging).

        Returns ``0.0`` when no model is loaded or when scoring fails.
        """
        if self.model is None:
            return 0.0

        try:
            samples = self._prepare_audio(audio_chunk, sample_rate, self._MIN_SAMPLES)
            return self._speech_probability(samples)
        except Exception as e:
            print(f"❌ Lỗi lấy speech probability: {e}")
            return 0.0