|
|
// C++ standard library
#include <chrono>
#include <cmath>
#include <cstdint>   // fixed-width integer types used by the packed WAV structs
#include <cstdio>    // std::remove — scratch WAV file cleanup
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <random>
#include <string>
#include <vector>

// Project / third-party headers
#include "audio_encoder_lib.h"

#include <onnxruntime_cxx_api.h>
|
|
|
|
|
#ifndef M_PI |
|
|
#define M_PI 3.14159265358979323846 |
|
|
#endif |
|
|
|
|
|
|
|
|
// Packed to 1-byte alignment so these structs mirror the exact on-disk byte
// layout and can be written to the file verbatim with a single write().
#pragma pack(push, 1)

// RIFF/WAVE chunk descriptor plus the "fmt " sub-chunk (the canonical
// 36-byte prefix of a PCM WAV file before the data chunk).
struct WavHeader {

    char riff_id[4];          // "RIFF"

    uint32_t file_size;       // total file size minus the 8 bytes of riff_id + file_size

    char wave_id[4];          // "WAVE"

    char fmt_id[4];           // "fmt "

    uint32_t fmt_size;        // size of the fmt sub-chunk body; 16 for integer PCM

    uint16_t audio_format;    // 1 = uncompressed integer PCM

    uint16_t num_channels;    // interleaved channel count

    uint32_t sample_rate;     // sample frames per second

    uint32_t byte_rate;       // sample_rate * num_channels * bits_per_sample / 8

    uint16_t block_align;     // bytes per sample frame across all channels

    uint16_t bits_per_sample; // bit depth of one sample

};

// "data" sub-chunk header; the raw interleaved sample bytes follow it
// immediately in the file.
struct WavDataChunk {

    char data_id[4];          // "data"

    uint32_t data_size;       // payload size in bytes (samples only)

};

