# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)
import enum
import logging
from collections import deque

import numpy as np

from .constants import FRAME_LENGTH_MS, FRAME_SHIFT_MS, FRAME_LENGTH_S, FRAME_SHIFT_S

logger = logging.getLogger(__name__)
@enum.unique
class VadState(enum.Enum):
    """States of the hangover state machine that smooths frame-level VAD decisions."""
    SILENCE = 0           # confirmed silence
    POSSIBLE_SPEECH = 1   # speech onset, not yet long enough to confirm
    SPEECH = 2            # confirmed speech
    POSSIBLE_SILENCE = 3  # silence onset inside speech, not yet confirmed
class VadPostprocessor:
    """Turn frame-level VAD speech probabilities into binary speech/silence decisions.

    Pipeline (see `process`):
      1. `_smooth_prob`                     - causal moving average over probabilities
      2. `_apply_threshold`                 - probability -> 0/1 prediction
      3. `_smooth_preds_with_state_machine` - enforce minimum speech/silence durations
      4. `_fix_smooth_window_start`         - compensate smoothing delay at speech onsets
      5. `_merge_short_silence_segments`    - bridge short gaps inside speech
      6. `_extend_speech_segments`          - pad speech segments on both sides
      7. `_split_long_speech_segments`      - break up over-long speech segments

    All ``*_frame`` parameters are counts of analysis frames (FRAME_SHIFT hop).
    """

    def __init__(self, smooth_window_size,
                 prob_threshold,
                 min_speech_frame,
                 max_speech_frame,
                 min_silence_frame,
                 merge_silence_frame,
                 extend_speech_frame):
        """
        Args:
            smooth_window_size: moving-average window in frames; clamped to >= 1.
            prob_threshold: probability >= threshold counts as speech.
            min_speech_frame: shortest run of speech frames accepted as speech.
            max_speech_frame: longest allowed speech segment before it is split.
            min_silence_frame: shortest run of silence frames that ends speech.
            merge_silence_frame: silences shorter than this are merged into speech.
            extend_speech_frame: frames of padding added before/after each segment.
        """
        self.smooth_window_size = max(1, smooth_window_size)
        self.prob_threshold = prob_threshold
        self.min_speech_frame = min_speech_frame
        self.max_speech_frame = max_speech_frame
        self.min_silence_frame = min_silence_frame
        self.merge_silence_frame = merge_silence_frame
        self.extend_speech_frame = extend_speech_frame

    def process(self, raw_probs):
        """Run the full post-processing pipeline.

        Args:
            raw_probs: list of per-frame speech probabilities.
        Returns:
            Per-frame decisions: 0 means silence, 1 means speech.
        """
        if not raw_probs:
            return []
        smoothed_probs = self._smooth_prob(raw_probs)
        binary_preds = self._apply_threshold(smoothed_probs)
        # decision: 0 means silence, 1 means speech
        decisions = self._smooth_preds_with_state_machine(binary_preds)
        fixed_decisions = self._fix_smooth_window_start(decisions)
        smoothed_decisions = self._merge_short_silence_segments(fixed_decisions)
        extend_decisions = self._extend_speech_segments(smoothed_decisions)
        final_decisions = self._split_long_speech_segments(extend_decisions, raw_probs)
        # don't call _merge_short_silence_segments after _split_long_speech_segments:
        # it would merge the single-frame split markers back into speech
        return final_decisions

    def decision_to_segment(self, decisions, wav_dur=None):
        """Convert per-frame decisions into speech segments in seconds.

        Args:
            decisions: per-frame 0/1 decisions.
            wav_dur: optional audio duration in seconds, used to clamp the
                final segment's end time.
        Returns:
            List of (start_s, end_s) tuples, rounded to milliseconds.
        """
        segments = []
        speech_start = None
        for t, decision in enumerate(decisions):
            if decision == 1 and speech_start is None:
                speech_start = t
            elif decision == 0 and speech_start is not None:
                if (t - speech_start) < self.min_speech_frame:
                    logger.warning("Unexpected short speech segment, check vad_postprocessor.py")
                segments.append((speech_start * FRAME_SHIFT_S,
                                 t * FRAME_SHIFT_S))
                speech_start = None
        if speech_start is not None:
            # trailing speech segment that runs to the end of the input
            t = len(decisions) - 1
            if (t - speech_start) < self.min_speech_frame:
                logger.warning("Unexpected short speech segment, check vad_postprocessor.py")
            # NOTE(review): the last frame starts at (len-1)*FRAME_SHIFT_S, so this
            # end time overshoots by one frame shift; it is clamped by wav_dur when
            # given — confirm the overshoot is intended when wav_dur is None.
            end_time = len(decisions) * FRAME_SHIFT_S + FRAME_LENGTH_S
            if wav_dur is not None:
                end_time = min(end_time, wav_dur)
            segments.append((speech_start * FRAME_SHIFT_S,
                             end_time))
        segments = [(round(s, 3), round(e, 3)) for s, e in segments]
        return segments

    def _smooth_prob_simple(self, probs):
        """Reference implementation of `_smooth_prob`: causal moving average.

        Frame i is averaged over the window ending at i; partially filled
        windows at the start fall back to the cumulative mean.
        """
        if self.smooth_window_size <= 1:
            return probs
        smoothed_probs = probs.copy()
        window = deque()
        window_sum = 0.0
        for i, p in enumerate(probs):
            window.append(p)
            window_sum += p
            if len(window) > self.smooth_window_size:
                left = window.popleft()
                window_sum -= left
            window_avg = window_sum / len(window)
            smoothed_probs[i] = window_avg
        return smoothed_probs

    def _smooth_prob(self, probs):
        """Causal moving average of `probs` (vectorized `_smooth_prob_simple`).

        Returns:
            numpy array with the same length as `probs`.
        """
        if self.smooth_window_size <= 1:
            return np.asarray(probs)
        probs_np = np.array(probs)
        kernel = np.ones(self.smooth_window_size) / self.smooth_window_size
        # mode='full' truncated to len(probs) keeps the average causal: frame i
        # averages the window *ending* at i (mode='same' would center the window)
        smoothed = np.convolve(probs_np, kernel, mode='full')[:len(probs)]
        # boundary: the first window_size-1 frames use the cumulative mean,
        # matching _smooth_prob_simple's partially filled window
        for i in range(min(self.smooth_window_size - 1, len(probs))):
            smoothed[i] = np.mean(probs_np[:i+1])
        return smoothed

    def _apply_threshold_simple(self, probs):
        """Reference implementation of `_apply_threshold`."""
        return [int(p >= self.prob_threshold) for p in probs]

    def _apply_threshold(self, probs):
        """Binarize probabilities: 1 where prob >= prob_threshold, else 0."""
        probs_np = np.asarray(probs)
        return (probs_np >= self.prob_threshold).astype(int).tolist()

    def _smooth_preds_with_state_machine(self, binary_preds):
        """Smooth binary predictions with a 4-state hangover machine.

        State transitions are constrained by min_speech_frame (speech must last
        this long to be confirmed) and min_silence_frame (silence must last this
        long to end a speech run).

        Returns:
            Per-frame 0/1 decisions of the same length as `binary_preds`.
        """
        if self.min_speech_frame <= 0 and self.min_silence_frame <= 0:
            # no duration constraints: predictions pass through unchanged
            return binary_preds
        decisions = [0] * len(binary_preds)
        state = VadState.SILENCE
        speech_start = -1
        silence_start = -1
        for t, is_speech in enumerate(binary_preds):
            # State transition
            if state == VadState.SILENCE:
                if is_speech:
                    state = VadState.POSSIBLE_SPEECH
                    speech_start = t
            elif state == VadState.POSSIBLE_SPEECH:
                if is_speech:
                    assert speech_start != -1
                    if t - speech_start >= self.min_speech_frame:
                        state = VadState.SPEECH
                        # retroactively mark the probation frames as speech
                        decisions[speech_start:t] = [1] * (t - speech_start)
                else:
                    state = VadState.SILENCE
                    speech_start = -1
            elif state == VadState.SPEECH:
                if not is_speech:
                    state = VadState.POSSIBLE_SILENCE
                    silence_start = t
            elif state == VadState.POSSIBLE_SILENCE:
                if not is_speech:
                    assert silence_start != -1
                    if t - silence_start >= self.min_silence_frame:
                        state = VadState.SILENCE
                        speech_start = -1
                else:
                    state = VadState.SPEECH
                    silence_start = -1
            # current frame's decision
            # NOTE: frames spent in POSSIBLE_SILENCE stay marked as speech even
            # when the silence is later confirmed (hangover behavior)
            if state == VadState.SPEECH or state == VadState.POSSIBLE_SILENCE:
                decision = 1
            elif state == VadState.SILENCE or state == VadState.POSSIBLE_SPEECH:
                decision = 0
            else:
                raise ValueError("Impossible VAD state")
            decisions[t] = decision
        return decisions

    def _fix_smooth_window_start(self, decisions):
        """Pull each detected speech onset back by up to smooth_window_size frames.

        The causal moving average delays onsets; this compensates by marking the
        frames just before every 0->1 transition as speech.
        """
        new_decisions = decisions.copy()
        for t, decision in enumerate(decisions):
            if t > 0 and decisions[t-1] == 0 and decision == 1:
                start = max(0, t-self.smooth_window_size)
                new_decisions[start:t] = [1] * (t - start)
        return new_decisions

    def _merge_short_silence_segments(self, decisions):
        """Fill silence gaps shorter than merge_silence_frame that lie between speech."""
        if self.merge_silence_frame <= 0:
            return decisions
        new_decisions = decisions.copy()
        silence_start = None
        for t, decision in enumerate(decisions):
            if t > 0 and decisions[t-1] == 1 and decision == 0 and silence_start is None:
                silence_start = t
            elif t > 0 and decisions[t-1] == 0 and decision == 1 and silence_start is not None:
                silence_frame = t - silence_start
                if silence_frame < self.merge_silence_frame:
                    new_decisions[silence_start:t] = [1] * silence_frame
                silence_start = None
        return new_decisions

    def _extend_speech_segments_simple(self, decisions):
        """Reference implementation of `_extend_speech_segments`:
        extend N frames before & after speech segments.
        """
        if self.extend_speech_frame <= 0:
            return decisions
        new_decisions = decisions.copy()
        for t, decision in enumerate(decisions):
            if decision == 1:
                start = max(0, t - self.extend_speech_frame)
                end = min(len(decisions), t + self.extend_speech_frame + 1)
                new_decisions[start:end] = [1] * (end - start)
        return new_decisions

    def _extend_speech_segments(self, decisions):
        """Extend N frames before & after speech segments (vectorized).

        A binary dilation: convolving with an all-ones kernel of width 2N+1
        marks every frame within N frames of existing speech.
        """
        if self.extend_speech_frame <= 0:
            return decisions
        decisions_np = np.array(decisions)
        kernel = np.ones(2 * self.extend_speech_frame + 1)
        extended = np.convolve(decisions_np, kernel, mode='same')
        return (extended > 0).astype(int).tolist()

    def _split_long_speech_segments(self, decisions, probs):
        """Insert single-frame silences into speech segments longer than max_speech_frame.

        Args:
            decisions: per-frame 0/1 decisions.
            probs: raw per-frame probabilities used to pick low-probability split points.
        """
        new_decisions = decisions.copy()
        segments = self.decision_to_segment(decisions)
        for start_s, end_s in segments:
            # round() instead of int(): segment times were rounded to milliseconds,
            # so plain truncation can be off by one frame (e.g. 0.06 / 0.02 ->
            # 2.999... would truncate to frame 2 instead of 3)
            start_frame = int(round(start_s / FRAME_SHIFT_S))
            end_frame = int(round(end_s / FRAME_SHIFT_S))
            dur_frames = end_frame - start_frame
            if dur_frames > self.max_speech_frame:
                segment_probs = probs[start_frame:end_frame]
                split_points = self._find_split_points(segment_probs)
                for split_point in split_points:
                    split_frame = start_frame + split_point
                    new_decisions[split_frame] = 0
        return new_decisions

    def _find_split_points(self, probs):
        """Pick split frames inside an over-long segment at local probability minima.

        Scans left to right; each split is the argmin of the window
        [start + max/2, start + max), so consecutive splits are at least
        max_speech_frame/2 frames apart and no remaining piece exceeds
        max_speech_frame.

        Returns:
            List of frame indices (relative to `probs`) to force to silence.
        """
        split_points = []
        length = len(probs)
        # guard: with max_speech_frame <= 0 the search window below would be
        # empty and np.argmin would raise on it
        if self.max_speech_frame <= 0:
            return split_points
        start = 0
        while start < length:
            if (length - start) <= self.max_speech_frame:
                break
            window_start = int(start + self.max_speech_frame / 2)
            window_end = int(start + self.max_speech_frame)
            window_probs = probs[window_start:window_end]
            min_index = window_start + np.argmin(window_probs)
            split_points.append(min_index)
            start = min_index + 1
        return split_points