// // SPDX-FileCopyrightText: Hadad // SPDX-License-Identifier: Apache-2.0 // #include "memory_pool.hpp" #include #include #include namespace pocket_tts_accelerator { MemoryPool::MemoryPool(std::size_t initial_pool_size_bytes) : total_allocated_bytes(0) , total_used_bytes(0) , maximum_pool_size_bytes(initial_pool_size_bytes) { } MemoryPool::~MemoryPool() { reset_pool(); } std::uint8_t* MemoryPool::allocate(std::size_t requested_size_bytes) { std::unique_lock lock(pool_mutex); std::size_t block_index = find_suitable_block_index(requested_size_bytes); if (block_index != static_cast(-1)) { MemoryBlock& existing_block = memory_blocks[block_index]; existing_block.is_in_use = true; existing_block.last_access_timestamp = get_current_timestamp(); total_used_bytes += existing_block.block_size; return existing_block.data.get(); } if (total_allocated_bytes + requested_size_bytes > maximum_pool_size_bytes) { clear_unused_blocks(); } std::size_t aligned_size = ((requested_size_bytes + 63) / 64) * 64; memory_blocks.push_back(MemoryBlock{ std::make_unique(aligned_size), aligned_size, true, get_current_timestamp() }); std::uint8_t* allocated_pointer = memory_blocks.back().data.get(); pointer_to_block_index[allocated_pointer] = memory_blocks.size() - 1; total_allocated_bytes += aligned_size; total_used_bytes += aligned_size; return allocated_pointer; } void MemoryPool::deallocate(std::uint8_t* pointer) { if (pointer == nullptr) { return; } std::unique_lock lock(pool_mutex); auto iterator = pointer_to_block_index.find(pointer); if (iterator != pointer_to_block_index.end()) { std::size_t block_index = iterator->second; if (block_index < memory_blocks.size()) { MemoryBlock& block = memory_blocks[block_index]; if (block.is_in_use) { block.is_in_use = false; block.last_access_timestamp = get_current_timestamp(); total_used_bytes -= block.block_size; } } } } void MemoryPool::clear_unused_blocks() { std::vector indices_to_remove; for (std::size_t index = 0; index < memory_blocks.size(); ++index) { if (!memory_blocks[index].is_in_use) { indices_to_remove.push_back(index); } } std::sort(indices_to_remove.rbegin(), indices_to_remove.rend()); for (std::size_t index : indices_to_remove) { std::uint8_t* pointer = memory_blocks[index].data.get(); total_allocated_bytes -= memory_blocks[index].block_size; pointer_to_block_index.erase(pointer); memory_blocks.erase(memory_blocks.begin() + static_cast(index)); } for (std::size_t index = 0; index < memory_blocks.size(); ++index) { pointer_to_block_index[memory_blocks[index].data.get()] = index; } } void MemoryPool::reset_pool() { std::unique_lock lock(pool_mutex); memory_blocks.clear(); pointer_to_block_index.clear(); total_allocated_bytes = 0; total_used_bytes = 0; } std::size_t MemoryPool::get_total_allocated_bytes() const { std::unique_lock lock(pool_mutex); return total_allocated_bytes; } std::size_t MemoryPool::get_total_used_bytes() const { std::unique_lock lock(pool_mutex); return total_used_bytes; } std::size_t MemoryPool::get_block_count() const { std::unique_lock lock(pool_mutex); return memory_blocks.size(); } std::size_t MemoryPool::find_suitable_block_index(std::size_t requested_size) const { std::size_t best_fit_index = static_cast(-1); std::size_t best_fit_size = static_cast(-1); for (std::size_t index = 0; index < memory_blocks.size(); ++index) { const MemoryBlock& block = memory_blocks[index]; if (!block.is_in_use && block.block_size >= requested_size) { if (block.block_size < best_fit_size) { best_fit_size = block.block_size; best_fit_index = index; } } } return best_fit_index; } void MemoryPool::create_new_block(std::size_t block_size) { std::size_t aligned_size = ((block_size + 63) / 64) * 64; memory_blocks.push_back(MemoryBlock{ std::make_unique(aligned_size), aligned_size, false, get_current_timestamp() }); pointer_to_block_index[memory_blocks.back().data.get()] = memory_blocks.size() - 1; total_allocated_bytes += aligned_size; } std::uint64_t MemoryPool::get_current_timestamp() const { auto current_time = std::chrono::steady_clock::now(); auto duration = current_time.time_since_epoch(); return std::chrono::duration_cast(duration).count(); } ScopedMemoryAllocation::ScopedMemoryAllocation(MemoryPool& pool, std::size_t size) : memory_pool_pointer(&pool) , allocated_pointer(pool.allocate(size)) , allocation_size(size) { } ScopedMemoryAllocation::~ScopedMemoryAllocation() { if (memory_pool_pointer != nullptr && allocated_pointer != nullptr) { memory_pool_pointer->deallocate(allocated_pointer); } } ScopedMemoryAllocation::ScopedMemoryAllocation(ScopedMemoryAllocation&& other) noexcept : memory_pool_pointer(other.memory_pool_pointer) , allocated_pointer(other.allocated_pointer) , allocation_size(other.allocation_size) { other.memory_pool_pointer = nullptr; other.allocated_pointer = nullptr; other.allocation_size = 0; } ScopedMemoryAllocation& ScopedMemoryAllocation::operator=(ScopedMemoryAllocation&& other) noexcept { if (this != &other) { if (memory_pool_pointer != nullptr && allocated_pointer != nullptr) { memory_pool_pointer->deallocate(allocated_pointer); } memory_pool_pointer = other.memory_pool_pointer; allocated_pointer = other.allocated_pointer; allocation_size = other.allocation_size; other.memory_pool_pointer = nullptr; other.allocated_pointer = nullptr; other.allocation_size = 0; } return *this; } std::uint8_t* ScopedMemoryAllocation::get() const { return allocated_pointer; } std::size_t ScopedMemoryAllocation::size() const { return allocation_size; } }