File size: 4,965 Bytes
a2dca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#ifndef AUDIO_INFERENCE_LIBRARY_H
#define AUDIO_INFERENCE_LIBRARY_H

#include <string>
#include <vector>
#include <cstdint> // For uint32_t, int16_t
#include <memory>  // For std::unique_ptr
#include <Eigen/Dense>
// NOTE(review): a using-directive at header scope makes every Eigen name visible
// unqualified in each file that includes this header. The declarations in this
// header already spell out Eigen:: explicitly, so it is not needed here; it is
// presumably relied on by source files that include this header — confirm
// before removing.
using namespace Eigen;
// ONNX Runtime types are only named here, not defined: this header uses them
// for smart-pointer members and function signatures, so the full ONNX Runtime
// headers can stay out of it.
namespace Ort {
struct AllocatorWithDefaultOptions;
struct Env;
struct MemoryInfo;
struct Session;
struct Value;
}  // namespace Ort

// Eigen::Matrix and Eigen::MatrixXf are provided by <Eigen/Dense>, which this
// header already includes (the class below stores an Eigen::MatrixXf by value,
// so the complete type is required anyway). No local forward declaration is
// kept: a hand-written MatrixXf typedef would have to match Eigen's own
// definition exactly, including its default storage-order option, or the
// redeclaration would be ill-formed.

/**
 * @brief Class to handle audio preprocessing and ONNX model inference.
 *
 * This class encapsulates the logic for loading WAV files, extracting Mel filterbank
 * features, and running inference on an ONNX model. All member functions are only
 * declared here; their definitions live in the accompanying source file.
 *
 * The ONNX Runtime objects are held through std::unique_ptr, so an engine cannot be
 * copied; the copy operations are deleted explicitly to make that visible.
 */
class AudioInferenceEngine {
public:
    /**
     * @brief Constructor for AudioInferenceEngine.
     *
     * Marked explicit so that a std::string is never silently converted into an
     * engine (which would load a model as a side effect of an implicit conversion).
     *
     * @param modelPath The file path to the ONNX model.
     * @throws Ort::Exception if ONNX Runtime initialization fails.
     */
    explicit AudioInferenceEngine(const std::string& modelPath);

    /**
     * @brief Destructor to clean up ONNX Runtime resources.
     *
     * Declared here and defined out of line because the std::unique_ptr members
     * point to types that are only forward-declared in this header.
     */
    ~AudioInferenceEngine();

    // Non-copyable: the std::unique_ptr members already make the implicit copy
    // operations deleted; spelling it out documents the intent and gives clearer
    // diagnostics.
    AudioInferenceEngine(const AudioInferenceEngine&) = delete;
    AudioInferenceEngine& operator=(const AudioInferenceEngine&) = delete;

    /**
     * @brief Preprocesses an audio WAV file to extract Mel filterbank features.
     *
     * This function loads the WAV file, converts it to a float array, and then
     * applies the spectrogram and Mel filterbank extraction steps.
     *
     * @param wavFilePath The path to the WAV audio file.
     * @return An Eigen::MatrixXf containing the extracted features (frames x N_MELS).
     * Returns an empty matrix if preprocessing fails.
     */
    Eigen::MatrixXf preprocessAudio(const std::string& wavFilePath);

    /**
     * @brief Runs inference on the loaded ONNX model using the provided features.
     *
     * The input features should be the output of `preprocessAudio`. This function
     * converts the features to an ONNX Runtime tensor and executes the model.
     *
     * @param features An Eigen::MatrixXf containing the preprocessed audio features.
     * Expected shape: (frames, N_MELS).
     * @return A std::vector<float> containing the flattened output of the ONNX model.
     * Returns an empty vector if inference fails.
     */
    std::vector<float> runInference(const Eigen::MatrixXf& features);

    /**
     * @brief Tensor-level variant of runInference.
     *
     * NOTE(review): undocumented in the original header. Presumably runs the loaded
     * model on an already-built input tensor and returns the raw output tensors —
     * confirm against the implementation.
     *
     * @param inputTensor The ONNX Runtime tensor to feed to the model.
     * @return The ONNX Runtime values produced by the call.
     */
    std::vector<Ort::Value> runInference_tensor(const Ort::Value& inputTensor);

private:
    // ONNX Runtime members
    std::unique_ptr<Ort::Env> env_;
    std::unique_ptr<Ort::Session> session_;
    std::unique_ptr<Ort::AllocatorWithDefaultOptions> allocator_;
    // NOTE(review): these store raw, non-owning C-string pointers; whatever owns
    // the pointed-to names must outlive this object — confirm in the implementation.
    std::vector<const char*> input_node_names_;
    std::vector<const char*> output_node_names_;

    // Precomputed Mel filterbank matrix
    Eigen::MatrixXf mel_filterbank_;

    // Private helper functions (implemented in .cpp)
    // WAV file parsing structures
    // Packed to 1-byte alignment so the in-memory layout has no padding between fields.
#pragma pack(push, 1)
    struct WavHeader {
        char     riff_id[4];
        uint32_t file_size;
        char     wave_id[4];
        char     fmt_id[4];
        uint32_t fmt_size;
        uint16_t audio_format;
        uint16_t num_channels;
        uint32_t sample_rate;
        uint32_t byte_rate;
        uint16_t block_align;
        uint16_t bits_per_sample;
    };

    struct WavDataChunk {
        char     data_id[4];
        uint32_t data_size;
    };
#pragma pack(pop)

    // Lock in the packed sizes (sum of the field sizes) so an accidental field
    // change or a lost pack pragma is caught at compile time.
    static_assert(sizeof(WavHeader) == 36, "WavHeader must be exactly 36 bytes");
    static_assert(sizeof(WavDataChunk) == 8, "WavDataChunk must be exactly 8 bytes");

    /**
     * @brief Loads audio data from a WAV file into a float vector.
     * @param filename The path to the WAV audio file.
     * @param actual_sample_rate Output parameter to store the sample rate read from the WAV file.
     * @return A std::vector<float> containing the normalized mono audio samples.
     */
    std::vector<float> loadWavToFloatArray(const std::string& filename, int& actual_sample_rate);

    /**
     * @brief Generates a Hamming window.
     * @param window_length The length of the window.
     * @return A std::vector<float> containing the Hamming window coefficients.
     */
    std::vector<float> generateHammingWindow(int window_length);

    /**
     * @brief Extracts spectrogram features from waveform.
     * @param wav The input waveform.
     * @param fs The sampling rate.
     * @return A 2D Eigen::MatrixXf representing the spectrogram.
     */
    Eigen::MatrixXf extractSpectrogram(const std::vector<float>& wav, int fs);

    /**
     * @brief Creates a Mel filter-bank matrix.
     * @param sample_rate Sample rate in Hz.
     * @param n_fft FFT size.
     * @param n_mels Mel filter size.
     * @param fmin Lowest frequency (in Hz).
     * @param fmax Highest frequency (in Hz).
     * @return An Eigen::MatrixXf representing the Mel transform matrix.
     */
    Eigen::MatrixXf speechlibMel(int sample_rate, int n_fft, int n_mels, float fmin, float fmax);
};

#endif // AUDIO_INFERENCE_LIBRARY_H