aaaaaaaaaaaaaaa / accelerator /src /ipc_handler.cpp
arifather51's picture
Upload 28 files
a57f260 verified
//
// SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
// SPDX-License-Identifier: Apache-2.0
//
#include "ipc_handler.hpp"
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>
#include <cerrno>
#include <chrono>
#include <cstddef>
#include <cstring>
#include <limits>
#include <thread>
namespace pocket_tts_accelerator {
IpcHandler::IpcHandler(const std::string& socket_path)
: socket_file_path(socket_path)
, server_socket_fd(-1)
, is_server_running(false) {
}
/// Stops the server if it is still running. stop_server() returns immediately
/// when the running flag is already clear, so this is safe after an explicit stop.
IpcHandler::~IpcHandler() {
    this->stop_server();
}
/// Creates a listening Unix-domain stream socket at socket_file_path and starts
/// the accept thread (accept_connections_loop).
///
/// Any existing file at the path is unlinked first, so a stale socket left by a
/// previous run does not make bind() fail.
///
/// @return true if the server is running on return (including when it already
///         was); false if the path does not fit in sockaddr_un::sun_path, is
///         empty, or socket()/bind()/listen()/thread creation fails. On failure
///         no descriptor is left open and the running flag stays false.
bool IpcHandler::start_server() {
    if (is_server_running.load()) {
        return true;
    }

    struct sockaddr_un server_address;
    std::memset(&server_address, 0, sizeof(server_address));
    server_address.sun_family = AF_UNIX;

    // sun_path is a small fixed-size array. Truncating a long path would bind a
    // different file than the one stop_server() later unlinks, so refuse it
    // (and the empty path, which names no filesystem socket) before touching
    // anything on disk.
    const char* const path = socket_file_path.c_str();
    const std::size_t path_length = std::strlen(path);
    if (path_length == 0 || path_length >= sizeof(server_address.sun_path)) {
        return false;
    }
    // The terminating NUL is already provided by the memset above.
    std::memcpy(server_address.sun_path, path, path_length);

    unlink(path);

    server_socket_fd = socket(AF_UNIX, SOCK_STREAM, 0);
    if (server_socket_fd < 0) {
        return false;
    }
    if (bind(server_socket_fd, reinterpret_cast<struct sockaddr*>(&server_address), sizeof(server_address)) < 0) {
        close(server_socket_fd);
        server_socket_fd = -1;
        return false;
    }
    if (listen(server_socket_fd, CONNECTION_BACKLOG) < 0) {
        close(server_socket_fd);
        server_socket_fd = -1;
        // bind() succeeded and created the socket file; don't leave it behind.
        unlink(path);
        return false;
    }

    is_server_running.store(true);
    try {
        accept_thread = std::thread(&IpcHandler::accept_connections_loop, this);
    } catch (...) {
        // std::thread's constructor throws (std::system_error) when the thread
        // cannot be started; roll back so the object stays in the stopped state.
        is_server_running.store(false);
        close(server_socket_fd);
        server_socket_fd = -1;
        unlink(path);
        return false;
    }
    return true;
}
/// Stops a running server; does nothing if the running flag is already clear.
///
/// Order of operations:
///   1. clear the running flag,
///   2. shutdown() and close() the listening socket,
///   3. join the accept thread,
///   4. unlink the socket file.
///
/// Step 2 is presumably meant to make the pending accept() in
/// accept_connections_loop() return, so that thread sees the cleared flag and exits.
///
/// NOTE(review): whether shutdown() wakes a blocked accept() is platform-dependent.
/// NOTE(review): if this runs on the accept thread itself (e.g. from the shutdown
/// callback), join() would throw. Confirm callers never do that.
void IpcHandler::stop_server() {
    const bool currently_running = is_server_running.load();
    if (currently_running) {
        is_server_running.store(false);

        const int listening_fd = server_socket_fd;
        if (listening_fd >= 0) {
            shutdown(listening_fd, SHUT_RDWR);
            close(listening_fd);
            server_socket_fd = -1;
        }

        if (accept_thread.joinable()) {
            accept_thread.join();
        }

        unlink(socket_file_path.c_str());
    }
}
/// Returns the running flag. start_server() sets it on success and stop_server()
/// clears it.
bool IpcHandler::is_running() const {
    const bool running = is_server_running.load();
    return running;
}
/// Registers @p handler for @p command_type, replacing any handler previously
/// registered for that type. Guarded by handlers_mutex, so it may be called while
/// the server is running.
///
/// SHUTDOWN requests are answered inside handle_client_connection() and never
/// reach a registered handler.
void IpcHandler::register_command_handler(CommandType command_type, CommandHandlerFunction handler) {
    const std::lock_guard<std::mutex> guard(handlers_mutex);
    command_handlers[command_type] = std::move(handler);
}
void IpcHandler::set_shutdown_callback(std::function<void()> callback) {
shutdown_callback = std::move(callback);
}
/// Body of accept_thread: accepts clients one at a time and serves each to
/// completion (then closes it) before accepting the next, until
/// is_server_running is cleared.
///
/// NOTE(review): no receive timeout is set on client sockets here, so a client
/// that connects and stalls blocks every other client.
void IpcHandler::accept_connections_loop() {
    // Persistent accept() failures such as EMFILE/ENFILE would otherwise be
    // retried immediately and spin a CPU core; pause briefly before retrying.
    constexpr auto accept_retry_delay = std::chrono::milliseconds(10);

    while (is_server_running.load()) {
        struct sockaddr_un client_address;
        socklen_t client_address_length = sizeof(client_address);
        const int client_socket_fd = accept(
            server_socket_fd,
            reinterpret_cast<struct sockaddr*>(&client_address),
            &client_address_length
        );
        if (client_socket_fd < 0) {
            const int accept_error = errno;  // capture before anything can overwrite it
            if (!is_server_running.load()) {
                break;
            }
            if (accept_error != EINTR && accept_error != ECONNABORTED) {
                std::this_thread::sleep_for(accept_retry_delay);
            }
            continue;
        }

        try {
            handle_client_connection(client_socket_fd);
        } catch (...) {
            // An exception leaving a thread's entry function calls std::terminate().
            // Drop this connection instead and keep serving; the fd is closed below.
        }
        close(client_socket_fd);
    }
}
void IpcHandler::handle_client_connection(int client_socket_fd) {
RequestHeader request_header;
std::vector<std::uint8_t> request_payload;
if (!receive_request(client_socket_fd, request_header, request_payload)) {
return;
}
if (request_header.magic_number != PROTOCOL_MAGIC_NUMBER) {
ResponseHeader error_response;
error_response.magic_number = PROTOCOL_MAGIC_NUMBER;
error_response.status_code = static_cast<std::uint32_t>(ResponseStatus::ERROR_INVALID_COMMAND);
error_response.payload_size = 0;
error_response.request_id = request_header.request_id;
send_response(client_socket_fd, error_response, {});
return;
}
CommandType command_type = static_cast<CommandType>(request_header.command_type);
if (command_type == CommandType::SHUTDOWN) {
ResponseHeader shutdown_response;
shutdown_response.magic_number = PROTOCOL_MAGIC_NUMBER;
shutdown_response.status_code = static_cast<std::uint32_t>(ResponseStatus::SUCCESS);
shutdown_response.payload_size = 0;
shutdown_response.request_id = request_header.request_id;
send_response(client_socket_fd, shutdown_response, {});
if (shutdown_callback) {
shutdown_callback();
}
return;
}
std::vector<std::uint8_t> response_payload;
ResponseStatus status = ResponseStatus::SUCCESS;
{
std::unique_lock<std::mutex> lock(handlers_mutex);
auto handler_iterator = command_handlers.find(command_type);
if (handler_iterator != command_handlers.end()) {
try {
response_payload = handler_iterator->second(request_payload);
} catch (...) {
status = ResponseStatus::ERROR_INTERNAL;
}
} else {
status = ResponseStatus::ERROR_INVALID_COMMAND;
}
}
ResponseHeader response_header;
response_header.magic_number = PROTOCOL_MAGIC_NUMBER;
response_header.status_code = static_cast<std::uint32_t>(status);
response_header.payload_size = static_cast<std::uint32_t>(response_payload.size());
response_header.request_id = request_header.request_id;
send_response(client_socket_fd, response_header, response_payload);
}
/// Writes @p header, then @p payload (if non-empty), to @p socket_fd.
///
/// Each buffer is sent in a loop until it is fully written. A stream socket may
/// accept fewer bytes than requested, and a blocking call can be interrupted by
/// a signal (EINTR) before any data is transferred.
///
/// @return true if every byte was written; false on any send error.
bool IpcHandler::send_response(
    int socket_fd,
    const ResponseHeader& header,
    const std::vector<std::uint8_t>& payload
) {
    // Writing to a peer that has already disconnected raises SIGPIPE, whose
    // default action terminates the process. MSG_NOSIGNAL reports EPIPE instead.
    // NOTE(review): platforms without MSG_NOSIGNAL (e.g. macOS) need SO_NOSIGPIPE
    // or an ignored SIGPIPE to get the same protection.
#ifdef MSG_NOSIGNAL
    constexpr int send_flags = MSG_NOSIGNAL;
#else
    constexpr int send_flags = 0;
#endif

    const auto send_all = [&](const void* data, std::size_t length) {
        const auto* cursor = static_cast<const std::uint8_t*>(data);
        while (length > 0) {
            const ssize_t sent = send(socket_fd, cursor, length, send_flags);
            if (sent < 0) {
                if (errno == EINTR) {
                    continue;
                }
                return false;
            }
            if (sent == 0) {
                return false;  // no progress possible; avoid looping forever
            }
            cursor += sent;
            length -= static_cast<std::size_t>(sent);
        }
        return true;
    };

    if (!send_all(&header, sizeof(ResponseHeader))) {
        return false;
    }
    if (!payload.empty() && !send_all(payload.data(), payload.size())) {
        return false;
    }
    return true;
}
/// Reads one request from @p socket_fd: a RequestHeader followed by
/// header.payload_size payload bytes.
///
/// Each read is looped until complete, because a stream socket may deliver a
/// message in several pieces. EINTR is retried; EOF before the expected byte
/// count counts as failure.
///
/// @param header  Filled with the received header.
/// @param payload Resized to header.payload_size and filled when that size is
///                non-zero; left untouched when it is zero.
/// @return false on read error, premature EOF, or a payload_size larger than
///         MAXIMUM_PAYLOAD_SIZE; true otherwise.
bool IpcHandler::receive_request(
    int socket_fd,
    RequestHeader& header,
    std::vector<std::uint8_t>& payload
) {
    const auto receive_all = [&](void* data, std::size_t length) {
        auto* cursor = static_cast<std::uint8_t*>(data);
        while (length > 0) {
            const ssize_t received = read(socket_fd, cursor, length);
            if (received < 0) {
                if (errno == EINTR) {
                    continue;
                }
                return false;
            }
            if (received == 0) {
                return false;  // peer closed the connection mid-message
            }
            cursor += received;
            length -= static_cast<std::size_t>(received);
        }
        return true;
    };

    if (!receive_all(&header, sizeof(RequestHeader))) {
        return false;
    }
    // Validate the advertised size before allocating, so a client cannot force
    // an arbitrarily large allocation.
    if (header.payload_size > MAXIMUM_PAYLOAD_SIZE) {
        return false;
    }
    if (header.payload_size > 0) {
        payload.resize(header.payload_size);
        if (!receive_all(payload.data(), payload.size())) {
            return false;
        }
    }
    return true;
}
}