#pragma pack(pop)
|
|
|
|
|
|
|
|
// Writes a PCM WAV file containing a 440 Hz sine tone of the requested
// duration (the same sample is duplicated across all channels).
//
// Only 16-bit integer PCM is supported: previously the header fields
// (byte_rate, block_align, data_size) were derived from bitsPerSample while
// the sample loop unconditionally wrote int16_t values, so any call with
// bitsPerSample != 16 silently produced a corrupt file whose header
// disagreed with its payload. Such calls are now rejected up front.
//
// @param filename        Output path; overwritten if it already exists.
// @param sampleRate      Sample frames per second (must be > 0).
// @param numChannels     Channel count (must be > 0).
// @param bitsPerSample   Must be 16; anything else is reported and skipped.
// @param durationSeconds Tone length in seconds (negative values rejected).
//
// Errors are reported on stderr; the function does not throw.
void createDummyWavFile(const std::string& filename, int sampleRate, int numChannels, int bitsPerSample, double durationSeconds) {
    // The sample-writing loop below emits int16_t values, so the header must
    // describe 16-bit PCM or the file would be corrupt.
    if (bitsPerSample != 16) {
        std::cerr << "Error: createDummyWavFile only supports 16-bit PCM, got "
                  << bitsPerSample << " bits per sample." << std::endl;
        return;
    }
    if (sampleRate <= 0 || numChannels <= 0 || durationSeconds < 0.0) {
        std::cerr << "Error: invalid WAV parameters (sampleRate=" << sampleRate
                  << ", numChannels=" << numChannels
                  << ", durationSeconds=" << durationSeconds << ")." << std::endl;
        return;
    }

    std::ofstream file(filename, std::ios::binary);
    if (!file.is_open()) {
        std::cerr << "Error: Could not create dummy WAV file: " << filename << std::endl;
        return;
    }

    // RIFF/"fmt " header describing uncompressed integer PCM.
    WavHeader header;
    std::memcpy(header.riff_id, "RIFF", 4);
    std::memcpy(header.wave_id, "WAVE", 4);
    std::memcpy(header.fmt_id, "fmt ", 4);
    header.fmt_size = 16;                 // fmt sub-chunk body size for PCM
    header.audio_format = 1;              // 1 = integer PCM
    header.num_channels = static_cast<uint16_t>(numChannels);
    header.sample_rate = static_cast<uint32_t>(sampleRate);
    header.bits_per_sample = static_cast<uint16_t>(bitsPerSample);
    header.byte_rate = static_cast<uint32_t>((sampleRate * numChannels * bitsPerSample) / 8);
    header.block_align = static_cast<uint16_t>((numChannels * bitsPerSample) / 8);

    // "data" sub-chunk header; payload size is frames * channels * bytes/sample.
    WavDataChunk data_chunk;
    std::memcpy(data_chunk.data_id, "data", 4);
    uint32_t num_samples = static_cast<uint32_t>(sampleRate * durationSeconds);
    data_chunk.data_size = num_samples * numChannels * (bitsPerSample / 8);
    // file_size excludes the 8-byte riff_id/file_size prefix: 36 header bytes
    // remain (4 WAVE + 8 fmt header + 16 fmt body + 8 data header) + payload.
    header.file_size = 36 + data_chunk.data_size;

    file.write(reinterpret_cast<const char*>(&header), sizeof(WavHeader));
    file.write(reinterpret_cast<const char*>(&data_chunk), sizeof(WavDataChunk));

    // 440 Hz sine at ~91% of int16_t full scale, duplicated per channel.
    for (uint32_t i = 0; i < num_samples; ++i) {
        int16_t sample = static_cast<int16_t>(30000 * std::sin(2 * M_PI * 440 * i / static_cast<double>(sampleRate)));
        for (int c = 0; c < numChannels; ++c) {
            file.write(reinterpret_cast<const char*>(&sample), sizeof(int16_t));
        }
    }

    // Surface silent write failures (disk full, etc.) instead of ignoring them.
    if (!file) {
        std::cerr << "Error: Failed while writing dummy WAV file: " << filename << std::endl;
    }
    file.close();
}
|
|
|
|
|
int main(int argc, char* argv[]) { |
|
|
|
|
|
if (argc != 3) { |
|
|
std::cerr << "Usage: " << argv[0] << " <path_to_onnx_model> <path_to_wav_file_for_temp_use>" << std::endl; |
|
|
std::cerr << "Example: " << argv[0] << " model.onnx temp_audio.wav" << std::endl; |
|
|
return 1; |
|
|
} |
|
|
|
|
|
std::string onnxModelPath = argv[1]; |
|
|
std::string wavFilename = argv[2]; |
|
|
|
|
|
|
|
|
std::mt19937 rng(static_cast<unsigned int>(std::time(nullptr))); |
|
|
std::uniform_int_distribution<int> dist_frames(100, 300); |
|
|
|
|
|
|
|
|
const int WIN_LENGTH = 400; |
|
|
const int HOP_LENGTH = 160; |
|
|
const int TARGET_SAMPLE_RATE = 16000; |
|
|
|
|
|
try { |
|
|
|
|
|
|
|
|
AudioInferenceEngine engine(onnxModelPath); |
|
|
std::cout << "Engine initialized." << std::endl; |
|
|
|
|
|
|
|
|
std::cout << "\nRunning model inference and measuring time (100 runs with varying input sizes)..." << std::endl; |
|
|
int num_runs = 100; |
|
|
long long total_inference_time_us = 0; |
|
|
|
|
|
for (int i = 0; i < num_runs; ++i) { |
|
|
|
|
|
int random_frames = dist_frames(rng); |
|
|
|
|
|
|
|
|
|
|
|
long long num_samples_for_frames = static_cast<long long>(random_frames - 1) * HOP_LENGTH + WIN_LENGTH; |
|
|
double duration_seconds_for_frames = static_cast<double>(num_samples_for_frames) / TARGET_SAMPLE_RATE; |
|
|
|
|
|
|
|
|
|
|
|
createDummyWavFile(wavFilename, TARGET_SAMPLE_RATE, 1, 16, duration_seconds_for_frames); |
|
|
|
|
|
|
|
|
auto start_time = std::chrono::high_resolution_clock::now(); |
|
|
Eigen::MatrixXf features = engine.preprocessAudio(wavFilename); |
|
|
std::vector<float> model_output = engine.runInference(features); |
|
|
auto end_time = std::chrono::high_resolution_clock::now(); |
|
|
|
|
|
if (model_output.empty()) { |
|
|
std::cerr << "Error: Model inference failed for run " << i + 1 << ". Exiting." << std::endl; |
|
|
return 1; |
|
|
} |
|
|
|
|
|
|
|
|
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time); |
|
|
total_inference_time_us += duration.count(); |
|
|
|
|
|
|
|
|
if (i == 0) { |
|
|
std::cout << "First run (frames=" << features.rows() << ")"<< " take : "<< static_cast<double>(total_inference_time_us) / 1000.0 / 1000.0 <<"s output (first few elements): ["; |
|
|
for (size_t k = 0; k < std::min((size_t)10, model_output.size()); ++k) { |
|
|
std::cout << model_output[k] << (k == std::min((size_t)10, model_output.size()) - 1 ? "" : ", "); |
|
|
} |
|
|
std::cout << "]" << std::endl; |
|
|
} |
|
|
} |
|
|
|
|
|
double average_inference_time_ms = static_cast<double>(total_inference_time_us) / num_runs / 1000.0 / 1000.0; |
|
|
std::cout << "\nAverage ONNX model inference time over " << num_runs << " runs (with varying input frames): " |
|
|
<< average_inference_time_ms << " s" << std::endl; |
|
|
|
|
|
} catch (const Ort::Exception& e) { |
|
|
std::cerr << "ONNX Runtime Exception: " << e.what() << std::endl; |
|
|
return 1; |
|
|
} catch (const std::exception& e) { |
|
|
std::cerr << "Standard Exception: " << e.what() << std::endl; |
|
|
return 1; |
|
|
} |
|
|
|
|
|
std::cout << "\nProgram finished successfully." << std::endl; |
|
|
return 0; |
|
|
} |