File size: 6,388 Bytes
693e27f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import sys
import os
import json
import time
import numpy as np
import torch
import torchaudio
import gradio as gr
from torchaudio.compliance.kaldi import fbank
from pypinyin import lazy_pinyin
from train_pinyin import MMKWS2_Wrapper

# Device and model loading.
# device = torch.device("cuda:4")  # original GPU device, kept for reference
device = torch.device("cpu")
wrapper = MMKWS2_Wrapper.load_from_checkpoint(
    "stepstep=024500.ckpt",
    map_location=device,
)
wrapper.eval()

# Enrollment state shared between the registration and detection handlers.
registered = {"text": "", "audios": []}  # raw enrollment inputs (keyword text + clip paths)
enroll = None       # fused enrollment embedding tensor; set by register_keyword
enroll_text = None  # enrolled wake-word text; set by register_keyword
last_wake_time = 0  # wall-clock time of the last wake event (2 s refractory timer)

def load_pinyin_index(save_path):
    """Read the pinyin index mapping from a JSON file.

    Returns a tuple ``(pinyin_to_index, index_to_pinyin)`` taken from the
    keys of the same names in the JSON document.
    """
    with open(save_path, "r", encoding="utf-8") as fh:
        mapping = json.load(fh)
    return mapping["pinyin_to_index"], mapping["index_to_pinyin"]

pinyin_to_index, index_to_pinyin = load_pinyin_index("pinyin_index.json")

def add_audio(text, audio, audio_list):
    """Append one enrollment clip (file path) to the session's clip list.

    Validates that a keyword text was entered, that a clip was provided,
    and that the list holds at most 5 clips.  Returns the (possibly
    updated) list and a status string for the UI.
    """
    if not text:
        return audio_list, "请先输入唤醒词文本"
    # BUG FIX: `audio` is a file path (the input is gr.Audio(type="filepath")),
    # but the original indexed it like a (sample_rate, data) tuple
    # (`audio[1]`), which would raise IndexError on a 1-character path and
    # never actually checked for an empty value.  A truthiness test covers
    # both None and "".
    if not audio:
        return audio_list, "请上传或录制音频"
    audio_list = audio_list or []
    if len(audio_list) >= 5:
        return audio_list, "最多支持5条音频"
    audio_list.append(audio)
    return audio_list, f"已录入 {len(audio_list)} 条音频"

def register_keyword(text, audio_list):
    """Enroll a wake word: fuse text and audio embeddings into one anchor.

    For each enrollment clip, a HuBERT feature sequence is extracted and
    fused with the pinyin-token embedding of ``text`` via
    ``wrapper.model.enrollment``; the per-clip fused features are
    max-pooled into a single enrollment tensor stored in the module-level
    ``enroll`` global.  Returns a ``gr.update`` carrying a status string.
    """
    global enroll, enroll_text
    if not text:
        return gr.update(value="请先输入唤醒词文本")
    if not audio_list:
        return gr.update(value="请至少上传或录制一条音频")
    registered["text"] = text
    registered["audios"] = audio_list
    enroll_text = text
    # The text embedding does not depend on the clip — compute it once
    # instead of once per clip (was recomputed inside the loop).
    # NOTE(review): indices are shifted by +1, presumably reserving 0 for
    # padding — confirm against the training code.
    anchor_text_embedding = torch.tensor(
        [pinyin_to_index[p] + 1 for p in lazy_pinyin(text)]
    ).to(device).unsqueeze(0)
    fused_feats = []
    for audio in audio_list:
        anchor_wave, _ = torchaudio.load(audio)
        anchor_wave = anchor_wave.to(device)
        with torch.no_grad():
            # HuBERT runs in half precision; cast features back afterwards.
            outputs = wrapper._hubert_model(anchor_wave.half())
            anchor_wave_embedding = outputs.last_hidden_state
        anchor_wave_embedding = anchor_wave_embedding.to(anchor_wave.dtype)
        fused_feats.append(
            wrapper.model.enrollment(anchor_wave_embedding, anchor_text_embedding)
        )
    # Max-pool across enrollment clips into a single anchor embedding.
    pooled, _ = torch.cat(fused_feats, dim=0).max(dim=0)
    enroll = pooled.unsqueeze(0)
    return gr.update(value=f"注册完成,唤醒词:{text},音频数:{len(audio_list)}")

def update_gallery(audio_list):
    """Show the most recently added clip in the preview player.

    Hides the player entirely while no clips have been added.
    """
    if not audio_list:
        return gr.update(visible=False, value=None)
    return gr.update(visible=True, value=audio_list[-1])

