File size: 6,388 Bytes
693e27f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import sys
import os
import json
import time
import numpy as np
import torch
import torchaudio
import gradio as gr
from torchaudio.compliance.kaldi import fbank
from pypinyin import lazy_pinyin
from train_pinyin import MMKWS2_Wrapper
# Device and model loading.
# A GPU device can be selected instead, e.g. torch.device("cuda:4").
device = torch.device("cpu")
# Restore the trained wrapper from the checkpoint file, mapping its tensors
# onto `device`.
wrapper = MMKWS2_Wrapper.load_from_checkpoint(
"stepstep=024500.ckpt",
map_location=device,
)
wrapper.eval()
# Registration info: the wake-word text and the audio entries it was
# registered with (filled in by register_keyword).
registered = {"text": "", "audios": []}
# Fused enrollment features produced by register_keyword; None until a
# wake word has been registered.
enroll = None
# Text of the registered wake word; None until registration.
enroll_text = None
# time.time() value of the most recent wake-up; 0 means "never woken".
last_wake_time = 0
def load_pinyin_index(save_path):
    """Load the pinyin index mappings from a JSON file.

    Args:
        save_path: Path of a UTF-8 JSON file containing the keys
            ``"pinyin_to_index"`` and ``"index_to_pinyin"``.

    Returns:
        A ``(pinyin_to_index, index_to_pinyin)`` tuple holding the two
        values stored under those keys, exactly as decoded from JSON.

    Raises:
        KeyError: If either key is missing from the file.
    """
    with open(save_path, "r", encoding="utf-8") as handle:
        mapping = json.load(handle)
    forward = mapping["pinyin_to_index"]
    backward = mapping["index_to_pinyin"]
    return forward, backward
# Load both pinyin lookup tables once at import time from
# "pinyin_index.json" (a relative path, so it is resolved against the
# current working directory).
pinyin_to_index, index_to_pinyin = load_pinyin_index("pinyin_index.json")
def _is_missing_audio(audio):
    """Return True when *audio* carries no usable recording.

    Accepts both shapes a Gradio audio input can deliver: a file path
    string, or a ``(sample_rate, samples)`` pair.
    """
    if audio is None:
        return True
    if isinstance(audio, str):
        # File-path input: only an empty path counts as missing.
        return not audio
    # Tuple-style input: the samples live in the second slot.
    samples = audio[1] if len(audio) > 1 else None
    return samples is None or len(samples) == 0


def add_audio(text, audio, audio_list, max_audios=5):
    """Append one enrollment recording to the list of collected recordings.

    Args:
        text: Wake-word text; a recording is only accepted once it is set.
        audio: The new recording, either a file path string or a
            ``(sample_rate, samples)`` pair.
        audio_list: Recordings collected so far (may be ``None``).
        max_audios: Maximum number of recordings that may be collected.

    Returns:
        A ``(audio_list, status_message)`` tuple. When the recording is
        rejected because of missing text or audio, ``audio_list`` is
        returned as it was passed in. Otherwise a new list is returned and
        the caller's list is never mutated.
    """
    if not text:
        return audio_list, "请先输入唤醒词文本"
    if _is_missing_audio(audio):
        return audio_list, "请上传或录制音频"

    # Work on a copy so the caller's list object is left untouched.
    collected = list(audio_list or [])
    if len(collected) >= max_audios:
        return collected, f"最多支持{max_audios}条音频"

    collected.append(audio)
    return collected, f"已录入 {len(collected)} 条音频"
