File size: 4,965 Bytes
a2dca42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
#ifndef AUDIO_INFERENCE_LIBRARY_H
#define AUDIO_INFERENCE_LIBRARY_H
#include <string>
#include <vector>
#include <cstdint> // For uint32_t, int16_t
#include <memory> // For std::unique_ptr
#include <Eigen/Dense>
// NOTE(review): this using-directive sits at file scope in a header, so every
// translation unit that includes this file sees all Eigen names unqualified.
// The declarations in this header already spell out Eigen:: explicitly; the
// directive is kept only because includers may rely on it — confirm before removing.
using namespace Eigen;
// ONNX Runtime types are only named here, never used by value, so this header
// forward-declares them instead of pulling in the full ONNX Runtime headers.
namespace Ort {

// Allocation-related types.
struct AllocatorWithDefaultOptions;
struct MemoryInfo;

// Runtime environment, loaded model session, and tensor value types.
struct Env;
struct Session;
struct Value;

}  // namespace Ort
// Redeclaration of Eigen's Matrix class template and of the MatrixXf typedef.
// NOTE(review): <Eigen/Dense> is already included near the top of this header,
// so both names are presumably declared before this point and this block looks
// redundant; it also has to stay in sync with Eigen's own declarations.
// Confirm against the Eigen version in use before removing.
namespace Eigen {
// Class template redeclared with all six parameters and no default arguments.
template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
class Matrix;
// Dynamically sized, column-major float matrix; the type used for features and
// the Mel filterbank in AudioInferenceEngine.
typedef Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor, Eigen::Dynamic, Eigen::Dynamic> MatrixXf;
}
/**
 * @brief Class to handle audio preprocessing and ONNX model inference.
 *
 * This class encapsulates the logic for loading WAV files, extracting Mel
 * filterbank features, and running inference on an ONNX model.
 *
 * The ONNX Runtime objects are held through std::unique_ptr to types that are
 * only forward-declared in this header, so the destructor is declared here and
 * defined in the .cpp. Instances are not copyable.
 */
class AudioInferenceEngine {
public:
    /**
     * @brief Constructor for AudioInferenceEngine.
     * @param modelPath The file path to the ONNX model.
     * @throws Ort::Exception if ONNX Runtime initialization fails.
     */
    AudioInferenceEngine(const std::string& modelPath);

    /**
     * @brief Destructor to clean up ONNX Runtime resources.
     */
    ~AudioInferenceEngine();

    // The std::unique_ptr members already make copying impossible; spelling it
    // out gives callers a clear diagnostic and documents the intent.
    AudioInferenceEngine(const AudioInferenceEngine&) = delete;
    AudioInferenceEngine& operator=(const AudioInferenceEngine&) = delete;

    /**
     * @brief Preprocesses an audio WAV file to extract Mel filterbank features.
     *
     * This function loads the WAV file, converts it to a float array, and then
     * applies the spectrogram and Mel filterbank extraction steps.
     *
     * @param wavFilePath The path to the WAV audio file.
     * @return An Eigen::MatrixXf containing the extracted features (frames x N_MELS).
     *         Returns an empty matrix if preprocessing fails.
     */
    Eigen::MatrixXf preprocessAudio(const std::string& wavFilePath);

    /**
     * @brief Runs inference on the loaded ONNX model using the provided features.
     *
     * The input features should be the output of `preprocessAudio`. This function
     * converts the features to an ONNX Runtime tensor and executes the model.
     *
     * @param features An Eigen::MatrixXf containing the preprocessed audio features.
     *                 Expected shape: (frames, N_MELS).
     * @return A std::vector<float> containing the flattened output of the ONNX model.
     *         Returns an empty vector if inference fails.
     */
    std::vector<float> runInference(const Eigen::MatrixXf& features);

    /**
     * @brief Runs inference on a caller-supplied ONNX Runtime tensor.
     *
     * NOTE(review): the implementation is not visible from this header;
     * presumably this is the tensor-level counterpart of `runInference` and
     * returns the model outputs unflattened — confirm in the .cpp, including
     * the expected tensor shape and the error behaviour.
     *
     * @param inputTensor The input tensor passed to the model.
     * @return The output values produced by the model.
     */
    std::vector<Ort::Value> runInference_tensor(const Ort::Value& inputTensor);

private:
    // ONNX Runtime members
    std::unique_ptr<Ort::Env> env_;
    std::unique_ptr<Ort::Session> session_;
    std::unique_ptr<Ort::AllocatorWithDefaultOptions> allocator_;
    // NOTE(review): these vectors hold raw, non-owning C-string pointers; who
    // keeps the pointed-to names alive is not visible here — verify in the .cpp.
    std::vector<const char*> input_node_names_;
    std::vector<const char*> output_node_names_;

    // Precomputed Mel filterbank matrix
    Eigen::MatrixXf mel_filterbank_;

    // Private helper functions (implemented in .cpp)

    // WAV file parsing structures. They are packed to 1-byte alignment so that
    // no padding is inserted between fields. Fixed-width types are
    // std::-qualified because <cstdint> only guarantees the names in namespace std.
#pragma pack(push, 1)
    struct WavHeader {
        char riff_id[4];
        std::uint32_t file_size;
        char wave_id[4];
        char fmt_id[4];
        std::uint32_t fmt_size;
        std::uint16_t audio_format;
        std::uint16_t num_channels;
        std::uint32_t sample_rate;
        std::uint32_t byte_rate;
        std::uint16_t block_align;
        std::uint16_t bits_per_sample;
    };
    struct WavDataChunk {
        char data_id[4];
        std::uint32_t data_size;
    };
#pragma pack(pop)

    // Fail the build if a compiler ever lays these structs out with padding:
    // the sizes below are the sums of the field sizes.
    static_assert(sizeof(WavHeader) == 36, "WavHeader must be exactly 36 bytes");
    static_assert(sizeof(WavDataChunk) == 8, "WavDataChunk must be exactly 8 bytes");

    /**
     * @brief Loads audio data from a WAV file into a float vector.
     * @param filename The path to the WAV audio file.
     * @param actual_sample_rate Output parameter to store the sample rate read from the WAV file.
     * @return A std::vector<float> containing the normalized mono audio samples.
     */
    std::vector<float> loadWavToFloatArray(const std::string& filename, int& actual_sample_rate);

    /**
     * @brief Generates a Hamming window.
     * @param window_length The length of the window.
     * @return A std::vector<float> containing the Hamming window coefficients.
     */
    std::vector<float> generateHammingWindow(int window_length);

    /**
     * @brief Extracts spectrogram features from waveform.
     * @param wav The input waveform.
     * @param fs The sampling rate.
     * @return A 2D Eigen::MatrixXf representing the spectrogram.
     */
    Eigen::MatrixXf extractSpectrogram(const std::vector<float>& wav, int fs);

    /**
     * @brief Creates a Mel filter-bank matrix.
     * @param sample_rate Sample rate in Hz.
     * @param n_fft FFT size.
     * @param n_mels Mel filter size.
     * @param fmin Lowest frequency (in Hz).
     * @param fmax Highest frequency (in Hz).
     * @return An Eigen::MatrixXf representing the Mel transform matrix.
     */
    Eigen::MatrixXf speechlibMel(int sample_rate, int n_fft, int n_mels, float fmin, float fmax);
};
#endif // AUDIO_INFERENCE_LIBRARY_H |