// File: go/cpp/inference/audio_encoder_lib.cpp
// Uploaded by jva96160 ("Upload 71 files"), commit a2dca42.
#include "audio_encoder_lib.h"
#include <iostream>
#include <fstream>
#include <cmath>
#include <numeric>
#include <algorithm>
#include <cstring> // For memcpy
// Include specific ONNX Runtime headers for implementation
#include <onnxruntime_cxx_api.h>
// Include specific Eigen headers for implementation
#include <Eigen/Dense>
// Include specific KissFFT headers for implementation
#include <kiss_fft.h>
#include <kiss_fftr.h>
// M_PI is not part of ISO C++ (<cmath> only provides it as an extension on some platforms),
// so supply a fallback definition when the toolchain did not.
#if !defined(M_PI)
#define M_PI 3.14159265358979323846
#endif
// --- Global parameters for feature extraction (matching Python script) ---
// Compile-time constants mirroring the Python preprocessing script. The anonymous
// namespace keeps them private to this translation unit.
namespace {
constexpr float PREEMPHASIS_COEFF = 0.97f;   // y[n] = x[n] - 0.97 * x[n-1]
constexpr int N_FFT = 512;                   // FFT size (frames are zero-padded up to this)
constexpr int WIN_LENGTH = 400;              // Analysis window length (samples)
constexpr int HOP_LENGTH = 160;              // Frame advance between windows (samples)
constexpr int N_MELS = 80;                   // Number of Mel filterbank channels
constexpr int TARGET_SAMPLE_RATE = 16000;    // Sample rate assumed by feature extraction (Hz)
}  // namespace
// --- Implementation of AudioInferenceEngine methods ---
/**
 * @brief Loads the ONNX model, caches its input/output node names and precomputes the
 *        Mel filterbank used by preprocessAudio().
 *
 * @param modelPath Filesystem path of the .onnx model.
 * @throws Ort::Exception if the session cannot be created or the model has no inputs/outputs.
 * @throws std::runtime_error if the Mel filterbank cannot be built.
 */
AudioInferenceEngine::AudioInferenceEngine(const std::string& modelPath) {
    // 1. Initialize ONNX Runtime Environment
    env_ = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "AudioInferenceEngine");

    // 2. Configure Session Options (0 intra-op threads = ONNX Runtime's default thread count)
    Ort::SessionOptions session_options;
    session_options.SetIntraOpNumThreads(0);
    session_options.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED);

    // 3. Create ONNX Runtime Session
    session_ = std::make_unique<Ort::Session>(*env_, modelPath.c_str(), session_options);

    // 4. Initialize Allocator (the node names below are allocated from it and freed in the destructor)
    allocator_ = std::make_unique<Ort::AllocatorWithDefaultOptions>();

    // 5. Validate the model's I/O before allocating anything that needs manual release.
    const size_t numInputNodes = session_->GetInputCount();
    if (numInputNodes == 0) {
        throw Ort::Exception("ONNX model has no input nodes.", ORT_FAIL);
    }
    const size_t numOutputNodes = session_->GetOutputCount();
    if (numOutputNodes == 0) {
        throw Ort::Exception("ONNX model has no output nodes.", ORT_FAIL);
    }

    // 6. Fetch the node names into owning smart pointers. If anything below throws, these
    //    locals free the names during unwinding; the destructor does not run for an object
    //    whose constructor threw, so raw pointers stored in the members would leak.
    std::vector<Ort::AllocatedStringPtr> owned_input_names;
    owned_input_names.reserve(numInputNodes);
    for (size_t i = 0; i < numInputNodes; ++i) {
        owned_input_names.push_back(session_->GetInputNameAllocated(i, *allocator_));
    }
    std::vector<Ort::AllocatedStringPtr> owned_output_names;
    owned_output_names.reserve(numOutputNodes);
    for (size_t i = 0; i < numOutputNodes; ++i) {
        owned_output_names.push_back(session_->GetOutputNameAllocated(i, *allocator_));
    }

    // 7. Precompute Mel filterbank
    // The Python example uses fmax=16000//2-80-230.
    const float mel_fmax = static_cast<float>(TARGET_SAMPLE_RATE) / 2.0f - 80.0f - 230.0f;
    mel_filterbank_ = speechlibMel(TARGET_SAMPLE_RATE, N_FFT, N_MELS, 0.0f, mel_fmax);
    if (mel_filterbank_.rows() == 0 || mel_filterbank_.cols() == 0) {
        throw std::runtime_error("Failed to create Mel filterbank during initialization.");
    }

    // 8. Every throwing step has succeeded: size both member vectors first (the only remaining
    //    allocations), then hand the raw C strings over. Ort::Session::Run needs plain
    //    `const char*` arrays, and the destructor returns each name to allocator_.
    input_node_names_.resize(numInputNodes);
    output_node_names_.resize(numOutputNodes);
    for (size_t i = 0; i < numInputNodes; ++i) {
        input_node_names_[i] = owned_input_names[i].release();
    }
    for (size_t i = 0; i < numOutputNodes; ++i) {
        output_node_names_[i] = owned_output_names[i].release();
    }

    std::cout << "AudioInferenceEngine initialized successfully with model: " << modelPath << std::endl;
}
/**
 * @brief Returns the cached node-name strings to the allocator they came from.
 *
 * The constructor detached these C strings from their owning smart pointers, so they must be
 * freed explicitly here. The unique_ptr members (env_, session_, allocator_) are destroyed
 * automatically after this body.
 */