def register_keyword(text, audio_list):
    """Enroll a wake word from its text and the collected recordings.

    For every recording the waveform is run through
    ``wrapper._hubert_model`` and fused with the pinyin index sequence of
    ``text`` by ``wrapper.model.enrollment``. The per-recording features
    are concatenated along dim 0 and reduced with an element-wise max over
    that dim.

    Side effects: on success, updates the module-level ``registered``,
    ``enroll_text`` and ``enroll``. Nothing is modified when validation or
    the pinyin lookup fails, or if enrollment raises part-way through.

    Args:
        text: Wake-word text; converted to pinyin with ``lazy_pinyin``.
        audio_list: Audio file paths readable by ``torchaudio.load``.

    Returns:
        A ``gr.update`` carrying the status message for the status textbox.
    """
    global enroll, enroll_text

    if not text:
        return gr.update(value="请先输入唤醒词文本")
    if not audio_list:
        return gr.update(value="请至少上传或录制一条音频")

    # The text branch is the same for every recording, so build it once.
    # Indices are shifted by +1, as in the original lookup.
    try:
        pinyin_ids = [pinyin_to_index[p] + 1 for p in lazy_pinyin(text)]
    except KeyError as err:
        # lazy_pinyin passes through tokens it cannot convert; report them
        # to the user rather than crashing the handler.
        return gr.update(value=f"唤醒词包含不支持的拼音或字符:{err.args[0]}")
    anchor_text_embedding = torch.tensor(pinyin_ids).to(device).unsqueeze(0)

    # The streaming detector resamples its input to 16 kHz; enrollment
    # audio is brought to the same rate here.
    # NOTE(review): assumes the HuBERT front end expects 16 kHz mono -
    # confirm against the training setup.
    target_sr = 16000

    fused_feats = []
    for audio in audio_list:
        anchor_wave, sr = torchaudio.load(audio)
        if anchor_wave.size(0) > 1:
            # Down-mix multi-channel files so the channel axis is not
            # mistaken for a batch axis.
            anchor_wave = anchor_wave.mean(dim=0, keepdim=True)
        if sr != target_sr:
            anchor_wave = torchaudio.functional.resample(
                anchor_wave, sr, target_sr
            )
        anchor_wave = anchor_wave.to(device)

        with torch.no_grad():
            # The model is fed a half-precision waveform; its output is
            # cast back to the waveform's dtype before fusion.
            outputs = wrapper._hubert_model(anchor_wave.half())
            anchor_wave_embedding = outputs.last_hidden_state.to(
                anchor_wave.dtype
            )
            fused_feat = wrapper.model.enrollment(
                anchor_wave_embedding,
                anchor_text_embedding,
            )
        fused_feats.append(fused_feat)

    # Merge all recordings: stack along dim 0, keep the element-wise max,
    # then restore the leading dim.
    stacked = torch.cat(fused_feats, dim=0)
    merged, _ = stacked.max(dim=0)
    merged = merged.unsqueeze(0)

    # Commit the shared state only after enrollment has fully succeeded, so
    # the detector never sees a new text paired with stale features.
    registered["text"] = text
    registered["audios"] = audio_list
    enroll = merged
    enroll_text = text

    return gr.update(value=f"注册完成,唤醒词:{text},音频数:{len(audio_list)}")
def update_gallery(audio_list):
    """Show the most recently added recording in the preview player.

    Args:
        audio_list: Recordings collected so far (may be empty or ``None``).

    Returns:
        A ``gr.update`` that makes the player visible with the last entry
        of ``audio_list``, or hides and clears it when there is none.
    """
    if not audio_list:
        return gr.update(visible=False, value=None)
    latest = audio_list[-1]
    return gr.update(visible=True, value=latest)
