|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#ifndef POCKET_TTS_THREAD_POOL_HPP |
|
|
#define POCKET_TTS_THREAD_POOL_HPP |
|
|
|
|
|
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
|
|
|
|
|
namespace pocket_tts_accelerator {

/// Pool of worker threads that run tasks queued through submit_task().
///
/// NOTE(review): the constructor, destructor, shutdown(), the query methods
/// and worker_thread_function() are defined outside this header. Comments on
/// those members describe presumed behaviour and should be confirmed against
/// their definitions.
class ThreadPool {
public:
    /// @param number_of_threads  Presumably the number of worker threads to
    ///                           start — confirm in the constructor definition.
    explicit ThreadPool(std::size_t number_of_threads);

    /// Presumably stops the pool and joins the workers — confirm.
    ~ThreadPool();

    // Neither copyable nor movable. The workers presumably run
    // worker_thread_function() on this object, so its address must stay stable.
    ThreadPool(const ThreadPool&) = delete;
    ThreadPool& operator=(const ThreadPool&) = delete;
    ThreadPool(ThreadPool&&) = delete;
    ThreadPool& operator=(ThreadPool&&) = delete;

    /// Queues `function(arguments...)` for execution on a worker thread.
    /// See the out-of-class definition for how arguments are stored and passed.
    /// @return A future for the call's result (or the exception it throws).
    /// @throws std::runtime_error if the pool has already been asked to stop.
    template<typename FunctionType, typename... ArgumentTypes>
    auto submit_task(FunctionType&& function, ArgumentTypes&&... arguments)
        -> std::future<typename std::invoke_result<FunctionType, ArgumentTypes...>::type>;

    /// Presumably sets should_stop and joins the workers — confirm.
    void shutdown();

    /// Presumably reports whether the pool still accepts tasks — confirm.
    bool is_running() const;

    /// Presumably the number of queued tasks not yet picked up — confirm.
    std::size_t get_pending_task_count() const;

    /// Presumably the number of worker threads — confirm.
    std::size_t get_thread_count() const;

private:
    /// Body executed by each worker thread (defined out of line).
    void worker_thread_function();

    std::vector<std::thread> worker_threads;

    // Tasks waiting to run; submit_task pushes onto it while holding queue_mutex.
    std::queue<std::function<void()>> task_queue;

    // mutable so that const members can lock it (presumably get_pending_task_count).
    mutable std::mutex queue_mutex;

    // Notified once by submit_task after each task is queued.
    std::condition_variable task_available_condition;

    // The default member initializers below give every flag and counter a
    // defined value even if the out-of-line constructor does not set it
    // (before C++20 a default-constructed std::atomic<bool> is indeterminate).
    // A constructor mem-initializer still takes precedence over them.

    // Once set, submit_task rejects new tasks with std::runtime_error.
    std::atomic<bool> should_stop{false};

    // Presumably set once shutdown has finished — confirm.
    std::atomic<bool> is_stopped{false};

    // Presumably mirrors number_of_threads — confirm.
    std::size_t thread_count{0};
};

/// Queues `function(arguments...)` for asynchronous execution on a worker thread.
///
/// The callable and every argument are decay-copied (or moved, when passed as
/// rvalues) into the task, mirroring std::async / std::thread. When the task
/// runs, the stored copies are handed over as rvalues, so move-only arguments
/// (e.g. std::unique_ptr) and rvalue-reference parameters are supported and
/// by-value parameters are not copied a second time. If the callable cannot be
/// invoked that way (for example it takes a non-const lvalue reference), it is
/// invoked with the stored copies as lvalues instead, so such calls still
/// compile; they operate on the task's private copies. Pass std::ref/std::cref
/// to operate on the caller's object instead.
///
/// @param function   Callable to run.
/// @param arguments  Arguments for the callable.
/// @return A future that becomes ready with the callable's result, or with the
///         exception it threw (captured by std::packaged_task).
/// @throws std::runtime_error if should_stop is already set when the task is
///         being queued; the task is then discarded.
template<typename FunctionType, typename... ArgumentTypes>
auto ThreadPool::submit_task(FunctionType&& function, ArgumentTypes&&... arguments)
    -> std::future<typename std::invoke_result<FunctionType, ArgumentTypes...>::type> {
    using ReturnType = typename std::invoke_result<FunctionType, ArgumentTypes...>::type;
    using StoredArguments = std::tuple<std::decay_t<ArgumentTypes>...>;
    // True when the stored copies can be consumed as rvalues (std::async semantics).
    using IsRvalueInvocable =
        std::is_invocable<std::decay_t<FunctionType>, std::decay_t<ArgumentTypes>...>;

    // The init-capture of stored_function decays exactly like std::decay_t.
    auto packaged_task = std::make_shared<std::packaged_task<ReturnType()>>(
        [stored_function = std::forward<FunctionType>(function),
         stored_arguments = StoredArguments(std::forward<ArgumentTypes>(arguments)...)]() mutable
            -> ReturnType {
            if constexpr (IsRvalueInvocable::value) {
                // The task runs at most once, so the stored state can be consumed.
                return std::apply(std::move(stored_function), std::move(stored_arguments));
            } else {
                // Lvalue fallback, e.g. for parameters of type T&.
                return std::apply(stored_function, stored_arguments);
            }
        });

    std::future<ReturnType> result_future = packaged_task->get_future();

    {
        std::lock_guard<std::mutex> lock(queue_mutex);

        if (should_stop.load()) {
            throw std::runtime_error("Cannot submit task to stopped thread pool");
        }

        // std::function requires a copyable target while packaged_task is
        // move-only; sharing it through a shared_ptr keeps the wrapper copyable.
        task_queue.emplace([packaged_task]() {
            (*packaged_task)();
        });
    }

    // Notify after the lock is released so a woken worker does not immediately
    // block on queue_mutex.
    task_available_condition.notify_one();

    return result_future;
}

}  // namespace pocket_tts_accelerator
|
|
|
|
|
#endif |