XiaoBai1221 commited on
Commit
ab55b54
·
1 Parent(s): 85a83b0

feat: 語音登入優化 - 錄音時間調整為3秒 & CNN模組整合

Browse files

- 將語音登入錄音時間從4秒調整為3秒,提升用戶體驗
- 在inference.py中添加predict_files函數,整合CNN說話者辨識
- 修改VoiceAuthService使用inference.py而非獨立adapter
- 前端語音綁定錄音時間同步調整為3秒
- 更新測試檔案中的模組引用路徑

app.py CHANGED
@@ -219,7 +219,7 @@ async def lifespan(app: FastAPI):
219
  # 初始化語音登入服務(硬編參數)
220
  try:
221
  app.state.voice_auth = VoiceAuthService(config=VoiceLoginConfig(
222
- window_seconds=4,
223
  required_windows=1,
224
  sample_rate=16000,
225
  prob_threshold=0.40,
 
219
  # 初始化語音登入服務(硬編參數)
220
  try:
221
  app.state.voice_auth = VoiceAuthService(config=VoiceLoginConfig(
222
+ window_seconds=3,
223
  required_windows=1,
224
  sample_rate=16000,
225
  prob_threshold=0.40,
models/speaker_identification/scripts/inference.py CHANGED
@@ -78,6 +78,76 @@ def softmax(x):
78
  return e / e.sum(dim=1, keepdim=True)
79
 
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  # ============== 錄音與前處理(比照 process_audio.py) ==============
82
  REC_SR = 22050
83
  TARGET_RMS = 0.1
 
78
  return e / e.sum(dim=1, keepdim=True)
79
 
80
 
81
+ def predict_files(model_dir, file_list, threshold=0.0):
82
+ """
83
+ 預測多個音訊檔案的說話者。
84
+
85
+ Args:
86
+ model_dir: 模型目錄,包含 speaker_id_model.pth 和 classes.txt
87
+ file_list: 檔案路徑列表
88
+ threshold: 預測門檻(目前未使用)
89
+
90
+ Returns:
91
+ 結果列表,每個元素為字典,包含 'pred', 'score', 'top'
92
+ """
93
+ device = get_device()
94
+ bundle = torchaudio.pipelines.WAV2VEC2_BASE
95
+ target_sr = bundle.sample_rate
96
+
97
+ model_path = os.path.join(model_dir, 'speaker_id_model.pth')
98
+ classes_path = os.path.join(model_dir, 'classes.txt')
99
+ processed_dir = os.path.join(model_dir, 'processed_audio')
100
+
101
+ if os.path.isfile(classes_path):
102
+ classes = load_classes(classes_path)
103
+ elif os.path.isdir(processed_dir):
104
+ classes = load_classes(processed_dir)
105
+ else:
106
+ raise FileNotFoundError(f"找不到類別定義:{classes_path} 或 {processed_dir}")
107
+
108
+ num_classes = len(classes)
109
+
110
+ model = Wav2Vec2SpeakerClassifier(bundle, num_classes)
111
+ state = torch.load(model_path, map_location='cpu')
112
+ model.load_state_dict(state)
113
+ model.to(device).eval()
114
+
115
+ results = []
116
+ for file_path in file_list:
117
+ try:
118
+ # 前處理音訊
119
+ y, sr = process_like_training(file_path)
120
+
121
+ # 轉成模型輸入
122
+ y_t = torch.tensor(y, dtype=torch.float32).unsqueeze(0)
123
+ if sr != target_sr:
124
+ resampler = torchaudio.transforms.Resample(sr, target_sr)
125
+ y_t = resampler(y_t)
126
+ length = torch.tensor([y_t.shape[1]], dtype=torch.long)
127
+ waveforms, lengths = y_t.to(device), length.to(device)
128
+
129
+ with torch.no_grad():
130
+ logits = model(waveforms, lengths)
131
+ probs = softmax(logits).squeeze(0).cpu()
132
+ top_prob, top_idx = torch.max(probs, dim=0)
133
+ pred = classes[top_idx.item()]
134
+
135
+ # 獲取 top 候選
136
+ topk = torch.topk(probs, k=min(3, num_classes))
137
+ top = [(classes[i], float(p)) for p, i in zip(topk.values.tolist(), topk.indices.tolist())]
138
+
139
+ result = {
140
+ 'pred': pred,
141
+ 'score': float(top_prob.item()),
142
+ 'top': top
143
+ }
144
+ results.append(result)
145
+ except Exception as e:
146
+ results.append({'error': str(e)})
147
+
148
+ return results
149
+
150
+
151
  # ============== 錄音與前處理(比照 process_audio.py) ==============
152
  REC_SR = 22050
153
  TARGET_RMS = 0.1
services/voice_login.py CHANGED
@@ -105,7 +105,7 @@ class VoiceAuthService:
105
  if str(self.identity_dir) not in os.sys.path:
106
  os.sys.path.insert(0, str(self.identity_dir))
107
  try:
108
- from models.speaker_identification.cnn_adapter import predict_files as _predict_files # type: ignore
109
  except Exception as e: # pragma: no cover
110
  raise RuntimeError(f"載入 CNN 說話者辨識模組失敗:{e}")
111
  self._predict_files = _predict_files
 
105
  if str(self.identity_dir) not in os.sys.path:
106
  os.sys.path.insert(0, str(self.identity_dir))
107
  try:
108
+ from scripts.inference import predict_files as _predict_files # type: ignore
109
  except Exception as e: # pragma: no cover
110
  raise RuntimeError(f"載入 CNN 說話者辨識模組失敗:{e}")
111
  self._predict_files = _predict_files
static/frontend/js/websocket.js CHANGED
@@ -783,7 +783,7 @@ async function handleVoiceBindingReady() {
783
 
784
  // 更新提示文字
785
  if (typeof transcript !== 'undefined') {
786
- transcript.textContent = '請開始說話(錄音 5 秒)...';
787
  transcript.className = 'voice-transcript provisional';
788
  }
789
 
@@ -814,10 +814,10 @@ async function handleVoiceBindingReady() {
814
  return;
815
  }
816
 
817
- console.log('⏱️ 開始倒數 5 秒錄音...');
818
 
819
  // 倒數計時提示
820
- let countdown = 5;
821
  const countdownInterval = setInterval(() => {
822
  countdown--;
823
  if (countdown > 0 && typeof transcript !== 'undefined') {
@@ -826,10 +826,10 @@ async function handleVoiceBindingReady() {
826
  }
827
  }, 1000);
828
 
829
- // 5 秒後自動停止錄音
830
  setTimeout(() => {
831
  clearInterval(countdownInterval);
832
- console.log('⏹️ 5 秒錄音完成,自動停止');
833
 
834
  // 停止音訊視覺化
835
  if (typeof stopRealAudioAnalysis === 'function') {
@@ -852,7 +852,7 @@ async function handleVoiceBindingReady() {
852
  transcript.className = 'voice-transcript provisional';
853
  }
854
 
855
- }, 5000); // 5 秒錄音時長
856
  } else {
857
  console.error('❌ WebSocket 管理器未初始化');
858
  showErrorNotification('系統錯誤:WebSocket 未連接');
 
783
 
784
  // 更新提示文字
785
  if (typeof transcript !== 'undefined') {
786
+ transcript.textContent = '請開始說話(錄音 3 秒)...';
787
  transcript.className = 'voice-transcript provisional';
788
  }
789
 
 
814
  return;
815
  }
816
 
817
+ console.log('⏱️ 開始倒數 3 秒錄音...');
818
 
819
  // 倒數計時提示
820
+ let countdown = 3;
821
  const countdownInterval = setInterval(() => {
822
  countdown--;
823
  if (countdown > 0 && typeof transcript !== 'undefined') {
 
826
  }
827
  }, 1000);
828
 
829
+ // 3 秒後自動停止錄音
830
  setTimeout(() => {
831
  clearInterval(countdownInterval);
832
+ console.log('⏹️ 3 秒錄音完成,自動停止');
833
 
834
  // 停止音訊視覺化
835
  if (typeof stopRealAudioAnalysis === 'function') {
 
852
  transcript.className = 'voice-transcript provisional';
853
  }
854
 
855
+ }, 3000); // 3 秒錄音時長
856
  } else {
857
  console.error('❌ WebSocket 管理器未初始化');
858
  showErrorNotification('系統錯誤:WebSocket 未連接');
tests/services/test_voice_login_cnn.py CHANGED
@@ -27,7 +27,7 @@ def test_voice_login_success_with_cnn_stub(monkeypatch):
27
  (tmpdir / "classes.txt").write_text("alice\nbob\n", encoding="utf-8")
28
  monkeypatch.setenv("VOICE_CNN_MODEL_DIR", str(tmpdir))
29
 
30
- # 在 VoiceAuthService 初始化前,先以假模組覆蓋 cnn_adapter,避免實際載入大型相依(如 torchaudio)
31
  dummy = types.SimpleNamespace(
32
  predict_files=lambda model_dir, inputs, threshold=0.0: [{
33
  "file": str(inputs[0]),
@@ -37,7 +37,7 @@ def test_voice_login_success_with_cnn_stub(monkeypatch):
37
  "is_unknown": False,
38
  }]
39
  )
40
- monkeypatch.setitem(sys.modules, 'models.speaker_identification.cnn_adapter', dummy)
41
 
42
  svc = VoiceAuthService(config=VoiceLoginConfig(
43
  window_seconds=1,
@@ -68,11 +68,11 @@ def test_voice_login_no_audio_returns_error(monkeypatch):
68
  (tmpdir / "classes.txt").write_text("alice\nbob\n", encoding="utf-8")
69
  monkeypatch.setenv("VOICE_CNN_MODEL_DIR", str(tmpdir))
70
 
71
- # 同樣先注入假 cnn_adapter 模組
72
  dummy = types.SimpleNamespace(
73
  predict_files=lambda model_dir, inputs, threshold=0.0: []
74
  )
75
- monkeypatch.setitem(sys.modules, 'models.speaker_identification.cnn_adapter', dummy)
76
 
77
  svc = VoiceAuthService(config=VoiceLoginConfig(
78
  window_seconds=1,
 
27
  (tmpdir / "classes.txt").write_text("alice\nbob\n", encoding="utf-8")
28
  monkeypatch.setenv("VOICE_CNN_MODEL_DIR", str(tmpdir))
29
 
30
+ # 在 VoiceAuthService 初始化前,先以假模組覆蓋 inference,避免實際載入大型相依(如 torchaudio)
31
  dummy = types.SimpleNamespace(
32
  predict_files=lambda model_dir, inputs, threshold=0.0: [{
33
  "file": str(inputs[0]),
 
37
  "is_unknown": False,
38
  }]
39
  )
40
+ monkeypatch.setitem(sys.modules, 'scripts.inference', dummy)
41
 
42
  svc = VoiceAuthService(config=VoiceLoginConfig(
43
  window_seconds=1,
 
68
  (tmpdir / "classes.txt").write_text("alice\nbob\n", encoding="utf-8")
69
  monkeypatch.setenv("VOICE_CNN_MODEL_DIR", str(tmpdir))
70
 
71
+ # 同樣先注入假 inference 模組
72
  dummy = types.SimpleNamespace(
73
  predict_files=lambda model_dir, inputs, threshold=0.0: []
74
  )
75
+ monkeypatch.setitem(sys.modules, 'scripts.inference', dummy)
76
 
77
  svc = VoiceAuthService(config=VoiceLoginConfig(
78
  window_seconds=1,