aaaaaaaaaaaaaaa / accelerator /src /memory_pool.cpp
arifather51's picture
Upload 28 files
a57f260 verified
//
// SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
// SPDX-License-Identifier: Apache-2.0
//
#include "memory_pool.hpp"

#include <algorithm>
#include <chrono>
#include <cstring>
#include <limits>
#include <new>
namespace pocket_tts_accelerator {
// Creates an empty pool. No memory is reserved up front; blocks are created
// lazily by allocate().
//
// @param initial_pool_size_bytes  stored as maximum_pool_size_bytes. allocate()
//        treats it as a soft limit: crossing it triggers a sweep of free blocks
//        but does not reject the request.
MemoryPool::MemoryPool(std::size_t initial_pool_size_bytes)
    : total_allocated_bytes(0),
      total_used_bytes(0),
      maximum_pool_size_bytes(initial_pool_size_bytes) {
    // Nothing else to do: the block list and the pointer index start empty.
}
// Destroys the pool and frees every block it owns, including blocks still
// marked in use. Any pointer previously returned by allocate() dangles
// afterwards.
MemoryPool::~MemoryPool() {
    this->reset_pool();
}
std::uint8_t* MemoryPool::allocate(std::size_t requested_size_bytes) {
std::unique_lock<std::mutex> lock(pool_mutex);
std::size_t block_index = find_suitable_block_index(requested_size_bytes);
if (block_index != static_cast<std::size_t>(-1)) {
MemoryBlock& existing_block = memory_blocks[block_index];
existing_block.is_in_use = true;
existing_block.last_access_timestamp = get_current_timestamp();
total_used_bytes += existing_block.block_size;
return existing_block.data.get();
}
if (total_allocated_bytes + requested_size_bytes > maximum_pool_size_bytes) {
clear_unused_blocks();
}
std::size_t aligned_size = ((requested_size_bytes + 63) / 64) * 64;
memory_blocks.push_back(MemoryBlock{
std::make_unique<std::uint8_t[]>(aligned_size),
aligned_size,
true,
get_current_timestamp()
});
std::uint8_t* allocated_pointer = memory_blocks.back().data.get();
pointer_to_block_index[allocated_pointer] = memory_blocks.size() - 1;
total_allocated_bytes += aligned_size;
total_used_bytes += aligned_size;
return allocated_pointer;
}
void MemoryPool::deallocate(std::uint8_t* pointer) {
if (pointer == nullptr) {
return;
}
std::unique_lock<std::mutex> lock(pool_mutex);
auto iterator = pointer_to_block_index.find(pointer);
if (iterator != pointer_to_block_index.end()) {
std::size_t block_index = iterator->second;
if (block_index < memory_blocks.size()) {
MemoryBlock& block = memory_blocks[block_index];
if (block.is_in_use) {
block.is_in_use = false;
block.last_access_timestamp = get_current_timestamp();
total_used_bytes -= block.block_size;
}
}
}
}
void MemoryPool::clear_unused_blocks() {
std::vector<std::size_t> indices_to_remove;
for (std::size_t index = 0; index < memory_blocks.size(); ++index) {
if (!memory_blocks[index].is_in_use) {
indices_to_remove.push_back(index);
}
}
std::sort(indices_to_remove.rbegin(), indices_to_remove.rend());
for (std::size_t index : indices_to_remove) {
std::uint8_t* pointer = memory_blocks[index].data.get();
total_allocated_bytes -= memory_blocks[index].block_size;
pointer_to_block_index.erase(pointer);
memory_blocks.erase(memory_blocks.begin() + static_cast<std::ptrdiff_t>(index));
}
for (std::size_t index = 0; index < memory_blocks.size(); ++index) {
pointer_to_block_index[memory_blocks[index].data.get()] = index;
}
}
// Drops every block (free or in use) and zeroes the byte counters, under the
// pool lock.
//
// Pointers previously returned by allocate() dangle after this call, and
// passing them to deallocate() is ignored.
void MemoryPool::reset_pool() {
    std::lock_guard<std::mutex> guard(pool_mutex);
    pointer_to_block_index.clear();
    memory_blocks.clear();
    total_used_bytes = 0;
    total_allocated_bytes = 0;
}
// Bytes currently owned by the pool across all blocks (in use or free), read
// under the pool lock.
std::size_t MemoryPool::get_total_allocated_bytes() const {
    std::lock_guard<std::mutex> guard(pool_mutex);
    const std::size_t snapshot = total_allocated_bytes;
    return snapshot;
}
// Sum of the sizes of blocks currently marked in use, read under the pool lock.
// Counts whole (rounded) block sizes, not the sizes callers requested.
std::size_t MemoryPool::get_total_used_bytes() const {
    std::lock_guard<std::mutex> guard(pool_mutex);
    const std::size_t snapshot = total_used_bytes;
    return snapshot;
}
// Number of blocks the pool currently holds, counting both in-use and free
// blocks. Read under the pool lock.
std::size_t MemoryPool::get_block_count() const {
    std::lock_guard<std::mutex> guard(pool_mutex);
    const std::size_t count = memory_blocks.size();
    return count;
}
std::size_t MemoryPool::find_suitable_block_index(std::size_t requested_size) const {
std::size_t best_fit_index = static_cast<std::size_t>(-1);
std::size_t best_fit_size = static_cast<std::size_t>(-1);
for (std::size_t index = 0; index < memory_blocks.size(); ++index) {
const MemoryBlock& block = memory_blocks[index];
if (!block.is_in_use && block.block_size >= requested_size) {
if (block.block_size < best_fit_size) {
best_fit_size = block.block_size;
best_fit_index = index;
}
}
}
return best_fit_index;
}
// Appends a new FREE block, indexes it in pointer_to_block_index, and charges
// it to total_allocated_bytes.
//
// The size is rounded up to a multiple of 64 bytes. Only the size is rounded;
// pointer alignment is whatever operator new[] provides. The storage comes
// from std::make_unique and is therefore zero-initialised. It does not check
// maximum_pool_size_bytes and does not lock pool_mutex. NOTE(review):
// presumably a private helper whose caller holds the lock; confirm in the
// header.
//
// @param block_size  minimum size of the new block in bytes.
// @throws std::bad_alloc if rounding `block_size` up would overflow
//         std::size_t, or if the allocation itself fails.
void MemoryPool::create_new_block(std::size_t block_size) {
    constexpr std::size_t block_alignment = 64;

    // Without this guard, sizes within 63 bytes of SIZE_MAX wrap around to a
    // tiny block.
    if (block_size > std::numeric_limits<std::size_t>::max() - (block_alignment - 1)) {
        throw std::bad_alloc();
    }
    const std::size_t aligned_size =
        ((block_size + block_alignment - 1) / block_alignment) * block_alignment;

    memory_blocks.push_back(MemoryBlock{
        std::make_unique<std::uint8_t[]>(aligned_size),
        aligned_size,
        false,
        get_current_timestamp()
    });
    pointer_to_block_index[memory_blocks.back().data.get()] = memory_blocks.size() - 1;
    total_allocated_bytes += aligned_size;
}
// Current std::chrono::steady_clock reading in whole milliseconds.
// steady_clock's epoch is unspecified, so values are only meaningful when
// compared with each other (e.g. the last_access_timestamp of two blocks).
std::uint64_t MemoryPool::get_current_timestamp() const {
    using std::chrono::milliseconds;
    using std::chrono::steady_clock;

    const milliseconds since_epoch =
        std::chrono::duration_cast<milliseconds>(steady_clock::now().time_since_epoch());
    return static_cast<std::uint64_t>(since_epoch.count());
}
// Acquires a buffer of at least `size` bytes from `pool` for the lifetime of
// this object. The destructor returns it with MemoryPool::deallocate().
//
// `pool` must outlive this object. Any exception thrown by MemoryPool::allocate
// propagates out of the constructor.
ScopedMemoryAllocation::ScopedMemoryAllocation(MemoryPool& pool, std::size_t size)
    : memory_pool_pointer(&pool),
      allocated_pointer(pool.allocate(size)),
      allocation_size(size) {
    // allocation_size records the requested size, not the (possibly larger)
    // size of the underlying pool block.
}
// Hands the buffer back to its pool, unless this object was moved from (in
// which case both pointers are null and there is nothing to release).
ScopedMemoryAllocation::~ScopedMemoryAllocation() {
    const bool owns_allocation = memory_pool_pointer && allocated_pointer;
    if (owns_allocation) {
        memory_pool_pointer->deallocate(allocated_pointer);
    }
}
// Move constructor: takes over `other`'s buffer. `other` is left empty (null
// pointers, size 0), so its destructor releases nothing.
ScopedMemoryAllocation::ScopedMemoryAllocation(ScopedMemoryAllocation&& other) noexcept
    : memory_pool_pointer(nullptr),
      allocated_pointer(nullptr),
      allocation_size(0) {
    memory_pool_pointer = other.memory_pool_pointer;
    allocated_pointer = other.allocated_pointer;
    allocation_size = other.allocation_size;

    other.memory_pool_pointer = nullptr;
    other.allocated_pointer = nullptr;
    other.allocation_size = 0;
}
// Move assignment. If this object holds a buffer, that buffer is first returned
// to its pool. Then `other`'s buffer is taken over and `other` is left empty.
// Self-assignment is a no-op.
ScopedMemoryAllocation& ScopedMemoryAllocation::operator=(ScopedMemoryAllocation&& other) noexcept {
    if (this == &other) {
        return *this;
    }

    if (memory_pool_pointer && allocated_pointer) {
        memory_pool_pointer->deallocate(allocated_pointer);
    }

    memory_pool_pointer = other.memory_pool_pointer;
    allocated_pointer = other.allocated_pointer;
    allocation_size = other.allocation_size;

    other.memory_pool_pointer = nullptr;
    other.allocated_pointer = nullptr;
    other.allocation_size = 0;

    return *this;
}
std::uint8_t* ScopedMemoryAllocation::get() const {
return allocated_pointer;
}
std::size_t ScopedMemoryAllocation::size() const {
return allocation_size;
}
}