Pocket TTS: System stability tuning.
Browse files- Dockerfile +11 -1
- accelerator/CMakeLists.txt +36 -0
- accelerator/include/accelerator_core.hpp +69 -0
- accelerator/include/audio_processor.hpp +84 -0
- accelerator/include/ipc_handler.hpp +107 -0
- accelerator/include/memory_pool.hpp +79 -0
- accelerator/include/thread_pool.hpp +83 -0
- accelerator/src/accelerator_core.cpp +485 -0
- accelerator/src/audio_processor.cpp +352 -0
- accelerator/src/ipc_handler.cpp +226 -0
- accelerator/src/main.cpp +79 -0
- accelerator/src/memory_pool.cpp +216 -0
- accelerator/src/thread_pool.cpp +84 -0
- app.py +15 -5
- config.py +7 -1
- src/accelerator/client.py +442 -0
- src/audio/converter.py +28 -0
Dockerfile
CHANGED
|
@@ -7,4 +7,14 @@ FROM hadadrjt/pocket-tts:hf
|
|
| 7 |
|
| 8 |
WORKDIR /app
|
| 9 |
|
| 10 |
-
COPY . .
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
WORKDIR /app
|
| 9 |
|
| 10 |
+
COPY . .
|
| 11 |
+
|
| 12 |
+
# Build the native accelerator out-of-tree, keep only the resulting binary in
# /app/bin, then delete both the sources and the build tree so neither the
# object files nor the CMake cache end up in the final image layer.
RUN mkdir build \
    && cd build \
    && cmake -DCMAKE_BUILD_TYPE=Release ../accelerator \
    && make -j"$(nproc)" \
    && mkdir -p "$PWD/../bin" \
    && mv pocket_tts_accelerator "$PWD/../bin/" \
    && chmod +x "$PWD/../bin/pocket_tts_accelerator" \
    && cd .. \
    && rm -rf build accelerator
|
accelerator/CMakeLists.txt
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
|
| 6 |
+
# NOTE(review): 3.31.6 is a very recent minimum and nothing below appears to
# need it - confirm the base image really ships a CMake this new.
cmake_minimum_required(VERSION 3.31.6)

project(pocket_tts_accelerator VERSION 0.0.0 LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

# -march=native emits instructions for the CPU of the *build* machine. When the
# binary is built inside a container image and run on a different host it can
# die with SIGILL, so host-specific tuning is opt-in:
#   cmake -DPOCKET_TTS_ENABLE_NATIVE_ARCH=ON ...
option(POCKET_TTS_ENABLE_NATIVE_ARCH
       "Tune for the build host CPU (-march=native); the binary may not run on other CPUs"
       OFF)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -ffast-math -funroll-loops")
if(POCKET_TTS_ENABLE_NATIVE_ARCH)
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native")
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wpedantic")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")

find_package(Threads REQUIRED)

set(ACCELERATOR_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include)

set(ACCELERATOR_SOURCES
    src/main.cpp
    src/accelerator_core.cpp
    src/audio_processor.cpp
    src/ipc_handler.cpp
    src/memory_pool.cpp
    src/thread_pool.cpp
)

add_executable(pocket_tts_accelerator ${ACCELERATOR_SOURCES})

target_include_directories(pocket_tts_accelerator PRIVATE ${ACCELERATOR_INCLUDE_DIR})
target_link_libraries(pocket_tts_accelerator PRIVATE Threads::Threads)

install(TARGETS pocket_tts_accelerator DESTINATION bin)
|
accelerator/include/accelerator_core.hpp
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//
|
| 2 |
+
// SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
//
|
| 5 |
+
|
| 6 |
+
#ifndef POCKET_TTS_ACCELERATOR_CORE_HPP
|
| 7 |
+
#define POCKET_TTS_ACCELERATOR_CORE_HPP
|
| 8 |
+
|
| 9 |
+
#include "audio_processor.hpp"
|
| 10 |
+
#include "ipc_handler.hpp"
|
| 11 |
+
#include "memory_pool.hpp"
|
| 12 |
+
#include "thread_pool.hpp"
|
| 13 |
+
#include <atomic>
|
| 14 |
+
#include <memory>
|
| 15 |
+
#include <string>
|
| 16 |
+
|
| 17 |
+
namespace pocket_tts_accelerator {
|
| 18 |
+
|
| 19 |
+
/// Start-up settings handed to AcceleratorCore.
///
/// Every member has a default member initializer so that a default-constructed
/// object is fully determinate (previously the scalar members were left
/// uninitialized). The struct is still an aggregate, so brace initialization
/// in declaration order keeps working. Production defaults come from
/// AcceleratorCore::get_default_configuration().
struct AcceleratorConfiguration {
    /// Passed to the ThreadPool constructor.
    std::size_t number_of_worker_threads{0};
    /// Passed to the MemoryPool constructor.
    std::size_t memory_pool_size_bytes{0};
    /// Path handed to the IpcHandler constructor.
    std::string ipc_socket_path{};
    /// Presumably gates log output - TODO confirm in log_message().
    bool enable_verbose_logging{false};
};
|
| 25 |
+
|
| 26 |
+
class AcceleratorCore {
|
| 27 |
+
public:
|
| 28 |
+
explicit AcceleratorCore(const AcceleratorConfiguration& configuration);
|
| 29 |
+
~AcceleratorCore();
|
| 30 |
+
|
| 31 |
+
AcceleratorCore(const AcceleratorCore&) = delete;
|
| 32 |
+
AcceleratorCore& operator=(const AcceleratorCore&) = delete;
|
| 33 |
+
|
| 34 |
+
bool initialize();
|
| 35 |
+
void run();
|
| 36 |
+
void shutdown();
|
| 37 |
+
|
| 38 |
+
bool is_running() const;
|
| 39 |
+
std::string get_status_string() const;
|
| 40 |
+
|
| 41 |
+
static AcceleratorConfiguration get_default_configuration();
|
| 42 |
+
|
| 43 |
+
private:
|
| 44 |
+
void register_all_command_handlers();
|
| 45 |
+
void setup_signal_handlers();
|
| 46 |
+
|
| 47 |
+
std::vector<std::uint8_t> handle_ping_command(const std::vector<std::uint8_t>& payload);
|
| 48 |
+
std::vector<std::uint8_t> handle_process_audio_command(const std::vector<std::uint8_t>& payload);
|
| 49 |
+
std::vector<std::uint8_t> handle_convert_to_mono_command(const std::vector<std::uint8_t>& payload);
|
| 50 |
+
std::vector<std::uint8_t> handle_convert_to_pcm_command(const std::vector<std::uint8_t>& payload);
|
| 51 |
+
std::vector<std::uint8_t> handle_resample_audio_command(const std::vector<std::uint8_t>& payload);
|
| 52 |
+
std::vector<std::uint8_t> handle_get_memory_stats_command(const std::vector<std::uint8_t>& payload);
|
| 53 |
+
std::vector<std::uint8_t> handle_clear_memory_pool_command(const std::vector<std::uint8_t>& payload);
|
| 54 |
+
std::vector<std::uint8_t> handle_shutdown_command(const std::vector<std::uint8_t>& payload);
|
| 55 |
+
|
| 56 |
+
void log_message(const std::string& message) const;
|
| 57 |
+
|
| 58 |
+
AcceleratorConfiguration config;
|
| 59 |
+
std::unique_ptr<MemoryPool> memory_pool;
|
| 60 |
+
std::unique_ptr<ThreadPool> thread_pool;
|
| 61 |
+
std::unique_ptr<AudioProcessor> audio_processor;
|
| 62 |
+
std::unique_ptr<IpcHandler> ipc_handler;
|
| 63 |
+
std::atomic<bool> is_initialized;
|
| 64 |
+
std::atomic<bool> should_shutdown;
|
| 65 |
+
};
|
| 66 |
+
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
#endif
|
accelerator/include/audio_processor.hpp
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//
|
| 2 |
+
// SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
//
|
| 5 |
+
|
| 6 |
+
#ifndef POCKET_TTS_AUDIO_PROCESSOR_HPP
|
| 7 |
+
#define POCKET_TTS_AUDIO_PROCESSOR_HPP
|
| 8 |
+
|
| 9 |
+
#include "memory_pool.hpp"
|
| 10 |
+
#include <cstddef>
|
| 11 |
+
#include <cstdint>
|
| 12 |
+
#include <string>
|
| 13 |
+
#include <vector>
|
| 14 |
+
|
| 15 |
+
namespace pocket_tts_accelerator {
|
| 16 |
+
|
| 17 |
+
/// Header of a WAV file. The field order mirrors the canonical 44-byte
/// RIFF/WAVE PCM header (RIFF chunk, "fmt " sub-chunk, "data" sub-chunk).
/// NOTE(review): presumably read from / written to disk as raw bytes by
/// AudioProcessor - confirm in audio_processor.cpp.
struct WavFileHeader {
    char riff_marker[4];
    std::uint32_t file_size;
    char wave_marker[4];
    char format_marker[4];
    std::uint32_t format_chunk_size;
    std::uint16_t audio_format;
    std::uint16_t number_of_channels;
    std::uint32_t sample_rate;
    std::uint32_t byte_rate;
    std::uint16_t block_align;
    std::uint16_t bits_per_sample;
    char data_marker[4];
    std::uint32_t data_size;
};

// Guard the byte layout: any padding or field change would silently stop the
// struct from matching the on-disk header.
static_assert(sizeof(WavFileHeader) == 44,
              "WavFileHeader must be exactly 44 bytes with no padding");
|
| 32 |
+
|
| 33 |
+
/// Decoded audio held as 16-bit samples together with its format description.
///
/// All scalar members are value-initialized so that a default-constructed
/// object is determinate and reports `is_valid == false` (previously these
/// members were left uninitialized). The struct remains an aggregate.
struct AudioData {
    std::vector<std::int16_t> samples{};
    std::uint32_t sample_rate{0};
    std::uint16_t number_of_channels{0};
    std::uint16_t bits_per_sample{0};
    /// False until a producer marks the data as usable.
    bool is_valid{false};
    /// Human-readable failure description; presumably set when is_valid is false.
    std::string error_message{};
};
|
| 41 |
+
|
| 42 |
+
/// Outcome of one AudioProcessor operation.
///
/// Scalar members are value-initialized so that a default-constructed result
/// is determinate and reports failure (previously `success` and
/// `output_sample_rate` were left uninitialized). The struct remains an
/// aggregate.
struct AudioProcessingResult {
    std::vector<std::int16_t> processed_samples{};
    std::uint32_t output_sample_rate{0};
    bool success{false};
    /// Human-readable failure description; presumably set when success is false.
    std::string error_message{};
};
|
| 48 |
+
|
| 49 |
+
class AudioProcessor {
|
| 50 |
+
public:
|
| 51 |
+
explicit AudioProcessor(MemoryPool& shared_memory_pool);
|
| 52 |
+
~AudioProcessor();
|
| 53 |
+
|
| 54 |
+
AudioProcessor(const AudioProcessor&) = delete;
|
| 55 |
+
AudioProcessor& operator=(const AudioProcessor&) = delete;
|
| 56 |
+
|
| 57 |
+
AudioData read_wav_file(const std::string& file_path);
|
| 58 |
+
bool write_wav_file(const std::string& file_path, const AudioData& audio_data);
|
| 59 |
+
|
| 60 |
+
AudioProcessingResult convert_to_mono(const AudioData& input_audio);
|
| 61 |
+
AudioProcessingResult convert_to_pcm_int16(const AudioData& input_audio);
|
| 62 |
+
AudioProcessingResult resample_audio(const AudioData& input_audio, std::uint32_t target_sample_rate);
|
| 63 |
+
AudioProcessingResult normalize_audio(const AudioData& input_audio, float target_peak_level);
|
| 64 |
+
|
| 65 |
+
AudioProcessingResult process_audio_for_voice_cloning(
|
| 66 |
+
const std::string& input_file_path,
|
| 67 |
+
const std::string& output_file_path
|
| 68 |
+
);
|
| 69 |
+
|
| 70 |
+
static bool validate_wav_header(const WavFileHeader& header);
|
| 71 |
+
static std::size_t calculate_audio_duration_milliseconds(const AudioData& audio_data);
|
| 72 |
+
|
| 73 |
+
private:
|
| 74 |
+
void convert_float32_to_int16(const float* input, std::int16_t* output, std::size_t sample_count);
|
| 75 |
+
void convert_int32_to_int16(const std::int32_t* input, std::int16_t* output, std::size_t sample_count);
|
| 76 |
+
void convert_uint8_to_int16(const std::uint8_t* input, std::int16_t* output, std::size_t sample_count);
|
| 77 |
+
void mix_channels_to_mono(const std::int16_t* input, std::int16_t* output, std::size_t frame_count, std::uint16_t channel_count);
|
| 78 |
+
|
| 79 |
+
MemoryPool& memory_pool;
|
| 80 |
+
};
|
| 81 |
+
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
#endif
|
accelerator/include/ipc_handler.hpp
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//
|
| 2 |
+
// SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
//
|
| 5 |
+
|
| 6 |
+
#ifndef POCKET_TTS_IPC_HANDLER_HPP
|
| 7 |
+
#define POCKET_TTS_IPC_HANDLER_HPP
|
| 8 |
+
|
| 9 |
+
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
|
| 18 |
+
|
| 19 |
+
namespace pocket_tts_accelerator {
|
| 20 |
+
|
| 21 |
+
// Command identifiers carried in RequestHeader::command_type.
// The numeric values are part of the IPC protocol, so they must not be
// renumbered without updating every client.
enum class CommandType : std::uint32_t {
|
| 22 |
+
// Health check; the handler answers "PONG", optionally echoing the payload.
PING = 0,
|
| 23 |
+
// File-to-file processing; the payload is a ProcessAudioRequest.
PROCESS_AUDIO = 1,
|
| 24 |
+
CONVERT_TO_MONO = 2,
|
| 25 |
+
CONVERT_TO_PCM = 3,
|
| 26 |
+
RESAMPLE_AUDIO = 4,
|
| 27 |
+
GET_MEMORY_STATS = 5,
|
| 28 |
+
CLEAR_MEMORY_POOL = 6,
|
| 29 |
+
// NOTE(review): AcceleratorCore::register_all_command_handlers() registers no
// handler for this value; presumably IpcHandler routes it to the shutdown
// callback - confirm in ipc_handler.cpp.
SHUTDOWN = 7,
|
| 30 |
+
// Presumably a placeholder for unrecognized command values - TODO confirm.
UNKNOWN = 255
|
| 31 |
+
};
|
| 32 |
+
|
| 33 |
+
// Result codes carried in ResponseHeader::status_code.
// The numeric values are part of the IPC protocol, so they must not be
// renumbered without updating every client.
enum class ResponseStatus : std::uint32_t {
|
| 34 |
+
SUCCESS = 0,
|
| 35 |
+
ERROR_INVALID_COMMAND = 1,
|
| 36 |
+
ERROR_FILE_NOT_FOUND = 2,
|
| 37 |
+
ERROR_PROCESSING_FAILED = 3,
|
| 38 |
+
ERROR_MEMORY_ALLOCATION = 4,
|
| 39 |
+
ERROR_INTERNAL = 5
|
| 40 |
+
};
|
| 41 |
+
|
| 42 |
+
/// Fixed-size header that precedes every request payload.
/// NOTE(review): presumably sent over the socket as raw bytes in host byte
/// order - confirm against ipc_handler.cpp and the Python client.
struct RequestHeader {
    std::uint32_t magic_number;  ///< Expected to equal IpcHandler::PROTOCOL_MAGIC_NUMBER.
    std::uint32_t command_type;  ///< A CommandType value.
    std::uint32_t payload_size;  ///< Number of payload bytes that follow the header.
    std::uint32_t request_id;    ///< Caller-chosen id; ResponseHeader has a matching field.
};

// Guard the byte layout shared with the client.
static_assert(sizeof(RequestHeader) == 16,
              "RequestHeader must be exactly 16 bytes with no padding");
|
| 48 |
+
|
| 49 |
+
/// Fixed-size header that precedes every response payload.
/// NOTE(review): presumably sent over the socket as raw bytes in host byte
/// order - confirm against ipc_handler.cpp and the Python client.
struct ResponseHeader {
    std::uint32_t magic_number;  ///< Expected to equal IpcHandler::PROTOCOL_MAGIC_NUMBER.
    std::uint32_t status_code;   ///< A ResponseStatus value.
    std::uint32_t payload_size;  ///< Number of payload bytes that follow the header.
    std::uint32_t request_id;    ///< Presumably echoes RequestHeader::request_id - TODO confirm.
};

// Guard the byte layout shared with the client.
static_assert(sizeof(ResponseHeader) == 16,
              "ResponseHeader must be exactly 16 bytes with no padding");
|
| 55 |
+
|
| 56 |
+
/// Payload of a PROCESS_AUDIO command. The receiving handler copies the raw
/// payload bytes straight into this struct, so its size and field offsets are
/// part of the IPC protocol.
///
/// NOTE(review): the path buffers arrive from the client unchecked; a sender
/// that fills all 512 bytes leaves no terminating NUL, so consumers should not
/// treat them as C strings without bounding the length.
struct ProcessAudioRequest {
    char input_file_path[512];
    char output_file_path[512];
    std::uint32_t target_sample_rate;
    std::uint32_t options_flags;
};

// Guard the byte layout shared with the client.
static_assert(sizeof(ProcessAudioRequest) == 1032,
              "ProcessAudioRequest must be exactly 1032 bytes with no padding");
|
| 62 |
+
|
| 63 |
+
/// Memory pool counters; the fields correspond to MemoryPool's
/// get_total_allocated_bytes(), get_total_used_bytes() and get_block_count().
/// NOTE(review): presumably the raw response payload of GET_MEMORY_STATS -
/// confirm in the handler.
struct MemoryStatsResponse {
    std::uint64_t total_allocated_bytes;
    std::uint64_t total_used_bytes;
    std::uint64_t block_count;
};

// Guard the byte layout shared with the client.
static_assert(sizeof(MemoryStatsResponse) == 24,
              "MemoryStatsResponse must be exactly 24 bytes with no padding");
|
| 68 |
+
|
| 69 |
+
class IpcHandler {
|
| 70 |
+
public:
|
| 71 |
+
using CommandHandlerFunction = std::function<std::vector<std::uint8_t>(const std::vector<std::uint8_t>&)>;
|
| 72 |
+
|
| 73 |
+
explicit IpcHandler(const std::string& socket_path);
|
| 74 |
+
~IpcHandler();
|
| 75 |
+
|
| 76 |
+
IpcHandler(const IpcHandler&) = delete;
|
| 77 |
+
IpcHandler& operator=(const IpcHandler&) = delete;
|
| 78 |
+
|
| 79 |
+
bool start_server();
|
| 80 |
+
void stop_server();
|
| 81 |
+
bool is_running() const;
|
| 82 |
+
|
| 83 |
+
void register_command_handler(CommandType command_type, CommandHandlerFunction handler);
|
| 84 |
+
void set_shutdown_callback(std::function<void()> callback);
|
| 85 |
+
|
| 86 |
+
static constexpr std::uint32_t PROTOCOL_MAGIC_NUMBER = 0x50545453;
|
| 87 |
+
static constexpr std::size_t MAXIMUM_PAYLOAD_SIZE = 16 * 1024 * 1024;
|
| 88 |
+
static constexpr int CONNECTION_BACKLOG = 5;
|
| 89 |
+
|
| 90 |
+
private:
|
| 91 |
+
void accept_connections_loop();
|
| 92 |
+
void handle_client_connection(int client_socket_fd);
|
| 93 |
+
bool send_response(int socket_fd, const ResponseHeader& header, const std::vector<std::uint8_t>& payload);
|
| 94 |
+
bool receive_request(int socket_fd, RequestHeader& header, std::vector<std::uint8_t>& payload);
|
| 95 |
+
|
| 96 |
+
std::string socket_file_path;
|
| 97 |
+
int server_socket_fd;
|
| 98 |
+
std::atomic<bool> is_server_running;
|
| 99 |
+
std::thread accept_thread;
|
| 100 |
+
std::mutex handlers_mutex;
|
| 101 |
+
std::unordered_map<CommandType, CommandHandlerFunction> command_handlers;
|
| 102 |
+
std::function<void()> shutdown_callback;
|
| 103 |
+
};
|
| 104 |
+
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
#endif
|
accelerator/include/memory_pool.hpp
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//
|
| 2 |
+
// SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
//
|
| 5 |
+
|
| 6 |
+
#ifndef POCKET_TTS_MEMORY_POOL_HPP
|
| 7 |
+
#define POCKET_TTS_MEMORY_POOL_HPP
|
| 8 |
+
|
| 9 |
+
#include <atomic>
|
| 10 |
+
#include <cstddef>
|
| 11 |
+
#include <cstdint>
|
| 12 |
+
#include <memory>
|
| 13 |
+
#include <mutex>
|
| 14 |
+
#include <unordered_map>
|
| 15 |
+
#include <vector>
|
| 16 |
+
|
| 17 |
+
namespace pocket_tts_accelerator {
|
| 18 |
+
|
| 19 |
+
/// One buffer owned by MemoryPool plus its bookkeeping.
///
/// Bookkeeping members are value-initialized so that a default-constructed
/// block is determinate: no buffer, size 0, not in use (previously they were
/// left uninitialized). The struct remains a move-only aggregate.
struct MemoryBlock {
    std::unique_ptr<std::uint8_t[]> data{};
    std::size_t block_size{0};
    bool is_in_use{false};
    std::uint64_t last_access_timestamp{0};
};
|
| 25 |
+
|
| 26 |
+
class MemoryPool {
|
| 27 |
+
public:
|
| 28 |
+
explicit MemoryPool(std::size_t initial_pool_size_bytes = 64 * 1024 * 1024);
|
| 29 |
+
~MemoryPool();
|
| 30 |
+
|
| 31 |
+
MemoryPool(const MemoryPool&) = delete;
|
| 32 |
+
MemoryPool& operator=(const MemoryPool&) = delete;
|
| 33 |
+
MemoryPool(MemoryPool&&) = delete;
|
| 34 |
+
MemoryPool& operator=(MemoryPool&&) = delete;
|
| 35 |
+
|
| 36 |
+
std::uint8_t* allocate(std::size_t requested_size_bytes);
|
| 37 |
+
void deallocate(std::uint8_t* pointer);
|
| 38 |
+
void clear_unused_blocks();
|
| 39 |
+
void reset_pool();
|
| 40 |
+
|
| 41 |
+
std::size_t get_total_allocated_bytes() const;
|
| 42 |
+
std::size_t get_total_used_bytes() const;
|
| 43 |
+
std::size_t get_block_count() const;
|
| 44 |
+
|
| 45 |
+
private:
|
| 46 |
+
std::size_t find_suitable_block_index(std::size_t requested_size) const;
|
| 47 |
+
void create_new_block(std::size_t block_size);
|
| 48 |
+
std::uint64_t get_current_timestamp() const;
|
| 49 |
+
|
| 50 |
+
std::vector<MemoryBlock> memory_blocks;
|
| 51 |
+
std::unordered_map<std::uint8_t*, std::size_t> pointer_to_block_index;
|
| 52 |
+
mutable std::mutex pool_mutex;
|
| 53 |
+
std::size_t total_allocated_bytes;
|
| 54 |
+
std::size_t total_used_bytes;
|
| 55 |
+
std::size_t maximum_pool_size_bytes;
|
| 56 |
+
};
|
| 57 |
+
|
| 58 |
+
class ScopedMemoryAllocation {
|
| 59 |
+
public:
|
| 60 |
+
ScopedMemoryAllocation(MemoryPool& pool, std::size_t size);
|
| 61 |
+
~ScopedMemoryAllocation();
|
| 62 |
+
|
| 63 |
+
ScopedMemoryAllocation(const ScopedMemoryAllocation&) = delete;
|
| 64 |
+
ScopedMemoryAllocation& operator=(const ScopedMemoryAllocation&) = delete;
|
| 65 |
+
ScopedMemoryAllocation(ScopedMemoryAllocation&& other) noexcept;
|
| 66 |
+
ScopedMemoryAllocation& operator=(ScopedMemoryAllocation&& other) noexcept;
|
| 67 |
+
|
| 68 |
+
std::uint8_t* get() const;
|
| 69 |
+
std::size_t size() const;
|
| 70 |
+
|
| 71 |
+
private:
|
| 72 |
+
MemoryPool* memory_pool_pointer;
|
| 73 |
+
std::uint8_t* allocated_pointer;
|
| 74 |
+
std::size_t allocation_size;
|
| 75 |
+
};
|
| 76 |
+
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
#endif
|
accelerator/include/thread_pool.hpp
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//
|
| 2 |
+
// SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
//
|
| 5 |
+
|
| 6 |
+
#ifndef POCKET_TTS_THREAD_POOL_HPP
|
| 7 |
+
#define POCKET_TTS_THREAD_POOL_HPP
|
| 8 |
+
|
| 9 |
+
#include <atomic>
|
| 10 |
+
#include <condition_variable>
|
| 11 |
+
#include <functional>
|
| 12 |
+
#include <future>
|
| 13 |
+
#include <memory>
|
| 14 |
+
#include <mutex>
|
| 15 |
+
#include <queue>
|
| 16 |
+
#include <thread>
|
| 17 |
+
#include <vector>
|
| 18 |
+
|
| 19 |
+
namespace pocket_tts_accelerator {
|
| 20 |
+
|
| 21 |
+
/// Worker pool that executes queued tasks. Neither copyable nor movable.
class ThreadPool {
public:
    /// @param number_of_threads number of worker threads requested.
    explicit ThreadPool(std::size_t number_of_threads);
    ~ThreadPool();

