# FireRedASR2S/fireredasr2s/fireredvad/core/vad_postprocessor.py
# FireRed Team — Init (commit 0ddb4a4, verified)
# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)
import enum
import logging
from collections import deque
import numpy as np
from .constants import FRAME_LENGTH_MS, FRAME_SHIFT_MS, FRAME_LENGTH_S, FRAME_SHIFT_S
logger = logging.getLogger(__name__)
@enum.unique
class VadState(enum.Enum):
    """States of the hysteresis machine in _smooth_preds_with_state_machine."""
    SILENCE = 0
    POSSIBLE_SPEECH = 1
    SPEECH = 2
    POSSIBLE_SILENCE = 3


class VadPostprocessor:
    """Turn per-frame VAD speech probabilities into clean 0/1 speech decisions.

    Pipeline (see process()):
      1. moving-average smoothing of the probabilities,
      2. thresholding to binary predictions,
      3. hysteresis state machine enforcing minimum speech/silence durations,
      4. onset fix-up for the smoothing-window delay,
      5. merging of short inter-speech silences,
      6. symmetric extension of speech segments,
      7. splitting of overly long speech segments.

    All ``*_frame`` parameters are counts of VAD frames (one frame every
    FRAME_SHIFT_S seconds).
    """

    def __init__(self, smooth_window_size,
                 prob_threshold,
                 min_speech_frame,
                 max_speech_frame,
                 min_silence_frame,
                 merge_silence_frame,
                 extend_speech_frame):
        """
        Args:
            smooth_window_size: trailing moving-average window in frames;
                clamped to at least 1.
            prob_threshold: probabilities >= this value count as speech.
            min_speech_frame: frames of speech required before entering SPEECH.
            max_speech_frame: speech segments longer than this are split.
            min_silence_frame: frames of silence required before leaving SPEECH.
            merge_silence_frame: silences shorter than this between two speech
                segments are relabelled as speech.
            extend_speech_frame: frames of padding added on both sides of each
                speech segment.
        """
        self.smooth_window_size = max(1, smooth_window_size)
        self.prob_threshold = prob_threshold
        self.min_speech_frame = min_speech_frame
        self.max_speech_frame = max_speech_frame
        self.min_silence_frame = min_silence_frame
        self.merge_silence_frame = merge_silence_frame
        self.extend_speech_frame = extend_speech_frame

    def process(self, raw_probs):
        """Convert raw per-frame speech probabilities into 0/1 decisions.

        Args:
            raw_probs: sequence (list or 1-D array) of speech probabilities.

        Returns:
            List of per-frame decisions: 0 = silence, 1 = speech.
        """
        # len() works for both lists and numpy arrays; `not raw_probs` would
        # raise "truth value of an array is ambiguous" on non-empty arrays.
        if raw_probs is None or len(raw_probs) == 0:
            return []
        smoothed_probs = self._smooth_prob(raw_probs)
        binary_preds = self._apply_threshold(smoothed_probs)
        # decision: 0 means silence, 1 means speech
        decisions = self._smooth_preds_with_state_machine(binary_preds)
        fixed_decisions = self._fix_smooth_window_start(decisions)
        smoothed_decisions = self._merge_short_silence_segments(fixed_decisions)
        extend_decisions = self._extend_speech_segments(smoothed_decisions)
        final_decisions = self._split_long_speech_segments(extend_decisions, raw_probs)
        # Don't call _merge_short_silence_segments after
        # _split_long_speech_segments: it would undo the one-frame splits.
        return final_decisions

    def decision_to_segment(self, decisions, wav_dur=None):
        """Convert per-frame decisions into (start_s, end_s) speech segments.

        Args:
            decisions: list of 0/1 per-frame decisions.
            wav_dur: optional audio duration in seconds, used to clamp the
                final segment's end time.

        Returns:
            List of (start, end) tuples in seconds, rounded to milliseconds.
        """
        segments = []
        speech_start = None
        for t, decision in enumerate(decisions):
            if decision == 1 and speech_start is None:
                speech_start = t
            elif decision == 0 and speech_start is not None:
                # The state machine should never produce a segment shorter
                # than min_speech_frame; warn if one slips through.
                if (t - speech_start) < self.min_speech_frame:
                    logger.warning("Unexpected short speech segment, check vad_postprocessor.py")
                segments.append((speech_start * FRAME_SHIFT_S,
                                 t * FRAME_SHIFT_S))
                speech_start = None
        if speech_start is not None:
            # Speech runs to the end of the input; close the segment there.
            t = len(decisions) - 1
            if (t - speech_start) < self.min_speech_frame:
                logger.warning("Unexpected short speech segment, check vad_postprocessor.py")
            # NOTE(review): len(decisions) * FRAME_SHIFT_S overshoots the last
            # frame's start by one shift ((len-1) * shift + frame_length would
            # be the last frame's end); harmless when clamped by wav_dur, but
            # confirm it is intended when wav_dur is None.
            end_time = len(decisions) * FRAME_SHIFT_S + FRAME_LENGTH_S
            if wav_dur is not None:
                end_time = min(end_time, wav_dur)
            segments.append((speech_start * FRAME_SHIFT_S,
                             end_time))
        segments = [(round(s, 3), round(e, 3)) for s, e in segments]
        return segments

    def _smooth_prob_simple(self, probs):
        """Reference (pure-Python) trailing moving average of the probabilities.

        Equivalent to _smooth_prob; kept as the readable baseline.
        """
        if self.smooth_window_size <= 1:
            return probs
        smoothed_probs = probs.copy()
        window = deque()
        window_sum = 0.0
        for i, p in enumerate(probs):
            window.append(p)
            window_sum += p
            if len(window) > self.smooth_window_size:
                left = window.popleft()
                window_sum -= left
            window_avg = window_sum / len(window)
            smoothed_probs[i] = window_avg
        return smoothed_probs

    def _smooth_prob(self, probs):
        """Vectorized trailing moving average; returns a numpy array."""
        if self.smooth_window_size <= 1:
            return np.asarray(probs)
        probs_np = np.array(probs)
        kernel = np.ones(self.smooth_window_size) / self.smooth_window_size
        # mode='full' plus taking the first len(probs) samples yields a causal
        # (trailing-window) average, matching _smooth_prob_simple.
        smoothed = np.convolve(probs_np, kernel, mode='full')[:len(probs)]
        # Boundary: the first window_size-1 frames have a partially filled
        # window, so use the cumulative mean of the available frames instead.
        for i in range(min(self.smooth_window_size - 1, len(probs))):
            smoothed[i] = np.mean(probs_np[:i+1])
        return smoothed

    def _apply_threshold_simple(self, probs):
        """Reference (pure-Python) thresholding: 1 if prob >= threshold."""
        return [int(p >= self.prob_threshold) for p in probs]

    def _apply_threshold(self, probs):
        """Vectorized thresholding; returns a list of 0/1 ints."""
        probs_np = np.asarray(probs)
        return (probs_np >= self.prob_threshold).astype(int).tolist()

    def _smooth_preds_with_state_machine(self, binary_preds):
        """Apply hysteresis so speech/silence must persist before switching.

        State transitions are constrained by min_speech_frame and
        min_silence_frame: a speech run shorter than min_speech_frame is
        discarded (never labelled 1), and a silence run shorter than
        min_silence_frame does not end a speech segment.

        NOTE(review): when a silence run is finally confirmed, the frames
        observed while in POSSIBLE_SILENCE keep their decision of 1, so each
        speech segment is trailed by up to min_silence_frame speech-labelled
        frames. This looks like intentional padding — confirm.
        """
        if self.min_speech_frame <= 0 and self.min_silence_frame <= 0:
            return binary_preds
        decisions = [0] * len(binary_preds)
        state = VadState.SILENCE
        speech_start = -1
        silence_start = -1
        for t, is_speech in enumerate(binary_preds):
            # State transition
            if state == VadState.SILENCE:
                if is_speech:
                    state = VadState.POSSIBLE_SPEECH
                    speech_start = t
            elif state == VadState.POSSIBLE_SPEECH:
                if is_speech:
                    assert speech_start != -1
                    if t - speech_start >= self.min_speech_frame:
                        state = VadState.SPEECH
                        # Back-fill the probation frames as speech.
                        decisions[speech_start:t] = [1] * (t - speech_start)
                else:
                    state = VadState.SILENCE
                    speech_start = -1
            elif state == VadState.SPEECH:
                if not is_speech:
                    state = VadState.POSSIBLE_SILENCE
                    silence_start = t
            elif state == VadState.POSSIBLE_SILENCE:
                if not is_speech:
                    assert silence_start != -1
                    if t - silence_start >= self.min_silence_frame:
                        state = VadState.SILENCE
                        speech_start = -1
                else:
                    state = VadState.SPEECH
                    silence_start = -1
            # current frame's decision
            if state == VadState.SPEECH or state == VadState.POSSIBLE_SILENCE:
                decision = 1
            elif state == VadState.SILENCE or state == VadState.POSSIBLE_SPEECH:
                decision = 0
            else:
                raise ValueError("Impossible VAD state")
            decisions[t] = decision
        return decisions

    def _fix_smooth_window_start(self, decisions):
        """Back-fill speech onsets delayed by the trailing smoothing window.

        The causal moving average in _smooth_prob can delay a 0->1 transition
        by up to smooth_window_size frames, so mark those frames as speech.
        """
        new_decisions = decisions.copy()
        for t, decision in enumerate(decisions):
            if t > 0 and decisions[t-1] == 0 and decision == 1:
                start = max(0, t-self.smooth_window_size)
                new_decisions[start:t] = [1] * (t - start)
        return new_decisions

    def _merge_short_silence_segments(self, decisions):
        """Relabel silences shorter than merge_silence_frame between two
        speech segments as speech (leading/trailing silences are untouched)."""
        if self.merge_silence_frame <= 0:
            return decisions
        new_decisions = decisions.copy()
        silence_start = None
        for t, decision in enumerate(decisions):
            if t > 0 and decisions[t-1] == 1 and decision == 0 and silence_start is None:
                silence_start = t
            elif t > 0 and decisions[t-1] == 0 and decision == 1 and silence_start is not None:
                silence_frame = t - silence_start
                if silence_frame < self.merge_silence_frame:
                    new_decisions[silence_start:t] = [1] * silence_frame
                silence_start = None
        return new_decisions

    def _extend_speech_segments_simple(self, decisions):
        """Reference (pure-Python) version: extend extend_speech_frame frames
        before and after every speech frame."""
        if self.extend_speech_frame <= 0:
            return decisions
        new_decisions = decisions.copy()
        for t, decision in enumerate(decisions):
            if decision == 1:
                start = max(0, t - self.extend_speech_frame)
                end = min(len(decisions), t + self.extend_speech_frame + 1)
                new_decisions[start:end] = [1] * (end - start)
        return new_decisions

    def _extend_speech_segments(self, decisions):
        """Vectorized version: dilate speech by extend_speech_frame frames on
        both sides via convolution with an all-ones kernel."""
        if self.extend_speech_frame <= 0:
            return decisions
        decisions_np = np.array(decisions)
        kernel = np.ones(2 * self.extend_speech_frame + 1)
        extended = np.convolve(decisions_np, kernel, mode='same')
        return (extended > 0).astype(int).tolist()

    def _split_long_speech_segments(self, decisions, probs):
        """Break speech segments longer than max_speech_frame by forcing a
        single silence frame at low-probability points."""
        new_decisions = decisions.copy()
        segments = self.decision_to_segment(decisions)
        for start_s, end_s in segments:
            # round() instead of int(): segment times were rounded to 3
            # decimals in decision_to_segment, so truncating the division
            # can be off by one frame (e.g. int(0.29 / 0.01) == 28).
            start_frame = round(start_s / FRAME_SHIFT_S)
            end_frame = round(end_s / FRAME_SHIFT_S)
            dur_frames = end_frame - start_frame
            if dur_frames > self.max_speech_frame:
                segment_probs = probs[start_frame:end_frame]
                split_points = self._find_split_points(segment_probs)
                for split_point in split_points:
                    split_frame = start_frame + split_point
                    new_decisions[split_frame] = 0
        return new_decisions

    def _find_split_points(self, probs):
        """Choose split frames for an over-long segment.

        Greedily scans left to right: while the remainder exceeds
        max_speech_frame, picks the minimum-probability frame within the
        window [start + max/2, start + max) and splits there, so every
        resulting piece is at most max_speech_frame frames.
        """
        split_points = []
        length = len(probs)
        start = 0
        while start < length:
            if (length - start) <= self.max_speech_frame:
                break
            window_start = int(start + self.max_speech_frame / 2)
            window_end = int(start + self.max_speech_frame)
            window_probs = probs[window_start:window_end]
            min_index = window_start + np.argmin(window_probs)
            split_points.append(min_index)
            start = min_index + 1
        return split_points