Spaces:
Sleeping
Sleeping
Commit ·
ab55b54
1
Parent(s): 85a83b0
feat: 語音登入優化 - 錄音時間調整為3秒 & CNN模組整合
Browse files- 將語音登入錄音時間從4秒調整為3秒,提升用戶體驗
- 在inference.py中添加predict_files函數,整合CNN說話者辨識
- 修改VoiceAuthService使用inference.py而非獨立adapter
- 前端語音綁定錄音時間同步調整為3秒
- 更新測試檔案中的模組引用路徑
app.py
CHANGED
|
@@ -219,7 +219,7 @@ async def lifespan(app: FastAPI):
|
|
| 219 |
# 初始化語音登入服務(硬編參數)
|
| 220 |
try:
|
| 221 |
app.state.voice_auth = VoiceAuthService(config=VoiceLoginConfig(
|
| 222 |
-
window_seconds=4,
|
| 223 |
required_windows=1,
|
| 224 |
sample_rate=16000,
|
| 225 |
prob_threshold=0.40,
|
|
|
|
| 219 |
# 初始化語音登入服務(硬編參數)
|
| 220 |
try:
|
| 221 |
app.state.voice_auth = VoiceAuthService(config=VoiceLoginConfig(
|
| 222 |
+
window_seconds=3,
|
| 223 |
required_windows=1,
|
| 224 |
sample_rate=16000,
|
| 225 |
prob_threshold=0.40,
|
models/speaker_identification/scripts/inference.py
CHANGED
|
@@ -78,6 +78,76 @@ def softmax(x):
|
|
| 78 |
return e / e.sum(dim=1, keepdim=True)
|
| 79 |
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
# ============== 錄音與前處理(比照 process_audio.py) ==============
|
| 82 |
REC_SR = 22050
|
| 83 |
TARGET_RMS = 0.1
|
|
|
|
| 78 |
return e / e.sum(dim=1, keepdim=True)
|
| 79 |
|
| 80 |
|
| 81 |
+
def predict_files(model_dir, file_list, threshold=0.0):
|
| 82 |
+
"""
|
| 83 |
+
預測多個音訊檔案的說話者。
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
model_dir: 模型目錄,包含 speaker_id_model.pth 和 classes.txt
|
| 87 |
+
file_list: 檔案路徑列表
|
| 88 |
+
threshold: 預測門檻(目前未使用)
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
結果列表,每個元素為字典,包含 'pred', 'score', 'top'
|
| 92 |
+
"""
|
| 93 |
+
device = get_device()
|
| 94 |
+
bundle = torchaudio.pipelines.WAV2VEC2_BASE
|
| 95 |
+
target_sr = bundle.sample_rate
|
| 96 |
+
|
| 97 |
+
model_path = os.path.join(model_dir, 'speaker_id_model.pth')
|
| 98 |
+
classes_path = os.path.join(model_dir, 'classes.txt')
|
| 99 |
+
processed_dir = os.path.join(model_dir, 'processed_audio')
|
| 100 |
+
|
| 101 |
+
if os.path.isfile(classes_path):
|
| 102 |
+
classes = load_classes(classes_path)
|
| 103 |
+
elif os.path.isdir(processed_dir):
|
| 104 |
+
classes = load_classes(processed_dir)
|
| 105 |
+
else:
|
| 106 |
+
raise FileNotFoundError(f"找不到類別定義:{classes_path} 或 {processed_dir}")
|
| 107 |
+
|
| 108 |
+
num_classes = len(classes)
|
| 109 |
+
|
| 110 |
+
model = Wav2Vec2SpeakerClassifier(bundle, num_classes)
|
| 111 |
+
state = torch.load(model_path, map_location='cpu')
|
| 112 |
+
model.load_state_dict(state)
|
| 113 |
+
model.to(device).eval()
|
| 114 |
+
|
| 115 |
+
results = []
|
| 116 |
+
for file_path in file_list:
|
| 117 |
+
try:
|
| 118 |
+
# 前處理音訊
|
| 119 |
+
y, sr = process_like_training(file_path)
|
| 120 |
+
|
| 121 |
+
# 轉成模型輸入
|
| 122 |
+
y_t = torch.tensor(y, dtype=torch.float32).unsqueeze(0)
|
| 123 |
+
if sr != target_sr:
|
| 124 |
+
resampler = torchaudio.transforms.Resample(sr, target_sr)
|
| 125 |
+
y_t = resampler(y_t)
|
| 126 |
+
length = torch.tensor([y_t.shape[1]], dtype=torch.long)
|
| 127 |
+
waveforms, lengths = y_t.to(device), length.to(device)
|
| 128 |
+
|
| 129 |
+
with torch.no_grad():
|
| 130 |
+
logits = model(waveforms, lengths)
|
| 131 |
+
probs = softmax(logits).squeeze(0).cpu()
|
| 132 |
+
top_prob, top_idx = torch.max(probs, dim=0)
|
| 133 |
+
pred = classes[top_idx.item()]
|
| 134 |
+
|
| 135 |
+
# 獲取 top 候選
|
| 136 |
+
topk = torch.topk(probs, k=min(3, num_classes))
|
| 137 |
+
top = [(classes[i], float(p)) for p, i in zip(topk.values.tolist(), topk.indices.tolist())]
|
| 138 |
+
|
| 139 |
+
result = {
|
| 140 |
+
'pred': pred,
|
| 141 |
+
'score': float(top_prob.item()),
|
| 142 |
+
'top': top
|
| 143 |
+
}
|
| 144 |
+
results.append(result)
|
| 145 |
+
except Exception as e:
|
| 146 |
+
results.append({'error': str(e)})
|
| 147 |
+
|
| 148 |
+
return results
|
| 149 |
+
|
| 150 |
+
|
| 151 |
# ============== 錄音與前處理(比照 process_audio.py) ==============
|
| 152 |
REC_SR = 22050
|
| 153 |
TARGET_RMS = 0.1
|
services/voice_login.py
CHANGED
|
@@ -105,7 +105,7 @@ class VoiceAuthService:
|
|
| 105 |
if str(self.identity_dir) not in os.sys.path:
|
| 106 |
os.sys.path.insert(0, str(self.identity_dir))
|
| 107 |
try:
|
| 108 |
-
from
|
| 109 |
except Exception as e: # pragma: no cover
|
| 110 |
raise RuntimeError(f"載入 CNN 說話者辨識模組失敗:{e}")
|
| 111 |
self._predict_files = _predict_files
|
|
|
|
| 105 |
if str(self.identity_dir) not in os.sys.path:
|
| 106 |
os.sys.path.insert(0, str(self.identity_dir))
|
| 107 |
try:
|
| 108 |
+
from scripts.inference import predict_files as _predict_files # type: ignore
|
| 109 |
except Exception as e: # pragma: no cover
|
| 110 |
raise RuntimeError(f"載入 CNN 說話者辨識模組失敗:{e}")
|
| 111 |
self._predict_files = _predict_files
|
static/frontend/js/websocket.js
CHANGED
|
@@ -783,7 +783,7 @@ async function handleVoiceBindingReady() {
|
|
| 783 |
|
| 784 |
// 更新提示文字
|
| 785 |
if (typeof transcript !== 'undefined') {
|
| 786 |
-
transcript.textContent = '請開始說話(錄音 4 秒)...';
|
| 787 |
transcript.className = 'voice-transcript provisional';
|
| 788 |
}
|
| 789 |
|
|
@@ -814,10 +814,10 @@ async function handleVoiceBindingReady() {
|
|
| 814 |
return;
|
| 815 |
}
|
| 816 |
|
| 817 |
-
console.log('⏱️ 開始倒數 4 秒錄音...');
|
| 818 |
|
| 819 |
// 倒數計時提示
|
| 820 |
-
let countdown = 4;
|
| 821 |
const countdownInterval = setInterval(() => {
|
| 822 |
countdown--;
|
| 823 |
if (countdown > 0 && typeof transcript !== 'undefined') {
|
|
@@ -826,10 +826,10 @@ async function handleVoiceBindingReady() {
|
|
| 826 |
}
|
| 827 |
}, 1000);
|
| 828 |
|
| 829 |
-
// 4 秒後自動停止錄音
|
| 830 |
setTimeout(() => {
|
| 831 |
clearInterval(countdownInterval);
|
| 832 |
-
console.log('⏹️ 4 秒錄音完成,自動停止');
|
| 833 |
|
| 834 |
// 停止音訊視覺化
|
| 835 |
if (typeof stopRealAudioAnalysis === 'function') {
|
|
@@ -852,7 +852,7 @@ async function handleVoiceBindingReady() {
|
|
| 852 |
transcript.className = 'voice-transcript provisional';
|
| 853 |
}
|
| 854 |
|
| 855 |
-
}, 4000); // 4 秒錄音時長
|
| 856 |
} else {
|
| 857 |
console.error('❌ WebSocket 管理器未初始化');
|
| 858 |
showErrorNotification('系統錯誤:WebSocket 未連接');
|
|
|
|
| 783 |
|
| 784 |
// 更新提示文字
|
| 785 |
if (typeof transcript !== 'undefined') {
|
| 786 |
+
transcript.textContent = '請開始說話(錄音 3 秒)...';
|
| 787 |
transcript.className = 'voice-transcript provisional';
|
| 788 |
}
|
| 789 |
|
|
|
|
| 814 |
return;
|
| 815 |
}
|
| 816 |
|
| 817 |
+
console.log('⏱️ 開始倒數 3 秒錄音...');
|
| 818 |
|
| 819 |
// 倒數計時提示
|
| 820 |
+
let countdown = 3;
|
| 821 |
const countdownInterval = setInterval(() => {
|
| 822 |
countdown--;
|
| 823 |
if (countdown > 0 && typeof transcript !== 'undefined') {
|
|
|
|
| 826 |
}
|
| 827 |
}, 1000);
|
| 828 |
|
| 829 |
+
// 3 秒後自動停止錄音
|
| 830 |
setTimeout(() => {
|
| 831 |
clearInterval(countdownInterval);
|
| 832 |
+
console.log('⏹️ 3 秒錄音完成,自動停止');
|
| 833 |
|
| 834 |
// 停止音訊視覺化
|
| 835 |
if (typeof stopRealAudioAnalysis === 'function') {
|
|
|
|
| 852 |
transcript.className = 'voice-transcript provisional';
|
| 853 |
}
|
| 854 |
|
| 855 |
+
}, 3000); // 3 秒錄音時長
|
| 856 |
} else {
|
| 857 |
console.error('❌ WebSocket 管理器未初始化');
|
| 858 |
showErrorNotification('系統錯誤:WebSocket 未連接');
|
tests/services/test_voice_login_cnn.py
CHANGED
|
@@ -27,7 +27,7 @@ def test_voice_login_success_with_cnn_stub(monkeypatch):
|
|
| 27 |
(tmpdir / "classes.txt").write_text("alice\nbob\n", encoding="utf-8")
|
| 28 |
monkeypatch.setenv("VOICE_CNN_MODEL_DIR", str(tmpdir))
|
| 29 |
|
| 30 |
-
# 在 VoiceAuthService 初始化前,先以假模組覆蓋
|
| 31 |
dummy = types.SimpleNamespace(
|
| 32 |
predict_files=lambda model_dir, inputs, threshold=0.0: [{
|
| 33 |
"file": str(inputs[0]),
|
|
@@ -37,7 +37,7 @@ def test_voice_login_success_with_cnn_stub(monkeypatch):
|
|
| 37 |
"is_unknown": False,
|
| 38 |
}]
|
| 39 |
)
|
| 40 |
-
monkeypatch.setitem(sys.modules, '
|
| 41 |
|
| 42 |
svc = VoiceAuthService(config=VoiceLoginConfig(
|
| 43 |
window_seconds=1,
|
|
@@ -68,11 +68,11 @@ def test_voice_login_no_audio_returns_error(monkeypatch):
|
|
| 68 |
(tmpdir / "classes.txt").write_text("alice\nbob\n", encoding="utf-8")
|
| 69 |
monkeypatch.setenv("VOICE_CNN_MODEL_DIR", str(tmpdir))
|
| 70 |
|
| 71 |
-
# 同樣先注入假
|
| 72 |
dummy = types.SimpleNamespace(
|
| 73 |
predict_files=lambda model_dir, inputs, threshold=0.0: []
|
| 74 |
)
|
| 75 |
-
monkeypatch.setitem(sys.modules, '
|
| 76 |
|
| 77 |
svc = VoiceAuthService(config=VoiceLoginConfig(
|
| 78 |
window_seconds=1,
|
|
|
|
| 27 |
(tmpdir / "classes.txt").write_text("alice\nbob\n", encoding="utf-8")
|
| 28 |
monkeypatch.setenv("VOICE_CNN_MODEL_DIR", str(tmpdir))
|
| 29 |
|
| 30 |
+
# 在 VoiceAuthService 初始化前,先以假模組覆蓋 inference,避免實際載入大型相依(如 torchaudio)
|
| 31 |
dummy = types.SimpleNamespace(
|
| 32 |
predict_files=lambda model_dir, inputs, threshold=0.0: [{
|
| 33 |
"file": str(inputs[0]),
|
|
|
|
| 37 |
"is_unknown": False,
|
| 38 |
}]
|
| 39 |
)
|
| 40 |
+
monkeypatch.setitem(sys.modules, 'scripts.inference', dummy)
|
| 41 |
|
| 42 |
svc = VoiceAuthService(config=VoiceLoginConfig(
|
| 43 |
window_seconds=1,
|
|
|
|
| 68 |
(tmpdir / "classes.txt").write_text("alice\nbob\n", encoding="utf-8")
|
| 69 |
monkeypatch.setenv("VOICE_CNN_MODEL_DIR", str(tmpdir))
|
| 70 |
|
| 71 |
+
# 同樣先注入假 inference 模組
|
| 72 |
dummy = types.SimpleNamespace(
|
| 73 |
predict_files=lambda model_dir, inputs, threshold=0.0: []
|
| 74 |
)
|
| 75 |
+
monkeypatch.setitem(sys.modules, 'scripts.inference', dummy)
|
| 76 |
|
| 77 |
svc = VoiceAuthService(config=VoiceLoginConfig(
|
| 78 |
window_seconds=1,
|