// File: go/cpp/inference/audio_encoder_lib.h
// Repository page residue kept for provenance: "jva96160's picture",
// "Upload 71 files", commit a2dca42 (verified).
#ifndef AUDIO_INFERENCE_LIBRARY_H
#define AUDIO_INFERENCE_LIBRARY_H
#include <string>
#include <vector>
#include <cstdint> // For uint32_t, int16_t
#include <memory> // For std::unique_ptr
#include <Eigen/Dense>
using namespace Eigen;
// Opaque declarations of the ONNX Runtime C++ types named in this header, so
// that including it does not pull in the ONNX Runtime headers. Only pointers,
// references, and declarations may use these types until the real headers are
// included by the implementation file.
namespace Ort {
struct AllocatorWithDefaultOptions;
struct Env;
struct MemoryInfo;
struct Session;
struct Value;
}  // namespace Ort
// Eigen::Matrix and Eigen::MatrixXf come from <Eigen/Dense>, which this header
// already includes. They are deliberately NOT re-declared here: a hand-written
// MatrixXf typedef hard-codes Eigen::ColMajor and conflicts with Eigen's own
// typedef when the build defines EIGEN_DEFAULT_TO_ROW_MAJOR.
/**
 * @brief Class to handle audio preprocessing and ONNX model inference.
 *
 * This class encapsulates the logic for loading WAV files, extracting Mel
 * filterbank features, and running inference on an ONNX model.
 *
 * Instances are non-copyable: the ONNX Runtime objects are held through
 * std::unique_ptr, so copying is disabled explicitly below.
 */
class AudioInferenceEngine {
public:
    /**
     * @brief Constructor for AudioInferenceEngine.
     *
     * Marked explicit so a std::string is never silently converted into an
     * engine (which would load a model as a side effect of a conversion).
     *
     * @param modelPath The file path to the ONNX model.
     * @throws Ort::Exception if ONNX Runtime initialization fails.
     */
    explicit AudioInferenceEngine(const std::string& modelPath);

    /**
     * @brief Destructor to clean up ONNX Runtime resources.
     */
    ~AudioInferenceEngine();

    // Copying is disabled; the unique_ptr members already make the implicit
    // copy operations unusable, this just states the intent at the interface.
    AudioInferenceEngine(const AudioInferenceEngine&) = delete;
    AudioInferenceEngine& operator=(const AudioInferenceEngine&) = delete;

    /**
     * @brief Preprocesses an audio WAV file to extract Mel filterbank features.
     *
     * This function loads the WAV file, converts it to a float array, and then
     * applies the spectrogram and Mel filterbank extraction steps.
     *
     * @param wavFilePath The path to the WAV audio file.
     * @return An Eigen::MatrixXf containing the extracted features (frames x N_MELS).
     *         Returns an empty matrix if preprocessing fails.
     */
    Eigen::MatrixXf preprocessAudio(const std::string& wavFilePath);

    /**
     * @brief Runs inference on the loaded ONNX model using the provided features.
     *
     * The input features should be the output of `preprocessAudio`. This function
     * converts the features to an ONNX Runtime tensor and executes the model.
     *
     * @param features An Eigen::MatrixXf containing the preprocessed audio features.
     *                 Expected shape: (frames, N_MELS).
     * @return A std::vector<float> containing the flattened output of the ONNX model.
     *         Returns an empty vector if inference fails.
     */
    std::vector<float> runInference(const Eigen::MatrixXf& features);

    /**
     * @brief Tensor-level counterpart of `runInference`.
     *
     * NOTE(review): presumably feeds `inputTensor` to the loaded session as-is
     * and hands back the raw output tensors; the failure behaviour (throw vs.
     * empty vector) is not visible from this header -- confirm against the
     * implementation file.
     *
     * @param inputTensor An already-constructed ONNX Runtime input tensor.
     * @return The ONNX Runtime output values produced by the model.
     */
    std::vector<Ort::Value> runInference_tensor(const Ort::Value& inputTensor);

private:
    // ONNX Runtime members. Declaration order is significant: members are
    // destroyed in reverse order, so keep env_ ahead of session_.
    std::unique_ptr<Ort::Env> env_;
    std::unique_ptr<Ort::Session> session_;
    std::unique_ptr<Ort::AllocatorWithDefaultOptions> allocator_;
    std::vector<const char*> input_node_names_;
    std::vector<const char*> output_node_names_;

    // Precomputed Mel filterbank matrix
    Eigen::MatrixXf mel_filterbank_;

    // Private helper functions (implemented in .cpp)

    // WAV file parsing structures. Packed to 1-byte alignment so the in-memory
    // layout has no padding between fields.
    // NOTE(review): presumably these are read straight from the file and so
    // assume the canonical RIFF/WAVE layout with a 16-byte fmt chunk -- confirm
    // in the implementation file.
#pragma pack(push, 1)
    struct WavHeader {
        char riff_id[4];
        uint32_t file_size;
        char wave_id[4];
        char fmt_id[4];
        uint32_t fmt_size;
        uint16_t audio_format;
        uint16_t num_channels;
        uint32_t sample_rate;
        uint32_t byte_rate;
        uint16_t block_align;
        uint16_t bits_per_sample;
    };
    struct WavDataChunk {
        char data_id[4];
        uint32_t data_size;
    };
#pragma pack(pop)

    // Fail the build, rather than mis-parse files, if a toolchain ever lays
    // these structs out with padding.
    static_assert(sizeof(WavHeader) == 36, "WavHeader must be exactly 36 bytes");
    static_assert(sizeof(WavDataChunk) == 8, "WavDataChunk must be exactly 8 bytes");

    /**
     * @brief Loads audio data from a WAV file into a float vector.
     * @param filename The path to the WAV audio file.
     * @param actual_sample_rate Output parameter to store the sample rate read from the WAV file.
     * @return A std::vector<float> containing the normalized mono audio samples.
     */
    std::vector<float> loadWavToFloatArray(const std::string& filename, int& actual_sample_rate);

    /**
     * @brief Generates a Hamming window.
     * @param window_length The length of the window.
     * @return A std::vector<float> containing the Hamming window coefficients.
     */
    std::vector<float> generateHammingWindow(int window_length);

    /**
     * @brief Extracts spectrogram features from waveform.
     * @param wav The input waveform.
     * @param fs The sampling rate.
     * @return A 2D Eigen::MatrixXf representing the spectrogram.
     */
    Eigen::MatrixXf extractSpectrogram(const std::vector<float>& wav, int fs);

    /**
     * @brief Creates a Mel filter-bank matrix.
     * @param sample_rate Sample rate in Hz.
     * @param n_fft FFT size.
     * @param n_mels Mel filter size.
     * @param fmin Lowest frequency (in Hz).
     * @param fmax Highest frequency (in Hz).
     * @return An Eigen::MatrixXf representing the Mel transform matrix.
     */
    Eigen::MatrixXf speechlibMel(int sample_rate, int n_fft, int n_mels, float fmin, float fmax);
};
#endif // AUDIO_INFERENCE_LIBRARY_H