// // SPDX-FileCopyrightText: Hadad // SPDX-License-Identifier: Apache-2.0 // #include "ipc_handler.hpp" #include #include #include #include namespace pocket_tts_accelerator { IpcHandler::IpcHandler(const std::string& socket_path) : socket_file_path(socket_path) , server_socket_fd(-1) , is_server_running(false) { } IpcHandler::~IpcHandler() { stop_server(); } bool IpcHandler::start_server() { if (is_server_running.load()) { return true; } unlink(socket_file_path.c_str()); server_socket_fd = socket(AF_UNIX, SOCK_STREAM, 0); if (server_socket_fd < 0) { return false; } struct sockaddr_un server_address; std::memset(&server_address, 0, sizeof(server_address)); server_address.sun_family = AF_UNIX; std::strncpy(server_address.sun_path, socket_file_path.c_str(), sizeof(server_address.sun_path) - 1); if (bind(server_socket_fd, reinterpret_cast(&server_address), sizeof(server_address)) < 0) { close(server_socket_fd); server_socket_fd = -1; return false; } if (listen(server_socket_fd, CONNECTION_BACKLOG) < 0) { close(server_socket_fd); server_socket_fd = -1; return false; } is_server_running.store(true); accept_thread = std::thread(&IpcHandler::accept_connections_loop, this); return true; } void IpcHandler::stop_server() { if (!is_server_running.load()) { return; } is_server_running.store(false); if (server_socket_fd >= 0) { shutdown(server_socket_fd, SHUT_RDWR); close(server_socket_fd); server_socket_fd = -1; } if (accept_thread.joinable()) { accept_thread.join(); } unlink(socket_file_path.c_str()); } bool IpcHandler::is_running() const { return is_server_running.load(); } void IpcHandler::register_command_handler(CommandType command_type, CommandHandlerFunction handler) { std::unique_lock lock(handlers_mutex); command_handlers[command_type] = std::move(handler); } void IpcHandler::set_shutdown_callback(std::function callback) { shutdown_callback = std::move(callback); } void IpcHandler::accept_connections_loop() { while (is_server_running.load()) { struct sockaddr_un client_address; socklen_t client_address_length = sizeof(client_address); int client_socket_fd = accept( server_socket_fd, reinterpret_cast(&client_address), &client_address_length ); if (client_socket_fd < 0) { if (!is_server_running.load()) { break; } continue; } handle_client_connection(client_socket_fd); close(client_socket_fd); } } void IpcHandler::handle_client_connection(int client_socket_fd) { RequestHeader request_header; std::vector request_payload; if (!receive_request(client_socket_fd, request_header, request_payload)) { return; } if (request_header.magic_number != PROTOCOL_MAGIC_NUMBER) { ResponseHeader error_response; error_response.magic_number = PROTOCOL_MAGIC_NUMBER; error_response.status_code = static_cast(ResponseStatus::ERROR_INVALID_COMMAND); error_response.payload_size = 0; error_response.request_id = request_header.request_id; send_response(client_socket_fd, error_response, {}); return; } CommandType command_type = static_cast(request_header.command_type); if (command_type == CommandType::SHUTDOWN) { ResponseHeader shutdown_response; shutdown_response.magic_number = PROTOCOL_MAGIC_NUMBER; shutdown_response.status_code = static_cast(ResponseStatus::SUCCESS); shutdown_response.payload_size = 0; shutdown_response.request_id = request_header.request_id; send_response(client_socket_fd, shutdown_response, {}); if (shutdown_callback) { shutdown_callback(); } return; } std::vector response_payload; ResponseStatus status = ResponseStatus::SUCCESS; { std::unique_lock lock(handlers_mutex); auto handler_iterator = command_handlers.find(command_type); if (handler_iterator != command_handlers.end()) { try { response_payload = handler_iterator->second(request_payload); } catch (...) { status = ResponseStatus::ERROR_INTERNAL; } } else { status = ResponseStatus::ERROR_INVALID_COMMAND; } } ResponseHeader response_header; response_header.magic_number = PROTOCOL_MAGIC_NUMBER; response_header.status_code = static_cast(status); response_header.payload_size = static_cast(response_payload.size()); response_header.request_id = request_header.request_id; send_response(client_socket_fd, response_header, response_payload); } bool IpcHandler::send_response( int socket_fd, const ResponseHeader& header, const std::vector& payload ) { ssize_t bytes_written = write(socket_fd, &header, sizeof(ResponseHeader)); if (bytes_written != sizeof(ResponseHeader)) { return false; } if (!payload.empty()) { bytes_written = write(socket_fd, payload.data(), payload.size()); if (bytes_written != static_cast(payload.size())) { return false; } } return true; } bool IpcHandler::receive_request( int socket_fd, RequestHeader& header, std::vector& payload ) { ssize_t bytes_read = read(socket_fd, &header, sizeof(RequestHeader)); if (bytes_read != sizeof(RequestHeader)) { return false; } if (header.payload_size > MAXIMUM_PAYLOAD_SIZE) { return false; } if (header.payload_size > 0) { payload.resize(header.payload_size); bytes_read = read(socket_fd, payload.data(), header.payload_size); if (bytes_read != static_cast(header.payload_size)) { return false; } } return true; } }