import sys import os import json import time import numpy as np import torch import torchaudio import gradio as gr from torchaudio.compliance.kaldi import fbank from pypinyin import lazy_pinyin from train_pinyin import MMKWS2_Wrapper # 设备与模型加载 # device = torch.device("cuda:4") device = torch.device("cpu") wrapper = MMKWS2_Wrapper.load_from_checkpoint( "stepstep=024500.ckpt", map_location=device, ) wrapper.eval() # 注册信息 registered = {"text": "", "audios": []} enroll = None enroll_text = None last_wake_time = 0 def load_pinyin_index(save_path): """加载拼音索引映射""" with open(save_path, "r", encoding="utf-8") as f: data = json.load(f) return data["pinyin_to_index"], data["index_to_pinyin"] pinyin_to_index, index_to_pinyin = load_pinyin_index("pinyin_index.json") def add_audio(text, audio, audio_list): if not text: return audio_list, "请先输入唤醒词文本" if audio is None or audio[1] is None or len(audio[1]) == 0: return audio_list, "请上传或录制音频" audio_list = audio_list or [] if len(audio_list) >= 5: return audio_list, "最多支持5条音频" audio_list.append(audio) return audio_list, f"已录入 {len(audio_list)} 条音频" def register_keyword(text, audio_list): if not text: return gr.update(value="请先输入唤醒词文本") if not audio_list or len(audio_list) == 0: return gr.update(value="请至少上传或录制一条音频") registered["text"] = text registered["audios"] = audio_list global enroll_text enroll_text = text fused_feats = [] for audio in audio_list: anchor_wave, _ = torchaudio.load(audio) anchor_text_embedding = torch.tensor([pinyin_to_index[p] + 1 for p in lazy_pinyin(text)]) anchor_wave = anchor_wave.to(device) anchor_text_embedding = anchor_text_embedding.to(device).unsqueeze(0) with torch.no_grad(): outputs = wrapper._hubert_model(anchor_wave.half()) anchor_wave_embedding = outputs.last_hidden_state anchor_wave_embedding = anchor_wave_embedding.to(anchor_wave.dtype) fused_feat = wrapper.model.enrollment( anchor_wave_embedding, anchor_text_embedding ) fused_feats.append(fused_feat) fused_feats = torch.cat(fused_feats, dim=0) fused_feats, _ = fused_feats.max(dim=0) fused_feats = fused_feats.unsqueeze(0) global enroll enroll = fused_feats return gr.update(value=f"注册完成,唤醒词:{text},音频数:{len(audio_list)}") def update_gallery(audio_list): if audio_list and len(audio_list) > 0: return gr.update(visible=True, value=audio_list[-1]) else: return gr.update(visible=False, value=None) def streaming_detect_handler(current_audio, state, audio_player): global last_wake_time, enroll_text, enroll if current_audio is None or current_audio[1] is None or len(current_audio[1]) == 0: return state, gr.update() if enroll_text is None: return state, gr.update() pad = len(enroll_text) * 5 state = (state or []) + [current_audio] state = state[-pad:] if len(state) < pad: return state, gr.update() sr = state[0][0] audio_list = [x[1] for x in state] audio_concat = np.concatenate(audio_list, axis=0) audio_concat = audio_concat.astype(np.float32) / 32768.0 audio_concat = torch.from_numpy(audio_concat).unsqueeze(0) audio_concat = torchaudio.functional.resample(audio_concat, sr, 16000) audio_concat = audio_concat / torch.max(torch.abs(audio_concat)) compare_wave = fbank(audio_concat, num_mel_bins=80) compare_wave = compare_wave.to(device).unsqueeze(0) compare_lengths = torch.tensor([compare_wave.size(1)], device=compare_wave.device) if enroll is None: return state, None current_time = time.time() if current_time - last_wake_time <= 2: return state, gr.update() with torch.no_grad(): preds = wrapper.model.verification( enroll, compare_wave, compare_lengths ) preds = torch.sigmoid(preds).item() if preds >= 0.85: print(f"Wake up! {preds}") last_wake_time = current_time audio_path = "tts-2025-04-27@197cd135f2a2451b9cab9cf2add1c1ab.wav" state = [] return state, gr.Audio(value=audio_path, visible=True, autoplay=True) if preds >= 0.6: print(f"Preds! {preds}") return state, None with gr.Blocks() as demo: gr.Markdown("# 自定义关键词检测 Demo") with gr.Row(): with gr.Column(scale=1): gr.Markdown("## 注册唤醒词") text_input = gr.Textbox(label="唤醒词文本", placeholder="请输入唤醒词") audio_list_state = gr.State([]) audio_input = gr.Audio(label="上传或录制音频", type="filepath") add_btn = gr.Button("添加音频") audio_status = gr.Textbox(label="音频状态", interactive=False) audio_gallery = gr.Audio(label="已添加音频", type="filepath", interactive=False, visible=False) register_btn = gr.Button("注册完成") register_status = gr.Textbox(label="注册状态", interactive=False) add_btn.click( add_audio, inputs=[text_input, audio_input, audio_list_state], outputs=[audio_list_state, audio_status] ).then( update_gallery, inputs=audio_list_state, outputs=audio_gallery ) register_btn.click( register_keyword, inputs=[text_input, audio_list_state], outputs=register_status ) with gr.Column(scale=2): gr.Markdown("## 实时检测") mic = gr.Audio(sources="microphone", streaming=True, label="实时监听") state = gr.State(value=[]) audio_player = gr.Audio(label="唤醒提示", visible=False) mic.stream( streaming_detect_handler, inputs=[mic, state, audio_player], outputs=[state, audio_player], time_limit=1000, stream_every=0.05, ) if __name__ == "__main__": demo.launch()