#include "vad_iterator.h"

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <limits>
#include <memory>
|
|
|
|
| void VadIterator::init_onnx_model(const std::string& model_path) { |
| init_engine_threads(1, 1); |
| session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options); |
| } |
|
|
| void VadIterator::init_engine_threads(int inter_threads, int intra_threads) { |
| session_options.SetIntraOpNumThreads(intra_threads); |
| session_options.SetInterOpNumThreads(inter_threads); |
| session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL); |
| } |
|
|
| void VadIterator::reset_states() { |
| std::memset(_state.data(), 0, _state.size() * sizeof(float)); |
| triggered = false; |
| temp_end = 0; |
| current_sample = 0; |
| prev_end = next_start = 0; |
| speeches.clear(); |
| current_speech = timestamp_t(); |
| std::fill(_context.begin(), _context.end(), 0.0f); |
| } |
|
|
| void VadIterator::predict(const std::vector<float>& data_chunk) { |
| std::vector<float> new_data(effective_window_size, 0.0f); |
| std::copy(_context.begin(), _context.end(), new_data.begin()); |
| std::copy(data_chunk.begin(), data_chunk.end(), new_data.begin() + context_samples); |
| input = new_data; |
|
|
| Ort::Value input_ort = Ort::Value::CreateTensor<float>( |
| memory_info, input.data(), input.size(), input_node_dims, 2); |
| Ort::Value state_ort = Ort::Value::CreateTensor<float>( |
| memory_info, _state.data(), _state.size(), state_node_dims, 3); |
| Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>( |
| memory_info, sr.data(), sr.size(), sr_node_dims, 1); |
|
|
| ort_inputs.clear(); |
| ort_inputs.emplace_back(std::move(input_ort)); |
| ort_inputs.emplace_back(std::move(state_ort)); |
| ort_inputs.emplace_back(std::move(sr_ort)); |
|
|
| ort_outputs = session->Run( |
| Ort::RunOptions{nullptr}, |
| input_node_names.data(), ort_inputs.data(), ort_inputs.size(), |
| output_node_names.data(), output_node_names.size()); |
|
|
| float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0]; |
| float* stateN = ort_outputs[1].GetTensorMutableData<float>(); |
| std::memcpy(_state.data(), stateN, size_state * sizeof(float)); |
|
|
| current_sample += static_cast<unsigned int>(window_size_samples); |
|
|
| if (speech_prob >= threshold) { |
| if (temp_end != 0) { |
| temp_end = 0; |
| if (next_start < prev_end) |
| next_start = current_sample - window_size_samples; |
| } |
| if (!triggered) { |
| triggered = true; |
| current_speech.start = current_sample - window_size_samples; |
| } |
| std::copy(new_data.end() - context_samples, new_data.end(), _context.begin()); |
| return; |
| } |
|
|
| if (triggered && ((current_sample - current_speech.start) > max_speech_samples)) { |
| if (prev_end > 0) { |
| current_speech.end = prev_end; |
| speeches.push_back(current_speech); |
| current_speech = timestamp_t(); |
| if (next_start < prev_end) |
| triggered = false; |
| else |
| current_speech.start = next_start; |
| prev_end = 0; |
| next_start = 0; |
| temp_end = 0; |
| } else { |
| current_speech.end = current_sample; |
| speeches.push_back(current_speech); |
| current_speech = timestamp_t(); |
| prev_end = 0; |
| next_start = 0; |
| temp_end = 0; |
| triggered = false; |
| } |
| std::copy(new_data.end() - context_samples, new_data.end(), _context.begin()); |
| return; |
| } |
|
|
| if ((speech_prob >= (threshold - 0.15)) && (speech_prob < threshold)) { |
| std::copy(new_data.end() - context_samples, new_data.end(), _context.begin()); |
| return; |
| } |
|
|
| if (speech_prob < (threshold - 0.15)) { |
| if (triggered) { |
| if (temp_end == 0) |
| temp_end = current_sample; |
| if (current_sample - temp_end > min_silence_samples_at_max_speech) |
| prev_end = temp_end; |
| if ((current_sample - temp_end) >= min_silence_samples) { |
| current_speech.end = temp_end; |
| if (current_speech.end - current_speech.start > min_speech_samples) { |
| speeches.push_back(current_speech); |
| current_speech = timestamp_t(); |
| prev_end = 0; |
| next_start = 0; |
| temp_end = 0; |
| triggered = false; |
| } |
| } |
| } |
| std::copy(new_data.end() - context_samples, new_data.end(), _context.begin()); |
| return; |
| } |
| } |
|
|
| void VadIterator::process(const std::vector<float>& input_wav) { |
| reset_states(); |
| audio_length_samples = static_cast<int>(input_wav.size()); |
|
|
| for (size_t j = 0; j < static_cast<size_t>(audio_length_samples); j += static_cast<size_t>(window_size_samples)) { |
| if (j + static_cast<size_t>(window_size_samples) > static_cast<size_t>(audio_length_samples)) |
| break; |
| std::vector<float> chunk(&input_wav[j], &input_wav[j] + window_size_samples); |
| predict(chunk); |
| } |
|
|
| if (current_speech.start >= 0) { |
| current_speech.end = audio_length_samples; |
| speeches.push_back(current_speech); |
| current_speech = timestamp_t(); |
| prev_end = 0; |
| next_start = 0; |
| temp_end = 0; |
| triggered = false; |
| } |
| } |
|
|
| const std::vector<timestamp_t>& VadIterator::get_speech_timestamps() const { |
| return speeches; |
| } |
|
|
| void VadIterator::reset() { |
| reset_states(); |
| } |
|
|
| |
| VadIterator::VadIterator(const std::string ModelPath, |
| int Sample_rate, |
| int windows_frame_size, |
| float Threshold, |
| int min_silence_duration_ms, |
| int speech_pad_ms, |
| int min_speech_duration_ms, |
| float max_speech_duration_s) |
| : sample_rate(Sample_rate), |
| threshold(Threshold), |
| speech_pad_samples(speech_pad_ms), |
| prev_end(0), |
| memory_info(Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtDeviceAllocator, OrtMemType::OrtMemTypeDefault)) |
| { |
|
|
| sr_per_ms = sample_rate / 1000; |
| window_size_samples = windows_frame_size * sr_per_ms; |
| effective_window_size = window_size_samples + context_samples; |
|
|
| input_node_dims[0] = 1; |
| input_node_dims[1] = effective_window_size; |
|
|
| _state.resize(size_state); |
| sr.resize(1); |
| sr[0] = sample_rate; |
| _context.assign(context_samples, 0.0f); |
|
|
| min_speech_samples = sr_per_ms * min_speech_duration_ms; |
|
|
| if (max_speech_duration_s < 0) { |
| max_speech_samples = std::numeric_limits<float>::infinity(); |
| } else { |
| max_speech_samples = (sample_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples); |
| } |
|
|
| min_silence_samples = sr_per_ms * min_silence_duration_ms; |
| min_silence_samples_at_max_speech = sr_per_ms * 98; |
| |
| init_onnx_model(ModelPath); |
| } |