File size: 7,338 Bytes
a2dca42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
#include <algorithm> // For std::min
#include <chrono> // For time measurement
#include <cmath> // For std::sin, M_PI
#include <cstdint> // For fixed-width integer types (int16_t, uint16_t, uint32_t)
#include <cstring> // For std::memcpy
#include <ctime> // For seeding random number generator
#include <fstream>
#include <iostream>
#include <limits> // For std::numeric_limits
#include <random> // For random number generation
#include <string>
#include <vector>
#include <onnxruntime_cxx_api.h>
// Include the new library header
#include "audio_encoder_lib.h"
// Define M_PI if it's not already defined
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
// --- WAV File Header Structures (for dummy file creation) ---
// Both structs are byte-packed so that writing them verbatim with
// std::ofstream::write produces the canonical 44-byte PCM WAV header.
#pragma pack(push, 1)
/// RIFF preamble followed by the 16-byte PCM "fmt " chunk (36 bytes in total).
struct WavHeader {
    char riff_id[4];          // "RIFF"
    uint32_t file_size;       // bytes after this field (total file size - 8)
    char wave_id[4];          // "WAVE"
    char fmt_id[4];           // "fmt "
    uint32_t fmt_size;        // size of the fmt chunk body: 16 for PCM
    uint16_t audio_format;    // 1 = uncompressed PCM
    uint16_t num_channels;
    uint32_t sample_rate;     // frames per second
    uint32_t byte_rate;       // sample_rate * block_align
    uint16_t block_align;     // bytes per frame (all channels together)
    uint16_t bits_per_sample;
};
/// Header of the "data" chunk; the interleaved PCM samples follow it directly.
struct WavDataChunk {
    char data_id[4];          // "data"
    uint32_t data_size;       // number of PCM payload bytes that follow
};
#pragma pack(pop)
// The RIFF size computation in createDummyWavFile relies on this exact layout.
static_assert(sizeof(WavHeader) == 36, "WavHeader must be packed to 36 bytes");
static_assert(sizeof(WavDataChunk) == 8, "WavDataChunk must be packed to 8 bytes");