    ThreadPool(const ThreadPool&) = delete;
    ThreadPool& operator=(const ThreadPool&) = delete;
    ThreadPool(ThreadPool&&) = delete;
    ThreadPool& operator=(ThreadPool&&) = delete;

    /// Queues `function(arguments...)` for execution on a worker thread.
    /// @return a future for the call's result.
    /// @throws std::runtime_error when the pool has been asked to stop.
    template<typename FunctionType, typename... ArgumentTypes>
    auto submit_task(FunctionType&& function, ArgumentTypes&&... arguments)
        -> std::future<typename std::invoke_result<FunctionType, ArgumentTypes...>::type>;

    void shutdown();
    bool is_running() const;
    std::size_t get_pending_task_count() const;
    std::size_t get_thread_count() const;

private:
    /// Entry point of each worker thread.
    void worker_thread_function();

    std::vector<std::thread> worker_threads;
    std::queue<std::function<void()>> task_queue;
    /// Guards task_queue; mutable so that const member functions can lock it.
    mutable std::mutex queue_mutex;
    /// Notified once per submitted task.
    std::condition_variable task_available_condition;
    std::atomic<bool> should_stop;
    std::atomic<bool> is_stopped;
    std::size_t thread_count;
};
|
| 51 |
+
|
| 52 |
+
template<typename FunctionType, typename... ArgumentTypes>
|
| 53 |
+
auto ThreadPool::submit_task(FunctionType&& function, ArgumentTypes&&... arguments)
|
| 54 |
+
-> std::future<typename std::invoke_result<FunctionType, ArgumentTypes...>::type> {
|
| 55 |
+
|
| 56 |
+
using ReturnType = typename std::invoke_result<FunctionType, ArgumentTypes...>::type;
|
| 57 |
+
|
| 58 |
+
auto packaged_task = std::make_shared<std::packaged_task<ReturnType()>>(
|
| 59 |
+
std::bind(std::forward<FunctionType>(function), std::forward<ArgumentTypes>(arguments)...)
|
| 60 |
+
);
|
| 61 |
+
|
| 62 |
+
std::future<ReturnType> result_future = packaged_task->get_future();
|
| 63 |
+
|
| 64 |
+
{
|
| 65 |
+
std::unique_lock<std::mutex> lock(queue_mutex);
|
| 66 |
+
|
| 67 |
+
if (should_stop.load()) {
|
| 68 |
+
throw std::runtime_error("Cannot submit task to stopped thread pool");
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
task_queue.emplace([packaged_task]() {
|
| 72 |
+
(*packaged_task)();
|
| 73 |
+
});
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
task_available_condition.notify_one();
|
| 77 |
+
|
| 78 |
+
return result_future;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
#endif
|
accelerator/src/accelerator_core.cpp
ADDED
|
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//
|
| 2 |
+
// SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
//
|
| 5 |
+
|
| 6 |
+
#include "accelerator_core.hpp"
|
| 7 |
+
#include <chrono>
|
| 8 |
+
#include <cstring>
|
| 9 |
+
#include <ctime>
|
| 10 |
+
#include <iomanip>
|
| 11 |
+
#include <iostream>
|
| 12 |
+
#include <sstream>
|
| 13 |
+
#include <signal.h>
|
| 14 |
+
|
| 15 |
+
namespace pocket_tts_accelerator {
|
| 16 |
+
|
| 17 |
+
static AcceleratorCore* global_accelerator_instance = nullptr;
|
| 18 |
+
static volatile sig_atomic_t last_received_signal = 0;
|
| 19 |
+
|
| 20 |
+
static void signal_handler_function(int signal_number) {
|
| 21 |
+
last_received_signal = signal_number;
|
| 22 |
+
if (global_accelerator_instance != nullptr) {
|
| 23 |
+
global_accelerator_instance->shutdown();
|
| 24 |
+
}
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
AcceleratorCore::AcceleratorCore(const AcceleratorConfiguration& configuration)
|
| 28 |
+
: config(configuration)
|
| 29 |
+
, is_initialized(false)
|
| 30 |
+
, should_shutdown(false) {
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
AcceleratorCore::~AcceleratorCore() {
|
| 34 |
+
shutdown();
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
bool AcceleratorCore::initialize() {
|
| 38 |
+
if (is_initialized.load()) {
|
| 39 |
+
return true;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
log_message("Initializing Pocket TTS Accelerator...");
|
| 43 |
+
|
| 44 |
+
memory_pool = std::make_unique<MemoryPool>(config.memory_pool_size_bytes);
|
| 45 |
+
log_message("Memory pool initialized with " + std::to_string(config.memory_pool_size_bytes / (1024 * 1024)) + " MB");
|
| 46 |
+
|
| 47 |
+
thread_pool = std::make_unique<ThreadPool>(config.number_of_worker_threads);
|
| 48 |
+
log_message("Thread pool initialized with " + std::to_string(config.number_of_worker_threads) + " worker threads");
|
| 49 |
+
|
| 50 |
+
audio_processor = std::make_unique<AudioProcessor>(*memory_pool);
|
| 51 |
+
log_message("Audio processor initialized");
|
| 52 |
+
|
| 53 |
+
ipc_handler = std::make_unique<IpcHandler>(config.ipc_socket_path);
|
| 54 |
+
log_message("IPC handler created for socket: " + config.ipc_socket_path);
|
| 55 |
+
|
| 56 |
+
register_all_command_handlers();
|
| 57 |
+
|
| 58 |
+
ipc_handler->set_shutdown_callback([this]() {
|
| 59 |
+
this->shutdown();
|
| 60 |
+
});
|
| 61 |
+
|
| 62 |
+
if (!ipc_handler->start_server()) {
|
| 63 |
+
log_message("ERROR: Failed to start IPC server");
|
| 64 |
+
return false;
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
log_message("IPC server started successfully");
|
| 68 |
+
|
| 69 |
+
global_accelerator_instance = this;
|
| 70 |
+
setup_signal_handlers();
|
| 71 |
+
|
| 72 |
+
is_initialized.store(true);
|
| 73 |
+
log_message("Pocket TTS Accelerator initialized successfully");
|
| 74 |
+
|
| 75 |
+
return true;
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
void AcceleratorCore::run() {
|
| 79 |
+
if (!is_initialized.load()) {
|
| 80 |
+
log_message("ERROR: Accelerator not initialized");
|
| 81 |
+
return;
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
log_message("Accelerator running and waiting for commands...");
|
| 85 |
+
|
| 86 |
+
while (!should_shutdown.load()) {
|
| 87 |
+
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
| 88 |
+
|
| 89 |
+
if (last_received_signal != 0) {
|
| 90 |
+
log_message("Received signal: " + std::to_string(last_received_signal));
|
| 91 |
+
last_received_signal = 0;
|
| 92 |
+
}
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
log_message("Accelerator main loop exited");
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
void AcceleratorCore::shutdown() {
|
| 99 |
+
if (should_shutdown.exchange(true)) {
|
| 100 |
+
return;
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
log_message("Shutting down Pocket TTS Accelerator...");
|
| 104 |
+
|
| 105 |
+
if (ipc_handler) {
|
| 106 |
+
ipc_handler->stop_server();
|
| 107 |
+
log_message("IPC server stopped");
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
if (thread_pool) {
|
| 111 |
+
thread_pool->shutdown();
|
| 112 |
+
log_message("Thread pool shut down");
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
if (memory_pool) {
|
| 116 |
+
memory_pool->reset_pool();
|
| 117 |
+
log_message("Memory pool reset");
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
is_initialized.store(false);
|
| 121 |
+
log_message("Pocket TTS Accelerator shut down complete");
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
bool AcceleratorCore::is_running() const {
|
| 125 |
+
return is_initialized.load() && !should_shutdown.load();
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
std::string AcceleratorCore::get_status_string() const {
|
| 129 |
+
if (!is_initialized.load()) {
|
| 130 |
+
return "Not initialized";
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
if (should_shutdown.load()) {
|
| 134 |
+
return "Shutting down";
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
return "Running";
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
// Built-in defaults: 2 worker threads, a 64 MiB memory pool, the socket at
// /tmp/pocket_tts_accelerator.sock and verbose logging enabled.
AcceleratorConfiguration AcceleratorCore::get_default_configuration() {
    constexpr std::size_t bytes_per_mebibyte = 1024 * 1024;

    AcceleratorConfiguration defaults;
    defaults.ipc_socket_path = "/tmp/pocket_tts_accelerator.sock";
    defaults.enable_verbose_logging = true;
    defaults.memory_pool_size_bytes = 64 * bytes_per_mebibyte;
    defaults.number_of_worker_threads = 2;
    return defaults;
}
|
| 148 |
+
|
| 149 |
+
void AcceleratorCore::register_all_command_handlers() {
|
| 150 |
+
ipc_handler->register_command_handler(
|
| 151 |
+
CommandType::PING,
|
| 152 |
+
[this](const std::vector<std::uint8_t>& payload) {
|
| 153 |
+
return this->handle_ping_command(payload);
|
| 154 |
+
}
|
| 155 |
+
);
|
| 156 |
+
|
| 157 |
+
ipc_handler->register_command_handler(
|
| 158 |
+
CommandType::PROCESS_AUDIO,
|
| 159 |
+
[this](const std::vector<std::uint8_t>& payload) {
|
| 160 |
+
return this->handle_process_audio_command(payload);
|
| 161 |
+
}
|
| 162 |
+
);
|
| 163 |
+
|
| 164 |
+
ipc_handler->register_command_handler(
|
| 165 |
+
CommandType::CONVERT_TO_MONO,
|
| 166 |
+
[this](const std::vector<std::uint8_t>& payload) {
|
| 167 |
+
return this->handle_convert_to_mono_command(payload);
|
| 168 |
+
}
|
| 169 |
+
);
|
| 170 |
+
|
| 171 |
+
ipc_handler->register_command_handler(
|
| 172 |
+
CommandType::CONVERT_TO_PCM,
|
| 173 |
+
[this](const std::vector<std::uint8_t>& payload) {
|
| 174 |
+
return this->handle_convert_to_pcm_command(payload);
|
| 175 |
+
}
|
| 176 |
+
);
|
| 177 |
+
|
| 178 |
+
ipc_handler->register_command_handler(
|
| 179 |
+
CommandType::RESAMPLE_AUDIO,
|
| 180 |
+
[this](const std::vector<std::uint8_t>& payload) {
|
| 181 |
+
return this->handle_resample_audio_command(payload);
|
| 182 |
+
}
|
| 183 |
+
);
|
| 184 |
+
|
| 185 |
+
ipc_handler->register_command_handler(
|
| 186 |
+
CommandType::GET_MEMORY_STATS,
|
| 187 |
+
[this](const std::vector<std::uint8_t>& payload) {
|
| 188 |
+
return this->handle_get_memory_stats_command(payload);
|
| 189 |
+
}
|
| 190 |
+
);
|
| 191 |
+
|
| 192 |
+
ipc_handler->register_command_handler(
|
| 193 |
+
CommandType::CLEAR_MEMORY_POOL,
|
| 194 |
+
[this](const std::vector<std::uint8_t>& payload) {
|
| 195 |
+
return this->handle_clear_memory_pool_command(payload);
|
| 196 |
+
}
|
| 197 |
+
);
|
| 198 |
+
|
| 199 |
+
log_message("All command handlers registered");
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
// Installs signal_handler_function as the handler for SIGINT and SIGTERM.
void AcceleratorCore::setup_signal_handlers() {
    const int handled_signals[] = {SIGINT, SIGTERM};

    for (const int signal_number : handled_signals) {
        signal(signal_number, signal_handler_function);
    }
}
|
| 206 |
+
|
| 207 |
+
// Handles PING: logs the request and replies with "PONG", or with
// "PONG:<payload>" when the request carried a non-empty payload.
std::vector<std::uint8_t> AcceleratorCore::handle_ping_command(const std::vector<std::uint8_t>& payload) {
    // The payload bytes are echoed back verbatim as text.
    const std::string echoed_text(payload.begin(), payload.end());

    if (echoed_text.empty()) {
        log_message("Received PING command");
    } else {
        log_message("Received PING command with payload: " + echoed_text);
    }

    std::string reply = "PONG";
    if (!echoed_text.empty()) {
        reply += ":";
        reply += echoed_text;
    }

    return std::vector<std::uint8_t>(reply.begin(), reply.end());
}
|
| 223 |
+
|
| 224 |
+
std::vector<std::uint8_t> AcceleratorCore::handle_process_audio_command(const std::vector<std::uint8_t>& payload) {
|
| 225 |
+
log_message("Received PROCESS_AUDIO command with payload size: " + std::to_string(payload.size()) + " bytes");
|
| 226 |
+
|
| 227 |
+
if (payload.size() < sizeof(ProcessAudioRequest)) {
|
| 228 |
+
std::string error_message = "ERROR:Invalid payload size, expected " + std::to_string(sizeof(ProcessAudioRequest)) + " bytes";
|
| 229 |
+
return std::vector<std::uint8_t>(error_message.begin(), error_message.end());
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
ProcessAudioRequest request;
|
| 233 |
+
std::memcpy(&request, payload.data(), sizeof(ProcessAudioRequest));
|
| 234 |
+
|
| 235 |
+
std::string input_path(request.input_file_path);
|
| 236 |
+
std::string output_path(request.output_file_path);
|
| 237 |
+
|
| 238 |
+
log_message("Processing audio from: " + input_path + " to: " + output_path);
|
| 239 |
+
|
| 240 |
+
auto future_result = thread_pool->submit_task([this, input_path, output_path]() {
|
| 241 |
+
return this->audio_processor->process_audio_for_voice_cloning(input_path, output_path);
|
| 242 |
+
});
|
| 243 |
+
|
| 244 |
+
AudioProcessingResult result = future_result.get();
|
| 245 |
+
|
| 246 |
+
if (result.success) {
|
| 247 |
+
log_message("Audio processing completed successfully");
|
| 248 |
+
std::string success_message = "SUCCESS:" + output_path;
|
| 249 |
+
return std::vector<std::uint8_t>(success_message.begin(), success_message.end());
|
| 250 |
+
} else {
|
| 251 |
+
log_message("Audio processing failed: " + result.error_message);
|
| 252 |
+
std::string error_message = "ERROR:" + result.error_message;
|
| 253 |
+
return std::vector<std::uint8_t>(error_message.begin(), error_message.end());
|
| 254 |
+
}
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
std::vector<std::uint8_t> AcceleratorCore::handle_convert_to_mono_command(const std::vector<std::uint8_t>& payload) {
|
| 258 |
+
log_message("Received CONVERT_TO_MONO command with payload size: " + std::to_string(payload.size()) + " bytes");
|
| 259 |
+
|
| 260 |
+
if (payload.size() < sizeof(ProcessAudioRequest)) {
|
| 261 |
+
std::string error_message = "ERROR:Invalid payload size, expected " + std::to_string(sizeof(ProcessAudioRequest)) + " bytes";
|
| 262 |
+
return std::vector<std::uint8_t>(error_message.begin(), error_message.end());
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
ProcessAudioRequest request;
|
| 266 |
+
std::memcpy(&request, payload.data(), sizeof(ProcessAudioRequest));
|
| 267 |
+
|
| 268 |
+
std::string input_path(request.input_file_path);
|
| 269 |
+
std::string output_path(request.output_file_path);
|
| 270 |
+
|
| 271 |
+
log_message("Converting to mono from: " + input_path + " to: " + output_path);
|
| 272 |
+
|
| 273 |
+
AudioData audio_data = audio_processor->read_wav_file(input_path);
|
| 274 |
+
|
| 275 |
+
if (!audio_data.is_valid) {
|
| 276 |
+
std::string error_message = "ERROR:" + audio_data.error_message;
|
| 277 |
+
return std::vector<std::uint8_t>(error_message.begin(), error_message.end());
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
AudioProcessingResult result = audio_processor->convert_to_mono(audio_data);
|
| 281 |
+
|
| 282 |
+
if (!result.success) {
|
| 283 |
+
std::string error_message = "ERROR:" + result.error_message;
|
| 284 |
+
return std::vector<std::uint8_t>(error_message.begin(), error_message.end());
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
AudioData output_audio;
|
| 288 |
+
output_audio.samples = std::move(result.processed_samples);
|
| 289 |
+
output_audio.sample_rate = result.output_sample_rate;
|
| 290 |
+
output_audio.number_of_channels = 1;
|
| 291 |
+
output_audio.bits_per_sample = 16;
|
| 292 |
+
output_audio.is_valid = true;
|
| 293 |
+
|
| 294 |
+
if (!audio_processor->write_wav_file(output_path, output_audio)) {
|
| 295 |
+
std::string error_message = "ERROR:Failed to write output file";
|
| 296 |
+
return std::vector<std::uint8_t>(error_message.begin(), error_message.end());
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
log_message("Mono conversion completed successfully");
|
| 300 |
+
std::string success_message = "SUCCESS:" + output_path;
|
| 301 |
+
return std::vector<std::uint8_t>(success_message.begin(), success_message.end());
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
std::vector<std::uint8_t> AcceleratorCore::handle_convert_to_pcm_command(const std::vector<std::uint8_t>& payload) {
|
| 305 |
+
log_message("Received CONVERT_TO_PCM command with payload size: " + std::to_string(payload.size()) + " bytes");
|
| 306 |
+
|
| 307 |
+
if (payload.size() < sizeof(ProcessAudioRequest)) {
|
| 308 |
+
std::string error_message = "ERROR:Invalid payload size, expected " + std::to_string(sizeof(ProcessAudioRequest)) + " bytes";
|
| 309 |
+
return std::vector<std::uint8_t>(error_message.begin(), error_message.end());
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
ProcessAudioRequest request;
|
| 313 |
+
std::memcpy(&request, payload.data(), sizeof(ProcessAudioRequest));
|
| 314 |
+
|
| 315 |
+
std::string input_path(request.input_file_path);
|
| 316 |
+
std::string output_path(request.output_file_path);
|
| 317 |
+
|
| 318 |
+
log_message("Converting to PCM from: " + input_path + " to: " + output_path);
|
| 319 |
+
|
| 320 |
+
AudioData audio_data = audio_processor->read_wav_file(input_path);
|
| 321 |
+
|
| 322 |
+
if (!audio_data.is_valid) {
|
| 323 |
+
std::string error_message = "ERROR:" + audio_data.error_message;
|
| 324 |
+
return std::vector<std::uint8_t>(error_message.begin(), error_message.end());
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
AudioData mono_audio;
|
| 328 |
+
|
| 329 |
+
if (audio_data.number_of_channels > 1) {
|
| 330 |
+
log_message("Input has " + std::to_string(audio_data.number_of_channels) + " channels, converting to mono");
|
| 331 |
+
AudioProcessingResult mono_result = audio_processor->convert_to_mono(audio_data);
|
| 332 |
+
|
| 333 |
+
if (!mono_result.success) {
|
| 334 |
+
std::string error_message = "ERROR:" + mono_result.error_message;
|
| 335 |
+
return std::vector<std::uint8_t>(error_message.begin(), error_message.end());
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
mono_audio.samples = std::move(mono_result.processed_samples);
|
| 339 |
+
mono_audio.sample_rate = mono_result.output_sample_rate;
|
| 340 |
+
} else {
|
| 341 |
+
mono_audio.samples = std::move(audio_data.samples);
|
| 342 |
+
mono_audio.sample_rate = audio_data.sample_rate;
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
mono_audio.number_of_channels = 1;
|
| 346 |
+
mono_audio.bits_per_sample = 16;
|
| 347 |
+
mono_audio.is_valid = true;
|
| 348 |
+
|
| 349 |
+
if (!audio_processor->write_wav_file(output_path, mono_audio)) {
|
| 350 |
+
std::string error_message = "ERROR:Failed to write output file";
|
| 351 |
+
return std::vector<std::uint8_t>(error_message.begin(), error_message.end());
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
log_message("PCM conversion completed successfully");
|
| 355 |
+
std::string success_message = "SUCCESS:" + output_path;
|
| 356 |
+
return std::vector<std::uint8_t>(success_message.begin(), success_message.end());
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
std::vector<std::uint8_t> AcceleratorCore::handle_resample_audio_command(const std::vector<std::uint8_t>& payload) {
|
| 360 |
+
log_message("Received RESAMPLE_AUDIO command with payload size: " + std::to_string(payload.size()) + " bytes");
|
| 361 |
+
|
| 362 |
+
if (payload.size() < sizeof(ProcessAudioRequest)) {
|
| 363 |
+
std::string error_message = "ERROR:Invalid payload size, expected " + std::to_string(sizeof(ProcessAudioRequest)) + " bytes";
|
| 364 |
+
return std::vector<std::uint8_t>(error_message.begin(), error_message.end());
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
ProcessAudioRequest request;
|
| 368 |
+
std::memcpy(&request, payload.data(), sizeof(ProcessAudioRequest));
|
| 369 |
+
|
| 370 |
+
std::string input_path(request.input_file_path);
|
| 371 |
+
std::string output_path(request.output_file_path);
|
| 372 |
+
std::uint32_t target_sample_rate = request.target_sample_rate;
|
| 373 |
+
|
| 374 |
+
log_message("Resampling audio from: " + input_path + " to: " + output_path + " at " + std::to_string(target_sample_rate) + " Hz");
|
| 375 |
+
|
| 376 |
+
AudioData audio_data = audio_processor->read_wav_file(input_path);
|
| 377 |
+
|
| 378 |
+
if (!audio_data.is_valid) {
|
| 379 |
+
std::string error_message = "ERROR:" + audio_data.error_message;
|
| 380 |
+
return std::vector<std::uint8_t>(error_message.begin(), error_message.end());
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
AudioProcessingResult result = audio_processor->resample_audio(audio_data, target_sample_rate);
|
| 384 |
+
|
| 385 |
+
if (!result.success) {
|
| 386 |
+
std::string error_message = "ERROR:" + result.error_message;
|
| 387 |
+
return std::vector<std::uint8_t>(error_message.begin(), error_message.end());
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
AudioData output_audio;
|
| 391 |
+
output_audio.samples = std::move(result.processed_samples);
|
| 392 |
+
output_audio.sample_rate = result.output_sample_rate;
|
| 393 |
+
output_audio.number_of_channels = audio_data.number_of_channels;
|
| 394 |
+
output_audio.bits_per_sample = 16;
|
| 395 |
+
output_audio.is_valid = true;
|
| 396 |
+
|
| 397 |
+
if (!audio_processor->write_wav_file(output_path, output_audio)) {
|
| 398 |
+
std::string error_message = "ERROR:Failed to write output file";
|
| 399 |
+
return std::vector<std::uint8_t>(error_message.begin(), error_message.end());
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
log_message("Resampling completed successfully");
|
| 403 |
+
std::string success_message = "SUCCESS:" + output_path;
|
| 404 |
+
return std::vector<std::uint8_t>(success_message.begin(), success_message.end());
|
| 405 |
+
}
|
| 406 |
+
|
| 407 |
+
// Handles GET_MEMORY_STATS: queries the memory pool for its allocated bytes,
// used bytes and block count, logs them, and returns the raw bytes of a
// MemoryStatsResponse holding those values. The request payload is only
// used for logging its size.
std::vector<std::uint8_t> AcceleratorCore::handle_get_memory_stats_command(const std::vector<std::uint8_t>& payload) {
    log_message("Received GET_MEMORY_STATS command with payload size: " + std::to_string(payload.size()) + " bytes");

    // Value-initialize: the struct is copied byte-for-byte into the reply, so
    // nothing in it may be left indeterminate.
    MemoryStatsResponse stats{};
    stats.total_allocated_bytes = memory_pool->get_total_allocated_bytes();
    stats.total_used_bytes = memory_pool->get_total_used_bytes();
    stats.block_count = memory_pool->get_block_count();

    log_message("Memory stats - Allocated: " + std::to_string(stats.total_allocated_bytes) +
                " bytes, Used: " + std::to_string(stats.total_used_bytes) +
                " bytes, Blocks: " + std::to_string(stats.block_count));

    std::vector<std::uint8_t> response(sizeof(MemoryStatsResponse));
    std::memcpy(response.data(), &stats, sizeof(MemoryStatsResponse));

    return response;
}
|
| 424 |
+
|
| 425 |
+
// Handles CLEAR_MEMORY_POOL: asks the memory pool to release its unused
// blocks and reports how many blocks and bytes were freed, based on the pool
// statistics sampled before and after the call.
//
// Returns "SUCCESS:Freed <blocks> blocks (<bytes> bytes)". The request
// payload is only used for logging its size.
std::vector<std::uint8_t> AcceleratorCore::handle_clear_memory_pool_command(const std::vector<std::uint8_t>& payload) {
    log_message("Received CLEAR_MEMORY_POOL command with payload size: " + std::to_string(payload.size()) + " bytes");

    const std::size_t blocks_before = memory_pool->get_block_count();
    const std::size_t allocated_before = memory_pool->get_total_allocated_bytes();

    memory_pool->clear_unused_blocks();

    const std::size_t blocks_after = memory_pool->get_block_count();
    const std::size_t allocated_after = memory_pool->get_total_allocated_bytes();

    // The pool may have grown between the two samples (e.g. allocations made
    // by other work in the meantime); report zero instead of letting the
    // unsigned subtraction wrap around to a huge number.
    const std::size_t blocks_freed = (blocks_before > blocks_after) ? blocks_before - blocks_after : 0;
    const std::size_t bytes_freed = (allocated_before > allocated_after) ? allocated_before - allocated_after : 0;

    const std::string summary = "Freed " + std::to_string(blocks_freed) +
                                " blocks (" + std::to_string(bytes_freed) + " bytes)";

    log_message("Memory pool cleared - " + summary);

    const std::string success_message = "SUCCESS:" + summary;
    return std::vector<std::uint8_t>(success_message.begin(), success_message.end());
}
|
| 446 |
+
|
| 447 |
+
// Handles SHUTDOWN: logs the request (with the payload text as the reason,
// when one was supplied), calls shutdown(), and replies with
// "SUCCESS:Shutting down", followed by " (reason: <payload>)" if a reason
// was given.
std::vector<std::uint8_t> AcceleratorCore::handle_shutdown_command(const std::vector<std::uint8_t>& payload) {
    const std::string reason(payload.begin(), payload.end());

    if (reason.empty()) {
        log_message("Received SHUTDOWN command");
    } else {
        log_message("Received SHUTDOWN command with reason: " + reason);
    }

    // shutdown() is invoked before the reply is assembled.
    shutdown();

    std::string reply = "SUCCESS:Shutting down";
    if (!reason.empty()) {
        reply += " (reason: ";
        reply += reason;
        reply += ")";
    }

    return std::vector<std::uint8_t>(reply.begin(), reply.end());
}
|
| 464 |
+
|
| 465 |
+
void AcceleratorCore::log_message(const std::string& message) const {
|
| 466 |
+
if (config.enable_verbose_logging) {
|
| 467 |
+
auto now = std::chrono::system_clock::now();
|
| 468 |
+
std::time_t time_t_now = std::chrono::system_clock::to_time_t(now);
|
| 469 |
+
|
| 470 |
+
auto milliseconds = std::chrono::duration_cast<std::chrono::milliseconds>(
|
| 471 |
+
now.time_since_epoch()
|
| 472 |
+
) % 1000;
|
| 473 |
+
|
| 474 |
+
std::tm time_info;
|
| 475 |
+
localtime_r(&time_t_now, &time_info);
|
| 476 |
+
|
| 477 |
+
std::ostringstream timestamp_stream;
|
| 478 |
+
timestamp_stream << std::put_time(&time_info, "%Y-%m-%d %H:%M:%S");
|
| 479 |
+
timestamp_stream << '.' << std::setfill('0') << std::setw(3) << milliseconds.count();
|
| 480 |
+
|
| 481 |
+
std::cout << "[" << timestamp_stream.str() << "] [ACCELERATOR] " << message << std::endl;
|
| 482 |
+
}
|
| 483 |
+
}
|
| 484 |
+
|
| 485 |
+
}
|
accelerator/src/audio_processor.cpp
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//
|
| 2 |
+
// SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
//
|
| 5 |
+
|
| 6 |
+
#include "audio_processor.hpp"
|
| 7 |
+
#include <algorithm>
|
| 8 |
+
#include <cmath>
|
| 9 |
+
#include <cstring>
|
| 10 |
+
#include <fstream>
|
| 11 |
+
|
| 12 |
+
namespace pocket_tts_accelerator {
|
| 13 |
+
|
| 14 |
+
AudioProcessor::AudioProcessor(MemoryPool& shared_memory_pool)
|
| 15 |
+
: memory_pool(shared_memory_pool) {
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
// Nothing to release explicitly; members clean up after themselves.
AudioProcessor::~AudioProcessor() = default;
|
| 20 |
+
|
| 21 |
+
// Reads a WAV file whose header matches the fixed WavFileHeader layout and
// decodes its samples to 16-bit signed integers.
//
// Supported encodings: 8-bit unsigned, 16-bit signed, 32-bit signed integer
// and 32-bit float (audio_format == 3). Samples are read in the host's byte
// order. sample_rate, number_of_channels and bits_per_sample in the result
// are copied from the file header.
//
// On failure the returned AudioData has is_valid == false and error_message
// set. On success the sample vector holds exactly the complete samples that
// were actually present in the file, which may be fewer than the header's
// data_size announces.
AudioData AudioProcessor::read_wav_file(const std::string& file_path) {
    AudioData result;
    result.is_valid = false;

    std::ifstream file_stream(file_path, std::ios::binary);
    if (!file_stream.is_open()) {
        result.error_message = "Failed to open file: " + file_path;
        return result;
    }

    WavFileHeader header;
    file_stream.read(reinterpret_cast<char*>(&header), sizeof(WavFileHeader));

    if (file_stream.gcount() < static_cast<std::streamsize>(sizeof(WavFileHeader))) {
        result.error_message = "File is too small to be a valid WAV file";
        return result;
    }

    if (!validate_wav_header(header)) {
        result.error_message = "Invalid WAV file header";
        return result;
    }

    result.sample_rate = header.sample_rate;
    result.number_of_channels = header.number_of_channels;
    result.bits_per_sample = header.bits_per_sample;

    // data_size comes from the file and cannot be trusted: when the stream is
    // seekable, cap it at the number of bytes that really follow the header
    // so a bogus size cannot trigger a multi-gigabyte allocation.
    std::uint64_t readable_bytes = header.data_size;
    const std::streampos data_start = file_stream.tellg();
    if (data_start != std::streampos(-1)) {
        file_stream.seekg(0, std::ios::end);
        const std::streampos file_end = file_stream.tellg();
        file_stream.clear();
        file_stream.seekg(data_start);

        if (file_end != std::streampos(-1)) {
            const std::streamoff bytes_after_header = file_end - data_start;
            if (bytes_after_header >= 0) {
                readable_bytes = std::min<std::uint64_t>(
                    readable_bytes, static_cast<std::uint64_t>(bytes_after_header));
            }
        }
    }

    // Only whole samples are read. Reading data_size bytes verbatim would
    // overrun the buffers below whenever data_size is not a multiple of the
    // sample width.
    const std::size_t bytes_per_sample = header.bits_per_sample / 8;
    const std::size_t sample_capacity = static_cast<std::size_t>(readable_bytes / bytes_per_sample);
    const std::streamsize bytes_to_read = static_cast<std::streamsize>(sample_capacity * bytes_per_sample);

    // Fills `destination` from the stream and returns the number of complete
    // samples that were obtained.
    const auto read_samples_into = [&](void* destination) -> std::size_t {
        file_stream.read(static_cast<char*>(destination), bytes_to_read);
        return static_cast<std::size_t>(file_stream.gcount()) / bytes_per_sample;
    };

    result.samples.resize(sample_capacity);
    std::size_t samples_read = 0;

    if (header.bits_per_sample == 16) {
        samples_read = read_samples_into(result.samples.data());
    } else if (header.bits_per_sample == 8) {
        std::vector<std::uint8_t> raw_data(sample_capacity);
        samples_read = read_samples_into(raw_data.data());
        convert_uint8_to_int16(raw_data.data(), result.samples.data(), samples_read);
    } else if (header.bits_per_sample == 32) {
        if (header.audio_format == 3) {
            std::vector<float> raw_data(sample_capacity);
            samples_read = read_samples_into(raw_data.data());
            convert_float32_to_int16(raw_data.data(), result.samples.data(), samples_read);
        } else {
            std::vector<std::int32_t> raw_data(sample_capacity);
            samples_read = read_samples_into(raw_data.data());
            convert_int32_to_int16(raw_data.data(), result.samples.data(), samples_read);
        }
    }

    // Drop the tail that a short read left unfilled instead of reporting
    // zero padding as audio.
    result.samples.resize(samples_read);

    result.is_valid = true;
    return result;
}
|
| 73 |
+
|
| 74 |
+
// Writes audio_data as a 16-bit PCM WAV file: a WavFileHeader followed by the
// raw sample data. Channel count and sample rate are taken from audio_data;
// the bit depth written is always 16.
//
// Returns true when the file was opened and everything was written and
// flushed successfully. Returns false if the file cannot be opened, a write
// fails, or the sample data is too large for the 32-bit size fields of the
// header (in which case the file is not touched).
bool AudioProcessor::write_wav_file(const std::string& file_path, const AudioData& audio_data) {
    // The RIFF size field counts every header byte after itself (36) plus
    // the sample data, and must fit in 32 bits.
    constexpr std::uint32_t header_bytes_after_size_field = 36;
    constexpr std::uint32_t max_data_bytes = 0xFFFFFFFFu - header_bytes_after_size_field;

    const std::size_t payload_bytes = audio_data.samples.size() * sizeof(std::int16_t);
    if (payload_bytes > max_data_bytes) {
        return false;
    }

    std::ofstream file_stream(file_path, std::ios::binary);
    if (!file_stream.is_open()) {
        return false;
    }

    const std::uint32_t data_size = static_cast<std::uint32_t>(payload_bytes);

    WavFileHeader header{};
    std::memcpy(header.riff_marker, "RIFF", 4);
    header.file_size = data_size + header_bytes_after_size_field;
    std::memcpy(header.wave_marker, "WAVE", 4);
    std::memcpy(header.format_marker, "fmt ", 4);
    header.format_chunk_size = 16;
    header.audio_format = 1;  // integer PCM
    header.number_of_channels = audio_data.number_of_channels;
    header.sample_rate = audio_data.sample_rate;
    header.bits_per_sample = 16;
    header.byte_rate = audio_data.sample_rate * audio_data.number_of_channels * 2;
    header.block_align = audio_data.number_of_channels * 2;
    std::memcpy(header.data_marker, "data", 4);
    header.data_size = data_size;

    file_stream.write(reinterpret_cast<const char*>(&header), sizeof(WavFileHeader));
    file_stream.write(reinterpret_cast<const char*>(audio_data.samples.data()), data_size);

    // Flush before checking the state so that a failure to write buffered
    // data is reported to the caller rather than lost when the stream closes.
    file_stream.flush();

    return file_stream.good();
}
|
| 104 |
+
|
| 105 |
+
// Down-mixes interleaved multi-channel audio to one channel by averaging the
// channels of each frame. Single-channel input is returned as a copy.
//
// The result keeps the input sample rate. On failure (input flagged invalid
// or reporting zero channels) success is false and error_message is set.
AudioProcessingResult AudioProcessor::convert_to_mono(const AudioData& input_audio) {
    AudioProcessingResult result;
    result.success = false;

    if (!input_audio.is_valid) {
        result.error_message = "Invalid input audio";
        return result;
    }

    // A zero channel count would make the frame computation below divide
    // by zero.
    if (input_audio.number_of_channels == 0) {
        result.error_message = "Invalid channel count";
        return result;
    }

    result.output_sample_rate = input_audio.sample_rate;

    if (input_audio.number_of_channels == 1) {
        result.processed_samples = input_audio.samples;
        result.success = true;
        return result;
    }

    // Any trailing samples that do not form a complete frame are ignored.
    const std::size_t frame_count = input_audio.samples.size() / input_audio.number_of_channels;
    result.processed_samples.resize(frame_count);

    mix_channels_to_mono(
        input_audio.samples.data(),
        result.processed_samples.data(),
        frame_count,
        input_audio.number_of_channels
    );

    result.success = true;
    return result;
}
|
| 135 |
+
|
| 136 |
+
// Returns the input samples unchanged (as a copy) together with the input
// sample rate. Fails with "Invalid input audio" when the input is not
// flagged valid.
AudioProcessingResult AudioProcessor::convert_to_pcm_int16(const AudioData& input_audio) {
    AudioProcessingResult outcome;
    outcome.success = input_audio.is_valid;

    if (outcome.success) {
        outcome.processed_samples = input_audio.samples;
        outcome.output_sample_rate = input_audio.sample_rate;
    } else {
        outcome.error_message = "Invalid input audio";
    }

    return outcome;
}
|
| 150 |
+
|
| 151 |
+
// Resamples audio to target_sample_rate using linear interpolation.
//
// Samples are treated as interleaved frames of number_of_channels samples and
// every channel is interpolated independently, so channels are never blended
// into each other. A channel count of zero is treated as one channel. If the
// rates already match, the input samples are returned as a copy.
//
// Fails (success == false, error_message set) when the input is not flagged
// valid or when either sample rate is zero.
AudioProcessingResult AudioProcessor::resample_audio(const AudioData& input_audio, std::uint32_t target_sample_rate) {
    AudioProcessingResult result;
    result.success = false;

    if (!input_audio.is_valid) {
        result.error_message = "Invalid input audio";
        return result;
    }

    // A zero rate would make the ratio zero or infinite.
    if (target_sample_rate == 0 || input_audio.sample_rate == 0) {
        result.error_message = "Invalid sample rate";
        return result;
    }

    if (input_audio.sample_rate == target_sample_rate) {
        result.processed_samples = input_audio.samples;
        result.output_sample_rate = target_sample_rate;
        result.success = true;
        return result;
    }

    const std::size_t channel_count =
        input_audio.number_of_channels > 0 ? input_audio.number_of_channels : 1;
    const std::size_t input_frame_count = input_audio.samples.size() / channel_count;

    const double ratio = static_cast<double>(target_sample_rate) / static_cast<double>(input_audio.sample_rate);
    const std::size_t output_frame_count = static_cast<std::size_t>(input_frame_count * ratio);

    result.processed_samples.resize(output_frame_count * channel_count);

    for (std::size_t output_frame = 0; output_frame < output_frame_count; ++output_frame) {
        const double source_position = output_frame / ratio;

        // Clamp both neighbours to the last input frame so rounding at the
        // very end of the signal can never index past the buffer.
        const std::size_t last_frame = input_frame_count - 1;
        const std::size_t floor_frame = std::min(static_cast<std::size_t>(source_position), last_frame);
        const std::size_t ceil_frame = std::min(floor_frame + 1, last_frame);
        const double fractional_part = source_position - floor_frame;

        for (std::size_t channel = 0; channel < channel_count; ++channel) {
            const double interpolated_value =
                input_audio.samples[floor_frame * channel_count + channel] * (1.0 - fractional_part) +
                input_audio.samples[ceil_frame * channel_count + channel] * fractional_part;

            result.processed_samples[output_frame * channel_count + channel] = static_cast<std::int16_t>(
                std::clamp(interpolated_value, -32768.0, 32767.0)
            );
        }
    }

    result.output_sample_rate = target_sample_rate;
    result.success = true;
    return result;
}
|
| 195 |
+
|
| 196 |
+
// Scales the samples so that the largest absolute sample value becomes
// target_peak_level * 32767, clamping every result to the 16-bit range.
// All-zero input is returned unchanged. The result keeps the input sample
// rate. Fails with "Invalid input audio" when the input is not flagged valid.
AudioProcessingResult AudioProcessor::normalize_audio(const AudioData& input_audio, float target_peak_level) {
    AudioProcessingResult result;
    result.success = false;

    if (!input_audio.is_valid) {
        result.error_message = "Invalid input audio";
        return result;
    }

    // Track the peak in 32 bits: the magnitude of -32768 does not fit in a
    // 16-bit signed integer and would otherwise wrap to a negative value and
    // be overlooked.
    std::int32_t max_absolute_value = 0;
    for (const std::int16_t sample : input_audio.samples) {
        const std::int32_t absolute_value = std::abs(static_cast<std::int32_t>(sample));
        if (absolute_value > max_absolute_value) {
            max_absolute_value = absolute_value;
        }
    }

    result.output_sample_rate = input_audio.sample_rate;

    if (max_absolute_value == 0) {
        // Nothing to scale.
        result.processed_samples = input_audio.samples;
        result.success = true;
        return result;
    }

    const float normalization_factor = (target_peak_level * 32767.0f) / static_cast<float>(max_absolute_value);

    result.processed_samples.resize(input_audio.samples.size());

    for (std::size_t index = 0; index < input_audio.samples.size(); ++index) {
        const float normalized_sample = static_cast<float>(input_audio.samples[index]) * normalization_factor;
        result.processed_samples[index] = static_cast<std::int16_t>(
            std::clamp(normalized_sample, -32768.0f, 32767.0f)
        );
    }

    result.success = true;
    return result;
}
|
| 235 |
+
|
| 236 |
+
// Reads the WAV file at input_file_path, down-mixes it to mono and writes the
// result to output_file_path as a single-channel 16-bit WAV.
//
// On success the result carries the mono samples and their sample rate. On
// failure success is false and error_message names the step that failed.
AudioProcessingResult AudioProcessor::process_audio_for_voice_cloning(
    const std::string& input_file_path,
    const std::string& output_file_path
) {
    AudioProcessingResult outcome;
    outcome.success = false;

    // Records the reason and hands back the (failed) result.
    const auto fail_with = [&outcome](const std::string& reason) {
        outcome.error_message = reason;
        return outcome;
    };

    AudioData source_audio = read_wav_file(input_file_path);
    if (!source_audio.is_valid) {
        return fail_with("Failed to read input file: " + source_audio.error_message);
    }

    AudioProcessingResult downmixed = convert_to_mono(source_audio);
    if (!downmixed.success) {
        return fail_with("Failed to convert to mono: " + downmixed.error_message);
    }

    AudioData mono_audio;
    mono_audio.samples = std::move(downmixed.processed_samples);
    mono_audio.sample_rate = downmixed.output_sample_rate;
    mono_audio.number_of_channels = 1;
    mono_audio.bits_per_sample = 16;
    mono_audio.is_valid = true;

    if (!write_wav_file(output_file_path, mono_audio)) {
        return fail_with("Failed to write output file");
    }

    outcome.processed_samples = std::move(mono_audio.samples);
    outcome.output_sample_rate = mono_audio.sample_rate;
    outcome.success = true;
    return outcome;
}
|
| 274 |
+
|
| 275 |
+
bool AudioProcessor::validate_wav_header(const WavFileHeader& header) {
|
| 276 |
+
if (std::memcmp(header.riff_marker, "RIFF", 4) != 0) {
|
| 277 |
+
return false;
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
if (std::memcmp(header.wave_marker, "WAVE", 4) != 0) {
|
| 281 |
+
return false;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
if (std::memcmp(header.format_marker, "fmt ", 4) != 0) {
|
| 285 |
+
return false;
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
if (header.audio_format != 1 && header.audio_format != 3) {
|
| 289 |
+
return false;
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
if (header.number_of_channels < 1 || header.number_of_channels > 16) {
|
| 293 |
+
return false;
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
if (header.sample_rate < 100 || header.sample_rate > 384000) {
|
| 297 |
+
return false;
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
if (header.bits_per_sample != 8 && header.bits_per_sample != 16 && header.bits_per_sample != 32) {
|
| 301 |
+
return false;
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
return true;
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
std::size_t AudioProcessor::calculate_audio_duration_milliseconds(const AudioData& audio_data) {
|
| 308 |
+
if (!audio_data.is_valid || audio_data.sample_rate == 0) {
|
| 309 |
+
return 0;
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
std::size_t frame_count = audio_data.samples.size() / audio_data.number_of_channels;
|
| 313 |
+
return (frame_count * 1000) / audio_data.sample_rate;
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
// Converts sample_count float samples to 16-bit signed integers.
// Each value is clamped to [-1.0, 1.0] and scaled by 32767; NaN input is
// written as 0. `input` and `output` must each hold at least sample_count
// elements.
void AudioProcessor::convert_float32_to_int16(const float* input, std::int16_t* output, std::size_t sample_count) {
    for (std::size_t index = 0; index < sample_count; ++index) {
        const float value = input[index];

        // std::clamp passes NaN through unchanged, and converting NaN to an
        // integer is undefined behaviour, so map it to silence explicitly.
        if (std::isnan(value)) {
            output[index] = 0;
            continue;
        }

        const float clamped_value = std::clamp(value, -1.0f, 1.0f);
        output[index] = static_cast<std::int16_t>(clamped_value * 32767.0f);
    }
}
|
| 322 |
+
|
| 323 |
+
// Converts sample_count 32-bit samples to 16-bit by shifting each value
// right by 16 bits. `input` and `output` must each hold at least
// sample_count elements.
void AudioProcessor::convert_int32_to_int16(const std::int32_t* input, std::int16_t* output, std::size_t sample_count) {
    const std::int32_t* const input_end = input + sample_count;

    while (input != input_end) {
        *output = static_cast<std::int16_t>(*input >> 16);
        ++input;
        ++output;
    }
}
|
| 328 |
+
|
| 329 |
+
// Converts sample_count unsigned 8-bit samples to signed 16-bit: each value
// is re-centred around zero (minus 128) and multiplied by 256. `input` and
// `output` must each hold at least sample_count elements.
void AudioProcessor::convert_uint8_to_int16(const std::uint8_t* input, std::int16_t* output, std::size_t sample_count) {
    const std::uint8_t* const input_end = input + sample_count;

    while (input != input_end) {
        const std::int16_t centred = static_cast<std::int16_t>(*input);
        *output = static_cast<std::int16_t>((centred - 128) * 256);
        ++input;
        ++output;
    }
}
|
| 334 |
+
|
| 335 |
+
// Averages the channels of each interleaved frame into one output sample.
// `input` must hold frame_count * channel_count samples and `output` must
// hold frame_count samples. The average uses integer division of the
// 32-bit channel sum by channel_count.
void AudioProcessor::mix_channels_to_mono(
    const std::int16_t* input,
    std::int16_t* output,
    std::size_t frame_count,
    std::uint16_t channel_count
) {
    const std::int16_t* frame_start = input;

    for (std::size_t frame = 0; frame != frame_count; ++frame, frame_start += channel_count) {
        const std::int16_t* const frame_end = frame_start + channel_count;

        std::int32_t channel_total = 0;
        for (const std::int16_t* sample = frame_start; sample != frame_end; ++sample) {
            channel_total += *sample;
        }

        output[frame] = static_cast<std::int16_t>(channel_total / channel_count);
    }
}
|
| 351 |
+
|
| 352 |
+
}
|
accelerator/src/ipc_handler.cpp
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//
|
| 2 |
+
// SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
//
|
| 5 |
+
|
| 6 |
+
#include "ipc_handler.hpp"
#include <cerrno>
#include <cstring>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>
|
| 11 |
+
|
| 12 |
+
namespace pocket_tts_accelerator {
|
| 13 |
+
|
| 14 |
+
// Records the socket path and starts in the "not running" state with no
// descriptor open (server_socket_fd == -1).
IpcHandler::IpcHandler(const std::string& socket_path)
    : socket_file_path(socket_path),
      server_socket_fd(-1),
      is_server_running(false) {}
|
| 19 |
+
|
| 20 |
+
// Stops the server (if it is running) before the handler is destroyed.
IpcHandler::~IpcHandler() { stop_server(); }
|
| 23 |
+
|
| 24 |
+
bool IpcHandler::start_server() {
|
| 25 |
+
if (is_server_running.load()) {
|
| 26 |
+
return true;
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
unlink(socket_file_path.c_str());
|
| 30 |
+
|
| 31 |
+
server_socket_fd = socket(AF_UNIX, SOCK_STREAM, 0);
|
| 32 |
+
|
| 33 |
+
if (server_socket_fd < 0) {
|
| 34 |
+
return false;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
struct sockaddr_un server_address;
|
| 38 |
+
std::memset(&server_address, 0, sizeof(server_address));
|
| 39 |
+
server_address.sun_family = AF_UNIX;
|
| 40 |
+
std::strncpy(server_address.sun_path, socket_file_path.c_str(), sizeof(server_address.sun_path) - 1);
|
| 41 |
+
|
| 42 |
+
if (bind(server_socket_fd, reinterpret_cast<struct sockaddr*>(&server_address), sizeof(server_address)) < 0) {
|
| 43 |
+
close(server_socket_fd);
|
| 44 |
+
server_socket_fd = -1;
|
| 45 |
+
return false;
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
if (listen(server_socket_fd, CONNECTION_BACKLOG) < 0) {
|
| 49 |
+
close(server_socket_fd);
|
| 50 |
+
server_socket_fd = -1;
|
| 51 |
+
return false;
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
is_server_running.store(true);
|
| 55 |
+
accept_thread = std::thread(&IpcHandler::accept_connections_loop, this);
|
| 56 |
+
|
| 57 |
+
return true;
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
void IpcHandler::stop_server() {
|
| 61 |
+
if (!is_server_running.load()) {
|
| 62 |
+
return;
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
is_server_running.store(false);
|
| 66 |
+
|
| 67 |
+
if (server_socket_fd >= 0) {
|
| 68 |
+
shutdown(server_socket_fd, SHUT_RDWR);
|
| 69 |
+
close(server_socket_fd);
|
| 70 |
+
server_socket_fd = -1;
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
if (accept_thread.joinable()) {
|
| 74 |
+
accept_thread.join();
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
unlink(socket_file_path.c_str());
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
bool IpcHandler::is_running() const {
|
| 81 |
+
return is_server_running.load();
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
// Installs (or replaces) the handler invoked for command_type.
// The handler table is updated while holding handlers_mutex.
void IpcHandler::register_command_handler(CommandType command_type, CommandHandlerFunction handler) {
    std::lock_guard<std::mutex> guard(handlers_mutex);
    command_handlers[command_type] = std::move(handler);
}
|
| 88 |
+
|
| 89 |
+
// Stores the callback that handle_client_connection() invokes after it has
// answered a SHUTDOWN command.
void IpcHandler::set_shutdown_callback(std::function<void()> callback) {
    shutdown_callback = std::move(callback);
}
|
| 92 |
+
|
| 93 |
+
void IpcHandler::accept_connections_loop() {
|
| 94 |
+
while (is_server_running.load()) {
|
| 95 |
+
struct sockaddr_un client_address;
|
| 96 |
+
socklen_t client_address_length = sizeof(client_address);
|
| 97 |
+
|
| 98 |
+
int client_socket_fd = accept(
|
| 99 |
+
server_socket_fd,
|
| 100 |
+
reinterpret_cast<struct sockaddr*>(&client_address),
|
| 101 |
+
&client_address_length
|
| 102 |
+
);
|
| 103 |
+
|
| 104 |
+
if (client_socket_fd < 0) {
|
| 105 |
+
if (!is_server_running.load()) {
|
| 106 |
+
break;
|
| 107 |
+
}
|
| 108 |
+
continue;
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
handle_client_connection(client_socket_fd);
|
| 112 |
+
close(client_socket_fd);
|
| 113 |
+
}
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
void IpcHandler::handle_client_connection(int client_socket_fd) {
|
| 117 |
+
RequestHeader request_header;
|
| 118 |
+
std::vector<std::uint8_t> request_payload;
|
| 119 |
+
|
| 120 |
+
if (!receive_request(client_socket_fd, request_header, request_payload)) {
|
| 121 |
+
return;
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
if (request_header.magic_number != PROTOCOL_MAGIC_NUMBER) {
|
| 125 |
+
ResponseHeader error_response;
|
| 126 |
+
error_response.magic_number = PROTOCOL_MAGIC_NUMBER;
|
| 127 |
+
error_response.status_code = static_cast<std::uint32_t>(ResponseStatus::ERROR_INVALID_COMMAND);
|
| 128 |
+
error_response.payload_size = 0;
|
| 129 |
+
error_response.request_id = request_header.request_id;
|
| 130 |
+
send_response(client_socket_fd, error_response, {});
|
| 131 |
+
return;
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
CommandType command_type = static_cast<CommandType>(request_header.command_type);
|
| 135 |
+
|
| 136 |
+
if (command_type == CommandType::SHUTDOWN) {
|
| 137 |
+
ResponseHeader shutdown_response;
|
| 138 |
+
shutdown_response.magic_number = PROTOCOL_MAGIC_NUMBER;
|
| 139 |
+
shutdown_response.status_code = static_cast<std::uint32_t>(ResponseStatus::SUCCESS);
|
| 140 |
+
shutdown_response.payload_size = 0;
|
| 141 |
+
shutdown_response.request_id = request_header.request_id;
|
| 142 |
+
send_response(client_socket_fd, shutdown_response, {});
|
| 143 |
+
|
| 144 |
+
if (shutdown_callback) {
|
| 145 |
+
shutdown_callback();
|
| 146 |
+
}
|
| 147 |
+
return;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
std::vector<std::uint8_t> response_payload;
|
| 151 |
+
ResponseStatus status = ResponseStatus::SUCCESS;
|
| 152 |
+
|
| 153 |
+
{
|
| 154 |
+
std::unique_lock<std::mutex> lock(handlers_mutex);
|
| 155 |
+
auto handler_iterator = command_handlers.find(command_type);
|
| 156 |
+
|
| 157 |
+
if (handler_iterator != command_handlers.end()) {
|
| 158 |
+
try {
|
| 159 |
+
response_payload = handler_iterator->second(request_payload);
|
| 160 |
+
} catch (...) {
|
| 161 |
+
status = ResponseStatus::ERROR_INTERNAL;
|
| 162 |
+
}
|
| 163 |
+
} else {
|
| 164 |
+
status = ResponseStatus::ERROR_INVALID_COMMAND;
|
| 165 |
+
}
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
ResponseHeader response_header;
|
| 169 |
+
response_header.magic_number = PROTOCOL_MAGIC_NUMBER;
|
| 170 |
+
response_header.status_code = static_cast<std::uint32_t>(status);
|
| 171 |
+
response_header.payload_size = static_cast<std::uint32_t>(response_payload.size());
|
| 172 |
+
response_header.request_id = request_header.request_id;
|
| 173 |
+
|
| 174 |
+
send_response(client_socket_fd, response_header, response_payload);
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
bool IpcHandler::send_response(
|
| 178 |
+
int socket_fd,
|
| 179 |
+
const ResponseHeader& header,
|
| 180 |
+
const std::vector<std::uint8_t>& payload
|
| 181 |
+
) {
|
| 182 |
+
ssize_t bytes_written = write(socket_fd, &header, sizeof(ResponseHeader));
|
| 183 |
+
|
| 184 |
+
if (bytes_written != sizeof(ResponseHeader)) {
|
| 185 |
+
return false;
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
if (!payload.empty()) {
|
| 189 |
+
bytes_written = write(socket_fd, payload.data(), payload.size());
|
| 190 |
+
|
| 191 |
+
if (bytes_written != static_cast<ssize_t>(payload.size())) {
|
| 192 |
+
return false;
|
| 193 |
+
}
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
return true;
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
bool IpcHandler::receive_request(
|
| 200 |
+
int socket_fd,
|
| 201 |
+
RequestHeader& header,
|
| 202 |
+
std::vector<std::uint8_t>& payload
|
| 203 |
+
) {
|
| 204 |
+
ssize_t bytes_read = read(socket_fd, &header, sizeof(RequestHeader));
|
| 205 |
+
|
| 206 |
+
if (bytes_read != sizeof(RequestHeader)) {
|
| 207 |
+
return false;
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
if (header.payload_size > MAXIMUM_PAYLOAD_SIZE) {
|
| 211 |
+
return false;
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
if (header.payload_size > 0) {
|
| 215 |
+
payload.resize(header.payload_size);
|
| 216 |
+
bytes_read = read(socket_fd, payload.data(), header.payload_size);
|
| 217 |
+
|
| 218 |
+
if (bytes_read != static_cast<ssize_t>(header.payload_size)) {
|
| 219 |
+
return false;
|
| 220 |
+
}
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
return true;
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
}
|
accelerator/src/main.cpp
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//
|
| 2 |
+
// SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
//
|
| 5 |
+
|
| 6 |
+
#include "accelerator_core.hpp"
|
| 7 |
+
#include <cstdlib>
|
| 8 |
+
#include <cstring>
|
| 9 |
+
#include <iostream>
|
| 10 |
+
#include <string>
|
| 11 |
+
|
| 12 |
+
// Prints the command-line help text to standard output.
// program_name is shown in the "Usage:" line.
void print_usage(const char* program_name) {
    static const char* const trailing_lines[] = {
        "",
        "Options:",
        " --socket PATH IPC socket path (default: /tmp/pocket_tts_accelerator.sock)",
        " --threads N Number of worker threads (default: 2)",
        " --memory MB Memory pool size in megabytes (default: 64)",
        " --quiet Disable verbose logging",
        " --help Show this help message",
    };

    std::cout << "Usage: " << program_name << " [options]" << std::endl;

    for (const char* line : trailing_lines) {
        std::cout << line << std::endl;
    }
}
|
| 22 |
+
|
| 23 |
+
int main(int argc, char* argv[]) {
|
| 24 |
+
pocket_tts_accelerator::AcceleratorConfiguration configuration =
|
| 25 |
+
pocket_tts_accelerator::AcceleratorCore::get_default_configuration();
|
| 26 |
+
|
| 27 |
+
for (int argument_index = 1; argument_index < argc; ++argument_index) {
|
| 28 |
+
std::string argument(argv[argument_index]);
|
| 29 |
+
|
| 30 |
+
if (argument == "--help" || argument == "-h") {
|
| 31 |
+
print_usage(argv[0]);
|
| 32 |
+
return 0;
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
if (argument == "--socket" && argument_index + 1 < argc) {
|
| 36 |
+
configuration.ipc_socket_path = argv[++argument_index];
|
| 37 |
+
continue;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
if (argument == "--threads" && argument_index + 1 < argc) {
|
| 41 |
+
configuration.number_of_worker_threads = std::stoul(argv[++argument_index]);
|
| 42 |
+
continue;
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
if (argument == "--memory" && argument_index + 1 < argc) {
|
| 46 |
+
std::size_t memory_mb = std::stoul(argv[++argument_index]);
|
| 47 |
+
configuration.memory_pool_size_bytes = memory_mb * 1024 * 1024;
|
| 48 |
+
continue;
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
if (argument == "--quiet" || argument == "-q") {
|
| 52 |
+
configuration.enable_verbose_logging = false;
|
| 53 |
+
continue;
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
std::cerr << "Unknown argument: " << argument << std::endl;
|
| 57 |
+
print_usage(argv[0]);
|
| 58 |
+
return 1;
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
if (configuration.number_of_worker_threads < 1) {
|
| 62 |
+
configuration.number_of_worker_threads = 1;
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
if (configuration.number_of_worker_threads > 2) {
|
| 66 |
+
configuration.number_of_worker_threads = 2;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
pocket_tts_accelerator::AcceleratorCore accelerator(configuration);
|
| 70 |
+
|
| 71 |
+
if (!accelerator.initialize()) {
|
| 72 |
+
std::cerr << "Failed to initialize accelerator" << std::endl;
|
| 73 |
+
return 1;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
accelerator.run();
|
| 77 |
+
|
| 78 |
+
return 0;
|
| 79 |
+
}
|
accelerator/src/memory_pool.cpp
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//
|
| 2 |
+
// SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
//
|
| 5 |
+
|
| 6 |
+
#include "memory_pool.hpp"
|
| 7 |
+
#include <algorithm>
|
| 8 |
+
#include <chrono>
|
| 9 |
+
#include <cstring>
|
| 10 |
+
|
| 11 |
+
namespace pocket_tts_accelerator {
|
| 12 |
+
|
| 13 |
+
// Creates an empty pool: both byte counters start at zero and
// initial_pool_size_bytes is stored as maximum_pool_size_bytes.
MemoryPool::MemoryPool(std::size_t initial_pool_size_bytes)
    : total_allocated_bytes(0),
      total_used_bytes(0),
      maximum_pool_size_bytes(initial_pool_size_bytes) {}
|
| 18 |
+
|
| 19 |
+
// Drops every block and resets the counters via reset_pool().
MemoryPool::~MemoryPool() { reset_pool(); }
|
| 22 |
+
|
| 23 |
+
std::uint8_t* MemoryPool::allocate(std::size_t requested_size_bytes) {
|
| 24 |
+
std::unique_lock<std::mutex> lock(pool_mutex);
|
| 25 |
+
|
| 26 |
+
std::size_t block_index = find_suitable_block_index(requested_size_bytes);
|
| 27 |
+
|
| 28 |
+
if (block_index != static_cast<std::size_t>(-1)) {
|
| 29 |
+
MemoryBlock& existing_block = memory_blocks[block_index];
|
| 30 |
+
existing_block.is_in_use = true;
|
| 31 |
+
existing_block.last_access_timestamp = get_current_timestamp();
|
| 32 |
+
total_used_bytes += existing_block.block_size;
|
| 33 |
+
return existing_block.data.get();
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
if (total_allocated_bytes + requested_size_bytes > maximum_pool_size_bytes) {
|
| 37 |
+
clear_unused_blocks();
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
std::size_t aligned_size = ((requested_size_bytes + 63) / 64) * 64;
|
| 41 |
+
|
| 42 |
+
memory_blocks.push_back(MemoryBlock{
|
| 43 |
+
std::make_unique<std::uint8_t[]>(aligned_size),
|
| 44 |
+
aligned_size,
|
| 45 |
+
true,
|
| 46 |
+
get_current_timestamp()
|
| 47 |
+
});
|
| 48 |
+
|
| 49 |
+
std::uint8_t* allocated_pointer = memory_blocks.back().data.get();
|
| 50 |
+
pointer_to_block_index[allocated_pointer] = memory_blocks.size() - 1;
|
| 51 |
+
|
| 52 |
+
total_allocated_bytes += aligned_size;
|
| 53 |
+
total_used_bytes += aligned_size;
|
| 54 |
+
|
| 55 |
+
return allocated_pointer;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
void MemoryPool::deallocate(std::uint8_t* pointer) {
|
| 59 |
+
if (pointer == nullptr) {
|
| 60 |
+
return;
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
std::unique_lock<std::mutex> lock(pool_mutex);
|
| 64 |
+
|
| 65 |
+
auto iterator = pointer_to_block_index.find(pointer);
|
| 66 |
+
|
| 67 |
+
if (iterator != pointer_to_block_index.end()) {
|
| 68 |
+
std::size_t block_index = iterator->second;
|
| 69 |
+
|
| 70 |
+
if (block_index < memory_blocks.size()) {
|
| 71 |
+
MemoryBlock& block = memory_blocks[block_index];
|
| 72 |
+
|
| 73 |
+
if (block.is_in_use) {
|
| 74 |
+
block.is_in_use = false;
|
| 75 |
+
block.last_access_timestamp = get_current_timestamp();
|
| 76 |
+
total_used_bytes -= block.block_size;
|
| 77 |
+
}
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
void MemoryPool::clear_unused_blocks() {
|
| 83 |
+
std::vector<std::size_t> indices_to_remove;
|
| 84 |
+
|
| 85 |
+
for (std::size_t index = 0; index < memory_blocks.size(); ++index) {
|
| 86 |
+
if (!memory_blocks[index].is_in_use) {
|
| 87 |
+
indices_to_remove.push_back(index);
|
| 88 |
+
}
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
std::sort(indices_to_remove.rbegin(), indices_to_remove.rend());
|
| 92 |
+
|
| 93 |
+
for (std::size_t index : indices_to_remove) {
|
| 94 |
+
std::uint8_t* pointer = memory_blocks[index].data.get();
|
| 95 |
+
total_allocated_bytes -= memory_blocks[index].block_size;
|
| 96 |
+
|
| 97 |
+
pointer_to_block_index.erase(pointer);
|
| 98 |
+
memory_blocks.erase(memory_blocks.begin() + static_cast<std::ptrdiff_t>(index));
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
for (std::size_t index = 0; index < memory_blocks.size(); ++index) {
|
| 102 |
+
pointer_to_block_index[memory_blocks[index].data.get()] = index;
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
void MemoryPool::reset_pool() {
|
| 107 |
+
std::unique_lock<std::mutex> lock(pool_mutex);
|
| 108 |
+
|
| 109 |
+
memory_blocks.clear();
|
| 110 |
+
pointer_to_block_index.clear();
|
| 111 |
+
total_allocated_bytes = 0;
|
| 112 |
+
total_used_bytes = 0;
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
std::size_t MemoryPool::get_total_allocated_bytes() const {
|
| 116 |
+
std::unique_lock<std::mutex> lock(pool_mutex);
|
| 117 |
+
return total_allocated_bytes;
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
std::size_t MemoryPool::get_total_used_bytes() const {
|
| 121 |
+
std::unique_lock<std::mutex> lock(pool_mutex);
|
| 122 |
+
return total_used_bytes;
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
std::size_t MemoryPool::get_block_count() const {
|
| 126 |
+
std::unique_lock<std::mutex> lock(pool_mutex);
|
| 127 |
+
return memory_blocks.size();
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
std::size_t MemoryPool::find_suitable_block_index(std::size_t requested_size) const {
|
| 131 |
+
std::size_t best_fit_index = static_cast<std::size_t>(-1);
|
| 132 |
+
std::size_t best_fit_size = static_cast<std::size_t>(-1);
|
| 133 |
+
|
| 134 |
+
for (std::size_t index = 0; index < memory_blocks.size(); ++index) {
|
| 135 |
+
const MemoryBlock& block = memory_blocks[index];
|
| 136 |
+
|
| 137 |
+
if (!block.is_in_use && block.block_size >= requested_size) {
|
| 138 |
+
if (block.block_size < best_fit_size) {
|
| 139 |
+
best_fit_size = block.block_size;
|
| 140 |
+
best_fit_index = index;
|
| 141 |
+
}
|
| 142 |
+
}
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
return best_fit_index;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
void MemoryPool::create_new_block(std::size_t block_size) {
|
| 149 |
+
std::size_t aligned_size = ((block_size + 63) / 64) * 64;
|
| 150 |
+
|
| 151 |
+
memory_blocks.push_back(MemoryBlock{
|
| 152 |
+
std::make_unique<std::uint8_t[]>(aligned_size),
|
| 153 |
+
aligned_size,
|
| 154 |
+
false,
|
| 155 |
+
get_current_timestamp()
|
| 156 |
+
});
|
| 157 |
+
|
| 158 |
+
pointer_to_block_index[memory_blocks.back().data.get()] = memory_blocks.size() - 1;
|
| 159 |
+
total_allocated_bytes += aligned_size;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
// Returns the steady_clock time since its epoch, in whole milliseconds.
std::uint64_t MemoryPool::get_current_timestamp() const {
    const auto since_epoch = std::chrono::steady_clock::now().time_since_epoch();
    const auto elapsed_milliseconds = std::chrono::duration_cast<std::chrono::milliseconds>(since_epoch);
    return static_cast<std::uint64_t>(elapsed_milliseconds.count());
}
|
| 167 |
+
|
| 168 |
+
// Allocates size bytes from pool and remembers the pool so the destructor can
// give the buffer back. allocation_size records the requested size.
ScopedMemoryAllocation::ScopedMemoryAllocation(MemoryPool& pool, std::size_t size)
    : memory_pool_pointer(&pool),
      allocated_pointer(pool.allocate(size)),
      allocation_size(size) {}
|
| 173 |
+
|
| 174 |
+
// Returns the buffer to its pool. Does nothing for an object without a pool
// or buffer, which is the state a moved-from object is left in.
ScopedMemoryAllocation::~ScopedMemoryAllocation() {
    const bool owns_buffer = memory_pool_pointer != nullptr && allocated_pointer != nullptr;

    if (!owns_buffer) {
        return;
    }

    memory_pool_pointer->deallocate(allocated_pointer);
}
|
| 179 |
+
|
| 180 |
+
// Move constructor: takes over other's pool pointer, buffer and size, then
// clears other so its destructor does not release the buffer.
ScopedMemoryAllocation::ScopedMemoryAllocation(ScopedMemoryAllocation&& other) noexcept
    : memory_pool_pointer(other.memory_pool_pointer),
      allocated_pointer(other.allocated_pointer),
      allocation_size(other.allocation_size) {
    other.allocation_size = 0;
    other.allocated_pointer = nullptr;
    other.memory_pool_pointer = nullptr;
}
|
| 189 |
+
|
| 190 |
+
// Move assignment: releases the buffer currently held (if any), takes over
// other's pool pointer, buffer and size, and clears other. Self-assignment is
// a no-op.
ScopedMemoryAllocation& ScopedMemoryAllocation::operator=(ScopedMemoryAllocation&& other) noexcept {
    if (this == &other) {
        return *this;
    }

    // Give the current buffer back before it is overwritten.
    if (memory_pool_pointer != nullptr && allocated_pointer != nullptr) {
        memory_pool_pointer->deallocate(allocated_pointer);
    }

    memory_pool_pointer = other.memory_pool_pointer;
    other.memory_pool_pointer = nullptr;

    allocated_pointer = other.allocated_pointer;
    other.allocated_pointer = nullptr;

    allocation_size = other.allocation_size;
    other.allocation_size = 0;

    return *this;
}
|
| 207 |
+
|
| 208 |
+
std::uint8_t* ScopedMemoryAllocation::get() const {
|
| 209 |
+
return allocated_pointer;
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
std::size_t ScopedMemoryAllocation::size() const {
|
| 213 |
+
return allocation_size;
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
}
|
accelerator/src/thread_pool.cpp
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//
|
| 2 |
+
// SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
//
|
| 5 |
+
|
| 6 |
+
#include "thread_pool.hpp"
|
| 7 |
+
|
| 8 |
+
namespace pocket_tts_accelerator {
|
| 9 |
+
|
| 10 |
+
// Starts number_of_threads worker threads, each running worker_thread_function().
ThreadPool::ThreadPool(std::size_t number_of_threads)
    : should_stop(false),
      is_stopped(false),
      thread_count(number_of_threads) {
    worker_threads.reserve(number_of_threads);

    std::size_t threads_left_to_start = number_of_threads;

    while (threads_left_to_start > 0) {
        worker_threads.emplace_back(&ThreadPool::worker_thread_function, this);
        --threads_left_to_start;
    }
}
|
| 21 |
+
|
| 22 |
+
// Shuts the pool down (joining all workers) before it is destroyed.
ThreadPool::~ThreadPool() { shutdown(); }
|
| 25 |
+
|
| 26 |
+
void ThreadPool::shutdown() {
|
| 27 |
+
{
|
| 28 |
+
std::unique_lock<std::mutex> lock(queue_mutex);
|
| 29 |
+
|
| 30 |
+
if (is_stopped.load()) {
|
| 31 |
+
return;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
should_stop.store(true);
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
task_available_condition.notify_all();
|
| 38 |
+
|
| 39 |
+
for (std::thread& worker_thread : worker_threads) {
|
| 40 |
+
if (worker_thread.joinable()) {
|
| 41 |
+
worker_thread.join();
|
| 42 |
+
}
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
is_stopped.store(true);
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
bool ThreadPool::is_running() const {
|
| 49 |
+
return !should_stop.load() && !is_stopped.load();
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
std::size_t ThreadPool::get_pending_task_count() const {
|
| 53 |
+
std::unique_lock<std::mutex> lock(queue_mutex);
|
| 54 |
+
return task_queue.size();
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
std::size_t ThreadPool::get_thread_count() const {
|
| 58 |
+
return thread_count;
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
void ThreadPool::worker_thread_function() {
|
| 62 |
+
while (true) {
|
| 63 |
+
std::function<void()> task_to_execute;
|
| 64 |
+
|
| 65 |
+
{
|
| 66 |
+
std::unique_lock<std::mutex> lock(queue_mutex);
|
| 67 |
+
|
| 68 |
+
task_available_condition.wait(lock, [this] {
|
| 69 |
+
return should_stop.load() || !task_queue.empty();
|
| 70 |
+
});
|
| 71 |
+
|
| 72 |
+
if (should_stop.load() && task_queue.empty()) {
|
| 73 |
+
return;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
task_to_execute = std::move(task_queue.front());
|
| 77 |
+
task_queue.pop();
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
task_to_execute();
|
| 81 |
+
}
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
}
|
app.py
CHANGED
|
@@ -3,11 +3,10 @@
|
|
| 3 |
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
#
|
| 5 |
|
|
|
|
| 6 |
import math
|
| 7 |
import torch
|
| 8 |
import gradio as gr
|
| 9 |
-
torch.set_num_threads(1)
|
| 10 |
-
torch.set_num_interop_threads(1)
|
| 11 |
from config import (
|
| 12 |
AVAILABLE_VOICES,
|
| 13 |
DEFAULT_VOICE,
|
|
@@ -20,10 +19,22 @@ from config import (
|
|
| 20 |
MAXIMUM_INPUT_LENGTH,
|
| 21 |
VOICE_MODE_PRESET,
|
| 22 |
VOICE_MODE_CLONE,
|
| 23 |
-
EXAMPLE_PROMPTS
|
|
|
|
|
|
|
| 24 |
)
|
|
|
|
|
|
|
| 25 |
from src.core.authentication import authenticate_huggingface
|
| 26 |
authenticate_huggingface()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
from src.core.memory import start_background_cleanup_thread
|
| 28 |
start_background_cleanup_thread()
|
| 29 |
from src.generation.handler import (
|
|
@@ -64,8 +75,7 @@ with gr.Blocks(css=CSS, fill_height=False, fill_width=True) as app:
|
|
| 64 |
audio_output_component = gr.Audio(
|
| 65 |
label="Generated Speech Output",
|
| 66 |
type="filepath",
|
| 67 |
-
interactive=False
|
| 68 |
-
autoplay=False
|
| 69 |
)
|
| 70 |
|
| 71 |
with gr.Accordion("Voice Selection", open=True):
|
|
|
|
| 3 |
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
#
|
| 5 |
|
| 6 |
+
import atexit
|
| 7 |
import math
|
| 8 |
import torch
|
| 9 |
import gradio as gr
|
|
|
|
|
|
|
| 10 |
from config import (
|
| 11 |
AVAILABLE_VOICES,
|
| 12 |
DEFAULT_VOICE,
|
|
|
|
| 19 |
MAXIMUM_INPUT_LENGTH,
|
| 20 |
VOICE_MODE_PRESET,
|
| 21 |
VOICE_MODE_CLONE,
|
| 22 |
+
EXAMPLE_PROMPTS,
|
| 23 |
+
ACCELERATOR_WORKER_THREADS,
|
| 24 |
+
ACCELERATOR_ENABLED
|
| 25 |
)
|
| 26 |
+
torch.set_num_threads(ACCELERATOR_WORKER_THREADS)
|
| 27 |
+
torch.set_num_interop_threads(ACCELERATOR_WORKER_THREADS)
|
| 28 |
from src.core.authentication import authenticate_huggingface
|
| 29 |
authenticate_huggingface()
|
| 30 |
+
if ACCELERATOR_ENABLED:
|
| 31 |
+
from src.accelerator.client import start_accelerator_daemon, stop_accelerator_daemon
|
| 32 |
+
accelerator_started = start_accelerator_daemon()
|
| 33 |
+
if accelerator_started:
|
| 34 |
+
print("Accelerator daemon started successfully", flush=True)
|
| 35 |
+
else:
|
| 36 |
+
print("Accelerator daemon not available, using Python fallback", flush=True)
|
| 37 |
+
atexit.register(stop_accelerator_daemon)
|
| 38 |
from src.core.memory import start_background_cleanup_thread
|
| 39 |
start_background_cleanup_thread()
|
| 40 |
from src.generation.handler import (
|
|
|
|
| 75 |
audio_output_component = gr.Audio(
|
| 76 |
label="Generated Speech Output",
|
| 77 |
type="filepath",
|
| 78 |
+
interactive=False
|
|
|
|
| 79 |
)
|
| 80 |
|
| 81 |
with gr.Accordion("Voice Selection", open=True):
|
config.py
CHANGED
|
@@ -110,4 +110,10 @@ COPYRIGHT_NAME = "Hadad Darajat"
|
|
| 110 |
COPYRIGHT_URL = "https://www.linkedin.com/in/hadadrjt"
|
| 111 |
|
| 112 |
DESIGN_BY_NAME = "D3vShoaib/pocket-tts"
|
| 113 |
-
DESIGN_BY_URL = f"https://huggingface.co/spaces/{DESIGN_BY_NAME}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
COPYRIGHT_URL = "https://www.linkedin.com/in/hadadrjt"
|
| 111 |
|
| 112 |
DESIGN_BY_NAME = "D3vShoaib/pocket-tts"
|
| 113 |
+
DESIGN_BY_URL = f"https://huggingface.co/spaces/{DESIGN_BY_NAME}"
|
| 114 |
+
|
| 115 |
+
ACCELERATOR_SOCKET_PATH = "/app/pocket_tts_accelerator.sock"
|
| 116 |
+
ACCELERATOR_BINARY_PATH = "/app/bin/pocket_tts_accelerator"
|
| 117 |
+
ACCELERATOR_WORKER_THREADS = 2
|
| 118 |
+
ACCELERATOR_MEMORY_POOL_MB = 64
|
| 119 |
+
ACCELERATOR_ENABLED = True
|
src/accelerator/client.py
ADDED
|
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import socket
|
| 8 |
+
import struct
|
| 9 |
+
import subprocess
|
| 10 |
+
import tempfile
|
| 11 |
+
import threading
|
| 12 |
+
import time
|
| 13 |
+
from typing import Optional, Tuple, Dict, Any
|
| 14 |
+
from config import (
|
| 15 |
+
ACCELERATOR_SOCKET_PATH,
|
| 16 |
+
ACCELERATOR_BINARY_PATH,
|
| 17 |
+
ACCELERATOR_WORKER_THREADS,
|
| 18 |
+
ACCELERATOR_MEMORY_POOL_MB
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
# Magic number placed first in every request header and checked on every
# response header; the bytes 0x50 0x54 0x54 0x53 spell "PTTS" in ASCII.
PROTOCOL_MAGIC_NUMBER = 0x50545453
|
| 22 |
+
|
| 23 |
+
# Command identifiers written into the second field of the request header.
COMMAND_PING = 0
|
| 24 |
+
COMMAND_PROCESS_AUDIO = 1
|
| 25 |
+
COMMAND_CONVERT_TO_MONO = 2
|
| 26 |
+
COMMAND_CONVERT_TO_PCM = 3
|
| 27 |
+
COMMAND_RESAMPLE_AUDIO = 4
|
| 28 |
+
COMMAND_GET_MEMORY_STATS = 5
|
| 29 |
+
COMMAND_CLEAR_MEMORY_POOL = 6
|
| 30 |
+
COMMAND_SHUTDOWN = 7
|
| 31 |
+
|
| 32 |
+
# Status codes carried in the second field of the response header. Within this
# module only RESPONSE_SUCCESS is compared against; the others are declared
# for reference.
RESPONSE_SUCCESS = 0
|
| 33 |
+
RESPONSE_ERROR_INVALID_COMMAND = 1
|
| 34 |
+
RESPONSE_ERROR_FILE_NOT_FOUND = 2
|
| 35 |
+
RESPONSE_ERROR_PROCESSING_FAILED = 3
|
| 36 |
+
RESPONSE_ERROR_MEMORY_ALLOCATION = 4
|
| 37 |
+
RESPONSE_ERROR_INTERNAL = 5
|
| 38 |
+
|
| 39 |
+
# Header layouts: "=" selects native byte order with standard sizes and no
# padding, "IIII" is four unsigned 32-bit integers (16 bytes in total).
# Request fields: magic, command, payload length, request id.
# Response fields: magic, status code, payload size, request id.
REQUEST_HEADER_FORMAT = "=IIII"
|
| 40 |
+
RESPONSE_HEADER_FORMAT = "=IIII"
|
| 41 |
+
REQUEST_HEADER_SIZE = struct.calcsize(REQUEST_HEADER_FORMAT)
|
| 42 |
+
RESPONSE_HEADER_SIZE = struct.calcsize(RESPONSE_HEADER_FORMAT)
|
| 43 |
+
|
| 44 |
+
# Payload of the file-based commands: two 512-byte path fields (input, output)
# followed by two unsigned 32-bit integers (target sample rate, option flags).
PROCESS_AUDIO_REQUEST_FORMAT = "=512s512sII"
|
| 45 |
+
PROCESS_AUDIO_REQUEST_SIZE = struct.calcsize(PROCESS_AUDIO_REQUEST_FORMAT)
|
| 46 |
+
|
| 47 |
+
# Payload of the memory-stats response: three unsigned 64-bit integers
# (total allocated bytes, total used bytes, block count).
MEMORY_STATS_RESPONSE_FORMAT = "=QQQ"
|
| 48 |
+
MEMORY_STATS_RESPONSE_SIZE = struct.calcsize(MEMORY_STATS_RESPONSE_FORMAT)
|
| 49 |
+
|
| 50 |
+
# Popen handle of the daemon started by start_accelerator_daemon(), or None;
# read and written only while holding accelerator_process_lock.
accelerator_process_handle = None
|
| 51 |
+
accelerator_process_lock = threading.Lock()
|
| 52 |
+
# Process-wide request id counter; incremented under request_id_lock.
request_id_counter = 0
|
| 53 |
+
request_id_lock = threading.Lock()
|
| 54 |
+
|
| 55 |
+
class AcceleratorClient:
    """Client for the accelerator daemon's request/response socket protocol.

    Every command opens a fresh Unix-domain stream connection to
    ``socket_path``, sends a 16-byte header (magic, command, payload length,
    request id) followed by an optional payload, reads exactly one response
    and closes the connection again.
    """

    def __init__(self, socket_path: str = ACCELERATOR_SOCKET_PATH):
        """Remember the socket path and set the default timeouts (seconds)."""
        self.socket_path = socket_path
        self.connection_timeout = 5.0
        self.read_timeout = 30.0

    def is_connected(self) -> bool:
        """Return True when a ping is answered with a payload starting with b"PONG"."""
        try:
            response = self.send_ping()
            return response is not None and response.startswith(b"PONG")

        except Exception:
            return False

    def send_ping(self) -> Optional[bytes]:
        """Send COMMAND_PING and return the raw response payload, or None on failure."""
        return self._send_command(COMMAND_PING, b"")

    def process_audio(
        self,
        input_file_path: str,
        output_file_path: str,
        target_sample_rate: int = 0,
        options_flags: int = 0
    ) -> Tuple[bool, str]:
        """Send COMMAND_PROCESS_AUDIO for the given files.

        Returns:
            ``(True, message)`` when the reply starts with "SUCCESS:",
            otherwise ``(False, message)``.
        """
        return self._run_file_command(
            COMMAND_PROCESS_AUDIO,
            input_file_path,
            output_file_path,
            target_sample_rate,
            options_flags
        )

    def convert_to_mono(
        self,
        input_file_path: str,
        output_file_path: str
    ) -> Tuple[bool, str]:
        """Send COMMAND_CONVERT_TO_MONO; see ``process_audio`` for the return shape."""
        return self._run_file_command(
            COMMAND_CONVERT_TO_MONO,
            input_file_path,
            output_file_path,
            0,
            0
        )

    def convert_to_pcm(
        self,
        input_file_path: str,
        output_file_path: str
    ) -> Tuple[bool, str]:
        """Send COMMAND_CONVERT_TO_PCM; see ``process_audio`` for the return shape."""
        return self._run_file_command(
            COMMAND_CONVERT_TO_PCM,
            input_file_path,
            output_file_path,
            0,
            0
        )

    def resample_audio(
        self,
        input_file_path: str,
        output_file_path: str,
        target_sample_rate: int
    ) -> Tuple[bool, str]:
        """Send COMMAND_RESAMPLE_AUDIO; see ``process_audio`` for the return shape."""
        return self._run_file_command(
            COMMAND_RESAMPLE_AUDIO,
            input_file_path,
            output_file_path,
            target_sample_rate,
            0
        )

    def get_memory_stats(self) -> Optional[Dict[str, int]]:
        """Query COMMAND_GET_MEMORY_STATS.

        Returns:
            A dict with ``total_allocated_bytes``, ``total_used_bytes`` and
            ``block_count``, or None when the request failed or the reply is
            shorter than three unsigned 64-bit integers.
        """
        response = self._send_command(COMMAND_GET_MEMORY_STATS, b"")

        if response is None or len(response) < MEMORY_STATS_RESPONSE_SIZE:
            return None

        total_allocated, total_used, block_count = struct.unpack(
            MEMORY_STATS_RESPONSE_FORMAT,
            response[:MEMORY_STATS_RESPONSE_SIZE]
        )

        return {
            "total_allocated_bytes": total_allocated,
            "total_used_bytes": total_used,
            "block_count": block_count
        }

    def clear_memory_pool(self) -> bool:
        """Send COMMAND_CLEAR_MEMORY_POOL; True when any response was received."""
        response = self._send_command(COMMAND_CLEAR_MEMORY_POOL, b"")
        return response is not None

    def shutdown_accelerator(self) -> bool:
        """Send COMMAND_SHUTDOWN; True when any response was received."""
        response = self._send_command(COMMAND_SHUTDOWN, b"")
        return response is not None

    def _get_next_request_id(self) -> int:
        """Return the next process-wide request id.

        The id is packed as an unsigned 32-bit integer, so the counter wraps
        at 2**32; without the wrap ``struct.pack`` would eventually raise and
        every later command would fail.
        """
        global request_id_counter

        with request_id_lock:
            request_id_counter = (request_id_counter + 1) & 0xFFFFFFFF
            return request_id_counter

    def _pack_process_audio_request(
        self,
        input_path: str,
        output_path: str,
        target_sample_rate: int,
        options_flags: int
    ) -> bytes:
        """Build the fixed-size payload shared by all file-based commands.

        Each path is UTF-8 encoded and cut to at most 511 bytes so that the
        512-byte field always ends in at least one NUL byte; ``struct.pack``
        pads a short ``512s`` value with NUL bytes on its own.
        NOTE(review): the byte-level cut can split a multi-byte UTF-8
        sequence and silently shortens over-long paths.
        """
        input_path_bytes = input_path.encode("utf-8")[:511]
        output_path_bytes = output_path.encode("utf-8")[:511]

        return struct.pack(
            PROCESS_AUDIO_REQUEST_FORMAT,
            input_path_bytes,
            output_path_bytes,
            target_sample_rate,
            options_flags
        )

    def _run_file_command(
        self,
        command_type: int,
        input_file_path: str,
        output_file_path: str,
        target_sample_rate: int,
        options_flags: int
    ) -> Tuple[bool, str]:
        """Pack a file request, send it and interpret the textual reply."""
        payload = self._pack_process_audio_request(
            input_file_path,
            output_file_path,
            target_sample_rate,
            options_flags
        )

        return self._parse_status_response(self._send_command(command_type, payload))

    @staticmethod
    def _parse_status_response(response: Optional[bytes]) -> Tuple[bool, str]:
        """Turn a raw reply into ``(ok, message)``, stripping the status prefix."""
        if response is None:
            return False, "Failed to communicate with accelerator"

        response_string = response.decode("utf-8", errors="ignore")

        if response_string.startswith("SUCCESS:"):
            return True, response_string[8:]

        if response_string.startswith("ERROR:"):
            return False, response_string[6:]

        return False, response_string

    def _send_command(
        self,
        command_type: int,
        payload: bytes
    ) -> Optional[bytes]:
        """Perform one request/response exchange with the daemon.

        Returns:
            The response payload (possibly ``b""`` on success), or None when
            the connection, the transfer or header validation fails. A
            non-success status with an empty payload is also reported as None.
        """
        try:
            # The context manager closes the socket on every path, including
            # exceptions raised by connect/sendall/unpack.
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client_socket:
                client_socket.settimeout(self.connection_timeout)
                client_socket.connect(self.socket_path)

                request_id = self._get_next_request_id()

                request_header = struct.pack(
                    REQUEST_HEADER_FORMAT,
                    PROTOCOL_MAGIC_NUMBER,
                    command_type,
                    len(payload),
                    request_id
                )

                client_socket.sendall(request_header)

                if payload:
                    client_socket.sendall(payload)

                # Processing may take far longer than connecting.
                client_socket.settimeout(self.read_timeout)

                response_header_data = self._receive_exactly(client_socket, RESPONSE_HEADER_SIZE)

                if response_header_data is None:
                    return None

                magic_number, status_code, payload_size, response_request_id = struct.unpack(
                    RESPONSE_HEADER_FORMAT,
                    response_header_data
                )

                # Reject replies that are not ours or not from this protocol.
                if magic_number != PROTOCOL_MAGIC_NUMBER or response_request_id != request_id:
                    return None

                response_payload = b""

                if payload_size > 0:
                    response_payload = self._receive_exactly(client_socket, payload_size)

                    if response_payload is None:
                        return None

        except Exception:
            # Covers socket.timeout and socket.error as well; all failures are
            # reported to the caller as None.
            return None

        if status_code != RESPONSE_SUCCESS:
            return response_payload if response_payload else None

        return response_payload

    def _receive_exactly(
        self,
        client_socket: socket.socket,
        num_bytes: int
    ) -> Optional[bytes]:
        """Read exactly ``num_bytes`` from the socket.

        Returns None when the peer closes the connection early, on timeout or
        on any other socket error.
        """
        # bytearray avoids re-copying the accumulated data on every chunk.
        received_data = bytearray()

        while len(received_data) < num_bytes:
            try:
                chunk = client_socket.recv(num_bytes - len(received_data))

            except (socket.timeout, socket.error):
                return None

            if not chunk:
                return None

            received_data += chunk

        return bytes(received_data)
|
| 336 |
+
|
| 337 |
+
def is_accelerator_available() -> bool:
    """Report whether the daemon's socket file exists and answers a ping."""
    socket_file_present = os.path.exists(ACCELERATOR_SOCKET_PATH)

    # Only attempt a connection when the socket file is there at all.
    return socket_file_present and AcceleratorClient().is_connected()
|
| 343 |
+
|
| 344 |
+
def start_accelerator_daemon() -> bool:
    """Launch the accelerator binary and wait for it to answer pings.

    Returns True when a previously started process is still running or the
    new process becomes reachable within 50 polls of 0.1 s each; False when
    the binary is missing, the launch raises, or the daemon never responds.
    """
    global accelerator_process_handle

    with accelerator_process_lock:
        previous_process_alive = (
            accelerator_process_handle is not None
            and accelerator_process_handle.poll() is None
        )

        if previous_process_alive:
            return True

        if not os.path.exists(ACCELERATOR_BINARY_PATH):
            return False

        try:
            launch_command = [
                ACCELERATOR_BINARY_PATH,
                "--socket", ACCELERATOR_SOCKET_PATH,
                "--threads", str(ACCELERATOR_WORKER_THREADS),
                "--memory", str(ACCELERATOR_MEMORY_POOL_MB)
            ]

            accelerator_process_handle = subprocess.Popen(
                launch_command,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                start_new_session=True
            )

            polls_remaining = 50

            while polls_remaining > 0:
                time.sleep(0.1)

                if is_accelerator_available():
                    return True

                polls_remaining -= 1

            # One last check after the polling window has elapsed.
            return is_accelerator_available()

        except Exception:
            return False
|
| 379 |
+
|
| 380 |
+
def stop_accelerator_daemon() -> bool:
    """Ask the daemon to shut down, then make sure the spawned process is gone.

    A reachable daemon first receives a shutdown command followed by a 0.5 s
    pause. A process started by this module that is still running is then
    terminated, and killed if it does not exit within 5 s. Always returns True.
    """
    global accelerator_process_handle

    with accelerator_process_lock:
        if is_accelerator_available():
            try:
                AcceleratorClient().shutdown_accelerator()
                time.sleep(0.5)

            except Exception:
                # Best effort: the process-level cleanup below still runs.
                pass

        daemon_process = accelerator_process_handle

        if daemon_process is not None:
            still_running = daemon_process.poll() is None

            if still_running:
                try:
                    daemon_process.terminate()
                    daemon_process.wait(timeout=5)

                except subprocess.TimeoutExpired:
                    daemon_process.kill()
                    daemon_process.wait()

            accelerator_process_handle = None

        return True
|
| 406 |
+
|
| 407 |
+
def process_audio_with_accelerator(
    input_file_path: str,
    output_file_path: str
) -> Tuple[bool, str]:
    """Run the process-audio command when the daemon is reachable.

    Returns ``(False, "Accelerator not available")`` without contacting the
    daemon further when the availability check fails.
    """
    if is_accelerator_available():
        return AcceleratorClient().process_audio(input_file_path, output_file_path)

    return False, "Accelerator not available"
|
| 416 |
+
|
| 417 |
+
def convert_to_mono_with_accelerator(
    input_file_path: str,
    output_file_path: str
) -> Tuple[bool, str]:
    """Run the convert-to-mono command when the daemon is reachable.

    Returns ``(False, "Accelerator not available")`` when the availability
    check fails.
    """
    if is_accelerator_available():
        return AcceleratorClient().convert_to_mono(input_file_path, output_file_path)

    return False, "Accelerator not available"
|
| 426 |
+
|
| 427 |
+
def convert_to_pcm_with_accelerator(
    input_file_path: str,
    output_file_path: str
) -> Tuple[bool, str]:
    """Run the convert-to-PCM command when the daemon is reachable.

    Returns ``(False, "Accelerator not available")`` when the availability
    check fails.
    """
    if is_accelerator_available():
        return AcceleratorClient().convert_to_pcm(input_file_path, output_file_path)

    return False, "Accelerator not available"
|
| 436 |
+
|
| 437 |
+
def get_accelerator_memory_stats() -> Optional[Dict[str, int]]:
    """Return the daemon's memory statistics, or None when it is unreachable."""
    if is_accelerator_available():
        return AcceleratorClient().get_memory_stats()

    return None
|
src/audio/converter.py
CHANGED
|
@@ -10,6 +10,11 @@ import numpy as np
|
|
| 10 |
import scipy.io.wavfile
|
| 11 |
from ..core.state import temporary_files_registry, temporary_files_lock
|
| 12 |
from ..core.memory import trigger_background_cleanup_check
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
def convert_audio_data_to_pcm_int16(audio_data):
|
| 15 |
if audio_data.dtype == np.float32 or audio_data.dtype == np.float64:
|
|
@@ -55,7 +60,30 @@ def register_temporary_file(file_path):
|
|
| 55 |
temporary_files_registry[file_path] = time.time()
|
| 56 |
trigger_background_cleanup_check()
|
| 57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
def convert_wav_file_to_pcm_format(input_path):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
try:
|
| 60 |
sample_rate, audio_data = scipy.io.wavfile.read(input_path)
|
| 61 |
|
|
|
|
| 10 |
import scipy.io.wavfile
|
| 11 |
from ..core.state import temporary_files_registry, temporary_files_lock
|
| 12 |
from ..core.memory import trigger_background_cleanup_check
|
| 13 |
+
from ..accelerator.client import (
|
| 14 |
+
is_accelerator_available,
|
| 15 |
+
convert_to_pcm_with_accelerator,
|
| 16 |
+
process_audio_with_accelerator
|
| 17 |
+
)
|
| 18 |
|
| 19 |
def convert_audio_data_to_pcm_int16(audio_data):
|
| 20 |
if audio_data.dtype == np.float32 or audio_data.dtype == np.float64:
|
|
|
|
| 60 |
temporary_files_registry[file_path] = time.time()
|
| 61 |
trigger_background_cleanup_check()
|
| 62 |
|
| 63 |
+
def convert_wav_file_to_pcm_format_with_accelerator(input_path):
    """Convert a WAV file to PCM through the accelerator daemon.

    Returns ``(output_path, None)`` on success, with the output registered as
    a temporary file, or ``(None, error_message)`` on failure after removing
    the unused output file.
    """
    # Reserve a unique output path; the file is closed again immediately and
    # kept on disk (delete=False) so the path can be handed to the daemon.
    with tempfile.NamedTemporaryFile(suffix="_accel_pcm_converted.wav", delete=False) as reserved_file:
        output_path = reserved_file.name

    converted, result_message = convert_to_pcm_with_accelerator(input_path, output_path)

    if not converted:
        if os.path.exists(output_path):
            try:
                os.remove(output_path)
            except Exception:
                # Best-effort cleanup; the conversion error is what matters.
                pass
        return None, result_message

    register_temporary_file(output_path)
    return output_path, None
|
| 80 |
+
|
| 81 |
def convert_wav_file_to_pcm_format(input_path):
|
| 82 |
+
if is_accelerator_available():
|
| 83 |
+
accelerated_result, accelerated_error = convert_wav_file_to_pcm_format_with_accelerator(input_path)
|
| 84 |
+
if accelerated_result is not None:
|
| 85 |
+
return accelerated_result, None
|
| 86 |
+
|
| 87 |
try:
|
| 88 |
sample_rate, audio_data = scipy.io.wavfile.read(input_path)
|
| 89 |
|