File size: 7,338 Bytes
a2dca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#include <algorithm> // For std::min, std::max
#include <chrono>    // For time measurement
#include <cmath>     // For std::sin, std::llround, M_PI
#include <cstdint>   // For fixed-width integer types and UINT32_MAX
#include <cstring>   // For std::memcpy
#include <ctime>     // For seeding random number generator
#include <fstream>
#include <iostream>
#include <random>    // For random number generation
#include <string>
#include <vector>

#include <onnxruntime_cxx_api.h>

// Include the new library header
#include "audio_encoder_lib.h"
// Define M_PI if it's not already defined
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

// --- WAV File Header Structures (for dummy file creation) ---
// Canonical 44-byte PCM WAV preamble: WavHeader (RIFF descriptor + "fmt " chunk)
// followed by WavDataChunk, after which the raw sample data begins.
// Packing to 1 byte makes the in-memory layout identical to the on-disk layout,
// so each struct can be written to a file directly.
#pragma pack(push, 1)
struct WavHeader {
    char     riff_id[4];       // "RIFF"
    uint32_t file_size;        // total file size minus 8 (the riff_id and file_size fields)
    char     wave_id[4];       // "WAVE"
    char     fmt_id[4];        // "fmt "
    uint32_t fmt_size;         // size of the fmt chunk body (16 for PCM)
    uint16_t audio_format;     // 1 = uncompressed PCM
    uint16_t num_channels;     // number of interleaved channels
    uint32_t sample_rate;      // samples per second per channel
    uint32_t byte_rate;        // sample_rate * num_channels * bits_per_sample / 8
    uint16_t block_align;      // bytes per frame: num_channels * bits_per_sample / 8
    uint16_t bits_per_sample;  // bit depth of one sample
};

struct WavDataChunk {
    char     data_id[4];       // "data"
    uint32_t data_size;        // number of bytes of sample data that follow
};
#pragma pack(pop)

// The structs are serialized byte-for-byte, so their sizes must match the WAV layout exactly.
static_assert(sizeof(WavHeader) == 36, "WavHeader must match the 36-byte RIFF + fmt chunk layout");
static_assert(sizeof(WavDataChunk) == 8, "WavDataChunk must match the 8-byte data chunk header");

// Function to write a dummy WAV file (moved here for example app)
void createDummyWavFile(const std::string& filename, int sampleRate, int numChannels, int bitsPerSample, double durationSeconds) {
    std::ofstream file(filename, std::ios::binary);
    if (!file.is_open()) {
        std::cerr << "Error: Could not create dummy WAV file: " << filename << std::endl;
        return;
    }

    WavHeader header;
    std::memcpy(header.riff_id, "RIFF", 4);
    std::memcpy(header.wave_id, "WAVE", 4);
    std::memcpy(header.fmt_id, "fmt ", 4);
    header.fmt_size = 16;
    header.audio_format = 1; // PCM
    header.num_channels = numChannels;
    header.sample_rate = sampleRate;
    header.bits_per_sample = bitsPerSample;
    header.byte_rate = (sampleRate * numChannels * bitsPerSample) / 8;
    header.block_align = (numChannels * bitsPerSample) / 8;

    WavDataChunk data_chunk;
    std::memcpy(data_chunk.data_id, "data", 4);
    uint32_t num_samples = static_cast<uint32_t>(sampleRate * durationSeconds);
    data_chunk.data_size = num_samples * numChannels * (bitsPerSample / 8);
    header.file_size = 36 + data_chunk.data_size; // 36 is size of header before data chunk

    file.write(reinterpret_cast<const char*>(&header), sizeof(WavHeader));
    file.write(reinterpret_cast<const char*>(&data_chunk), sizeof(WavDataChunk));

    // Generate a 440 Hz sine wave
    for (uint32_t i = 0; i < num_samples; ++i) {
        int16_t sample = static_cast<int16_t>(30000 * std::sin(2 * M_PI * 440 * i / static_cast<double>(sampleRate)));
        for (int c = 0; c < numChannels; ++c) {
            file.write(reinterpret_cast<const char*>(&sample), sizeof(int16_t));
        }
    }

    file.close();
    // std::cout << "Dummy WAV file '" << filename << "' created successfully." << std::endl; // Suppress verbose creation message
}