AudioInferenceEngine::~AudioInferenceEngine() {
    auto free_every_name = [this](const auto& names) {
        for (const char* raw_name : names) {
            allocator_->Free(const_cast<char*>(raw_name));
        }
    };
    free_every_name(input_node_names_);
    free_every_name(output_node_names_);
}
/**
 * @brief Private helper: Loads audio data from a WAV file.
 *
 * Accepts 16-bit PCM RIFF/WAVE files with one or two channels. Each sample is scaled by
 * 1/32768, and stereo frames are averaged down to mono. Chunks between the fixed header and the
 * 'data' chunk are skipped, honouring RIFF's even-byte padding of odd-sized chunks.
 *
 * NOTE(review): the header is read as one fixed-size WavHeader struct, which presumably assumes
 * the 'fmt ' chunk directly follows 'WAVE' with no extension bytes. Confirm this against
 * WavHeader's definition. Samples are read raw with no byte swapping, so a little-endian host
 * is assumed.
 *
 * @param filename Path of the WAV file.
 * @param actual_sample_rate Out: the file's sample rate, written once the header validates.
 * @return Mono samples, or an empty vector on error (a message is printed to std::cerr).
 *         If the data chunk is truncated, the samples read so far are returned with a warning.
 */
std::vector<float> AudioInferenceEngine::loadWavToFloatArray(const std::string& filename, int& actual_sample_rate) {
    std::ifstream file(filename, std::ios::binary);
    if (!file.is_open()) {
        std::cerr << "Error: Could not open WAV file: " << filename << std::endl;
        return {};
    }

    WavHeader header;
    // On a short read, part of `header` stays indeterminate, so stop before inspecting it.
    if (!file.read(reinterpret_cast<char*>(&header), sizeof(WavHeader))) {
        std::cerr << "Error: WAV file is too short to contain a header: " << filename << std::endl;
        return {};
    }
    if (std::string(header.riff_id, 4) != "RIFF" ||
        std::string(header.wave_id, 4) != "WAVE" ||
        std::string(header.fmt_id, 4) != "fmt ") {
        std::cerr << "Error: Invalid WAV header (RIFF, WAVE, or fmt chunk missing/invalid)." << std::endl;
        return {};
    }
    if (header.audio_format != 1) { // 1 = PCM
        std::cerr << "Error: Only PCM audio format (1) is supported. Found: " << header.audio_format << std::endl;
        return {};
    }
    if (header.bits_per_sample != 16) {
        std::cerr << "Error: Only 16-bit PCM is supported. Found: " << header.bits_per_sample << " bits per sample." << std::endl;
        return {};
    }
    actual_sample_rate = header.sample_rate;
    if (header.num_channels != 1 && header.num_channels != 2) {
        std::cerr << "Error: Unsupported number of channels: " << header.num_channels << std::endl;
        return {};
    }

    // Walk the chunk list until 'data'. The loop is driven by the read results themselves, so a
    // failed read can never be mistaken for a chunk and a failed stream cannot spin forever.
    WavDataChunk data_chunk;
    bool data_chunk_found = false;
    while (file.read(reinterpret_cast<char*>(&data_chunk.data_id), 4) &&
           file.read(reinterpret_cast<char*>(&data_chunk.data_size), 4)) {
        if (std::string(data_chunk.data_id, 4) == "data") {
            data_chunk_found = true;
            break;
        }
        // RIFF chunk bodies are padded to an even length; the pad byte is not counted in the size.
        const std::streamoff skip =
            static_cast<std::streamoff>(data_chunk.data_size) + static_cast<std::streamoff>(data_chunk.data_size & 1);
        if (!file.seekg(skip, std::ios::cur)) {
            break;
        }
    }
    if (!data_chunk_found) {
        std::cerr << "Error: 'data' chunk not found in WAV file." << std::endl;
        return {};
    }

    const std::size_t num_channels = header.num_channels;
    const std::size_t bytes_per_frame = num_channels * sizeof(int16_t);
    // Only whole frames are read. A trailing partial frame (e.g. an odd sample count in a
    // stereo file) is ignored instead of reading past the data chunk.
    const std::size_t num_frames = data_chunk.data_size / bytes_per_frame;

    std::vector<float> audioData;
    int16_t frame[2] = {0, 0};
    for (std::size_t f = 0; f < num_frames; ++f) {
        if (!file.read(reinterpret_cast<char*>(frame), static_cast<std::streamsize>(bytes_per_frame))) {
            std::cerr << "Warning: Unexpected end of file while reading audio data." << std::endl;
            break;
        }
        float sample = static_cast<float>(frame[0]) / 32768.0f;
        if (num_channels == 2) {
            sample = (sample + static_cast<float>(frame[1]) / 32768.0f) / 2.0f;
        }
        audioData.push_back(sample);
    }
    return audioData;
}
/**
 * @brief Private helper: Generates a symmetric Hamming window.
 *
 * w[i] = 0.54 - 0.46 * cos(2*pi*i / (N - 1)) for i = 0..N-1, which is the symmetric form used
 * by numpy.hamming. Degenerate lengths follow numpy.hamming as well: N <= 0 gives an empty
 * window, and N == 1 gives {1.0}. The general formula would divide 0 by 0 there and produce NaN.
 *
 * @param window_length Number of window coefficients N.
 * @return The N window coefficients.
 */
