// Source: aaaaaaaaaaaaaaa / accelerator / src / audio_processor.cpp
// Uploaded by arifather51 ("Upload 28 files"), commit a57f260 (verified).
//
// SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
// SPDX-License-Identifier: Apache-2.0
//
#include "audio_processor.hpp"
#include <algorithm>
#include <cmath>
#include <cstring>
#include <fstream>
namespace pocket_tts_accelerator {
// Constructs an audio processor and initialises its memory_pool member from
// the supplied pool.
// NOTE(review): the declaration of memory_pool is not visible here; presumably
// it is a reference, in which case shared_memory_pool must outlive this
// object — confirm against the class declaration.
AudioProcessor::AudioProcessor(MemoryPool& shared_memory_pool)
: memory_pool(shared_memory_pool) {
}
// Destroys the processor. No explicit cleanup is performed; members are
// released by their own destructors.
AudioProcessor::~AudioProcessor() = default;
// Reads a WAV file and returns its samples as 16-bit PCM.
//
// The file must start with a header laid out exactly like WavFileHeader, and
// the sample payload is expected to follow it immediately. 16-bit payloads are
// read straight into the sample buffer; 8-bit, 32-bit integer and 32-bit float
// payloads are converted to int16.
//
// Parameters:
//   file_path - path of the WAV file to read.
//
// Returns:
//   An AudioData with is_valid == true on success. On failure is_valid is
//   false and error_message describes the problem.
//
// Notes:
//   - bits_per_sample in the result reports the depth stored in the file,
//     while samples always holds int16 values.
//   - Only whole samples that are actually present in the file are returned:
//     a header that overstates data_size, a truncated file, or a data_size
//     that is not a multiple of the sample width never causes reads past the
//     end of a buffer or an allocation larger than the file itself.
//   - Sample bytes are copied verbatim from the file.
//     NOTE(review): this assumes a little-endian host — confirm.
AudioData AudioProcessor::read_wav_file(const std::string& file_path) {
    AudioData result;
    result.is_valid = false;

    std::ifstream file_stream(file_path, std::ios::binary);
    if (!file_stream.is_open()) {
        result.error_message = "Failed to open file: " + file_path;
        return result;
    }

    WavFileHeader header;
    file_stream.read(reinterpret_cast<char*>(&header), sizeof(WavFileHeader));
    if (file_stream.gcount() < static_cast<std::streamsize>(sizeof(WavFileHeader))) {
        result.error_message = "File is too small to be a valid WAV file";
        return result;
    }
    if (!validate_wav_header(header)) {
        result.error_message = "Invalid WAV file header";
        return result;
    }

    result.sample_rate = header.sample_rate;
    result.number_of_channels = header.number_of_channels;
    result.bits_per_sample = header.bits_per_sample;

    // The header's data_size is untrusted input. Bound it by the number of
    // bytes that really follow the header so a corrupt size cannot trigger a
    // multi-gigabyte allocation.
    std::size_t payload_bytes = header.data_size;
    const std::streampos payload_start = file_stream.tellg();
    if (payload_start != std::streampos(-1)) {
        file_stream.seekg(0, std::ios::end);
        const std::streampos file_end = file_stream.tellg();
        file_stream.clear();
        file_stream.seekg(payload_start);
        if (!file_stream) {
            result.error_message = "Failed to read WAV sample data";
            return result;
        }
        if (file_end != std::streampos(-1)) {
            const std::streamoff available_bytes = file_end - payload_start;
            if (available_bytes >= 0 &&
                static_cast<std::uint64_t>(available_bytes) < payload_bytes) {
                payload_bytes = static_cast<std::size_t>(available_bytes);
            }
        }
    }

    // validate_wav_header only accepts 8, 16 or 32 bits, so this is never zero.
    const std::size_t bytes_per_sample = header.bits_per_sample / 8;
    const std::size_t expected_samples = payload_bytes / bytes_per_sample;

    // Reads up to `count` whole samples into `destination` and returns how many
    // complete samples were actually obtained. Only count * bytes_per_sample
    // bytes are requested, which is exactly the size of the destination buffer.
    const auto read_whole_samples = [&file_stream, bytes_per_sample](void* destination, std::size_t count) {
        file_stream.read(
            static_cast<char*>(destination),
            static_cast<std::streamsize>(count * bytes_per_sample)
        );
        return static_cast<std::size_t>(file_stream.gcount()) / bytes_per_sample;
    };

    if (header.bits_per_sample == 16) {
        result.samples.resize(expected_samples);
        const std::size_t samples_read = read_whole_samples(result.samples.data(), expected_samples);
        result.samples.resize(samples_read);
    } else if (header.bits_per_sample == 8) {
        std::vector<std::uint8_t> raw_data(expected_samples);
        const std::size_t samples_read = read_whole_samples(raw_data.data(), expected_samples);
        result.samples.resize(samples_read);
        convert_uint8_to_int16(raw_data.data(), result.samples.data(), samples_read);
    } else if (header.bits_per_sample == 32) {
        if (header.audio_format == 3) {
            // Format tag 3 is handled as 32-bit floating point samples.
            std::vector<float> raw_data(expected_samples);
            const std::size_t samples_read = read_whole_samples(raw_data.data(), expected_samples);
            result.samples.resize(samples_read);
            convert_float32_to_int16(raw_data.data(), result.samples.data(), samples_read);
        } else {
            std::vector<std::int32_t> raw_data(expected_samples);
            const std::size_t samples_read = read_whole_samples(raw_data.data(), expected_samples);
            result.samples.resize(samples_read);
            convert_int32_to_int16(raw_data.data(), result.samples.data(), samples_read);
        }
    }

    result.is_valid = true;
    return result;
}
// Writes audio_data to file_path as a 16-bit PCM WAV file.
//
// The header is a single WavFileHeader written as raw bytes, followed by the
// int16 samples exactly as stored in audio_data.samples.
//
// Parameters:
//   file_path  - destination path; an existing file is overwritten.
//   audio_data - samples plus the sample rate and channel count to record in
//                the header. Its is_valid flag is not consulted.
//
// Returns:
//   true when the header and all samples were written successfully; false when
//   the file could not be opened, a write failed, or the payload is too large
//   to be described by the 32-bit size fields of the WAV header.
bool AudioProcessor::write_wav_file(const std::string& file_path, const AudioData& audio_data) {
    // The RIFF size field counts everything after itself: 36 header bytes plus
    // the payload. Both sizes are 32-bit, so refuse payloads that would wrap
    // instead of silently writing a truncated, inconsistent file.
    constexpr std::uint64_t header_bytes_after_riff_size = 36;
    constexpr std::uint64_t maximum_payload_bytes = 0xFFFFFFFFull - header_bytes_after_riff_size;
    const std::uint64_t payload_bytes =
        static_cast<std::uint64_t>(audio_data.samples.size()) * sizeof(std::int16_t);
    if (payload_bytes > maximum_payload_bytes) {
        return false;
    }

    std::ofstream file_stream(file_path, std::ios::binary);
    if (!file_stream.is_open()) {
        return false;
    }

    const std::uint32_t data_size = static_cast<std::uint32_t>(payload_bytes);
    const std::uint32_t file_size = data_size + static_cast<std::uint32_t>(header_bytes_after_riff_size);

    // Value-initialise so that no indeterminate bytes reach the file.
    WavFileHeader header{};
    std::memcpy(header.riff_marker, "RIFF", 4);
    header.file_size = file_size;
    std::memcpy(header.wave_marker, "WAVE", 4);
    std::memcpy(header.format_marker, "fmt ", 4);
    header.format_chunk_size = 16;
    header.audio_format = 1;  // integer PCM
    header.number_of_channels = audio_data.number_of_channels;
    header.sample_rate = audio_data.sample_rate;
    header.bits_per_sample = 16;
    // Two bytes per sample per channel.
    header.byte_rate = audio_data.sample_rate * audio_data.number_of_channels * 2;
    header.block_align = audio_data.number_of_channels * 2;
    std::memcpy(header.data_marker, "data", 4);
    header.data_size = data_size;

    file_stream.write(reinterpret_cast<const char*>(&header), sizeof(WavFileHeader));
    file_stream.write(
        reinterpret_cast<const char*>(audio_data.samples.data()),
        static_cast<std::streamsize>(data_size)
    );
    return file_stream.good();
}
// Down-mixes interleaved multi-channel audio to a single channel.
//
// Parameters:
//   input_audio - interleaved int16 samples with their channel count and
//                 sample rate.
//
// Returns:
//   On success, processed_samples holds one sample per input frame (computed
//   by mix_channels_to_mono) and output_sample_rate equals the input rate.
//   Mono input is returned as an unchanged copy. Trailing samples that do not
//   form a complete frame are dropped. On failure success is false and
//   error_message is set: the input is flagged invalid, or it declares zero
//   channels.
AudioProcessingResult AudioProcessor::convert_to_mono(const AudioData& input_audio) {
    AudioProcessingResult result;
    result.success = false;

    if (!input_audio.is_valid) {
        result.error_message = "Invalid input audio";
        return result;
    }
    // A zero channel count would divide by zero when computing the frame count.
    if (input_audio.number_of_channels == 0) {
        result.error_message = "Invalid channel count";
        return result;
    }

    result.output_sample_rate = input_audio.sample_rate;

    if (input_audio.number_of_channels == 1) {
        result.processed_samples = input_audio.samples;
        result.success = true;
        return result;
    }

    const std::size_t frame_count = input_audio.samples.size() / input_audio.number_of_channels;
    result.processed_samples.resize(frame_count);
    mix_channels_to_mono(
        input_audio.samples.data(),
        result.processed_samples.data(),
        frame_count,
        input_audio.number_of_channels
    );
    result.success = true;
    return result;
}
// Produces a 16-bit PCM result from input_audio.
//
// The samples are copied unchanged and the sample rate is carried over; no
// conversion is performed here. When input_audio is flagged invalid the result
// has success == false and error_message "Invalid input audio".
AudioProcessingResult AudioProcessor::convert_to_pcm_int16(const AudioData& input_audio) {
    AudioProcessingResult outcome;
    outcome.success = input_audio.is_valid;

    if (!outcome.success) {
        outcome.error_message = "Invalid input audio";
        return outcome;
    }

    outcome.output_sample_rate = input_audio.sample_rate;
    outcome.processed_samples = input_audio.samples;
    return outcome;
}
// Resamples audio to target_sample_rate using linear interpolation.
//
// Parameters:
//   input_audio        - interleaved int16 samples with their sample rate and
//                        channel count.
//   target_sample_rate - desired output rate in Hz.
//
// Returns:
//   On success, processed_samples holds the resampled, still interleaved
//   samples and output_sample_rate equals target_sample_rate. When the rates
//   already match the samples are copied unchanged. On failure success is
//   false and error_message is set: the input is flagged invalid, or either
//   rate is zero while the rates differ.
//
// Notes:
//   - Interpolation is done per channel between consecutive frames, so the
//     channels of interleaved audio are never blended into each other. A
//     channel count of zero is treated as mono.
//   - Trailing samples that do not form a complete frame are dropped.
//   - No low-pass filtering is applied before reducing the sample rate.
//   - Interpolated values are clamped to the int16 range and truncated toward
//     zero.
AudioProcessingResult AudioProcessor::resample_audio(const AudioData& input_audio, std::uint32_t target_sample_rate) {
    AudioProcessingResult result;
    result.success = false;

    if (!input_audio.is_valid) {
        result.error_message = "Invalid input audio";
        return result;
    }
    if (input_audio.sample_rate == target_sample_rate) {
        result.processed_samples = input_audio.samples;
        result.output_sample_rate = target_sample_rate;
        result.success = true;
        return result;
    }
    // A zero rate would make the ratio zero or infinite, and converting the
    // resulting inf/NaN sample count to an integer is undefined.
    if (input_audio.sample_rate == 0 || target_sample_rate == 0) {
        result.error_message = "Invalid sample rate";
        return result;
    }

    const std::size_t channel_count =
        std::max<std::size_t>(1, input_audio.number_of_channels);
    const std::size_t input_frame_count = input_audio.samples.size() / channel_count;

    const double ratio =
        static_cast<double>(target_sample_rate) / static_cast<double>(input_audio.sample_rate);
    const std::size_t output_frame_count = static_cast<std::size_t>(input_frame_count * ratio);
    result.processed_samples.resize(output_frame_count * channel_count);

    for (std::size_t output_frame = 0; output_frame < output_frame_count; ++output_frame) {
        // output_frame_count > 0 implies input_frame_count > 0, so this cannot wrap.
        const std::size_t last_input_frame = input_frame_count - 1;

        const double source_position = output_frame / ratio;
        // Clamp both neighbours so rounding can never index past the input.
        const std::size_t lower_frame =
            std::min(static_cast<std::size_t>(source_position), last_input_frame);
        const std::size_t upper_frame = std::min(lower_frame + 1, last_input_frame);
        const double fractional_part = source_position - lower_frame;

        for (std::size_t channel = 0; channel < channel_count; ++channel) {
            const double lower_value = input_audio.samples[lower_frame * channel_count + channel];
            const double upper_value = input_audio.samples[upper_frame * channel_count + channel];
            const double interpolated_value =
                lower_value * (1.0 - fractional_part) + upper_value * fractional_part;
            result.processed_samples[output_frame * channel_count + channel] =
                static_cast<std::int16_t>(std::clamp(interpolated_value, -32768.0, 32767.0));
        }
    }

    result.output_sample_rate = target_sample_rate;
    result.success = true;
    return result;
}
// Scales audio so that its largest absolute sample reaches target_peak_level.
//
// Parameters:
//   input_audio       - int16 samples to normalise.
//   target_peak_level - desired peak as a fraction of 32767 (1.0f scales the
//                       loudest sample to 32767).
//
// Returns:
//   On success, processed_samples holds the scaled samples and
//   output_sample_rate equals the input rate. Entirely silent (or empty) input
//   is returned as an unchanged copy. When input_audio is flagged invalid the
//   result has success == false and error_message "Invalid input audio".
//
// Notes:
//   - The peak is tracked in 32 bits because the magnitude of -32768 does not
//     fit in int16; narrowing it would under-estimate the peak and over-amplify
//     the signal.
//   - Scaled values are clamped to the int16 range and truncated toward zero.
AudioProcessingResult AudioProcessor::normalize_audio(const AudioData& input_audio, float target_peak_level) {
    AudioProcessingResult result;
    result.success = false;

    if (!input_audio.is_valid) {
        result.error_message = "Invalid input audio";
        return result;
    }

    std::int32_t max_absolute_value = 0;
    for (const std::int16_t sample : input_audio.samples) {
        const std::int32_t widened_sample = sample;
        const std::int32_t absolute_value = widened_sample < 0 ? -widened_sample : widened_sample;
        max_absolute_value = std::max(max_absolute_value, absolute_value);
    }

    result.output_sample_rate = input_audio.sample_rate;

    if (max_absolute_value == 0) {
        // Nothing to scale; avoids dividing by zero below.
        result.processed_samples = input_audio.samples;
        result.success = true;
        return result;
    }

    const float normalization_factor =
        (target_peak_level * 32767.0f) / static_cast<float>(max_absolute_value);

    result.processed_samples.resize(input_audio.samples.size());
    for (std::size_t index = 0; index < input_audio.samples.size(); ++index) {
        const float normalized_sample =
            static_cast<float>(input_audio.samples[index]) * normalization_factor;
        result.processed_samples[index] = static_cast<std::int16_t>(
            std::clamp(normalized_sample, -32768.0f, 32767.0f)
        );
    }

    result.success = true;
    return result;
}
// Reads a WAV file, down-mixes it to mono and writes the result as a 16-bit
// PCM WAV file.
//
// Parameters:
//   input_file_path  - WAV file to read (via read_wav_file).
//   output_file_path - destination for the mono file (via write_wav_file).
//
// Returns:
//   On success, processed_samples holds the mono samples that were written and
//   output_sample_rate their sample rate. On failure success is false and
//   error_message names the stage that failed (read, mono conversion or write).
AudioProcessingResult AudioProcessor::process_audio_for_voice_cloning(
    const std::string& input_file_path,
    const std::string& output_file_path
) {
    AudioProcessingResult outcome;
    outcome.success = false;

    // Stage 1: load the source file.
    const AudioData source_audio = read_wav_file(input_file_path);
    if (!source_audio.is_valid) {
        outcome.error_message = "Failed to read input file: " + source_audio.error_message;
        return outcome;
    }

    // Stage 2: collapse all channels into one.
    AudioProcessingResult downmixed = convert_to_mono(source_audio);
    if (!downmixed.success) {
        outcome.error_message = "Failed to convert to mono: " + downmixed.error_message;
        return outcome;
    }

    // Stage 3: describe the mono clip and persist it.
    AudioData mono_clip;
    mono_clip.is_valid = true;
    mono_clip.bits_per_sample = 16;
    mono_clip.number_of_channels = 1;
    mono_clip.sample_rate = downmixed.output_sample_rate;
    mono_clip.samples = std::move(downmixed.processed_samples);

    const bool was_written = write_wav_file(output_file_path, mono_clip);
    if (!was_written) {
        outcome.error_message = "Failed to write output file";
        return outcome;
    }

    // Hand the written samples back to the caller without copying them.
    outcome.output_sample_rate = mono_clip.sample_rate;
    outcome.processed_samples = std::move(mono_clip.samples);
    outcome.success = true;
    return outcome;
}
bool AudioProcessor::validate_wav_header(const WavFileHeader& header) {
if (std::memcmp(header.riff_marker, "RIFF", 4) != 0) {
return false;
}
if (std::memcmp(header.wave_marker, "WAVE", 4) != 0) {
return false;
}
if (std::memcmp(header.format_marker, "fmt ", 4) != 0) {
return false;
}
if (header.audio_format != 1 && header.audio_format != 3) {
return false;
}
if (header.number_of_channels < 1 || header.number_of_channels > 16) {
return false;
}
if (header.sample_rate < 100 || header.sample_rate > 384000) {
return false;
}
if (header.bits_per_sample != 8 && header.bits_per_sample != 16 && header.bits_per_sample != 32) {
return false;
}
return true;
}
// Returns the duration of audio_data in whole milliseconds (rounded down).
//
// The duration is derived from the number of complete frames (samples divided
// by channel count) and the sample rate. Returns 0 when the audio is flagged
// invalid or when the sample rate or the channel count is zero, since either
// would otherwise cause a division by zero.
std::size_t AudioProcessor::calculate_audio_duration_milliseconds(const AudioData& audio_data) {
    if (!audio_data.is_valid || audio_data.sample_rate == 0 || audio_data.number_of_channels == 0) {
        return 0;
    }

    // 64-bit arithmetic keeps frame_count * 1000 from wrapping where
    // std::size_t is only 32 bits wide.
    const std::uint64_t frame_count = audio_data.samples.size() / audio_data.number_of_channels;
    return static_cast<std::size_t>((frame_count * 1000u) / audio_data.sample_rate);
}
// Converts floating-point samples to 16-bit PCM.
//
// Parameters:
//   input        - sample_count float samples, nominally in [-1.0, 1.0].
//   output       - destination for sample_count int16 samples.
//   sample_count - number of samples to convert.
//
// Each sample is clamped to [-1.0, 1.0], scaled by 32767 and truncated toward
// zero. NaN samples are written as 0: std::clamp would pass NaN through, and
// converting NaN to an integer is undefined behaviour.
void AudioProcessor::convert_float32_to_int16(const float* input, std::int16_t* output, std::size_t sample_count) {
    for (std::size_t index = 0; index < sample_count; ++index) {
        const float sample = input[index];
        const float clamped_value = std::isnan(sample) ? 0.0f : std::clamp(sample, -1.0f, 1.0f);
        output[index] = static_cast<std::int16_t>(clamped_value * 32767.0f);
    }
}
// Converts 32-bit integer samples to 16-bit PCM by keeping the upper 16 bits
// of every sample (right shift by 16).
//
// Parameters:
//   input        - sample_count int32 samples.
//   output       - destination for sample_count int16 samples.
//   sample_count - number of samples to convert.
void AudioProcessor::convert_int32_to_int16(const std::int32_t* input, std::int16_t* output, std::size_t sample_count) {
    std::transform(
        input,
        input + sample_count,
        output,
        [](std::int32_t wide_sample) {
            return static_cast<std::int16_t>(wide_sample >> 16);
        }
    );
}
// Converts unsigned 8-bit samples to signed 16-bit PCM.
//
// Parameters:
//   input        - sample_count unsigned 8-bit samples.
//   output       - destination for sample_count int16 samples.
//   sample_count - number of samples to convert.
//
// Each byte is re-centred around zero by subtracting 128 and then multiplied
// by 256, giving values from -32768 to 32512.
void AudioProcessor::convert_uint8_to_int16(const std::uint8_t* input, std::int16_t* output, std::size_t sample_count) {
    const std::uint8_t* const input_end = input + sample_count;
    for (const std::uint8_t* cursor = input; cursor != input_end; ++cursor, ++output) {
        const int centred_sample = static_cast<int>(*cursor) - 128;
        *output = static_cast<std::int16_t>(centred_sample * 256);
    }
}
// Averages the channels of interleaved audio into a single channel.
//
// Parameters:
//   input         - frame_count * channel_count interleaved int16 samples.
//   output        - destination for frame_count int16 samples.
//   frame_count   - number of frames to mix.
//   channel_count - number of interleaved channels per frame.
//
// Every output sample is the integer average (truncated toward zero) of the
// samples in the corresponding frame. With channel_count == 0 there is nothing
// to average, so the output is filled with silence rather than dividing by
// zero.
void AudioProcessor::mix_channels_to_mono(
    const std::int16_t* input,
    std::int16_t* output,
    std::size_t frame_count,
    std::uint16_t channel_count
) {
    if (channel_count == 0) {
        std::fill_n(output, frame_count, static_cast<std::int16_t>(0));
        return;
    }

    for (std::size_t frame_index = 0; frame_index < frame_count; ++frame_index) {
        const std::int16_t* const frame_start = input + frame_index * channel_count;

        // 32 bits are enough: even 65535 channels at full scale stay in range.
        std::int32_t sum = 0;
        for (std::uint16_t channel_index = 0; channel_index < channel_count; ++channel_index) {
            sum += frame_start[channel_index];
        }
        output[frame_index] = static_cast<std::int16_t>(sum / channel_count);
    }
}
}