|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include "ipc_handler.hpp"

#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>

#include <cerrno>
#include <chrono>
#include <cstring>
#include <exception>
#include <iostream>
#include <limits>
#include <thread>
|
|
|
|
|
namespace pocket_tts_accelerator { |
|
|
|
|
|
IpcHandler::IpcHandler(const std::string& socket_path) |
|
|
: socket_file_path(socket_path) |
|
|
, server_socket_fd(-1) |
|
|
, is_server_running(false) { |
|
|
} |
|
|
|
|
|
/// Tears the server down via stop_server(). When the server is running, that
/// call shuts down and closes the listening socket, joins the accept thread
/// and unlinks the socket file. Otherwise it does nothing.
IpcHandler::~IpcHandler() {
    this->stop_server();
}
|
|
|
|
|
bool IpcHandler::start_server() { |
|
|
if (is_server_running.load()) { |
|
|
return true; |
|
|
} |
|
|
|
|
|
unlink(socket_file_path.c_str()); |
|
|
|
|
|
server_socket_fd = socket(AF_UNIX, SOCK_STREAM, 0); |
|
|
|
|
|
if (server_socket_fd < 0) { |
|
|
std::cerr << "Failed to create socket: " << strerror(errno) << std::endl; |
|
|
return false; |
|
|
} |
|
|
|
|
|
struct sockaddr_un server_address; |
|
|
std::memset(&server_address, 0, sizeof(server_address)); |
|
|
server_address.sun_family = AF_UNIX; |
|
|
std::strncpy(server_address.sun_path, socket_file_path.c_str(), sizeof(server_address.sun_path) - 1); |
|
|
|
|
|
if (bind(server_socket_fd, reinterpret_cast<struct sockaddr*>(&server_address), sizeof(server_address)) < 0) { |
|
|
std::cerr << "Failed to bind socket: " << strerror(errno) << std::endl; |
|
|
close(server_socket_fd); |
|
|
server_socket_fd = -1; |
|
|
return false; |
|
|
} |
|
|
|
|
|
if (listen(server_socket_fd, CONNECTION_BACKLOG) < 0) { |
|
|
std::cerr << "Failed to listen on socket: " << strerror(errno) << std::endl; |
|
|
close(server_socket_fd); |
|
|
server_socket_fd = -1; |
|
|
return false; |
|
|
} |
|
|
|
|
|
is_server_running.store(true); |
|
|
accept_thread = std::thread(&IpcHandler::accept_connections_loop, this); |
|
|
|
|
|
return true; |
|
|
} |
|
|
|
|
|
void IpcHandler::stop_server() { |
|
|
if (!is_server_running.load()) { |
|
|
return; |
|
|
} |
|
|
|
|
|
is_server_running.store(false); |
|
|
|
|
|
if (server_socket_fd >= 0) { |
|
|
shutdown(server_socket_fd, SHUT_RDWR); |
|
|
close(server_socket_fd); |
|
|
server_socket_fd = -1; |
|
|
} |
|
|
|
|
|
if (accept_thread.joinable()) { |
|
|
accept_thread.join(); |
|
|
} |
|
|
|
|
|
unlink(socket_file_path.c_str()); |
|
|
} |
|
|
|
|
|
bool IpcHandler::is_running() const { |
|
|
return is_server_running.load(); |
|
|
} |
|
|
|
|
|
/// Installs `handler` for `command_type`, replacing any handler previously
/// registered for that command.
///
/// The handler table is guarded by handlers_mutex, the same mutex taken when
/// a request's handler is looked up.
void IpcHandler::register_command_handler(CommandType command_type, CommandHandlerFunction handler) {
    const std::lock_guard<std::mutex> guard(handlers_mutex);
    auto& slot = command_handlers[command_type];
    slot = std::move(handler);
}
|
|
|
|
|
void IpcHandler::set_shutdown_callback(std::function<void()> callback) { |
|
|
shutdown_callback = std::move(callback); |
|
|
} |
|
|
|
|
|
void IpcHandler::accept_connections_loop() { |
|
|
while (is_server_running.load()) { |
|
|
struct sockaddr_un client_address; |
|
|
socklen_t client_address_length = sizeof(client_address); |
|
|
|
|
|
int client_socket_fd = accept( |
|
|
server_socket_fd, |
|
|
reinterpret_cast<struct sockaddr*>(&client_address), |
|
|
&client_address_length |
|
|
); |
|
|
|
|
|
if (client_socket_fd < 0) { |
|
|
if (!is_server_running.load()) { |
|
|
break; |
|
|
} |
|
|
continue; |
|
|
} |
|
|
|
|
|
handle_client_connection(client_socket_fd); |
|
|
close(client_socket_fd); |
|
|
} |
|
|
} |
|
|
|
|
|
void IpcHandler::handle_client_connection(int client_socket_fd) { |
|
|
RequestHeader request_header; |
|
|
std::vector<std::uint8_t> request_payload; |
|
|
|
|
|
if (!receive_request(client_socket_fd, request_header, request_payload)) { |
|
|
return; |
|
|
} |
|
|
|
|
|
if (request_header.magic_number != PROTOCOL_MAGIC_NUMBER) { |
|
|
ResponseHeader error_response; |
|
|
error_response.magic_number = PROTOCOL_MAGIC_NUMBER; |
|
|
error_response.status_code = static_cast<std::uint32_t>(ResponseStatus::ERROR_INVALID_COMMAND); |
|
|
error_response.payload_size = 0; |
|
|
error_response.request_id = request_header.request_id; |
|
|
send_response(client_socket_fd, error_response, {}); |
|
|
return; |
|
|
} |
|
|
|
|
|
CommandType command_type = static_cast<CommandType>(request_header.command_type); |
|
|
|
|
|
std::vector<std::uint8_t> response_payload; |
|
|
ResponseStatus status = ResponseStatus::SUCCESS; |
|
|
|
|
|
{ |
|
|
std::unique_lock<std::mutex> lock(handlers_mutex); |
|
|
auto handler_iterator = command_handlers.find(command_type); |
|
|
|
|
|
if (handler_iterator != command_handlers.end()) { |
|
|
try { |
|
|
response_payload = handler_iterator->second(request_payload); |
|
|
} catch (const std::exception& exception) { |
|
|
std::cerr << "Handler exception: " << exception.what() << std::endl; |
|
|
status = ResponseStatus::ERROR_INTERNAL; |
|
|
} catch (...) { |
|
|
std::cerr << "Handler unknown exception" << std::endl; |
|
|
status = ResponseStatus::ERROR_INTERNAL; |
|
|
} |
|
|
} else { |
|
|
status = ResponseStatus::ERROR_INVALID_COMMAND; |
|
|
} |
|
|
} |
|
|
|
|
|
ResponseHeader response_header; |
|
|
response_header.magic_number = PROTOCOL_MAGIC_NUMBER; |
|
|
response_header.status_code = static_cast<std::uint32_t>(status); |
|
|
response_header.payload_size = static_cast<std::uint32_t>(response_payload.size()); |
|
|
response_header.request_id = request_header.request_id; |
|
|
|
|
|
send_response(client_socket_fd, response_header, response_payload); |
|
|
} |
|
|
|
|
|
bool IpcHandler::send_response( |
|
|
int socket_fd, |
|
|
const ResponseHeader& header, |
|
|
const std::vector<std::uint8_t>& payload |
|
|
) { |
|
|
ssize_t bytes_written = write(socket_fd, &header, sizeof(ResponseHeader)); |
|
|
|
|
|
if (bytes_written != sizeof(ResponseHeader)) { |
|
|
return false; |
|
|
} |
|
|
|
|
|
if (!payload.empty()) { |
|
|
bytes_written = write(socket_fd, payload.data(), payload.size()); |
|
|
|
|
|
if (bytes_written != static_cast<ssize_t>(payload.size())) { |
|
|
return false; |
|
|
} |
|
|
} |
|
|
|
|
|
return true; |
|
|
} |
|
|
|
|
|
bool IpcHandler::receive_request( |
|
|
int socket_fd, |
|
|
RequestHeader& header, |
|
|
std::vector<std::uint8_t>& payload |
|
|
) { |
|
|
ssize_t bytes_read = read(socket_fd, &header, sizeof(RequestHeader)); |
|
|
|
|
|
if (bytes_read != sizeof(RequestHeader)) { |
|
|
return false; |
|
|
} |
|
|
|
|
|
if (header.payload_size > MAXIMUM_PAYLOAD_SIZE) { |
|
|
return false; |
|
|
} |
|
|
|
|
|
if (header.payload_size > 0) { |
|
|
payload.resize(header.payload_size); |
|
|
std::size_t total_read = 0; |
|
|
|
|
|
while (total_read < header.payload_size) { |
|
|
bytes_read = read(socket_fd, payload.data() + total_read, header.payload_size - total_read); |
|
|
|
|
|
if (bytes_read <= 0) { |
|
|
return false; |
|
|
} |
|
|
|
|
|
total_read += static_cast<std::size_t>(bytes_read); |
|
|
} |
|
|
} |
|
|
|
|
|
return true; |
|
|
} |
|
|
|
|
|
} |