//
// SPDX-FileCopyrightText: Hadad
// SPDX-License-Identifier: Apache-2.0
//

#ifndef POCKET_TTS_THREAD_POOL_HPP
#define POCKET_TTS_THREAD_POOL_HPP

#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

namespace pocket_tts_accelerator {

/// @brief Task queue with an associated set of worker threads.
///
/// Only submit_task() is defined in this header; the constructor, destructor,
/// shutdown(), the accessors and the worker loop are defined out of line.
/// The pool is neither copyable nor movable.
class ThreadPool {
public:
    /// @param number_of_threads Requested worker count (presumably the number
    ///        of threads spawned -- confirm against the out-of-line definition).
    explicit ThreadPool(std::size_t number_of_threads);
    ~ThreadPool();

    ThreadPool(const ThreadPool&) = delete;
    ThreadPool& operator=(const ThreadPool&) = delete;
    ThreadPool(ThreadPool&&) = delete;
    ThreadPool& operator=(ThreadPool&&) = delete;

    /// @brief Queues `function(arguments...)` and returns a future for its result.
    /// @param function  Callable to run.
    /// @param arguments Arguments bound to the callable via std::bind (stored
    ///        as decayed copies; wrap in std::ref to pass by reference).
    /// @return Future that becomes ready once the queued task has been invoked.
    /// @throws std::runtime_error if the pool's stop flag is already set.
    template <typename FunctionType, typename... ArgumentTypes>
    auto submit_task(FunctionType&& function, ArgumentTypes&&... arguments)
        -> std::future<typename std::invoke_result<FunctionType, ArgumentTypes...>::type>;

    void shutdown();
    bool is_running() const;
    std::size_t get_pending_task_count() const;
    std::size_t get_thread_count() const;

private:
    void worker_thread_function();

    std::vector<std::thread> worker_threads;
    std::queue<std::function<void()>> task_queue;  // guarded by queue_mutex
    mutable std::mutex queue_mutex;                // mutable: usable from const accessors
    std::condition_variable task_available_condition;
    std::atomic<bool> should_stop;  // when set, submit_task() rejects new work
    std::atomic<bool> is_stopped;
    std::size_t thread_count;
};

template <typename FunctionType, typename... ArgumentTypes>
auto ThreadPool::submit_task(FunctionType&& function, ArgumentTypes&&... arguments)
    -> std::future<typename std::invoke_result<FunctionType, ArgumentTypes...>::type> {
    using ReturnType = typename std::invoke_result<FunctionType, ArgumentTypes...>::type;

    // std::function requires a copyable target while std::packaged_task is
    // move-only, so the task is held through a shared_ptr captured by the
    // queued closure.
    auto packaged_task = std::make_shared<std::packaged_task<ReturnType()>>(
        std::bind(std::forward<FunctionType>(function),
                  std::forward<ArgumentTypes>(arguments)...)
    );

    std::future<ReturnType> result_future = packaged_task->get_future();

    {
        std::unique_lock<std::mutex> lock(queue_mutex);

        // Checked while holding the queue lock so the test and the enqueue
        // happen as one step with respect to other users of queue_mutex.
        if (should_stop.load()) {
            throw std::runtime_error("Cannot submit task to stopped thread pool");
        }

        task_queue.emplace([packaged_task]() { (*packaged_task)(); });
    }

    // Notify after releasing the lock so a woken waiter can take it at once.
    task_available_condition.notify_one();

    return result_future;
}

}  // namespace pocket_tts_accelerator

#endif  // POCKET_TTS_THREAD_POOL_HPP