|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include "audio_processor.hpp" |
|
|
#include <algorithm> |
|
|
#include <cmath> |
|
|
#include <cstring> |
|
|
#include <fstream> |
|
|
|
|
|
namespace pocket_tts_accelerator { |
|
|
|
|
|
/// Builds an AudioProcessor whose memory_pool member is initialized from the
/// pool handed in by the caller.
///
/// @param shared_memory_pool Pool used to initialize the memory_pool member.
AudioProcessor::AudioProcessor(MemoryPool& shared_memory_pool)
    : memory_pool(shared_memory_pool) {}
|
|
|
|
|
/// Destructor. Performs no work of its own beyond destroying the members.
AudioProcessor::~AudioProcessor() = default;
|
|
|
|
|
/// Reads a WAV file and returns its audio as 16-bit samples.
///
/// The first sizeof(WavFileHeader) bytes are read as the header and checked
/// with validate_wav_header(); the bytes that follow are treated as the sample
/// payload. 16-bit payloads are read straight into the sample buffer (raw
/// bytes, so the host byte order must match the file). 8-bit, 32-bit integer
/// and 32-bit float (audio_format == 3) payloads are converted to int16.
///
/// The payload length comes from header.data_size, but it is clamped to the
/// number of bytes actually present after the header and rounded down to whole
/// samples. This keeps a corrupt or hostile header from overrunning the read
/// buffers or forcing an allocation larger than the file itself. If fewer
/// samples arrive than expected, the result holds only the samples read.
///
/// @param file_path Path of the file to open in binary mode.
/// @return AudioData with is_valid == true on success; sample_rate,
///         number_of_channels and bits_per_sample are copied from the header.
///         On failure is_valid is false and error_message says why.
AudioData AudioProcessor::read_wav_file(const std::string& file_path) {
    AudioData result;
    result.is_valid = false;

    std::ifstream file_stream(file_path, std::ios::binary);
    if (!file_stream.is_open()) {
        result.error_message = "Failed to open file: " + file_path;
        return result;
    }

    WavFileHeader header;
    file_stream.read(reinterpret_cast<char*>(&header), sizeof(WavFileHeader));
    if (file_stream.gcount() < static_cast<std::streamsize>(sizeof(WavFileHeader))) {
        result.error_message = "File is too small to be a valid WAV file";
        return result;
    }

    if (!validate_wav_header(header)) {
        result.error_message = "Invalid WAV file header";
        return result;
    }

    result.sample_rate = header.sample_rate;
    result.number_of_channels = header.number_of_channels;
    result.bits_per_sample = header.bits_per_sample;

    // Measure how many bytes really follow the header so that data_size, which
    // comes from the file and cannot be trusted, never drives the allocation.
    const std::streamoff payload_start = static_cast<std::streamoff>(file_stream.tellg());
    file_stream.seekg(0, std::ios::end);
    const std::streamoff payload_end = static_cast<std::streamoff>(file_stream.tellg());
    file_stream.seekg(payload_start, std::ios::beg);
    if (!file_stream || payload_start < 0 || payload_end < payload_start) {
        result.error_message = "Failed to determine WAV data size";
        return result;
    }

    // validate_wav_header() only accepts 8, 16 or 32 bits, so this is 1, 2 or 4.
    const std::size_t bytes_per_sample = header.bits_per_sample / 8;
    const std::uintmax_t available_bytes = static_cast<std::uintmax_t>(payload_end - payload_start);
    const std::uintmax_t declared_bytes = header.data_size;
    const std::size_t usable_bytes = static_cast<std::size_t>(std::min(available_bytes, declared_bytes));

    // Only whole samples are read; a trailing partial sample is ignored so the
    // byte count always matches the size of the destination buffer exactly.
    const std::size_t sample_count = usable_bytes / bytes_per_sample;
    const std::size_t byte_count = sample_count * bytes_per_sample;

    result.samples.resize(sample_count);

    // Reads the payload into `destination` and reports how many whole samples arrived.
    const auto read_payload = [&](void* destination) -> std::size_t {
        file_stream.read(static_cast<char*>(destination), static_cast<std::streamsize>(byte_count));
        return static_cast<std::size_t>(file_stream.gcount()) / bytes_per_sample;
    };

    std::size_t samples_read = 0;
    if (header.bits_per_sample == 16) {
        samples_read = read_payload(result.samples.data());
    } else if (header.bits_per_sample == 8) {
        std::vector<std::uint8_t> raw_data(sample_count);
        samples_read = read_payload(raw_data.data());
        convert_uint8_to_int16(raw_data.data(), result.samples.data(), samples_read);
    } else if (header.bits_per_sample == 32) {
        if (header.audio_format == 3) {
            std::vector<float> raw_data(sample_count);
            samples_read = read_payload(raw_data.data());
            convert_float32_to_int16(raw_data.data(), result.samples.data(), samples_read);
        } else {
            std::vector<std::int32_t> raw_data(sample_count);
            samples_read = read_payload(raw_data.data());
            convert_int32_to_int16(raw_data.data(), result.samples.data(), samples_read);
        }
    }

    // Drop any tail that was sized for but never delivered by the stream.
    result.samples.resize(samples_read);

    result.is_valid = true;
    return result;
}
|
|
|
|
|
bool AudioProcessor::write_wav_file(const std::string& file_path, const AudioData& audio_data) { |
|
|
std::ofstream file_stream(file_path, std::ios::binary); |
|
|
|
|
|
if (!file_stream.is_open()) { |
|
|
return false; |
|
|
} |
|
|
|
|
|
std::uint32_t data_size = static_cast<std::uint32_t>(audio_data.samples.size() * sizeof(std::int16_t)); |
|
|
std::uint32_t file_size = data_size + 36; |
|
|
|
|
|
WavFileHeader header; |
|
|
std::memcpy(header.riff_marker, "RIFF", 4); |
|
|
header.file_size = file_size; |
|
|
std::memcpy(header.wave_marker, "WAVE", 4); |
|
|
std::memcpy(header.format_marker, "fmt ", 4); |
|
|
header.format_chunk_size = 16; |
|
|
header.audio_format = 1; |
|
|
header.number_of_channels = audio_data.number_of_channels; |
|
|
header.sample_rate = audio_data.sample_rate; |
|
|
header.bits_per_sample = 16; |
|
|
header.byte_rate = audio_data.sample_rate * audio_data.number_of_channels * 2; |
|
|
header.block_align = audio_data.number_of_channels * 2; |
|
|
std::memcpy(header.data_marker, "data", 4); |
|
|
header.data_size = data_size; |
|
|
|
|
|
file_stream.write(reinterpret_cast<const char*>(&header), sizeof(WavFileHeader)); |
|
|
file_stream.write(reinterpret_cast<const char*>(audio_data.samples.data()), data_size); |
|
|
|
|
|
return file_stream.good(); |
|
|
} |
|
|
|
|
|
/// Produces a single-channel version of input_audio.
///
/// Single-channel input is copied unchanged. Otherwise the samples are treated
/// as interleaved frames of number_of_channels samples each and passed to
/// mix_channels_to_mono(); samples beyond the last complete frame are ignored.
///
/// @param input_audio Audio to down-mix; must have is_valid set and a non-zero
///                    channel count.
/// @return Result with success == true, the mono samples and the unchanged
///         sample rate; on failure success is false and error_message is set.
AudioProcessingResult AudioProcessor::convert_to_mono(const AudioData& input_audio) {
    AudioProcessingResult result;
    result.success = false;

    if (!input_audio.is_valid) {
        result.error_message = "Invalid input audio";
        return result;
    }

    // A zero channel count would divide by zero when computing the frame count.
    if (input_audio.number_of_channels == 0) {
        result.error_message = "Invalid channel count";
        return result;
    }

    if (input_audio.number_of_channels == 1) {
        result.processed_samples = input_audio.samples;
        result.output_sample_rate = input_audio.sample_rate;
        result.success = true;
        return result;
    }

    const std::size_t frame_count = input_audio.samples.size() / input_audio.number_of_channels;
    result.processed_samples.resize(frame_count);

    mix_channels_to_mono(
        input_audio.samples.data(),
        result.processed_samples.data(),
        frame_count,
        input_audio.number_of_channels
    );

    result.output_sample_rate = input_audio.sample_rate;
    result.success = true;
    return result;
}
|
|
|
|
|
/// Returns the samples of input_audio as a processing result.
///
/// The sample buffer and sample rate are copied through unchanged; the only
/// check performed is that input_audio.is_valid is set.
///
/// @param input_audio Audio whose samples are copied.
/// @return Result with success == true and the copied samples, or success ==
///         false with error_message set when the input is not valid.
AudioProcessingResult AudioProcessor::convert_to_pcm_int16(const AudioData& input_audio) {
    AudioProcessingResult outcome;

    if (input_audio.is_valid) {
        outcome.processed_samples = input_audio.samples;
        outcome.output_sample_rate = input_audio.sample_rate;
        outcome.success = true;
    } else {
        outcome.success = false;
        outcome.error_message = "Invalid input audio";
    }

    return outcome;
}
|
|
|
|
|
/// Resamples input_audio to target_sample_rate using linear interpolation.
///
/// If the rates already match, the samples are copied unchanged. Otherwise the
/// buffer is treated as interleaved frames of number_of_channels samples and
/// every channel is interpolated independently, so samples from different
/// channels are never blended together. A channel count of 0 is handled as a
/// single channel. Samples beyond the last complete frame are ignored.
///
/// Interpolated values are clamped to the int16 range and truncated toward
/// zero when converted back to int16.
///
/// @param input_audio        Audio to resample; must have is_valid set.
/// @param target_sample_rate Desired output rate.
/// @return Result with success == true, the resampled interleaved samples and
///         output_sample_rate == target_sample_rate. On failure success is
///         false and error_message is set (invalid input, or a zero source or
///         target rate when the two rates differ).
AudioProcessingResult AudioProcessor::resample_audio(const AudioData& input_audio, std::uint32_t target_sample_rate) {
    AudioProcessingResult result;
    result.success = false;

    if (!input_audio.is_valid) {
        result.error_message = "Invalid input audio";
        return result;
    }

    if (input_audio.sample_rate == target_sample_rate) {
        result.processed_samples = input_audio.samples;
        result.output_sample_rate = target_sample_rate;
        result.success = true;
        return result;
    }

    // A zero source rate makes the ratio infinite (and the size cast undefined);
    // a zero target rate describes no meaningful output.
    if (input_audio.sample_rate == 0 || target_sample_rate == 0) {
        result.error_message = "Invalid sample rate";
        return result;
    }

    const std::size_t channel_count = input_audio.number_of_channels == 0
        ? std::size_t{1}
        : static_cast<std::size_t>(input_audio.number_of_channels);
    const std::size_t input_frame_count = input_audio.samples.size() / channel_count;

    const double ratio = static_cast<double>(target_sample_rate) / static_cast<double>(input_audio.sample_rate);
    const std::size_t output_frame_count = static_cast<std::size_t>(input_frame_count * ratio);

    result.processed_samples.resize(output_frame_count * channel_count);

    // output_frame_count > 0 implies input_frame_count > 0, so the loop below
    // never runs with an empty input.
    const std::size_t last_input_frame = input_frame_count > 0 ? input_frame_count - 1 : 0;

    for (std::size_t output_frame = 0; output_frame < output_frame_count; ++output_frame) {
        const double source_position = output_frame / ratio;
        const std::size_t unclamped_floor = static_cast<std::size_t>(source_position);
        const double fractional_part = source_position - unclamped_floor;

        // Clamp both neighbours so rounding at the very end of the buffer can
        // never index past the last frame.
        const std::size_t floor_frame = std::min(unclamped_floor, last_input_frame);
        const std::size_t ceil_frame = std::min(floor_frame + 1, last_input_frame);

        for (std::size_t channel = 0; channel < channel_count; ++channel) {
            const double interpolated_value =
                input_audio.samples[floor_frame * channel_count + channel] * (1.0 - fractional_part) +
                input_audio.samples[ceil_frame * channel_count + channel] * fractional_part;

            result.processed_samples[output_frame * channel_count + channel] = static_cast<std::int16_t>(
                std::clamp(interpolated_value, -32768.0, 32767.0)
            );
        }
    }

    result.output_sample_rate = target_sample_rate;
    result.success = true;
    return result;
}
|
|
|
|
|
/// Scales input_audio so that its largest absolute sample maps to
/// target_peak_level * 32767.
///
/// If every sample is zero (or the buffer is empty) the samples are copied
/// unchanged. Scaled values are clamped to the int16 range and truncated
/// toward zero when converted back to int16.
///
/// @param input_audio       Audio to normalize; must have is_valid set.
/// @param target_peak_level Multiplier applied to 32767 to obtain the desired peak.
/// @return Result with success == true, the scaled samples and the unchanged
///         sample rate; on invalid input success is false and error_message is set.
AudioProcessingResult AudioProcessor::normalize_audio(const AudioData& input_audio, float target_peak_level) {
    AudioProcessingResult result;
    result.success = false;

    if (!input_audio.is_valid) {
        result.error_message = "Invalid input audio";
        return result;
    }

    // Track the peak in int: |-32768| is 32768, which does not fit in int16 and
    // would otherwise wrap negative and be ignored as a peak candidate.
    int max_absolute_value = 0;
    for (const std::int16_t sample : input_audio.samples) {
        max_absolute_value = std::max(max_absolute_value, std::abs(static_cast<int>(sample)));
    }

    if (max_absolute_value == 0) {
        result.processed_samples = input_audio.samples;
        result.output_sample_rate = input_audio.sample_rate;
        result.success = true;
        return result;
    }

    const float normalization_factor = (target_peak_level * 32767.0f) / static_cast<float>(max_absolute_value);

    result.processed_samples.resize(input_audio.samples.size());

    for (std::size_t index = 0; index < input_audio.samples.size(); ++index) {
        const float normalized_sample = static_cast<float>(input_audio.samples[index]) * normalization_factor;
        result.processed_samples[index] = static_cast<std::int16_t>(
            std::clamp(normalized_sample, -32768.0f, 32767.0f)
        );
    }

    result.output_sample_rate = input_audio.sample_rate;
    result.success = true;
    return result;
}
|
|
|
|
|
/// Reads a WAV file, down-mixes it to mono and writes the result as a new WAV file.
///
/// Steps, in order: read_wav_file(input_file_path), convert_to_mono(), then
/// write_wav_file(output_file_path) with the mono samples described as one
/// channel at 16 bits per sample. The first step that fails stops the
/// pipeline.
///
/// @param input_file_path  WAV file to read.
/// @param output_file_path Destination for the mono WAV file.
/// @return Result with success == true, the mono samples and their sample
///         rate; on failure success is false and error_message names the
///         step that failed.
AudioProcessingResult AudioProcessor::process_audio_for_voice_cloning(
    const std::string& input_file_path,
    const std::string& output_file_path
) {
    AudioProcessingResult outcome;
    outcome.success = false;

    const AudioData source_audio = read_wav_file(input_file_path);
    if (!source_audio.is_valid) {
        outcome.error_message = "Failed to read input file: " + source_audio.error_message;
        return outcome;
    }

    AudioProcessingResult downmixed = convert_to_mono(source_audio);
    if (!downmixed.success) {
        outcome.error_message = "Failed to convert to mono: " + downmixed.error_message;
        return outcome;
    }

    // Repackage the down-mixed samples as an AudioData that write_wav_file() accepts.
    AudioData mono_clip;
    mono_clip.is_valid = true;
    mono_clip.bits_per_sample = 16;
    mono_clip.number_of_channels = 1;
    mono_clip.sample_rate = downmixed.output_sample_rate;
    mono_clip.samples = std::move(downmixed.processed_samples);

    if (!write_wav_file(output_file_path, mono_clip)) {
        outcome.error_message = "Failed to write output file";
        return outcome;
    }

    outcome.output_sample_rate = mono_clip.sample_rate;
    outcome.processed_samples = std::move(mono_clip.samples);
    outcome.success = true;
    return outcome;
}
|
|
|
|
|
bool AudioProcessor::validate_wav_header(const WavFileHeader& header) { |
|
|
if (std::memcmp(header.riff_marker, "RIFF", 4) != 0) { |
|
|
return false; |
|
|
} |
|
|
|
|
|
if (std::memcmp(header.wave_marker, "WAVE", 4) != 0) { |
|
|
return false; |
|
|
} |
|
|
|
|
|
if (std::memcmp(header.format_marker, "fmt ", 4) != 0) { |
|
|
return false; |
|
|
} |
|
|
|
|
|
if (header.audio_format != 1 && header.audio_format != 3) { |
|
|
return false; |
|
|
} |
|
|
|
|
|
if (header.number_of_channels < 1 || header.number_of_channels > 16) { |
|
|
return false; |
|
|
} |
|
|
|
|
|
if (header.sample_rate < 100 || header.sample_rate > 384000) { |
|
|
return false; |
|
|
} |
|
|
|
|
|
if (header.bits_per_sample != 8 && header.bits_per_sample != 16 && header.bits_per_sample != 32) { |
|
|
return false; |
|
|
} |
|
|
|
|
|
return true; |
|
|
} |
|
|
|
|
|
std::size_t AudioProcessor::calculate_audio_duration_milliseconds(const AudioData& audio_data) { |
|
|
if (!audio_data.is_valid || audio_data.sample_rate == 0) { |
|
|
return 0; |
|
|
} |
|
|
|
|
|
std::size_t frame_count = audio_data.samples.size() / audio_data.number_of_channels; |
|
|
return (frame_count * 1000) / audio_data.sample_rate; |
|
|
} |
|
|
|
|
|
/// Converts float samples to int16.
///
/// Each input value is clamped to [-1.0, 1.0], multiplied by 32767 and
/// truncated toward zero. NaN inputs are written as 0.
///
/// @param input        Source buffer with at least sample_count floats.
/// @param output       Destination buffer with room for sample_count int16 values.
/// @param sample_count Number of samples to convert.
void AudioProcessor::convert_float32_to_int16(const float* input, std::int16_t* output, std::size_t sample_count) {
    for (std::size_t index = 0; index < sample_count; ++index) {
        const float sample = input[index];
        // NaN compares false against both clamp bounds, so it would pass through
        // unchanged and make the float-to-integer cast undefined; map it to 0.
        const float clamped_value = std::isnan(sample) ? 0.0f : std::clamp(sample, -1.0f, 1.0f);
        output[index] = static_cast<std::int16_t>(clamped_value * 32767.0f);
    }
}
|
|
|
|
|
/// Converts 32-bit integer samples to int16 by shifting each value right by 16 bits.
///
/// @param input        Source buffer with at least sample_count int32 values.
/// @param output       Destination buffer with room for sample_count int16 values.
/// @param sample_count Number of samples to convert.
void AudioProcessor::convert_int32_to_int16(const std::int32_t* input, std::int16_t* output, std::size_t sample_count) {
    std::transform(input, input + sample_count, output, [](std::int32_t wide_sample) {
        return static_cast<std::int16_t>(wide_sample >> 16);
    });
}
|
|
|
|
|
/// Converts unsigned 8-bit samples to int16.
///
/// Each value has 128 subtracted from it and is then multiplied by 256, so
/// 0 becomes -32768, 128 becomes 0 and 255 becomes 32512.
///
/// @param input        Source buffer with at least sample_count bytes.
/// @param output       Destination buffer with room for sample_count int16 values.
/// @param sample_count Number of samples to convert.
void AudioProcessor::convert_uint8_to_int16(const std::uint8_t* input, std::int16_t* output, std::size_t sample_count) {
    const std::uint8_t* const input_end = input + sample_count;

    for (; input != input_end; ++input, ++output) {
        const int centered = static_cast<int>(*input) - 128;
        *output = static_cast<std::int16_t>(centered * 256);
    }
}
|
|
|
|
|
/// Averages interleaved multi-channel samples into one sample per frame.
///
/// For every frame the channel_count samples are summed in a 32-bit
/// accumulator and divided by channel_count (integer division, truncating
/// toward zero). When channel_count is 0 nothing is written to output.
///
/// @param input         Interleaved source with at least frame_count * channel_count samples.
/// @param output        Destination with room for frame_count samples.
/// @param frame_count   Number of frames to mix.
/// @param channel_count Number of interleaved channels per frame.
void AudioProcessor::mix_channels_to_mono(
    const std::int16_t* input,
    std::int16_t* output,
    std::size_t frame_count,
    std::uint16_t channel_count
) {
    // There is nothing to average with zero channels, and dividing by the
    // channel count below would be a division by zero.
    if (channel_count == 0) {
        return;
    }

    for (std::size_t frame_index = 0; frame_index < frame_count; ++frame_index) {
        const std::int16_t* const frame = input + frame_index * channel_count;

        std::int32_t sum = 0;
        for (std::uint16_t channel_index = 0; channel_index < channel_count; ++channel_index) {
            sum += frame[channel_index];
        }

        output[frame_index] = static_cast<std::int16_t>(sum / channel_count);
    }
}
|
|
|
|
|
} |