datbkpro committed on
Commit
093eb67
·
verified ·
1 Parent(s): 6323ac8

Update core/silero_vad.py

Browse files
Files changed (1) hide show
  1. core/silero_vad.py +1 -240
core/silero_vad.py CHANGED
@@ -1,243 +1,4 @@
1
- # import torch
2
- # import numpy as np
3
- # from typing import Optional, Callable
4
- # from config.settings import settings
5
- # import os
6
-
7
- # class SileroVAD:
8
- # def __init__(self):
9
- # self.model = None
10
- # self.sample_rate = 16000
11
- # self.is_streaming = False
12
- # self.speech_callback = None
13
- # self.audio_buffer = []
14
- # self._initialize_model()
15
-
16
- # def _initialize_model(self):
17
- # """Khởi tạo Silero VAD model sử dụng torch.hub"""
18
- # try:
19
- # print("🔄 Đang tải Silero VAD model từ torch.hub...")
20
-
21
- # # Sử dụng torch.hub để load model (cách chính thức)
22
- # self.model = torch.hub.load(
23
- # repo_or_dir=settings.VAD_MODEL,
24
- # model='silero_vad',
25
- # force_reload=False, # Sử dụng cache nếu có
26
- # trust_repo=True
27
- # )
28
-
29
- # print("✅ Đã tải Silero VAD model thành công")
30
-
31
- # except Exception as e:
32
- # print(f"❌ Lỗi tải Silero VAD model: {e}")
33
- # print("🔄 Đang thử cách tải thay thế...")
34
- # self._initialize_model_fallback()
35
-
36
- # def _initialize_model_fallback(self):
37
- # """Fallback method nếu cách chính thức không hoạt động"""
38
- # try:
39
- # # Cách 2: Sử dụng direct download
40
- # model_urls = {
41
- # 'silero_vad.jit': 'https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.jit'
42
- # }
43
-
44
- # # Tạo thư mục cache
45
- # os.makedirs('./models', exist_ok=True)
46
- # model_path = './models/silero_vad.jit'
47
-
48
- # if not os.path.exists(model_path):
49
- # print("📥 Đang download Silero VAD model...")
50
- # torch.hub.download_url_to_file(
51
- # model_urls['silero_vad.jit'],
52
- # model_path
53
- # )
54
-
55
- # # Load model
56
- # self.model = torch.jit.load(model_path)
57
- # self.model.eval()
58
- # print("✅ Đã tải Silero VAD model thành công (fallback)")
59
-
60
- # except Exception as e:
61
- # print(f"❌ Lỗi tải Silero VAD model fallback: {e}")
62
- # self.model = None
63
-
64
- # def start_stream(self, speech_callback: Callable):
65
- # """Bắt đầu stream với VAD"""
66
- # if self.model is None:
67
- # print("❌ Silero VAD model chưa được khởi tạo")
68
- # return False
69
-
70
- # self.is_streaming = True
71
- # self.speech_callback = speech_callback
72
- # self.audio_buffer = []
73
- # print("🎙️ Bắt đầu Silero VAD streaming...")
74
- # return True
75
-
76
- # def stop_stream(self):
77
- # """Dừng stream"""
78
- # self.is_streaming = False
79
- # self.speech_callback = None
80
- # self.audio_buffer = []
81
- # print("🛑 Đã dừng Silero VAD streaming")
82
-
83
- # def process_stream(self, audio_chunk: np.ndarray, sample_rate: int):
84
- # """Xử lý audio chunk với Silero VAD"""
85
- # if not self.is_streaming or self.model is None:
86
- # return
87
-
88
- # try:
89
- # # Resample nếu cần
90
- # if sample_rate != self.sample_rate:
91
- # audio_chunk = self._resample_audio(audio_chunk, sample_rate, self.sample_rate)
92
-
93
- # # Thêm vào buffer
94
- # self.audio_buffer.extend(audio_chunk)
95
-
96
- # # Xử lý khi buffer đủ lớn (1 giây)
97
- # buffer_duration = len(self.audio_buffer) / self.sample_rate
98
- # if buffer_duration >= 1.0:
99
- # self._process_buffer()
100
-
101
- # except Exception as e:
102
- # print(f"❌ Lỗi xử lý Silero VAD: {e}")
103
-
104
- # def _process_buffer(self):
105
- # """Xử lý buffer audio với Silero VAD"""
106
- # try:
107
- # chunk_size = self.sample_rate # 1 giây
108
- # if len(self.audio_buffer) < chunk_size:
109
- # return
110
-
111
- # # Lấy chunk 1 giây
112
- # audio_chunk = np.array(self.audio_buffer[:chunk_size])
113
-
114
- # # Chuẩn hóa audio cho Silero
115
- # if audio_chunk.dtype != np.float32:
116
- # audio_chunk = audio_chunk.astype(np.float32)
117
- # if np.max(np.abs(audio_chunk)) > 1.0:
118
- # audio_chunk = audio_chunk / 32768.0 # Normalize từ int16
119
-
120
- # # Đảm bảo audio trong range [-1, 1]
121
- # audio_chunk = np.clip(audio_chunk, -1.0, 1.0)
122
-
123
- # # Chuyển thành tensor
124
- # audio_tensor = torch.from_numpy(audio_chunk).float().unsqueeze(0)
125
-
126
- # # Phát hiện speech với Silero VAD
127
- # with torch.no_grad():
128
- # speech_prob = self.model(audio_tensor, self.sample_rate).item()
129
-
130
- # print(f"🎯 Silero VAD speech probability: {speech_prob:.3f}")
131
-
132
- # # Ngưỡng phát hiện speech
133
- # if speech_prob > settings.VAD_THRESHOLD:
134
- # print(f"🎯 Silero VAD phát hiện speech: {speech_prob:.3f}")
135
-
136
- # # Gọi callback với speech segment
137
- # if self.speech_callback:
138
- # self.speech_callback(audio_chunk, self.sample_rate)
139
-
140
- # # Giữ lại 0.3 giây cuối để overlap
141
- # keep_samples = int(self.sample_rate * 0.3)
142
- # if len(self.audio_buffer) > keep_samples:
143
- # self.audio_buffer = self.audio_buffer[-keep_samples:]
144
- # else:
145
- # self.audio_buffer = []
146
-
147
- # except Exception as e:
148
- # print(f"❌ Lỗi xử lý Silero VAD buffer: {e}")
149
- # self.audio_buffer = []
150
-
151
- # def _resample_audio(self, audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
152
- # """Resample audio nếu cần"""
153
- # if orig_sr == target_sr:
154
- # return audio
155
-
156
- # try:
157
- # # Simple resampling bằng interpolation
158
- # orig_length = len(audio)
159
- # new_length = int(orig_length * target_sr / orig_sr)
160
-
161
- # # Linear interpolation
162
- # x_old = np.linspace(0, 1, orig_length)
163
- # x_new = np.linspace(0, 1, new_length)
164
- # resampled_audio = np.interp(x_new, x_old, audio)
165
-
166
- # return resampled_audio
167
- # except Exception as e:
168
- # print(f"⚠️ Lỗi resample: {e}")
169
- # return audio
170
-
171
- # def is_speech(self, audio_chunk: np.ndarray, sample_rate: int) -> bool:
172
- # """Kiểm tra xem audio chunk có phải là speech không"""
173
- # if self.model is None:
174
- # return True # Fallback: luôn coi là speech
175
-
176
- # try:
177
- # # Resample nếu cần
178
- # if sample_rate != self.sample_rate:
179
- # audio_chunk = self._resample_audio(audio_chunk, sample_rate, self.sample_rate)
180
-
181
- # # Chuẩn hóa audio
182
- # if audio_chunk.dtype != np.float32:
183
- # audio_chunk = audio_chunk.astype(np.float32)
184
- # if np.max(np.abs(audio_chunk)) > 1.0:
185
- # audio_chunk = audio_chunk / 32768.0
186
-
187
- # audio_chunk = np.clip(audio_chunk, -1.0, 1.0)
188
-
189
- # # Đảm bảo độ dài phù hợp
190
- # if len(audio_chunk) < 512:
191
- # padding = np.zeros(512 - len(audio_chunk), dtype=np.float32)
192
- # audio_chunk = np.concatenate([audio_chunk, padding])
193
-
194
- # # Chuyển thành tensor
195
- # audio_tensor = torch.from_numpy(audio_chunk).float().unsqueeze(0)
196
-
197
- # # Phát hiện speech
198
- # with torch.no_grad():
199
- # speech_prob = self.model(audio_tensor, self.sample_rate).item()
200
-
201
- # # Kiểm tra ngưỡng
202
- # return speech_prob > settings.VAD_THRESHOLD
203
-
204
- # except Exception as e:
205
- # print(f"❌ Lỗi kiểm tra speech với Silero: {e}")
206
- # return True
207
-
208
- # def get_speech_probability(self, audio_chunk: np.ndarray, sample_rate: int) -> float:
209
- # """Lấy xác suất speech"""
210
- # if self.model is None:
211
- # return 0.0
212
-
213
- # try:
214
- # # Resample nếu cần
215
- # if sample_rate != self.sample_rate:
216
- # audio_chunk = self._resample_audio(audio_chunk, sample_rate, self.sample_rate)
217
-
218
- # # Chuẩn hóa audio
219
- # if audio_chunk.dtype != np.float32:
220
- # audio_chunk = audio_chunk.astype(np.float32)
221
- # if np.max(np.abs(audio_chunk)) > 1.0:
222
- # audio_chunk = audio_chunk / 32768.0
223
-
224
- # audio_chunk = np.clip(audio_chunk, -1.0, 1.0)
225
-
226
- # # Đảm bảo độ dài phù hợp
227
- # if len(audio_chunk) < 512:
228
- # padding = np.zeros(512 - len(audio_chunk), dtype=np.float32)
229
- # audio_chunk = np.concatenate([audio_chunk, padding])
230
-
231
- # # Chuyển thành tensor
232
- # audio_tensor = torch.from_numpy(audio_chunk).float().unsqueeze(0)
233
-
234
- # # Phát hiện speech
235
- # with torch.no_grad():
236
- # return self.model(audio_tensor, self.sample_rate).item()
237
-
238
- # except Exception as e:
239
- # print(f"❌ Lỗi lấy speech probability: {e}")
240
- # return 0.0import torch
241
  import torch
242
  import numpy as np
243
  from typing import Callable
 
1
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  import numpy as np
4
  from typing import Callable