File size: 9,730 Bytes
0ddb4a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)

import enum
import logging
from collections import deque

import numpy as np

from .constants import FRAME_LENGTH_MS, FRAME_SHIFT_MS, FRAME_LENGTH_S, FRAME_SHIFT_S

logger = logging.getLogger(__name__)


@enum.unique
class VadState(enum.Enum):
    """States of the duration-gated state machine in
    ``VadPostprocessor._smooth_preds_with_state_machine``.

    The POSSIBLE_* states hold a tentative transition until it has lasted
    min_speech_frame / min_silence_frame frames, filtering out blips.
    """
    SILENCE = 0           # confirmed silence; frame decision is 0
    POSSIBLE_SPEECH = 1   # speech seen, not yet min_speech_frame long; still 0
    SPEECH = 2            # confirmed speech; frame decision is 1
    POSSIBLE_SILENCE = 3  # silence seen, not yet min_silence_frame long; still 1


class VadPostprocessor:
    """Turn raw per-frame VAD probabilities into clean 0/1 speech decisions.

    Pipeline (see ``process``): smooth probabilities with a trailing moving
    average, threshold them, enforce minimum speech/silence durations with a
    state machine, back-fill onsets delayed by the smoothing window, merge
    short silence gaps, pad speech segment edges, and split over-long
    segments at local probability minima.

    All ``*_frame`` parameters are frame counts; one frame advance is
    FRAME_SHIFT_S seconds (frame window is FRAME_LENGTH_S seconds).
    """

    def __init__(self, smooth_window_size,
                 prob_threshold,
                 min_speech_frame,
                 max_speech_frame,
                 min_silence_frame,
                 merge_silence_frame,
                 extend_speech_frame):
        # Clamp to >= 1 so window-averaging code never divides by zero.
        self.smooth_window_size = max(1, smooth_window_size)
        self.prob_threshold = prob_threshold
        self.min_speech_frame = min_speech_frame
        self.max_speech_frame = max_speech_frame
        self.min_silence_frame = min_silence_frame
        self.merge_silence_frame = merge_silence_frame
        self.extend_speech_frame = extend_speech_frame

    def process(self, raw_probs):
        """Return a per-frame 0/1 decision list for ``raw_probs``.

        Args:
            raw_probs: sequence of speech probabilities, one per frame.
        Returns:
            list[int]: 1 for speech, 0 for silence; [] for empty input.
        """
        # len() instead of truthiness: a numpy array input would raise
        # "truth value is ambiguous" under `if not raw_probs`.
        if len(raw_probs) == 0:
            return []

        smoothed_probs = self._smooth_prob(raw_probs)

        binary_preds = self._apply_threshold(smoothed_probs)

        # decision: 0 means silence, 1 means speech
        decisions = self._smooth_preds_with_state_machine(binary_preds)

        fixed_decisions = self._fix_smooth_window_start(decisions)
        smoothed_decisions = self._merge_short_silence_segments(fixed_decisions)
        extend_decisions = self._extend_speech_segments(smoothed_decisions)
        final_decisions = self._split_long_speech_segments(extend_decisions, raw_probs)
        # don't call _merge_short_silence_segments after _split_long_speech_segments
        # (it would fill back the single-frame gaps the splitter just inserted)

        return final_decisions

    def decision_to_segment(self, decisions, wav_dur=None):
        """Convert a 0/1 decision list into (start_s, end_s) segments.

        Args:
            decisions: per-frame 0/1 decisions.
            wav_dur: optional total audio duration in seconds; caps the end
                time of a segment running to the end of ``decisions``.
        Returns:
            list of (start, end) tuples in seconds, rounded to 3 decimals.
        """
        segments = []
        speech_start = None
        for t, decision in enumerate(decisions):
            if decision == 1 and speech_start is None:
                speech_start = t
            elif decision == 0 and speech_start is not None:
                # t is the first silence frame, so t - speech_start is the
                # exact number of speech frames in the segment.
                if (t - speech_start) < self.min_speech_frame:
                    logger.warning("Unexpected short speech segment, check vad_postprocessor.py")
                segments.append((speech_start * FRAME_SHIFT_S,
                                 t * FRAME_SHIFT_S))
                speech_start = None
        if speech_start is not None:
            # Fix: the trailing segment spans len(decisions) - speech_start
            # frames; the previous `len(decisions) - 1 - speech_start` check
            # under-counted by one frame versus the in-loop convention.
            if (len(decisions) - speech_start) < self.min_speech_frame:
                logger.warning("Unexpected short speech segment, check vad_postprocessor.py")
            # NOTE(review): the tail end adds FRAME_LENGTH_S on top of the
            # shift offset while in-loop ends do not — presumably deliberate
            # (last frame's window extends past its shift); kept as-is.
            end_time = len(decisions) * FRAME_SHIFT_S + FRAME_LENGTH_S
            if wav_dur is not None:
                end_time = min(end_time, wav_dur)
            segments.append((speech_start * FRAME_SHIFT_S,
                             end_time))
        segments = [(round(s, 3), round(e, 3)) for s, e in segments]
        return segments

    def _smooth_prob_simple(self, probs):
        """Reference pure-Python trailing moving average over probs.

        Frame i becomes the mean of the last ``smooth_window_size``
        probabilities up to and including i (fewer near the start).
        """
        if self.smooth_window_size <= 1:
            return probs
        smoothed_probs = probs.copy()
        window = deque()
        window_sum = 0.0
        for i, p in enumerate(probs):
            window.append(p)
            window_sum += p
            if len(window) > self.smooth_window_size:
                left = window.popleft()
                window_sum -= left
            window_avg = window_sum / len(window)
            smoothed_probs[i] = window_avg
        return smoothed_probs

    def _smooth_prob(self, probs):
        """Vectorized trailing moving average; returns a numpy array.

        Numerically equivalent to ``_smooth_prob_simple``.
        """
        if self.smooth_window_size <= 1:
            return np.asarray(probs)
        probs_np = np.array(probs)
        kernel = np.ones(self.smooth_window_size) / self.smooth_window_size
        # mode='full' truncated to len(probs) gives a *trailing* (causal)
        # window; mode='same' would center the window instead.
        smoothed = np.convolve(probs_np, kernel, mode='full')[:len(probs)]
        # Boundary: the first window_size-1 frames cover fewer samples, so
        # renormalize them to the cumulative mean (matches the simple version).
        for i in range(min(self.smooth_window_size - 1, len(probs))):
            smoothed[i] = np.mean(probs_np[:i+1])
        return smoothed

    def _apply_threshold_simple(self, probs):
        """Reference pure-Python thresholding: 1 iff prob >= threshold."""
        return [int(p >= self.prob_threshold) for p in probs]

    def _apply_threshold(self, probs):
        """Vectorized thresholding: list of 1 iff prob >= threshold."""
        probs_np = np.asarray(probs)
        return (probs_np >= self.prob_threshold).astype(int).tolist()

    def _smooth_preds_with_state_machine(self, binary_preds):
        """Suppress blips shorter than the minimum durations.

        State transition is constrained by min_speech_frame &
        min_silence_frame: a run must persist that long before the output
        flips.  When speech is confirmed, the tentative frames since
        speech_start are retroactively marked 1.
        """
        if self.min_speech_frame <= 0 and self.min_silence_frame <= 0:
            return binary_preds
        decisions = [0] * len(binary_preds)
        state = VadState.SILENCE
        speech_start = -1
        silence_start = -1
        for t, is_speech in enumerate(binary_preds):
            # State transition
            if state == VadState.SILENCE:
                if is_speech:
                    state = VadState.POSSIBLE_SPEECH
                    speech_start = t

            elif state == VadState.POSSIBLE_SPEECH:
                if is_speech:
                    assert speech_start != -1
                    if t - speech_start >= self.min_speech_frame:
                        state = VadState.SPEECH
                        # back-fill the frames that were tentatively silence
                        decisions[speech_start:t] = [1] * (t - speech_start)
                else:
                    state = VadState.SILENCE
                    speech_start = -1

            elif state == VadState.SPEECH:
                if not is_speech:
                    state = VadState.POSSIBLE_SILENCE
                    silence_start = t

            elif state == VadState.POSSIBLE_SILENCE:
                if not is_speech:
                    assert silence_start != -1
                    if t - silence_start >= self.min_silence_frame:
                        state = VadState.SILENCE
                        speech_start = -1
                else:
                    state = VadState.SPEECH
                    silence_start = -1

            # current frame's decision: POSSIBLE_SILENCE still counts as
            # speech, POSSIBLE_SPEECH still counts as silence
            if state == VadState.SPEECH or state == VadState.POSSIBLE_SILENCE:
                decision = 1
            elif state == VadState.SILENCE or state == VadState.POSSIBLE_SPEECH:
                decision = 0
            else:
                raise ValueError("Impossible VAD state")

            decisions[t] = decision
        return decisions

    def _fix_smooth_window_start(self, decisions):
        """Back-fill up to smooth_window_size frames before each speech onset.

        The trailing moving average delays onsets by up to a window; this
        compensates by marking the preceding frames as speech.
        """
        new_decisions = decisions.copy()
        for t, decision in enumerate(decisions):
            if t > 0 and decisions[t-1] == 0 and decision == 1:
                start = max(0, t-self.smooth_window_size)
                new_decisions[start:t] = [1] * (t - start)
        return new_decisions

    def _merge_short_silence_segments(self, decisions):
        """Fill silence gaps shorter than merge_silence_frame with speech."""
        if self.merge_silence_frame <= 0:
            return decisions
        new_decisions = decisions.copy()
        silence_start = None
        for t, decision in enumerate(decisions):
            if t > 0 and decisions[t-1] == 1 and decision == 0 and silence_start is None:
                silence_start = t
            elif t > 0 and decisions[t-1] == 0 and decision == 1 and silence_start is not None:
                silence_frame = t - silence_start
                if silence_frame < self.merge_silence_frame:
                    new_decisions[silence_start:t] = [1] * silence_frame
                silence_start = None
        return new_decisions

    def _extend_speech_segments_simple(self, decisions):
        """Reference implementation: extend extend_speech_frame frames
        before & after every speech frame.
        """
        if self.extend_speech_frame <= 0:
            return decisions
        new_decisions = decisions.copy()
        for t, decision in enumerate(decisions):
            if decision == 1:
                start = max(0, t - self.extend_speech_frame)
                end = min(len(decisions), t + self.extend_speech_frame + 1)
                new_decisions[start:end] = [1] * (end - start)
        return new_decisions

    def _extend_speech_segments(self, decisions):
        """Vectorized: extend extend_speech_frame frames before & after
        speech segments (binary dilation via convolution).
        """
        if self.extend_speech_frame <= 0:
            return decisions
        decisions_np = np.array(decisions)
        kernel = np.ones(2 * self.extend_speech_frame + 1)
        extended = np.convolve(decisions_np, kernel, mode='same')
        return (extended > 0).astype(int).tolist()

    def _split_long_speech_segments(self, decisions, probs):
        """Break segments longer than max_speech_frame by inserting
        single silence frames at low-probability points.
        """
        new_decisions = decisions.copy()
        segments = self.decision_to_segment(decisions)
        for start_s, end_s in segments:
            # Fix: round before int() — decision_to_segment rounds times to
            # 3 decimals, so start_s / FRAME_SHIFT_S can come out as e.g.
            # 47.999999... and plain int() truncated to the wrong frame.
            start_frame = int(round(start_s / FRAME_SHIFT_S))
            end_frame = int(round(end_s / FRAME_SHIFT_S))
            dur_frames = end_frame - start_frame
            if dur_frames > self.max_speech_frame:
                segment_probs = probs[start_frame:end_frame]
                split_points = self._find_split_points(segment_probs)
                for split_point in split_points:
                    split_frame = start_frame + split_point
                    new_decisions[split_frame] = 0
        return new_decisions

    def _find_split_points(self, probs):
        """Return indices (into probs) at which to cut an over-long segment.

        Greedily scans: while the remainder exceeds max_speech_frame, picks
        the minimum-probability frame within the second half of the next
        max_speech_frame window and cuts there.
        """
        split_points = []
        length = len(probs)
        start = 0
        while start < length:
            if (length - start) <= self.max_speech_frame:
                break
            # search window: [start + max/2, start + max) keeps every piece
            # at least max_speech_frame / 2 frames long
            window_start = int(start + self.max_speech_frame / 2)
            window_end = int(start + self.max_speech_frame)
            window_probs = probs[window_start:window_end]

            min_index = window_start + np.argmin(window_probs)
            split_points.append(min_index)

            start = min_index + 1
        return split_points