std::vector<float> AudioInferenceEngine::generateHammingWindow(int window_length) {
    if (window_length <= 0) {
        return {};
    }
    if (window_length == 1) {
        return {1.0f};
    }
    std::vector<float> window(window_length);
    for (int i = 0; i < window_length; ++i) {
        window[i] = 0.54f - 0.46f * std::cos(2 * M_PI * i / static_cast<float>(window_length - 1));
    }
    return window;
}
/**
 * @brief Private helper: Extracts spectrogram features.
 *
 * Processing steps:
 *  - Slice @p wav into WIN_LENGTH-sample frames every HOP_LENGTH samples, with no padding.
 *    A trailing partial frame is dropped.
 *  - Pre-emphasise each frame and scale it by 32768.
 *  - Apply a Hamming window and zero-pad to N_FFT.
 *  - Store the magnitude of the frame's real FFT.
 *
 * @param wav Mono samples, as produced by loadWavToFloatArray.
 * @param fs  Sample rate of @p wav. Currently unused, because framing is defined purely in samples.
 * @return (n_frames x N_FFT/2+1) magnitude matrix. It has 0 rows when @p wav is shorter than
 *         one window or the KissFFT plan cannot be allocated.
 */
Eigen::MatrixXf AudioInferenceEngine::extractSpectrogram(const std::vector<float>& wav, int fs) {
    static_assert(WIN_LENGTH > 0 && WIN_LENGTH <= N_FFT, "a frame must fit in the FFT input buffer");
    static_cast<void>(fs);
    const int n_bins = N_FFT / 2 + 1;

    // wav.size() is unsigned. Without this guard, wav.size() - WIN_LENGTH wraps around for short
    // inputs and produces a bogus (possibly positive) frame count that reads past the buffer.
    if (wav.size() < static_cast<std::size_t>(WIN_LENGTH)) {
        return Eigen::MatrixXf(0, n_bins);
    }
    const Eigen::Index n_batch =
        static_cast<Eigen::Index>((wav.size() - WIN_LENGTH) / HOP_LENGTH + 1);

    // Do every allocation that may throw before creating the KissFFT plan, so the plan is
    // always reached by kiss_fftr_free below.
    const std::vector<float> fft_window = generateHammingWindow(WIN_LENGTH);
    Eigen::MatrixXf spec_matrix(n_batch, n_bins);

    kiss_fftr_cfg fft_cfg = kiss_fftr_alloc(N_FFT, 0 /* is_inverse_fft */, nullptr, nullptr);
    if (!fft_cfg) {
        std::cerr << "Error: Failed to allocate KissFFT configuration." << std::endl;
        return Eigen::MatrixXf(0, n_bins);
    }

    kiss_fft_scalar fft_input[N_FFT];
    kiss_fft_cpx fft_output[N_FFT / 2 + 1];
    for (Eigen::Index i = 0; i < n_batch; ++i) {
        const float* frame = wav.data() + static_cast<std::size_t>(i) * HOP_LENGTH;

        // Pre-emphasis y[j] = x[j] - 0.97 * x[j-1], then scaling by 32768 and windowing.
        // NOTE(review): this assumes the Python reference builds x[j-1] as
        //   prev = np.roll(frame, 1); prev[0] = prev[1]
        // prev[1] of the *rolled* array is x[0], so the first sample is pre-emphasised against
        // itself, y[0] = (1 - 0.97) * x[0]. Pairing it with x[1], as before, was a misreading of
        // that roll. Confirm against the reference script.
        fft_input[0] = (frame[0] - PREEMPHASIS_COEFF * frame[0]) * 32768.0f * fft_window[0];
        for (int j = 1; j < WIN_LENGTH; ++j) {
            fft_input[j] = (frame[j] - PREEMPHASIS_COEFF * frame[j - 1]) * 32768.0f * fft_window[j];
        }
        std::fill(fft_input + WIN_LENGTH, fft_input + N_FFT, 0.0f);  // zero-pad to N_FFT

        kiss_fftr(fft_cfg, fft_input, fft_output);
        for (int j = 0; j < n_bins; ++j) {
            spec_matrix(i, j) = std::sqrt(fft_output[j].r * fft_output[j].r + fft_output[j].i * fft_output[j].i);
        }
    }
    kiss_fftr_free(fft_cfg);
    return spec_matrix;
}
/**
 * @brief Private helper: Creates a Mel filter-bank matrix.
 *
 * Builds n_mels triangular filters that are equally spaced on the mel scale
 * mel(f) = 1127 * ln(1 + f / 700) between fmin and fmax. Filter m peaks at mel centre m+1 and
 * reaches zero at centres m and m+2. Only FFT bins in [klo, khi) receive weights, where
 * klo = bin(fmin) + 1 and khi = bin(fmax).
 *
 * @param sample_rate Sample rate in Hz.
 * @param n_fft       FFT size. The matrix has n_fft/2 + 1 columns.
 * @param n_mels      Number of filters (rows).
 * @param fmin        Lower edge in Hz.
 * @param fmax        Upper edge in Hz. 0 selects sample_rate / 2.
 * @return n_mels x (n_fft/2 + 1) weight matrix, or an empty matrix when the arguments are
 *         invalid (non-positive sizes, fmin < 0, fmin >= fmax, or fmax above sample_rate / 2).
 */