def streaming_detect_handler(current_audio, state, audio_player):
    """Streaming callback: buffer mic chunks, score them, fire the wake cue.

    ``current_audio`` is a ``(sample_rate, samples)`` chunk from the
    streaming ``gr.Audio`` (assumes int16 PCM — the /32768 scaling below
    relies on it; TODO confirm against the Gradio version in use).
    Chunks accumulate in ``state``; once 5 chunks per enrolled character
    are buffered, they are concatenated, resampled to 16 kHz,
    peak-normalized, turned into an 80-bin fbank, and verified against
    the enrollment anchor.  A score >= 0.85 outside the 2-second
    refractory window triggers the wake-prompt audio and clears the buffer.
    """
    # Only last_wake_time is written; enroll/enroll_text are read-only here.
    global last_wake_time
    if current_audio is None or current_audio[1] is None or len(current_audio[1]) == 0:
        return state, gr.update()
    if enroll_text is None:
        return state, gr.update()
    pad = len(enroll_text) * 5  # buffer 5 streaming chunks per character
    state = (state or []) + [current_audio]
    state = state[-pad:]
    if len(state) < pad:
        return state, gr.update()
    # Cheap early exits moved BEFORE the expensive resample/fbank work
    # (originally these checks ran after feature extraction, wasting a
    # full forward pipeline per discarded frame; outputs are unchanged).
    if enroll is None:
        return state, None
    current_time = time.time()
    if current_time - last_wake_time <= 2:
        # Refractory period: suppress re-triggering for 2 s after a wake.
        return state, gr.update()
    sr = state[0][0]
    chunks = [x[1] for x in state]
    audio_concat = np.concatenate(chunks, axis=0).astype(np.float32) / 32768.0
    audio_concat = torch.from_numpy(audio_concat).unsqueeze(0)
    audio_concat = torchaudio.functional.resample(audio_concat, sr, 16000)
    peak = torch.max(torch.abs(audio_concat))
    if peak > 0:  # BUG FIX: an all-silent buffer divided by 0 -> NaN features
        audio_concat = audio_concat / peak
    compare_wave = fbank(audio_concat, num_mel_bins=80).to(device).unsqueeze(0)
    compare_lengths = torch.tensor([compare_wave.size(1)], device=compare_wave.device)
    with torch.no_grad():
        preds = wrapper.model.verification(enroll, compare_wave, compare_lengths)
    score = torch.sigmoid(preds).item()
    if score >= 0.85:
        print(f"Wake up! {score}")
        last_wake_time = current_time
        # Play the canned wake-response clip and reset the chunk buffer.
        audio_path = "tts-2025-04-27@197cd135f2a2451b9cab9cf2add1c1ab.wav"
        return [], gr.Audio(value=audio_path, visible=True, autoplay=True)
    if score >= 0.6:
        print(f"Preds! {score}")
    return state, None

# --- Gradio UI layout and event wiring ---
with gr.Blocks() as demo:
    gr.Markdown("# 自定义关键词检测 Demo")
    with gr.Row():
        # Left column: wake-word enrollment (keyword text + up to 5 clips).
        with gr.Column(scale=1):
            gr.Markdown("## 注册唤醒词")
            text_input = gr.Textbox(label="唤醒词文本", placeholder="请输入唤醒词")
            audio_list_state = gr.State([])  # per-session list of clip file paths
            audio_input = gr.Audio(label="上传或录制音频", type="filepath")
            add_btn = gr.Button("添加音频")
            audio_status = gr.Textbox(label="音频状态", interactive=False)
            audio_gallery = gr.Audio(label="已添加音频", type="filepath", interactive=False, visible=False)
            register_btn = gr.Button("注册完成")
            register_status = gr.Textbox(label="注册状态", interactive=False)
            # Add a clip, then refresh the preview player with the newest one.
            add_btn.click(
                add_audio,
                inputs=[text_input, audio_input, audio_list_state],
                outputs=[audio_list_state, audio_status]
            ).then(
                update_gallery,
                inputs=audio_list_state,
                outputs=audio_gallery
            )
            register_btn.click(
                register_keyword,
                inputs=[text_input, audio_list_state],
                outputs=register_status
            )
        # Right column: live microphone streaming and wake playback.
        with gr.Column(scale=2):
            gr.Markdown("## 实时检测")
            mic = gr.Audio(sources="microphone", streaming=True, label="实时监听")
            state = gr.State(value=[])  # rolling buffer of streamed chunks
            audio_player = gr.Audio(label="唤醒提示", visible=False)
            # Stream ~50 ms chunks to the detector callback.
            mic.stream(
                streaming_detect_handler,
                inputs=[mic, state, audio_player],
                outputs=[state, audio_player],
                time_limit=1000,
                stream_every=0.05,
            )
if __name__ == "__main__":
    demo.launch()