def streaming_detect_handler(current_audio, state, audio_player):
    """Process one streamed microphone chunk and check for the wake word.

    Chunks are buffered in ``state``. Once the buffer holds
    ``len(enroll_text) * 5`` chunks, they are concatenated, resampled to
    16 kHz, peak-normalised, converted to 80-bin fbank features and scored
    against the enrolled features with ``wrapper.model.verification``.

    Args:
        current_audio: The newest chunk; index 0 is read as the sample rate
            and index 1 as the samples.
            NOTE(review): the 1/32768 scaling assumes int16 samples - confirm.
        state: List of previously buffered chunks (may be ``None``).
        audio_player: Current value of the prompt player; unused, kept
            because it is wired as an input of the stream event.

    Returns:
        A ``(state, player_update)`` tuple. ``player_update`` is:
          * ``gr.update()`` (no change) while buffering or within 2 seconds
            of the last wake-up,
          * a ``gr.Audio`` that autoplays the prompt sound when the score
            is at least 0.85 (the buffer is then cleared),
          * ``None`` otherwise.
    """
    global last_wake_time

    if (
        current_audio is None
        or current_audio[1] is None
        or len(current_audio[1]) == 0
    ):
        return state, gr.update()
    if enroll_text is None:
        return state, gr.update()

    # Keep a sliding window of the most recent `pad` chunks.
    pad = len(enroll_text) * 5
    state = ((state or []) + [current_audio])[-pad:]
    if len(state) < pad:
        return state, gr.update()

    # Cheap exits first, so no feature extraction is done when the result
    # would be discarded anyway.
    if enroll is None:
        return state, None
    current_time = time.time()
    if current_time - last_wake_time <= 2:
        return state, gr.update()

    sr = state[0][0]
    audio_concat = np.concatenate([chunk[1] for chunk in state], axis=0)
    audio_concat = audio_concat.astype(np.float32) / 32768.0
    audio_concat = torch.from_numpy(audio_concat).unsqueeze(0)
    audio_concat = torchaudio.functional.resample(audio_concat, sr, 16000)

    peak = torch.max(torch.abs(audio_concat))
    if not peak > 0:
        # An all-silent window would be divided by zero and fill the
        # features with NaNs; treat it as "no detection".
        return state, None
    audio_concat = audio_concat / peak

    compare_wave = fbank(audio_concat, num_mel_bins=80)
    compare_wave = compare_wave.to(device).unsqueeze(0)
    compare_lengths = torch.tensor(
        [compare_wave.size(1)], device=compare_wave.device
    )

    with torch.no_grad():
        preds = wrapper.model.verification(
            enroll,
            compare_wave,
            compare_lengths,
        )
        preds = torch.sigmoid(preds).item()

    if preds >= 0.85:
        print(f"Wake up! {preds}")
        last_wake_time = current_time
        audio_path = "tts-2025-04-27@197cd135f2a2451b9cab9cf2add1c1ab.wav"
        # Drop the buffer so the same audio cannot trigger twice.
        state = []
        return state, gr.Audio(value=audio_path, visible=True, autoplay=True)
    if preds >= 0.6:
        # Near-miss: log the score for threshold tuning.
        print(f"Preds! {preds}")
    return state, None
# UI definition: a registration column next to a live-detection column.
with gr.Blocks() as demo:
    gr.Markdown("# 自定义关键词检测 Demo")
    with gr.Row():
        # Column 1: enter the wake-word text, collect recordings, register.
        with gr.Column(scale=1):
            gr.Markdown("## 注册唤醒词")
            text_input = gr.Textbox(
                label="唤醒词文本",
                placeholder="请输入唤醒词",
            )
            audio_list_state = gr.State([])
            audio_input = gr.Audio(label="上传或录制音频", type="filepath")
            add_btn = gr.Button("添加音频")
            audio_status = gr.Textbox(label="音频状态", interactive=False)
            audio_gallery = gr.Audio(
                label="已添加音频",
                type="filepath",
                interactive=False,
                visible=False,
            )
            register_btn = gr.Button("注册完成")
            register_status = gr.Textbox(label="注册状态", interactive=False)

            # Add the recording first, then refresh the preview player.
            add_btn.click(
                add_audio,
                inputs=[text_input, audio_input, audio_list_state],
                outputs=[audio_list_state, audio_status],
            ).then(
                update_gallery,
                inputs=audio_list_state,
                outputs=audio_gallery,
            )
            register_btn.click(
                register_keyword,
                inputs=[text_input, audio_list_state],
                outputs=register_status,
            )

        # Column 2: stream the microphone into the detector.
        with gr.Column(scale=2):
            gr.Markdown("## 实时检测")
            mic = gr.Audio(sources="microphone", streaming=True, label="实时监听")
            state = gr.State(value=[])
            audio_player = gr.Audio(label="唤醒提示", visible=False)
            mic.stream(
                streaming_detect_handler,
                inputs=[mic, state, audio_player],
                outputs=[state, audio_player],
                time_limit=1000,
                stream_every=0.05,
            )
if __name__ == "__main__":
    # Start the Gradio server only when the file is run as a script.
    demo.launch()