Spaces:
Runtime error
Runtime error
| // | |
| // SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org> | |
| // SPDX-License-Identifier: Apache-2.0 | |
| // | |
| namespace pocket_tts_accelerator { | |
| IpcHandler::IpcHandler(const std::string& socket_path) | |
| : socket_file_path(socket_path) | |
| , server_socket_fd(-1) | |
| , is_server_running(false) { | |
| } | |
| IpcHandler::~IpcHandler() { | |
| stop_server(); | |
| } | |
| bool IpcHandler::start_server() { | |
| if (is_server_running.load()) { | |
| return true; | |
| } | |
| unlink(socket_file_path.c_str()); | |
| server_socket_fd = socket(AF_UNIX, SOCK_STREAM, 0); | |
| if (server_socket_fd < 0) { | |
| return false; | |
| } | |
| struct sockaddr_un server_address; | |
| std::memset(&server_address, 0, sizeof(server_address)); | |
| server_address.sun_family = AF_UNIX; | |
| std::strncpy(server_address.sun_path, socket_file_path.c_str(), sizeof(server_address.sun_path) - 1); | |
| if (bind(server_socket_fd, reinterpret_cast<struct sockaddr*>(&server_address), sizeof(server_address)) < 0) { | |
| close(server_socket_fd); | |
| server_socket_fd = -1; | |
| return false; | |
| } | |
| if (listen(server_socket_fd, CONNECTION_BACKLOG) < 0) { | |
| close(server_socket_fd); | |
| server_socket_fd = -1; | |
| return false; | |
| } | |
| is_server_running.store(true); | |
| accept_thread = std::thread(&IpcHandler::accept_connections_loop, this); | |
| return true; | |
| } | |
| void IpcHandler::stop_server() { | |
| if (!is_server_running.load()) { | |
| return; | |
| } | |
| is_server_running.store(false); | |
| if (server_socket_fd >= 0) { | |
| shutdown(server_socket_fd, SHUT_RDWR); | |
| close(server_socket_fd); | |
| server_socket_fd = -1; | |
| } | |
| if (accept_thread.joinable()) { | |
| accept_thread.join(); | |
| } | |
| unlink(socket_file_path.c_str()); | |
| } | |
| bool IpcHandler::is_running() const { | |
| return is_server_running.load(); | |
| } | |
| void IpcHandler::register_command_handler(CommandType command_type, CommandHandlerFunction handler) { | |
| std::unique_lock<std::mutex> lock(handlers_mutex); | |
| command_handlers[command_type] = std::move(handler); | |
| } | |
| void IpcHandler::set_shutdown_callback(std::function<void()> callback) { | |
| shutdown_callback = std::move(callback); | |
| } | |
| void IpcHandler::accept_connections_loop() { | |
| while (is_server_running.load()) { | |
| struct sockaddr_un client_address; | |
| socklen_t client_address_length = sizeof(client_address); | |
| int client_socket_fd = accept( | |
| server_socket_fd, | |
| reinterpret_cast<struct sockaddr*>(&client_address), | |
| &client_address_length | |
| ); | |
| if (client_socket_fd < 0) { | |
| if (!is_server_running.load()) { | |
| break; | |
| } | |
| continue; | |
| } | |
| handle_client_connection(client_socket_fd); | |
| close(client_socket_fd); | |
| } | |
| } | |
| void IpcHandler::handle_client_connection(int client_socket_fd) { | |
| RequestHeader request_header; | |
| std::vector<std::uint8_t> request_payload; | |
| if (!receive_request(client_socket_fd, request_header, request_payload)) { | |
| return; | |
| } | |
| if (request_header.magic_number != PROTOCOL_MAGIC_NUMBER) { | |
| ResponseHeader error_response; | |
| error_response.magic_number = PROTOCOL_MAGIC_NUMBER; | |
| error_response.status_code = static_cast<std::uint32_t>(ResponseStatus::ERROR_INVALID_COMMAND); | |
| error_response.payload_size = 0; | |
| error_response.request_id = request_header.request_id; | |
| send_response(client_socket_fd, error_response, {}); | |
| return; | |
| } | |
| CommandType command_type = static_cast<CommandType>(request_header.command_type); | |
| if (command_type == CommandType::SHUTDOWN) { | |
| ResponseHeader shutdown_response; | |
| shutdown_response.magic_number = PROTOCOL_MAGIC_NUMBER; | |
| shutdown_response.status_code = static_cast<std::uint32_t>(ResponseStatus::SUCCESS); | |
| shutdown_response.payload_size = 0; | |
| shutdown_response.request_id = request_header.request_id; | |
| send_response(client_socket_fd, shutdown_response, {}); | |
| if (shutdown_callback) { | |
| shutdown_callback(); | |
| } | |
| return; | |
| } | |
| std::vector<std::uint8_t> response_payload; | |
| ResponseStatus status = ResponseStatus::SUCCESS; | |
| { | |
| std::unique_lock<std::mutex> lock(handlers_mutex); | |
| auto handler_iterator = command_handlers.find(command_type); | |
| if (handler_iterator != command_handlers.end()) { | |
| try { | |
| response_payload = handler_iterator->second(request_payload); | |
| } catch (...) { | |
| status = ResponseStatus::ERROR_INTERNAL; | |
| } | |
| } else { | |
| status = ResponseStatus::ERROR_INVALID_COMMAND; | |
| } | |
| } | |
| ResponseHeader response_header; | |
| response_header.magic_number = PROTOCOL_MAGIC_NUMBER; | |
| response_header.status_code = static_cast<std::uint32_t>(status); | |
| response_header.payload_size = static_cast<std::uint32_t>(response_payload.size()); | |
| response_header.request_id = request_header.request_id; | |
| send_response(client_socket_fd, response_header, response_payload); | |
| } | |
| bool IpcHandler::send_response( | |
| int socket_fd, | |
| const ResponseHeader& header, | |
| const std::vector<std::uint8_t>& payload | |
| ) { | |
| ssize_t bytes_written = write(socket_fd, &header, sizeof(ResponseHeader)); | |
| if (bytes_written != sizeof(ResponseHeader)) { | |
| return false; | |
| } | |
| if (!payload.empty()) { | |
| bytes_written = write(socket_fd, payload.data(), payload.size()); | |
| if (bytes_written != static_cast<ssize_t>(payload.size())) { | |
| return false; | |
| } | |
| } | |
| return true; | |
| } | |
| bool IpcHandler::receive_request( | |
| int socket_fd, | |
| RequestHeader& header, | |
| std::vector<std::uint8_t>& payload | |
| ) { | |
| ssize_t bytes_read = read(socket_fd, &header, sizeof(RequestHeader)); | |
| if (bytes_read != sizeof(RequestHeader)) { | |
| return false; | |
| } | |
| if (header.payload_size > MAXIMUM_PAYLOAD_SIZE) { | |
| return false; | |
| } | |
| if (header.payload_size > 0) { | |
| payload.resize(header.payload_size); | |
| bytes_read = read(socket_fd, payload.data(), header.payload_size); | |
| if (bytes_read != static_cast<ssize_t>(header.payload_size)) { | |
| return false; | |
| } | |
| } | |
| return true; | |
| } | |
| } |