#include <algorithm>
#include <cstring>
#include <memory>

#include "vad_onnx.h"

static void get_input_names(Ort::Session* session,
                            std::vector<std::string>& input_names_str,
                            std::vector<const char*>& input_names_char) {
    Ort::AllocatorWithDefaultOptions allocator;
    size_t nodes_num = session->GetInputCount();
    input_names_str.resize(nodes_num);
    input_names_char.resize(nodes_num);
    for (size_t i = 0; i != nodes_num; ++i) {
        auto t = session->GetInputNameAllocated(i, allocator);
        input_names_str[i] = t.get();
        input_names_char[i] = input_names_str[i].c_str();
    }
}

static void get_output_names(Ort::Session* session,
                             std::vector<std::string>& output_names_,
                             std::vector<const char*>& vad_out_names_) {
    Ort::AllocatorWithDefaultOptions allocator;
    size_t nodes_num = session->GetOutputCount();
    output_names_.resize(nodes_num);
    vad_out_names_.resize(nodes_num);
    for (size_t i = 0; i != nodes_num; ++i) {
        auto t = session->GetOutputNameAllocated(i, allocator);
        output_names_[i] = t.get();
        vad_out_names_[i] = output_names_[i].c_str();
    }
}

VadOnnx::VadOnnx(const std::string& model_path, int batch_size, int thread_num,
                 float threshold, int sampling_rate, int min_silence_duration_ms,
                 float max_speech_duration_s, int speech_pad_ms)
    : batch_size_(batch_size),
      thread_num_(thread_num),
      threshold_(threshold),
      sample_rates_(sampling_rate),
      min_silence_samples_(sampling_rate * min_silence_duration_ms / 1000.0),
      speech_pad_samples_(sampling_rate * speech_pad_ms / 1000.0),
      triggered_(false),
      temp_end_(0),
      current_sample_(0),
      start_(0),
      memory_info(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU)) {
    init_onnx_model(model_path);
    get_input_names(session.get(), input_names_, vad_in_names_);
    get_output_names(session.get(), output_names_, vad_out_names_);

    sr.resize(1);
    sr[0] = sample_rates_;

    if (batch_size_ != 1) {
        state_shape = {2, batch_size_, 128};
        state_size = 2 * batch_size_ * 128;
    }
    state_.resize(state_size);

    // Silero-style models prepend a short context carried over from the
    // previous chunk: 64 samples at 16 kHz, 32 samples at 8 kHz.
    context_size = (sample_rates_ == 16000) ? 64 : 32;
    context_.resize(context_size);
    effective_window_size = window_size_samples + context_size;
    input_node_shape[0] = 1;
    input_node_shape[1] = effective_window_size;

    reset_states();
}
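// Model I/O contract assumed by this wrapper (a Silero VAD v5-style graph;
// the actual names and shapes come from the model file, not from this code):
//   input:  float32 [1, context_size + window_size_samples]  audio samples
//   state:  float32 [2, batch, 128]                          recurrent state
//   sr:     int64   [1]                                      sampling rate
// The model returns the per-window speech probability plus the updated state.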
VadOnnx::~VadOnnx() = default;

void VadOnnx::reset_states() {
    std::memset(state_.data(), 0, state_.size() * sizeof(float));
    std::fill(context_.begin(), context_.end(), 0.0f);
    triggered_ = false;
    temp_end_ = 0;
    current_sample_ = 0;
    start_ = 0;
    last_sr_ = 0;
    last_batch_size_ = 0;
}

float VadOnnx::forward_infer(std::vector<float>& data_chunk) {
    // Concatenate the saved context with the new input chunk.
    std::vector<float> x_with_context(effective_window_size, 0.0f);
    std::copy(context_.begin(), context_.end(), x_with_context.begin());
    std::copy(data_chunk.begin(), data_chunk.end(), x_with_context.begin() + context_size);
    input = x_with_context;

    // Prepare inputs
    Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
        memory_info, input.data(), input.size(), input_node_shape.data(), 2);
    Ort::Value state_tensor = Ort::Value::CreateTensor<float>(
        memory_info, state_.data(), state_.size(), state_shape.data(), 3);
    Ort::Value sr_tensor = Ort::Value::CreateTensor<int64_t>(
        memory_info, sr.data(), 1, sr_shape.data(), 1);

    ort_inputs.clear();
    ort_inputs.emplace_back(std::move(input_tensor));
    ort_inputs.emplace_back(std::move(state_tensor));
    ort_inputs.emplace_back(std::move(sr_tensor));

    // Run inference
    ort_outputs = session->Run(
        Ort::RunOptions{nullptr},
        vad_in_names_.data(), ort_inputs.data(), ort_inputs.size(),
        vad_out_names_.data(), vad_out_names_.size());

    // Get output
    float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0];

    // Update state
    float* stateN = ort_outputs[1].GetTensorMutableData<float>();
    std::memcpy(state_.data(), stateN, state_size * sizeof(float));

    // Update context: keep the tail of the current window for the next call.
    std::copy(x_with_context.end() - context_size, x_with_context.end(), context_.begin());

    return speech_prob;
}

std::vector<float> VadOnnx::vad_dectect(std::vector<float>& audio) {
    std::vector<float> result;

    // Zero-pad the tail so the length is a multiple of window_size_samples.
    int pad_num = (window_size_samples - (audio.size() % window_size_samples)) % window_size_samples;
    audio.insert(audio.end(), pad_num, 0.0f);

    for (size_t i = 0; i < audio.size(); i += window_size_samples) {
        std::vector<float> chunk(audio.begin() + i, audio.begin() + i + window_size_samples);
        auto prob = forward_infer(chunk);
        result.emplace_back(prob);
    }
    return result;
}
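// A hedged sketch, not part of the original file: one way to consume the
// per-window probabilities returned by the batch vad_dectect() above. The
// helper name and parameters are illustrative; `window_size` and
// `sample_rate` must match the values the VadOnnx instance was built with.
static std::vector<double> probs_to_window_starts(const std::vector<float>& probs,
                                                  float threshold,
                                                  int window_size,
                                                  int sample_rate) {
    std::vector<double> starts;
    for (size_t i = 0; i < probs.size(); ++i) {
        // Record the start offset (in seconds) of every window whose
        // speech probability reaches the threshold.
        if (probs[i] >= threshold) {
            starts.push_back(static_cast<double>(i) * window_size / sample_rate);
        }
    }
    return starts;
}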
std::map<std::string, double> VadOnnx::vad_dectect(std::vector<float>& audio, bool return_seconds) {
    std::map<std::string, double> result;

    // Append the new audio to the internal buffer.
    buffer_.insert(buffer_.end(), audio.begin(), audio.end());

    while (buffer_.size() > 0) {
        std::map<std::string, double> tmp;
        std::vector<float> chunk(
            buffer_.begin(),
            buffer_.begin() + std::min(static_cast<int64_t>(buffer_.size()), window_size_samples));

        // Zero-pad to the fixed window length.
        if (chunk.size() < static_cast<size_t>(window_size_samples)) {
            chunk.resize(window_size_samples, 0.0f);
        }
        current_sample_ += window_size_samples;

        // Run inference to get the speech probability for this window.
        float speech_prob = forward_infer(chunk);

        if (speech_prob >= threshold_ && temp_end_ > 0) {
            temp_end_ = 0;
        }
        if (speech_prob >= threshold_ && !triggered_) {
            triggered_ = true;
            start_ = std::max(0.0, current_sample_ - window_size_samples);
            tmp["start"] = return_seconds ? start_ / sample_rates_ : start_;
        }
        if (speech_prob < (threshold_ - 0.15) && triggered_) {
            if (temp_end_ == 0) {
                temp_end_ = current_sample_;
            }
            if (current_sample_ - temp_end_ >= min_silence_samples_) {
                double speech_end = temp_end_;
                tmp["end"] = return_seconds ? speech_end / sample_rates_ : speech_end;
                temp_end_ = 0;
                triggered_ = false;
            }
        }

        // Remove the processed samples from the buffer.
        if (static_cast<size_t>(window_size_samples) >= buffer_.size()) {
            buffer_.clear();  // everything consumed
        } else {
            std::copy(buffer_.begin() + window_size_samples, buffer_.end(), buffer_.begin());
            buffer_.resize(buffer_.size() - window_size_samples);
        }

        // Merge this chunk's result into the overall result.
        if (result.empty()) {
            result = tmp;
        } else if (!tmp.empty()) {
            // If the current chunk produced an 'end', update the final end.
            if (tmp.find("end") != tmp.end()) {
                result["end"] = tmp["end"];
            }
            // If a new 'start' arrives while the previous result already has
            // an 'end', merge the two into one continuous speech segment.
            if (tmp.find("start") != tmp.end() && result.find("end") != result.end()) {
                result.erase("end");
            }
        }
    }
    return result;
}

void VadOnnx::init_onnx_model(const std::string& model_path) {
    init_engine_threads(1, 1);
    init_exec_provider();

    // Initialize the ONNX Runtime session.
    env_ = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "VadOnnx");
    session = std::make_unique<Ort::Session>(env_, ORTCHAR(model_path.c_str()), session_options);
}

void VadOnnx::init_engine_threads(int inter_threads, int intra_threads) {
    session_options.SetInterOpNumThreads(inter_threads);
    session_options.SetIntraOpNumThreads(intra_threads);
    session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
}

void VadOnnx::init_exec_provider() {
    // Query all execution providers available in this ONNX Runtime build.
    std::vector<std::string> providers = Ort::GetAvailableProviders();

    // Add execution providers according to what is supported.
    if (std::find(providers.begin(), providers.end(), "CUDAExecutionProvider") != providers.end()) {
        OrtCUDAProviderOptions cuda_options{};
        session_options.AppendExecutionProvider_CUDA(cuda_options);
    }
    // #if defined(__APPLE__)
    // if (std::find(providers.begin(), providers.end(), "CoreMLExecutionProvider") != providers.end()) {
    //     session_options.AppendExecutionProvider_CoreML();
    // }
    // #endif
}
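// A minimal usage sketch, not part of the original file. It assumes a Silero
// VAD-style model file named "silero_vad.onnx" and 16 kHz mono float PCM in
// [-1, 1]; the constructor arguments mirror the signature defined above.
#ifdef VAD_ONNX_EXAMPLE
#include <cstdio>

int main() {
    VadOnnx vad("silero_vad.onnx", /*batch_size=*/1, /*thread_num=*/1,
                /*threshold=*/0.5f, /*sampling_rate=*/16000,
                /*min_silence_duration_ms=*/100,
                /*max_speech_duration_s=*/30.0f, /*speech_pad_ms=*/30);

    // Stream audio in small chunks; the internal buffer accepts arbitrary
    // chunk sizes. Here: 512 samples (32 ms at 16 kHz) of silence as a stand-in.
    std::vector<float> chunk(512, 0.0f);
    std::map<std::string, double> segment = vad.vad_dectect(chunk, /*return_seconds=*/true);
    if (segment.count("start")) std::printf("speech start: %.3f s\n", segment["start"]);
    if (segment.count("end"))   std::printf("speech end:   %.3f s\n", segment["end"]);
    return 0;
}
#endif  // VAD_ONNX_EXAMPLE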