int main(int argc, char* argv[]) {
    // --- 1. Process command-line arguments ---
    if (argc != 3) {
        std::cerr << "Usage: " << argv[0] << " <path_to_onnx_model> <path_to_wav_file_for_temp_use>" << std::endl;
        std::cerr << "Example: " << argv[0] << " model.onnx temp_audio.wav" << std::endl;
        return 1;
    }

    std::string onnxModelPath = argv[1];
    std::string wavFilename = argv[2]; // This will be used as a temporary file

    // --- Random number generation setup for dummy input frames ---
    std::mt19937 rng(static_cast<unsigned int>(std::time(nullptr))); // Seed with current time
    std::uniform_int_distribution<int> dist_frames(100, 300); // Distribution for frames (100 to 300)

    // Define fixed parameters for feature extraction to calculate required duration
    const int WIN_LENGTH = 400; // Window length (samples) - must match library's constant
    const int HOP_LENGTH = 160; // Hop length (samples) - must match library's constant
    const int TARGET_SAMPLE_RATE = 16000; // Target sample rate - must match library's constant

    try {
        // --- 2. Model Initialization ---
        // This will load the ONNX model and precompute the Mel filterbank.
        AudioInferenceEngine engine(onnxModelPath);
        std::cout << "Engine initialized." << std::endl;

        // --- 3. Model Inference and Time Measurement ---
        std::cout << "\nRunning model inference and measuring time (100 runs with varying input sizes)..." << std::endl;
        int num_runs = 100;
        long long total_inference_time_us = 0; // Use microseconds for finer granularity

        for (int i = 0; i < num_runs; ++i) {
            // Generate a random number of frames for this run
            int random_frames = dist_frames(rng);
            // Calculate the number of samples needed to produce 'random_frames'
            // frames = (num_samples - WIN_LENGTH) / HOP_LENGTH + 1
            // num_samples = (frames - 1) * HOP_LENGTH + WIN_LENGTH
            long long num_samples_for_frames = static_cast<long long>(random_frames - 1) * HOP_LENGTH + WIN_LENGTH;
            double duration_seconds_for_frames = static_cast<double>(num_samples_for_frames) / TARGET_SAMPLE_RATE;

            // Create a new dummy WAV file for this specific run
            // This ensures the input size changes for each test.
            createDummyWavFile(wavFilename, TARGET_SAMPLE_RATE, 1, 16, duration_seconds_for_frames);

            // --- Measure the inference time ---
            auto start_time = std::chrono::high_resolution_clock::now();
            Eigen::MatrixXf features = engine.preprocessAudio(wavFilename);
            std::vector<float> model_output = engine.runInference(features);
            auto end_time = std::chrono::high_resolution_clock::now();

            if (model_output.empty()) {
                std::cerr << "Error: Model inference failed for run " << i + 1 << ". Exiting." << std::endl;
                return 1;
            }

            // Calculate duration for this run in microseconds
            auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time);
            total_inference_time_us += duration.count();

            // Optionally print output for the first run or specific runs
            if (i == 0) {
                std::cout << "First run (frames=" << features.rows() << ")"<< " take : "<< static_cast<double>(total_inference_time_us) / 1000.0 / 1000.0 <<"s output (first few elements): [";
                for (size_t k = 0; k < std::min((size_t)10, model_output.size()); ++k) {
                    std::cout << model_output[k] << (k == std::min((size_t)10, model_output.size()) - 1 ? "" : ", ");
                }
                std::cout << "]" << std::endl;
            }
        }

        double average_inference_time_ms = static_cast<double>(total_inference_time_us) / num_runs / 1000.0 / 1000.0; // Convert microseconds to milliseconds
        std::cout << "\nAverage ONNX model inference time over " << num_runs << " runs (with varying input frames): "
                  << average_inference_time_ms << " s" << std::endl;

    } catch (const Ort::Exception& e) {
        std::cerr << "ONNX Runtime Exception: " << e.what() << std::endl;
        return 1;
    } catch (const std::exception& e) {
        std::cerr << "Standard Exception: " << e.what() << std::endl;
        return 1;
    }

    std::cout << "\nProgram finished successfully." << std::endl;
    return 0;
}