Eigen::MatrixXf AudioInferenceEngine::speechlibMel(int sample_rate, int n_fft, int n_mels, float fmin, float fmax) {
    const float nyquist = sample_rate / 2.0f;
    if (fmax == 0.0f) fmax = nyquist;
    // Reject arguments that would make the mel spacing `ms` non-positive or push filters past
    // the last FFT bin. The constructor treats an empty matrix as failure.
    if (sample_rate <= 0 || n_fft <= 0 || n_mels <= 0 || fmin < 0.0f || !(fmin < fmax) || fmax > nyquist) {
        return Eigen::MatrixXf(0, 0);
    }

    const int bank_width = n_fft / 2 + 1;
    auto mel = [](float f) { return 1127.0f * std::log(1.0f + f / 700.0f); };
    auto bin2mel = [&](int fft_bin) { return 1127.0f * std::log(1.0f + static_cast<float>(fft_bin) * sample_rate / (static_cast<float>(n_fft) * 700.0f)); };
    auto f2bin = [&](float f) { return static_cast<int>((f * n_fft / sample_rate) + 0.5f); };

    // FFT bin range [klo, khi). khi used to be computed but ignored, which let bins at or above
    // bin(fmax) pick up weight from the top filter. It is clamped so the loop stays inside the matrix.
    const int klo = f2bin(fmin) + 1;
    const int khi = std::min(std::max(f2bin(fmax), klo), bank_width);

    const float mlo = mel(fmin);
    const float mhi = mel(fmax);
    const float ms = (mhi - mlo) / (n_mels + 1);  // mel distance between adjacent filter centres
    std::vector<float> m_centers(n_mels + 2);
    for (int i = 0; i < n_mels + 2; ++i) {
        m_centers[i] = mlo + i * ms;
    }

    Eigen::MatrixXf matrix = Eigen::MatrixXf::Zero(n_mels, bank_width);
    for (int m = 0; m < n_mels; ++m) {
        const float left = m_centers[m];
        const float center = m_centers[m + 1];
        const float right = m_centers[m + 2];
        for (int fft_bin = klo; fft_bin < khi; ++fft_bin) {
            const float mbin = bin2mel(fft_bin);
            if (left < mbin && mbin < right) {
                matrix(m, fft_bin) = 1.0f - std::abs(center - mbin) / ms;
            }
        }
    }
    return matrix;
}
/**
 * @brief Public method: Preprocesses an audio WAV file.
 *
 * Pipeline:
 *  - Load the WAV file and compute its magnitude spectrogram.
 *  - Square the magnitudes to get power, then project onto the Mel filterbank.
 *  - Floor at 1.0, then take the natural log.
 *
 * No resampling is done. A sample-rate mismatch only produces a warning.
 *
 * @param wavFilePath Path of the 16-bit PCM WAV file.
 * @return (frames x N_MELS) log-Mel features, or an empty (0 x N_MELS) matrix on failure.
 */
