|
|
import gc
import logging
import os
import time
from datetime import timedelta

import numpy as np
import onnxruntime
from pydub import AudioSegment
from silero_vad import load_silero_vad, get_speech_timestamps, VADIterator
|
|
|
|
|
class FixedVADIterator(VADIterator):
    """Fixes VADIterator so that it can process audio of any length,
    not only chunks of exactly 512 samples at a time.

    If the audio processed in one call is long and contains multiple voiced
    segments, __call__ returns the start of the first segment and the end of
    the last segment (or only the start, if the last segment has not ended yet).
    """

    def reset_states(self):
        super().reset_states()
        self.buffer = np.array([], dtype=np.float32)

    def __call__(self, x, return_seconds=False):
        self.buffer = np.append(self.buffer, x)
        ret = None
        # Feed the buffered audio to the parent iterator in 512-sample windows.
        while len(self.buffer) >= 512:
            r = super().__call__(self.buffer[:512], return_seconds=return_seconds)
            self.buffer = self.buffer[512:]
            if ret is None:
                ret = r
            elif r is not None:
                if 'end' in r:
                    ret['end'] = r['end']  # keep the most recent end
                if 'start' in r and 'end' in ret:
                    # A later window opened a new segment: drop 'end' so the
                    # result reports one merged, still-open segment.
                    del ret['end']
        return ret if ret != {} else None
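
    # Usage sketch (assumptions: a 16 kHz float32 stream and a model from
    # load_silero_vad; chunk lengths need not be multiples of 512):
    #
    #     vac = FixedVADIterator(load_silero_vad(onnx=True), sampling_rate=16000)
    #     for chunk in audio_chunks:               # np.float32 arrays, any length
    #         event = vac(chunk, return_seconds=True)
    #         if event is not None:
    #             print(event)                     # {'start': ...} and/or {'end': ...}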
|
|
|
|
|
class SileroVADProcessor:
    """
    Processes audio files with Silero VAD to detect voice activity
    and extract the voiced segments.
    """

    def __init__(self,
                 activate_threshold=0.5,
                 fusion_threshold=0.3,
                 min_speech_duration=0.25,
                 max_speech_duration=20,
                 min_silence_duration=250,
                 sample_rate=16000,
                 ort_providers=None):
        """
        Initialize the SileroVADProcessor.

        Args:
            activate_threshold (float): Threshold for voice activity detection
            fusion_threshold (float): Maximum gap between speech segments that get merged (seconds)
            min_speech_duration (float): Minimum duration of speech to be considered valid (seconds)
            max_speech_duration (float): Maximum duration of a single speech segment (seconds)
            min_silence_duration (int): Minimum silence duration (ms)
            sample_rate (int): Sample rate of the audio (8000 or 16000 Hz)
            ort_providers (list): ONNX Runtime providers for acceleration
        """
        self.activate_threshold = activate_threshold
        self.fusion_threshold = fusion_threshold
        self.min_speech_duration = min_speech_duration
        self.max_speech_duration = max_speech_duration
        self.min_silence_duration = min_silence_duration
        self.sample_rate = sample_rate
        self.ort_providers = ort_providers if ort_providers else []
        # load_audio reads this flag; float16 only pays off on GPU execution
        # providers, so default to float32.
        self.use_gpu_fp16 = False

        self.logger = logging.getLogger(__name__)

        self._init_onnx_session()
        self.silero_vad = load_silero_vad(onnx=True)
|
|
|
|
|
    def _init_onnx_session(self):
        """Prepare ONNX Runtime session options with appropriate settings.

        Note that load_silero_vad(onnx=True) builds its own inference session
        internally; these options take effect only if a custom
        onnxruntime.InferenceSession is created from them.
        """
        session_opts = onnxruntime.SessionOptions()
        session_opts.log_severity_level = 3
        session_opts.inter_op_num_threads = 0   # 0 lets ORT choose the thread count
        session_opts.intra_op_num_threads = 0
        session_opts.enable_cpu_mem_arena = True
        session_opts.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
        session_opts.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

        session_opts.add_session_config_entry("session.intra_op.allow_spinning", "1")
        session_opts.add_session_config_entry("session.inter_op.allow_spinning", "1")
        session_opts.add_session_config_entry("session.set_denormal_as_zero", "1")

        self.session_opts = session_opts
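
    # Sketch of how these options could be applied (hypothetical model path;
    # not wired into load_silero_vad, which manages its own session):
    #
    #     sess = onnxruntime.InferenceSession("silero_vad.onnx",
    #                                         sess_options=self.session_opts,
    #                                         providers=self.ort_providers or None)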
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    def load_audio(self, audio_path):
        """
        Load an audio file and prepare it for VAD processing.

        Args:
            audio_path (str): Path to the audio file

        Returns:
            numpy.ndarray: Mono audio at self.sample_rate, normalized to [-1.0, 1.0]
        """
        self.logger.info(f"Loading audio from {audio_path}")
        audio_segment = AudioSegment.from_file(audio_path)
        audio_segment = audio_segment.set_channels(1).set_frame_rate(self.sample_rate)

        dtype = np.float16 if self.use_gpu_fp16 else np.float32
        # Scale 16-bit PCM samples to [-1.0, 1.0] (0.000030517578 ~= 1 / 32768).
        audio_array = np.array(audio_segment.get_array_of_samples(), dtype=dtype) * 0.000030517578

        self.audio_segment = audio_segment
        return audio_array
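
    # For example, load_audio("speech.wav") (placeholder path) returns a mono
    # float32 array sampled at 16 kHz with values in [-1.0, 1.0].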
|
|
|
|
|
    @property
    def model(self):
        return self.silero_vad
|
|
|
|
|
    def process_timestamps(self, timestamps):
        """
        Process VAD timestamps: drop segments shorter than min_speech_duration,
        then merge segments separated by at most fusion_threshold seconds.

        Args:
            timestamps (list): List of (start, end) tuples in seconds

        Returns:
            list: Processed list of (start, end) tuples
        """
        # Drop segments that are too short to be meaningful speech.
        filtered_timestamps = [(start, end) for start, end in timestamps
                               if (end - start) >= self.min_speech_duration]

        # Merge segments whose gap to the previous one is within the threshold.
        # A single left-to-right pass suffices: each merge extends the last
        # accepted segment, so chains of close segments collapse in one sweep.
        fused_timestamps = []
        for start, end in filtered_timestamps:
            if fused_timestamps and (start - fused_timestamps[-1][1] <= self.fusion_threshold):
                fused_timestamps[-1] = (fused_timestamps[-1][0], end)
            else:
                fused_timestamps.append((start, end))

        return fused_timestamps
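
    # Worked example (hypothetical numbers): with min_speech_duration=0.25 and
    # fusion_threshold=0.3, process_timestamps([(0.0, 0.1), (1.0, 2.0), (2.2, 3.0)])
    # returns [(1.0, 3.0)]: the 0.1 s segment is dropped as too short, and the
    # 0.2 s gap between the remaining segments is fused.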
|
|
|
|
|
    def format_time(self, seconds):
        """
        Convert seconds to VTT time format 'hh:mm:ss.mmm'.

        Args:
            seconds (float): Time in seconds

        Returns:
            str: Formatted time string
        """
        td = timedelta(seconds=seconds)
        td_sec = td.total_seconds()
        total_seconds = int(td_sec)
        milliseconds = int((td_sec - total_seconds) * 1000)
        hours = total_seconds // 3600
        minutes = (total_seconds % 3600) // 60
        seconds = total_seconds % 60
        return f"{hours:02}:{minutes:02}:{seconds:02}.{milliseconds:03}"
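
    # For example, format_time(3725.5) returns "01:02:05.500".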
|
|
|
|
|
    def detect_speech(self, audio: np.ndarray):
        """
        Run VAD on the audio to detect speech segments.

        Args:
            audio (numpy.ndarray): Audio samples normalized to [-1.0, 1.0]

        Returns:
            list: List of processed timestamps as (start, end) tuples in seconds
        """
        self.logger.info("Starting VAD process")
        start_time = time.time()

        raw_timestamps = get_speech_timestamps(
            audio,
            model=self.silero_vad,
            threshold=self.activate_threshold,
            max_speech_duration_s=self.max_speech_duration,
            min_speech_duration_ms=int(self.min_speech_duration * 1000),
            min_silence_duration_ms=self.min_silence_duration,
            return_seconds=True
        )

        timestamps = [(item['start'], item['end']) for item in raw_timestamps]
        processed_timestamps = self.process_timestamps(timestamps)

        # Drop this function's reference to the (potentially large) buffer;
        # the caller may still hold its own reference.
        del audio
        gc.collect()

        self.logger.info(f"VAD completed in {time.time() - start_time:.3f} seconds")
        return processed_timestamps
|
|
|
|
|
""" |
|
|
Save timestamps in both second and sample indices formats. |
|
|
|
|
|
Args: |
|
|
timestamps (list): List of (start, end) tuples |
|
|
output_prefix (str): Prefix for output files |
|
|
""" |
|
|
|
|
|
seconds_path = f"{output_prefix}_timestamps_second.txt" |
|
|
with open(seconds_path, "w", encoding='UTF-8') as file: |
|
|
self.logger.info("Saving timestamps in seconds format") |
|
|
for start, end in timestamps: |
|
|
s_time = self.format_time(start) |
|
|
e_time = self.format_time(end) |
|
|
line = f"{s_time} --> {e_time}\n" |
|
|
file.write(line) |
|
|
|
|
|
|
|
|
indices_path = f"{output_prefix}_timestamps_indices.txt" |
|
|
with open(indices_path, "w", encoding='UTF-8') as file: |
|
|
self.logger.info("Saving timestamps in indices format") |
|
|
for start, end in timestamps: |
|
|
line = f"{int(start * self.sample_rate)} --> {int(end * self.sample_rate)}\n" |
|
|
file.write(line) |
|
|
|
|
|
self.logger.info(f"Timestamps saved to {seconds_path} and {indices_path}") |
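
    # Example output lines for a segment (1.5, 3.25) at 16 kHz:
    #   *_timestamps_second.txt:  00:00:01.500 --> 00:00:03.250
    #   *_timestamps_indices.txt: 24000 --> 52000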
|
|
|
|
|
    def extract_speech_segments(self, audio_array, timestamps):
        """
        Extract speech segments from the audio and concatenate them.

        Args:
            audio_array (numpy.ndarray): Audio samples at self.sample_rate
            timestamps (list): List of (start, end) tuples in seconds

        Returns:
            numpy.ndarray: The concatenated speech samples
        """
        combined_speech = np.array([], dtype=audio_array.dtype)

        for start, end in timestamps:
            # Convert second-based timestamps to sample indices, clamping the
            # end to the length of the audio.
            start_idx = int(start * self.sample_rate)
            end_idx = min(int(end * self.sample_rate), len(audio_array))

            combined_speech = np.append(combined_speech, audio_array[start_idx:end_idx])

        return combined_speech
|
|
|
|
|
    def process_audio(self, audio_array: np.ndarray):
        """
        Complete processing pipeline: detect speech and extract the voiced segments.

        Args:
            audio_array (numpy.ndarray): Audio samples as returned by load_audio

        Returns:
            tuple: (timestamps, combined speech as numpy.ndarray)
        """
        timestamps = self.detect_speech(audio_array)
        combined_speech = self.extract_speech_segments(audio_array, timestamps)
        return timestamps, combined_speech
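

if __name__ == "__main__":
    # Minimal usage sketch; "input.wav" is a placeholder path.
    logging.basicConfig(level=logging.INFO)

    processor = SileroVADProcessor()
    audio = processor.load_audio("input.wav")
    timestamps, speech = processor.process_audio(audio)

    processor.save_timestamps(timestamps, "input")
    print(f"Detected {len(timestamps)} speech segments "
          f"({len(speech) / processor.sample_rate:.2f} s of speech in total).")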