// Source: go/cpp/inference/main_text.cpp
// Uploaded by jva96160 in commit a2dca42 ("Upload 71 files", verified).
#include <algorithm> // For std::min
#include <chrono> // For time measurement
#include <cmath> // For std::sin, M_PI
#include <cstdint> // For fixed-width integer types (uint32_t, int16_t, UINT32_MAX)
#include <cstring> // For std::memcpy
#include <ctime> // For seeding random number generator
#include <fstream>
#include <iostream>
#include <random> // For random number generation
#include <string>
#include <vector>
// Include the new library header
#include "audio_encoder_lib.h"
#include <onnxruntime_cxx_api.h>
// Define M_PI if it's not already defined
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
// --- WAV file header structures (used only to create the dummy input file) ---
// These structs are written to disk verbatim with ostream::write, so their
// in-memory layout must match the canonical 44-byte PCM WAV header exactly.
// Packing is forced to 1 byte, and the sizes are verified at compile time below.
#pragma pack(push, 1)
// RIFF chunk descriptor followed by the 16-byte "fmt " sub-chunk of a PCM WAV file.
struct WavHeader {
    char riff_id[4];          // "RIFF"
    uint32_t file_size;       // Total file size minus 8 bytes (= 36 + data size)
    char wave_id[4];          // "WAVE"
    char fmt_id[4];           // "fmt "
    uint32_t fmt_size;        // Size of the fmt sub-chunk body (16 for PCM)
    uint16_t audio_format;    // 1 = PCM
    uint16_t num_channels;
    uint32_t sample_rate;     // Samples per second per channel
    uint32_t byte_rate;       // sample_rate * num_channels * bits_per_sample / 8
    uint16_t block_align;     // num_channels * bits_per_sample / 8
    uint16_t bits_per_sample;
};
// Header of the "data" sub-chunk; the raw PCM payload follows immediately.
struct WavDataChunk {
    char data_id[4];          // "data"
    uint32_t data_size;       // Payload size in bytes
};
#pragma pack(pop)
// Catch any compiler/ABI that would introduce padding and corrupt the file format.
static_assert(sizeof(WavHeader) == 36, "WavHeader must match the 36-byte RIFF + fmt header layout");
static_assert(sizeof(WavDataChunk) == 8, "WavDataChunk must match the 8-byte data chunk header layout");
// Writes a PCM WAV file containing a 440 Hz sine tone (peak amplitude 30000).
//
// Used by the benchmark to synthesize inputs of a controlled length. Every
// channel of a frame receives the same sample value. Only 16-bit PCM is
// supported because the generator emits int16_t samples. Other bit depths are
// rejected instead of producing a file whose header disagrees with its payload.
//
// Parameters:
//   filename        - output path; an existing file is overwritten.
//   sampleRate      - samples per second per channel; must be > 0.
//   numChannels     - number of interleaved channels; must be > 0.
//   bitsPerSample   - must be 16.
//   durationSeconds - tone length; must be >= 0. The sample count is
//                     sampleRate * durationSeconds rounded to the nearest integer.
//
// Failures are reported on std::cerr. The function does not throw for them and
// gives the caller no other signal.
// NOTE(review): header fields and samples are written in host byte order, while
// WAV is little-endian, so this assumes a little-endian host.
void createDummyWavFile(const std::string& filename, int sampleRate, int numChannels, int bitsPerSample, double durationSeconds) {
    if (bitsPerSample != 16) {
        std::cerr << "Error: Only 16-bit PCM is supported for dummy WAV file: " << filename << std::endl;
        return;
    }
    // `!(x >= 0.0)` also rejects NaN; converting a negative/NaN double to an
    // unsigned integer below would be undefined behavior.
    if (sampleRate <= 0 || numChannels <= 0 || !(durationSeconds >= 0.0)) {
        std::cerr << "Error: Invalid parameters for dummy WAV file: " << filename << std::endl;
        return;
    }

    // Round rather than truncate. Callers typically derive durationSeconds as
    // sample_count / sampleRate, and multiplying back can land a hair below the
    // integer. Truncation would then produce one sample (and possibly one
    // analysis frame) fewer than requested.
    const double rounded_samples = std::round(sampleRate * durationSeconds);
    // RIFF sizes are 32-bit: the payload plus the 36 header bytes preceding it must fit.
    if (rounded_samples * numChannels * 2.0 > static_cast<double>(UINT32_MAX - 36u)) {
        std::cerr << "Error: Dummy WAV file would exceed the 32-bit RIFF size limit: " << filename << std::endl;
        return;
    }
    const uint32_t num_samples = static_cast<uint32_t>(rounded_samples);

    std::ofstream file(filename, std::ios::binary);
    if (!file.is_open()) {
        std::cerr << "Error: Could not create dummy WAV file: " << filename << std::endl;
        return;
    }

    WavHeader header{};
    std::memcpy(header.riff_id, "RIFF", 4);
    std::memcpy(header.wave_id, "WAVE", 4);
    std::memcpy(header.fmt_id, "fmt ", 4);
    header.fmt_size = 16;
    header.audio_format = 1; // PCM
    header.num_channels = numChannels;
    header.sample_rate = sampleRate;
    header.bits_per_sample = bitsPerSample;
    header.byte_rate = (sampleRate * numChannels * bitsPerSample) / 8;
    header.block_align = (numChannels * bitsPerSample) / 8;

    WavDataChunk data_chunk{};
    std::memcpy(data_chunk.data_id, "data", 4);
    data_chunk.data_size = num_samples * numChannels * (bitsPerSample / 8);
    header.file_size = 36 + data_chunk.data_size; // 36 bytes of header precede the data chunk header

    file.write(reinterpret_cast<const char*>(&header), sizeof(WavHeader));
    file.write(reinterpret_cast<const char*>(&data_chunk), sizeof(WavDataChunk));

    // Generate a 440 Hz sine wave; the same sample is duplicated into every channel.
    for (uint32_t i = 0; i < num_samples; ++i) {
        const int16_t sample = static_cast<int16_t>(30000 * std::sin(2 * M_PI * 440 * i / static_cast<double>(sampleRate)));
        for (int c = 0; c < numChannels; ++c) {
            file.write(reinterpret_cast<const char*>(&sample), sizeof(int16_t));
        }
    }

    // close() flushes; failbit/badbit then reflects any write or flush failure.
    file.close();
    if (!file) {
        std::cerr << "Error: Failed to write dummy WAV file: " << filename << std::endl;
    }
}
// Benchmark driver for AudioInferenceEngine.
//
// Usage: <program> <path_to_onnx_model> <path_to_wav_file_for_temp_use>
//
// Loads the model once, then performs num_runs iterations. Each iteration:
//   - writes a fresh 16 kHz, mono, 16-bit sine-tone WAV file to the given path
//     (overwritten each run and left on disk afterwards);
//   - sizes that file so feature extraction should yield a random number of
//     frames in [100, 300];
//   - times preprocessAudio() + runInference() on it.
//
// NOTE: the timed region includes WAV loading and feature extraction, not only
// the ONNX session run, so the reported average is an end-to-end figure.
//
// Returns 0 on success. Returns 1 on bad arguments, an empty model output, or
// any Ort::Exception / std::exception that escapes the benchmark.
int main(int argc, char* argv[]) {
    // --- 1. Process command-line arguments ---
    if (argc != 3) {
        std::cerr << "Usage: " << argv[0] << " <path_to_onnx_model> <path_to_wav_file_for_temp_use>" << std::endl;
        std::cerr << "Example: " << argv[0] << " model.onnx temp_audio.wav" << std::endl;
        return 1;
    }
    const std::string onnxModelPath = argv[1];
    const std::string wavFilename = argv[2]; // Scratch file, overwritten on every run

    // --- Random number generation setup for dummy input frames ---
    std::mt19937 rng(static_cast<unsigned int>(std::time(nullptr))); // Seed with current time
    std::uniform_int_distribution<int> dist_frames(100, 300); // Frames per run, inclusive [100, 300]

    // Feature-extraction parameters used to convert a frame count into a sample
    // count. They duplicate the library's constants, so keep them in sync with it.
    constexpr int WIN_LENGTH = 400;           // Window length (samples)
    constexpr int HOP_LENGTH = 160;           // Hop length (samples)
    constexpr int TARGET_SAMPLE_RATE = 16000; // Target sample rate (Hz)
    constexpr std::size_t PREVIEW_COUNT = 10; // Output elements printed for the first run

    try {
        // --- 2. Model Initialization ---
        // This will load the ONNX model and precompute the Mel filterbank.
        AudioInferenceEngine engine(onnxModelPath);
        std::cout << "Engine initialized." << std::endl;

        // --- 3. Model Inference and Time Measurement ---
        std::cout << "\nRunning model inference and measuring time (100 runs with varying input sizes)..." << std::endl;
        const int num_runs = 100;
        long long total_inference_time_us = 0; // Microseconds, for finer granularity

        for (int i = 0; i < num_runs; ++i) {
            const int random_frames = dist_frames(rng);
            // Invert frames = (num_samples - WIN_LENGTH) / HOP_LENGTH + 1:
            //   num_samples = (frames - 1) * HOP_LENGTH + WIN_LENGTH
            const long long num_samples_for_frames = static_cast<long long>(random_frames - 1) * HOP_LENGTH + WIN_LENGTH;
            const double duration_seconds_for_frames = static_cast<double>(num_samples_for_frames) / TARGET_SAMPLE_RATE;

            // Regenerate the input so its length changes on every run.
            createDummyWavFile(wavFilename, TARGET_SAMPLE_RATE, 1, 16, duration_seconds_for_frames);

            // steady_clock is monotonic. high_resolution_clock may alias the
            // adjustable system_clock, which makes it unreliable for intervals.
            const auto start_time = std::chrono::steady_clock::now();
            Eigen::MatrixXf features = engine.preprocessAudio(wavFilename);
            std::vector<float> model_output = engine.runInference(features);
            const auto end_time = std::chrono::steady_clock::now();

            if (model_output.empty()) {
                std::cerr << "Error: Model inference failed for run " << i + 1 << ". Exiting." << std::endl;
                return 1;
            }

            const long long run_time_us =
                std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
            total_inference_time_us += run_time_us;

            // Print a short preview of the first run's output.
            if (i == 0) {
                std::cout << "First run (frames=" << features.rows() << ")" << " take : "
                          << static_cast<double>(run_time_us) / 1000.0 / 1000.0
                          << "s output (first few elements): [";
                const std::size_t preview = std::min(PREVIEW_COUNT, model_output.size());
                for (std::size_t k = 0; k < preview; ++k) {
                    std::cout << model_output[k] << (k + 1 == preview ? "" : ", ");
                }
                std::cout << "]" << std::endl;
            }
        }

        // Convert the accumulated microseconds into an average in seconds.
        const double average_inference_time_s =
            static_cast<double>(total_inference_time_us) / num_runs / 1000.0 / 1000.0;
        std::cout << "\nAverage ONNX model inference time over " << num_runs << " runs (with varying input frames): "
                  << average_inference_time_s << " s" << std::endl;
    } catch (const Ort::Exception& e) {
        std::cerr << "ONNX Runtime Exception: " << e.what() << std::endl;
        return 1;
    } catch (const std::exception& e) {
        std::cerr << "Standard Exception: " << e.what() << std::endl;
        return 1;
    }
    std::cout << "\nProgram finished successfully." << std::endl;
    return 0;
}