Eigen::MatrixXf AudioInferenceEngine::preprocessAudio(const std::string& wavFilePath) {
    int file_rate = 0;
    const std::vector<float> samples = loadWavToFloatArray(wavFilePath, file_rate);
    if (samples.empty()) {
        std::cerr << "Failed to load audio data from " << wavFilePath << "." << std::endl;
        return Eigen::MatrixXf(0, N_MELS);
    }

    const bool rate_mismatch = (file_rate != TARGET_SAMPLE_RATE);
    if (rate_mismatch) {
        std::cerr << "Warning: WAV file sample rate (" << file_rate
                  << " Hz) does not match the target sample rate for feature extraction ("
                  << TARGET_SAMPLE_RATE << " Hz)." << std::endl;
        std::cerr << "This example does NOT include resampling. Features will be extracted at "
                  << TARGET_SAMPLE_RATE << " Hz, which might lead to incorrect results if the WAV file's sample rate is different." << std::endl;
    }

    const Eigen::MatrixXf magnitude = extractSpectrogram(samples, TARGET_SAMPLE_RATE);
    if (magnitude.rows() == 0) {
        std::cerr << "Error: Spectrogram extraction failed." << std::endl;
        return Eigen::MatrixXf(0, N_MELS);
    }

    // mel_filterbank_ is (N_MELS x bins), so it is transposed to project each (bins) power row.
    const Eigen::MatrixXf power = magnitude.cwiseAbs2();
    const Eigen::MatrixXf mel_energy = power * mel_filterbank_.transpose();
    // Flooring at 1.0 keeps log() finite; the smallest possible feature value is log(1) = 0.
    const Eigen::MatrixXf floored = mel_energy.cwiseMax(1.0f);
    return floored.array().log().matrix();
}
/**
 * @brief Public method: Runs inference on the loaded ONNX model.
 *
 * Feeds @p features as one float tensor of shape [1, rows, cols] (a batch of one), bound to
 * the model's first input name. All model outputs are requested, and the first one is returned
 * flattened.
 *
 * @param features (frames x feature_size) matrix, e.g. the result of preprocessAudio().
 * @return The elements of the first output tensor in its storage order. Returns an empty vector
 *         if the input is empty or the first output is missing, is not a tensor, or is not float32.
 *         Exceptions thrown by Ort::Session::Run are not caught here.
 */
