| | |
| | |
| |
|
| | #include <atomic> |
| | #include <cstdlib> |
| | #include <functional> |
| | #include <memory> |
| | #include <mutex> |
| | #include <stdexcept> |
| | #include <thread> |
| | #include <unordered_map> |
| | #include <vector> |
| |
|
| | #include <catch2/catch_test_macros.hpp> |
| |
|
| | #include "common/common_types.h" |
| | #include "common/fiber.h" |
| |
|
| | namespace Common { |
| |
|
| | class ThreadIds { |
| | public: |
| | void Register(u32 id) { |
| | const auto thread_id = std::this_thread::get_id(); |
| | std::scoped_lock lock{mutex}; |
| | if (ids.contains(thread_id)) { |
| | throw std::logic_error{"Registering the same thread twice"}; |
| | } |
| | ids.emplace(thread_id, id); |
| | } |
| |
|
| | [[nodiscard]] u32 Get() const { |
| | std::scoped_lock lock{mutex}; |
| | return ids.at(std::this_thread::get_id()); |
| | } |
| |
|
| | private: |
| | mutable std::mutex mutex; |
| | std::unordered_map<std::thread::id, u32> ids; |
| | }; |
| |
|
| | class TestControl1 { |
| | public: |
| | TestControl1() = default; |
| |
|
| | void DoWork() { |
| | const u32 id = thread_ids.Get(); |
| | u32 value = items[id]; |
| | for (u32 i = 0; i < id; i++) { |
| | value++; |
| | } |
| | results[id] = value; |
| | Fiber::YieldTo(work_fibers[id], *thread_fibers[id]); |
| | } |
| |
|
| | void ExecuteThread(u32 id); |
| |
|
| | ThreadIds thread_ids; |
| | std::vector<std::shared_ptr<Common::Fiber>> thread_fibers; |
| | std::vector<std::shared_ptr<Common::Fiber>> work_fibers; |
| | std::vector<u32> items; |
| | std::vector<u32> results; |
| | }; |
| |
|
| | void TestControl1::ExecuteThread(u32 id) { |
| | thread_ids.Register(id); |
| | auto thread_fiber = Fiber::ThreadToFiber(); |
| | thread_fibers[id] = thread_fiber; |
| | work_fibers[id] = std::make_shared<Fiber>([this] { DoWork(); }); |
| | items[id] = rand() % 256; |
| | Fiber::YieldTo(thread_fibers[id], *work_fibers[id]); |
| | thread_fibers[id]->Exit(); |
| | } |
| |
|
| | |
| | |
| | |
| | TEST_CASE("Fibers::Setup", "[common]") { |
| | constexpr std::size_t num_threads = 7; |
| | TestControl1 test_control{}; |
| | test_control.thread_fibers.resize(num_threads); |
| | test_control.work_fibers.resize(num_threads); |
| | test_control.items.resize(num_threads, 0); |
| | test_control.results.resize(num_threads, 0); |
| | std::vector<std::thread> threads; |
| | for (u32 i = 0; i < num_threads; i++) { |
| | threads.emplace_back([&test_control, i] { test_control.ExecuteThread(i); }); |
| | } |
| | for (u32 i = 0; i < num_threads; i++) { |
| | threads[i].join(); |
| | } |
| | for (u32 i = 0; i < num_threads; i++) { |
| | REQUIRE(test_control.items[i] + i == test_control.results[i]); |
| | } |
| | } |
| |
|
| | class TestControl2 { |
| | public: |
| | TestControl2() = default; |
| |
|
| | void DoWork1() { |
| | trap2 = false; |
| | while (trap.load()) |
| | ; |
| | for (u32 i = 0; i < 12000; i++) { |
| | value1 += i; |
| | } |
| | Fiber::YieldTo(fiber1, *fiber3); |
| | const u32 id = thread_ids.Get(); |
| | assert1 = id == 1; |
| | value2 += 5000; |
| | Fiber::YieldTo(fiber1, *thread_fibers[id]); |
| | } |
| |
|
| | void DoWork2() { |
| | while (trap2.load()) |
| | ; |
| | value2 = 2000; |
| | trap = false; |
| | Fiber::YieldTo(fiber2, *fiber1); |
| | assert3 = false; |
| | } |
| |
|
| | void DoWork3() { |
| | const u32 id = thread_ids.Get(); |
| | assert2 = id == 0; |
| | value1 += 1000; |
| | Fiber::YieldTo(fiber3, *thread_fibers[id]); |
| | } |
| |
|
| | void ExecuteThread(u32 id); |
| |
|
| | void CallFiber1() { |
| | const u32 id = thread_ids.Get(); |
| | Fiber::YieldTo(thread_fibers[id], *fiber1); |
| | } |
| |
|
| | void CallFiber2() { |
| | const u32 id = thread_ids.Get(); |
| | Fiber::YieldTo(thread_fibers[id], *fiber2); |
| | } |
| |
|
| | void Exit(); |
| |
|
| | bool assert1{}; |
| | bool assert2{}; |
| | bool assert3{true}; |
| | u32 value1{}; |
| | u32 value2{}; |
| | std::atomic<bool> trap{true}; |
| | std::atomic<bool> trap2{true}; |
| | ThreadIds thread_ids; |
| | std::vector<std::shared_ptr<Common::Fiber>> thread_fibers; |
| | std::shared_ptr<Common::Fiber> fiber1; |
| | std::shared_ptr<Common::Fiber> fiber2; |
| | std::shared_ptr<Common::Fiber> fiber3; |
| | }; |
| |
|
| | void TestControl2::ExecuteThread(u32 id) { |
| | thread_ids.Register(id); |
| | auto thread_fiber = Fiber::ThreadToFiber(); |
| | thread_fibers[id] = thread_fiber; |
| | } |
| |
|
| | void TestControl2::Exit() { |
| | const u32 id = thread_ids.Get(); |
| | thread_fibers[id]->Exit(); |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | TEST_CASE("Fibers::InterExchange", "[common]") { |
| | TestControl2 test_control{}; |
| | test_control.thread_fibers.resize(2); |
| | test_control.fiber1 = std::make_shared<Fiber>([&test_control] { test_control.DoWork1(); }); |
| | test_control.fiber2 = std::make_shared<Fiber>([&test_control] { test_control.DoWork2(); }); |
| | test_control.fiber3 = std::make_shared<Fiber>([&test_control] { test_control.DoWork3(); }); |
| | std::thread thread1{[&test_control] { |
| | test_control.ExecuteThread(0); |
| | test_control.CallFiber1(); |
| | test_control.Exit(); |
| | }}; |
| | std::thread thread2{[&test_control] { |
| | test_control.ExecuteThread(1); |
| | test_control.CallFiber2(); |
| | test_control.Exit(); |
| | }}; |
| | thread1.join(); |
| | thread2.join(); |
| | REQUIRE(test_control.assert1); |
| | REQUIRE(test_control.assert2); |
| | REQUIRE(test_control.assert3); |
| | REQUIRE(test_control.value2 == 7000); |
| | u32 cal_value = 0; |
| | for (u32 i = 0; i < 12000; i++) { |
| | cal_value += i; |
| | } |
| | cal_value += 1000; |
| | REQUIRE(test_control.value1 == cal_value); |
| | } |
| |
|
| | class TestControl3 { |
| | public: |
| | TestControl3() = default; |
| |
|
| | void DoWork1() { |
| | value1 += 1; |
| | Fiber::YieldTo(fiber1, *fiber2); |
| | const u32 id = thread_ids.Get(); |
| | value3 += 1; |
| | Fiber::YieldTo(fiber1, *thread_fibers[id]); |
| | } |
| |
|
| | void DoWork2() { |
| | value2 += 1; |
| | const u32 id = thread_ids.Get(); |
| | Fiber::YieldTo(fiber2, *thread_fibers[id]); |
| | } |
| |
|
| | void ExecuteThread(u32 id); |
| |
|
| | void CallFiber1() { |
| | const u32 id = thread_ids.Get(); |
| | Fiber::YieldTo(thread_fibers[id], *fiber1); |
| | } |
| |
|
| | void Exit(); |
| |
|
| | u32 value1{}; |
| | u32 value2{}; |
| | u32 value3{}; |
| | ThreadIds thread_ids; |
| | std::vector<std::shared_ptr<Common::Fiber>> thread_fibers; |
| | std::shared_ptr<Common::Fiber> fiber1; |
| | std::shared_ptr<Common::Fiber> fiber2; |
| | }; |
| |
|
| | void TestControl3::ExecuteThread(u32 id) { |
| | thread_ids.Register(id); |
| | auto thread_fiber = Fiber::ThreadToFiber(); |
| | thread_fibers[id] = thread_fiber; |
| | } |
| |
|
| | void TestControl3::Exit() { |
| | const u32 id = thread_ids.Get(); |
| | thread_fibers[id]->Exit(); |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | TEST_CASE("Fibers::StartRace", "[common]") { |
| | TestControl3 test_control{}; |
| | test_control.thread_fibers.resize(2); |
| | test_control.fiber1 = std::make_shared<Fiber>([&test_control] { test_control.DoWork1(); }); |
| | test_control.fiber2 = std::make_shared<Fiber>([&test_control] { test_control.DoWork2(); }); |
| | const auto race_function{[&test_control](u32 id) { |
| | test_control.ExecuteThread(id); |
| | test_control.CallFiber1(); |
| | test_control.Exit(); |
| | }}; |
| | std::thread thread1([&] { race_function(0); }); |
| | std::thread thread2([&] { race_function(1); }); |
| | thread1.join(); |
| | thread2.join(); |
| | REQUIRE(test_control.value1 == 1); |
| | REQUIRE(test_control.value2 == 1); |
| | REQUIRE(test_control.value3 == 1); |
| | } |
| |
|
| | class TestControl4; |
| |
|
| | class TestControl4 { |
| | public: |
| | TestControl4() { |
| | fiber1 = std::make_shared<Fiber>([this] { DoWork(); }); |
| | goal_reached = false; |
| | rewinded = false; |
| | } |
| |
|
| | void Execute() { |
| | thread_fiber = Fiber::ThreadToFiber(); |
| | Fiber::YieldTo(thread_fiber, *fiber1); |
| | thread_fiber->Exit(); |
| | } |
| |
|
| | void DoWork() { |
| | fiber1->SetRewindPoint([this] { DoWork(); }); |
| | if (rewinded) { |
| | goal_reached = true; |
| | Fiber::YieldTo(fiber1, *thread_fiber); |
| | } |
| | rewinded = true; |
| | fiber1->Rewind(); |
| | } |
| |
|
| | std::shared_ptr<Common::Fiber> fiber1; |
| | std::shared_ptr<Common::Fiber> thread_fiber; |
| | bool goal_reached; |
| | bool rewinded; |
| | }; |
| |
|
| | TEST_CASE("Fibers::Rewind", "[common]") { |
| | TestControl4 test_control{}; |
| | test_control.Execute(); |
| | REQUIRE(test_control.goal_reached); |
| | REQUIRE(test_control.rewinded); |
| | } |
| |
|
| | } |
| |
|