/// Writes a 16-bit PCM WAV file containing a sine tone, used by the example
/// app to produce inputs of a controlled length.
///
/// @param filename        Output path; an existing file is overwritten.
/// @param sampleRate      Frames per second; must be > 0.
/// @param numChannels     Interleaved channel count in [1, 32767] (block_align
///                        is a 16-bit field). Every channel gets the same sample.
/// @param bitsPerSample   Must be 16: samples are generated as int16_t, so any
///                        other value would produce a header that disagrees
///                        with the payload.
/// @param durationSeconds Tone length, finite and >= 0. It is converted to a frame
///                        count by rounding to the nearest whole frame.
/// @param frequencyHz     Tone frequency in Hz (default 440, as before).
///
/// Failures are reported on std::cerr and the function returns without
/// throwing. Failures are: invalid arguments, a payload too large for the 32-bit
/// RIFF size fields, a file that cannot be opened, or a write error. On invalid
/// arguments the file is not opened, so an existing file is left untouched.
/// NOTE: header fields and samples are written in host byte order. WAV is
/// little-endian, so the output is only correct on little-endian hosts.
void createDummyWavFile(const std::string& filename, int sampleRate, int numChannels, int bitsPerSample,
                        double durationSeconds, double frequencyHz = 440.0) {
    constexpr double kPi = 3.14159265358979323846;
    constexpr int kSupportedBitsPerSample = 16;
    constexpr int kAmplitude = 30000;  // peak value, below INT16_MAX (32767)
    // Bytes counted by the RIFF size field besides the PCM payload: the whole
    // header minus the 8-byte "RIFF"+size preamble (= 36).
    constexpr uint32_t kRiffOverhead = static_cast<uint32_t>(sizeof(WavHeader) + sizeof(WavDataChunk) - 8);

    if (sampleRate <= 0 || numChannels <= 0 ||
        numChannels > std::numeric_limits<uint16_t>::max() / static_cast<int>(sizeof(int16_t))) {
        std::cerr << "Error: Invalid sample rate or channel count for dummy WAV file: " << filename << std::endl;
        return;
    }
    if (bitsPerSample != kSupportedBitsPerSample) {
        std::cerr << "Error: Only 16-bit PCM is supported for dummy WAV file: " << filename << std::endl;
        return;
    }
    if (!std::isfinite(durationSeconds) || durationSeconds < 0.0) {
        std::cerr << "Error: Invalid duration for dummy WAV file: " << filename << std::endl;
        return;
    }

    const uint64_t bytesPerFrame = static_cast<uint64_t>(numChannels) * sizeof(int16_t);
    const uint64_t byteRate = static_cast<uint64_t>(sampleRate) * bytesPerFrame;
    // Largest frame count whose payload still fits the 32-bit RIFF size field.
    const uint64_t maxFrames = (std::numeric_limits<uint32_t>::max() - kRiffOverhead) / bytesPerFrame;
    // Round instead of truncating. Callers compute durationSeconds as
    // frames / sampleRate, and truncating (frames / rate) * rate can lose a
    // whole frame to floating-point error (e.g. 49 * (1.0 / 49) < 1).
    const double exactFrames = static_cast<double>(sampleRate) * durationSeconds;
    if (byteRate > std::numeric_limits<uint32_t>::max() || exactFrames > static_cast<double>(maxFrames)) {
        std::cerr << "Error: Requested dummy WAV file is too large: " << filename << std::endl;
        return;
    }
    const uint32_t numFrames = static_cast<uint32_t>(std::llround(exactFrames));
    const uint32_t dataSize = static_cast<uint32_t>(numFrames * bytesPerFrame);

    std::ofstream file(filename, std::ios::binary);
    if (!file.is_open()) {
        std::cerr << "Error: Could not create dummy WAV file: " << filename << std::endl;
        return;
    }

    WavHeader header{};
    std::memcpy(header.riff_id, "RIFF", 4);
    std::memcpy(header.wave_id, "WAVE", 4);
    std::memcpy(header.fmt_id, "fmt ", 4);
    header.fmt_size = 16;
    header.audio_format = 1;  // PCM
    header.num_channels = static_cast<uint16_t>(numChannels);
    header.sample_rate = static_cast<uint32_t>(sampleRate);
    header.bits_per_sample = static_cast<uint16_t>(kSupportedBitsPerSample);
    header.byte_rate = static_cast<uint32_t>(byteRate);
    header.block_align = static_cast<uint16_t>(bytesPerFrame);
    header.file_size = kRiffOverhead + dataSize;

    WavDataChunk dataChunk{};
    std::memcpy(dataChunk.data_id, "data", 4);
    dataChunk.data_size = dataSize;

    file.write(reinterpret_cast<const char*>(&header), sizeof(header));
    file.write(reinterpret_cast<const char*>(&dataChunk), sizeof(dataChunk));

    // Sine tone; the same sample is written once per channel (interleaved).
    for (uint32_t i = 0; i < numFrames; ++i) {
        const int16_t sample = static_cast<int16_t>(
            kAmplitude * std::sin(2 * kPi * frequencyHz * i / static_cast<double>(sampleRate)));
        for (int c = 0; c < numChannels; ++c) {
            file.write(reinterpret_cast<const char*>(&sample), sizeof(sample));
        }
    }

    file.close();
    if (!file) {
        std::cerr << "Error: Failed while writing dummy WAV file: " << filename << std::endl;
    }
}
/// Benchmarks AudioInferenceEngine on synthetic audio.
///
/// Usage: <program> <path_to_onnx_model> <path_to_wav_file_for_temp_use>
///
/// Each run does the following:
/// - Writes a sine-tone WAV of random length to the temp path. The length is
///   sized to yield 100-300 frames under the WIN_LENGTH/HOP_LENGTH framing below.
/// - Passes the file through engine.preprocessAudio and engine.runInference.
///
/// Preprocessing and model inference are timed separately with a monotonic
/// clock, so the reported "ONNX model inference time" no longer includes
/// file reading / feature extraction. Averages are printed in seconds.
///
/// @return 0 on success. Returns 1 on bad arguments, empty model output, or an
///         exception derived from Ort::Exception / std::exception.
int main(int argc, char* argv[]) {
    // --- 1. Process command-line arguments ---
    if (argc != 3) {
        std::cerr << "Usage: " << argv[0] << " <path_to_onnx_model> <path_to_wav_file_for_temp_use>" << std::endl;
        std::cerr << "Example: " << argv[0] << " model.onnx temp_audio.wav" << std::endl;
        return 1;
    }
    const std::string onnxModelPath = argv[1];
    const std::string wavFilename = argv[2];  // Overwritten on every run; used as a temporary file.

    // --- Random number generation setup for dummy input frames ---
    std::mt19937 rng(static_cast<unsigned int>(std::time(nullptr)));  // Seed with current time
    std::uniform_int_distribution<int> dist_frames(100, 300);          // Frames per run (inclusive range)

    // Framing parameters used to size the dummy audio.
    // They must match the library's constants; TODO(review): expose them from the library instead.
    constexpr int WIN_LENGTH = 400;           // Window length (samples)
    constexpr int HOP_LENGTH = 160;           // Hop length (samples)
    constexpr int TARGET_SAMPLE_RATE = 16000; // Target sample rate (Hz)

    // steady_clock is monotonic. high_resolution_clock may alias system_clock
    // and jump if the wall clock is adjusted while the benchmark runs.
    using BenchClock = std::chrono::steady_clock;

    try {
        // --- 2. Model Initialization ---
        // This will load the ONNX model and precompute the Mel filterbank.
        AudioInferenceEngine engine(onnxModelPath);
        std::cout << "Engine initialized." << std::endl;

        // --- 3. Model Inference and Time Measurement ---
        constexpr int num_runs = 100;
        std::cout << "\nRunning model inference and measuring time (" << num_runs
                  << " runs with varying input sizes)..." << std::endl;

        // Accumulate in the clock's native tick type; this avoids per-run truncation.
        BenchClock::duration total_preprocess_time{0};
        BenchClock::duration total_inference_time{0};

        for (int i = 0; i < num_runs; ++i) {
            const int random_frames = dist_frames(rng);
            // frames      = (num_samples - WIN_LENGTH) / HOP_LENGTH + 1
            // num_samples = (frames - 1) * HOP_LENGTH + WIN_LENGTH
            const long long num_samples_for_frames =
                static_cast<long long>(random_frames - 1) * HOP_LENGTH + WIN_LENGTH;
            const double duration_seconds_for_frames =
                static_cast<double>(num_samples_for_frames) / TARGET_SAMPLE_RATE;

            // Recreate the temp WAV so the input size changes for each run.
            createDummyWavFile(wavFilename, TARGET_SAMPLE_RATE, 1, 16, duration_seconds_for_frames);

            // Time preprocessAudio (file -> feature matrix) and runInference
            // separately so the inference figure is not inflated by preprocessing.
            const auto preprocess_start = BenchClock::now();
            Eigen::MatrixXf features = engine.preprocessAudio(wavFilename);
            const auto inference_start = BenchClock::now();
            const std::vector<float> model_output = engine.runInference(features);
            const auto inference_end = BenchClock::now();

            if (model_output.empty()) {
                std::cerr << "Error: Model inference failed for run " << i + 1 << ". Exiting." << std::endl;
                return 1;
            }

            const auto inference_time = inference_end - inference_start;
            total_preprocess_time += inference_start - preprocess_start;
            total_inference_time += inference_time;

            // Print the first run's inference time and a short preview of its output.
            if (i == 0) {
                const std::size_t preview = std::min<std::size_t>(10, model_output.size());
                std::cout << "First run (frames=" << features.rows() << ")" << " take : "
                          << std::chrono::duration<double>(inference_time).count()
                          << "s output (first few elements): [";
                for (std::size_t k = 0; k < preview; ++k) {
                    std::cout << model_output[k] << (k + 1 == preview ? "" : ", ");
                }
                std::cout << "]" << std::endl;
            }
        }

        // duration<double> converts the accumulated ticks to seconds.
        const double average_preprocess_time_s =
            std::chrono::duration<double>(total_preprocess_time).count() / num_runs;
        const double average_inference_time_s =
            std::chrono::duration<double>(total_inference_time).count() / num_runs;
        std::cout << "\nAverage preprocessing time over " << num_runs << " runs (with varying input frames): "
                  << average_preprocess_time_s << " s" << std::endl;
        std::cout << "Average ONNX model inference time over " << num_runs << " runs (with varying input frames): "
                  << average_inference_time_s << " s" << std::endl;
    } catch (const Ort::Exception& e) {
        std::cerr << "ONNX Runtime Exception: " << e.what() << std::endl;
        return 1;
    } catch (const std::exception& e) {
        std::cerr << "Standard Exception: " << e.what() << std::endl;
        return 1;
    }
    std::cout << "\nProgram finished successfully." << std::endl;
    return 0;
}