// vad_cpp / silero_vad_onnx / vad_iterator.h
// hzeng412's picture
// Duplicate from MoYoYoTech/vad_cpp
// d21d362
#ifndef VAD_ITERATOR_H
#define VAD_ITERATOR_H
#include "time_stamp.h"
#include <vector>
#include <string>
#if defined(__APPLE__)
#include <onnxruntime/onnxruntime_cxx_api.h>
#else
#include "onnxruntime_run_options_config_keys.h"
#include "onnxruntime_cxx_api.h"
#endif
// Forward declaration of timestamp_t.
// NOTE(review): VadIterator below stores timestamp_t by value (a vector of them
// and a single instance), which requires the complete type -- presumably
// provided by "time_stamp.h"; confirm whether this declaration is still needed.
class timestamp_t;
/// Voice-activity-detection (VAD) iterator that owns an ONNX Runtime session
/// together with the buffers and bookkeeping used while scanning audio.
///
/// Only the interface and per-instance state are declared here; every member
/// function is defined out of line, so the signatures below are kept exactly
/// as declared. Member order is also preserved because it determines the
/// construction order of the ONNX Runtime objects.
class VadIterator {
public:
    /// Constructs the iterator.
    ///
    /// @param ModelPath               Path of the model to load (presumably an
    ///                                ONNX file handed to init_onnx_model --
    ///                                confirm against the implementation).
    /// @param Sample_rate             Audio sample rate; defaults to 16000
    ///                                (presumably Hz).
    /// @param windows_frame_size      Analysis window size; defaults to 32
    ///                                (presumably milliseconds -- TODO confirm).
    /// @param Threshold               Detection threshold; defaults to 0.5.
    /// @param min_silence_duration_ms Minimum silence duration in ms; defaults to 100.
    /// @param speech_pad_ms           Speech padding in ms; defaults to 30.
    /// @param min_speech_duration_ms  Minimum speech duration in ms; defaults to 250.
    /// @param max_speech_duration_s   Maximum speech duration in seconds;
    ///                                -1 means unlimited (infinity).
    VadIterator(const std::string ModelPath,
                int Sample_rate = 16000,
                int windows_frame_size = 32,
                float Threshold = 0.5,
                int min_silence_duration_ms = 100,
                int speech_pad_ms = 30,
                int min_speech_duration_ms = 250,
                float max_speech_duration_s = -1); // -1 means infinity

    // Public methods

    /// Processes a buffer of float audio samples.
    /// NOTE(review): presumably updates the segments returned by
    /// get_speech_timestamps() -- confirm against the implementation.
    void process(const std::vector<float>& input_wav);

    /// Returns a read-only reference to a list of speech timestamps
    /// (presumably the `speeches` member -- confirm).
    const std::vector<timestamp_t>& get_speech_timestamps() const;

    /// Resets the iterator (exact scope is defined by the implementation).
    void reset();

private:
    // ONNX Runtime resources
    Ort::Env env;
    Ort::SessionOptions session_options;
    std::shared_ptr<Ort::Session> session = nullptr;
    Ort::AllocatorWithDefaultOptions allocator;
    Ort::MemoryInfo memory_info;

    // Context-related variables
    const int context_samples = 64;
    std::vector<float> _context;
    // Zero-initialized so these never hold indeterminate values before the
    // constructor assigns the real ones.
    int window_size_samples = 0;
    int effective_window_size = 0;
    int sr_per_ms = 0;

    // ONNX input/output related
    std::vector<Ort::Value> ort_inputs;
    std::vector<const char*> input_node_names = {"input", "state", "sr"};
    std::vector<float> input;
    // Element count of the state tensor; matches state_node_dims {2, 1, 128}.
    unsigned int size_state = 2 * 1 * 128;
    std::vector<float> _state;
    std::vector<int64_t> sr;
    int64_t input_node_dims[2] = {0, 0};
    const int64_t state_node_dims[3] = {2, 1, 128};
    const int64_t sr_node_dims[1] = {1};
    std::vector<Ort::Value> ort_outputs;
    std::vector<const char*> output_node_names = {"output", "stateN"};

    // Model parameters (defaults are placeholders; the constructor is expected
    // to overwrite them).
    int sample_rate = 0;
    float threshold = 0.0f;
    int min_silence_samples = 0;
    int min_silence_samples_at_max_speech = 0;
    int min_speech_samples = 0;
    // Declared as float (unlike the other sample counts); presumably so that an
    // unlimited maximum can be represented -- TODO confirm.
    float max_speech_samples = 0.0f;
    int speech_pad_samples = 0;
    int audio_length_samples = 0;

    // State management
    bool triggered = false;
    unsigned int temp_end = 0;
    unsigned int current_sample = 0;
    // Previously left uninitialized; now starts at 0 like next_start.
    int prev_end = 0;
    int next_start = 0;
    std::vector<timestamp_t> speeches;
    timestamp_t current_speech;

    // Private methods
    void init_onnx_model(const std::string& model_path);
    void init_engine_threads(int inter_threads, int intra_threads);
    void reset_states();
    void predict(const std::vector<float>& data_chunk);
};
#endif // VAD_ITERATOR_H