std::vector<float> AudioInferenceEngine::runInference(const Eigen::MatrixXf& features) {
    if (features.rows() == 0 || features.cols() == 0) {
        std::cerr << "Error: Input features are empty for inference." << std::endl;
        return {};
    }

    // Prepare Input Tensor Shape: [batch, frames, feature_size]
    const std::vector<int64_t> inputTensorShape = {1, static_cast<int64_t>(features.rows()),
                                                   static_cast<int64_t>(features.cols())};

    // Eigen::MatrixXf is column-major, but the tensor buffer must be row-major (C order).
    // Copy through a row-major view of the destination buffer.
    std::vector<float> inputTensorData(static_cast<size_t>(features.size()));
    Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
        inputTensorData.data(), features.rows(), features.cols()) = features;

    // The tensor borrows inputTensorData's buffer, which stays alive until Run() returns.
    Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    Ort::Value inputTensor = Ort::Value::CreateTensor<float>(memory_info, inputTensorData.data(), inputTensorData.size(),
                                                             inputTensorShape.data(), inputTensorShape.size());
    if (!inputTensor.IsTensor()) {
        std::cerr << "Error: Created input tensor is not valid!" << std::endl;
        return {};
    }

    // Run Inference
    std::vector<Ort::Value> outputTensors = session_->Run(Ort::RunOptions{nullptr},
                                                          input_node_names_.data(), &inputTensor, 1,
                                                          output_node_names_.data(), output_node_names_.size());
    if (outputTensors.empty() || !outputTensors[0].IsTensor()) {
        std::cerr << "Error: No valid output tensors received from the model." << std::endl;
        return {};
    }

    // Check the element type before reinterpreting the buffer. Reading a non-float tensor
    // (e.g. float16) as float would return garbage.
    Ort::TensorTypeAndShapeInfo outputShapeInfo = outputTensors[0].GetTensorTypeAndShapeInfo();
    if (outputShapeInfo.GetElementType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
        std::cerr << "Error: Expected a float32 output tensor, got ONNX element type "
                  << outputShapeInfo.GetElementType() << "." << std::endl;
        return {};
    }

    // Copy output data
    const float* outputData = outputTensors[0].GetTensorData<float>();
    const size_t outputSize = outputShapeInfo.GetElementCount();
    return std::vector<float>(outputData, outputData + outputSize);
}
/**
 * @brief Public method: Runs the session on a caller-built input tensor.
 *
 * @p inputTensor is passed as the single input, bound to the model's first input name, and
 * every cached output name is requested. No error handling is done here, so any exception
 * thrown by Ort::Session::Run propagates to the caller.
 *
 * @param inputTensor Tensor to feed. It is borrowed, not copied, for the duration of the call.
 * @return One Ort::Value per requested output name.
 */
std::vector<Ort::Value> AudioInferenceEngine::runInference_tensor(const Ort::Value& inputTensor) {
    const size_t single_input = 1;
    return session_->Run(Ort::RunOptions{nullptr},
                         input_node_names_.data(), &inputTensor, single_input,
                         output_node_names_.data(), output_node_names_.size());
}