#pragma once

#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/MPSEvent.h>
#include <ATen/mps/MPSStream.h>

#include <c10/util/flat_hash_map.h>
#include <mach/vm_page_size.h>
#include <cstdio>
#include <mutex>
#include <set>
#include <unordered_set>
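
// Heap-backed caching allocator for the MPS backend: MTLBuffers are sub-allocated
// from MTLHeaps and cached for reuse, following the general design of the CUDA
// caching allocator. This is an implementation header; client code should include
// MPSAllocatorInterface.h instead.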
namespace at::mps::HeapAllocator {

// largest "small" allocation is 1 MiB
static const size_t kMaxSmallAlloc = MB(1);
// allocations between 1 MiB and 10 MiB may use kLargeHeap
static const size_t kMinLargeAlloc = MB(10);
// round up large allocations to multiples of 2 MiB
static const size_t kRoundLarge = MB(2);
// "small" allocations are packed in 8 MiB heaps
static const size_t kSmallHeap = MB(8);
// "large" allocations may be packed in 32 MiB heaps
static const size_t kLargeHeap = MB(32);
// "extra large" allocations on discrete devices may be packed in 128 MiB heaps
static const size_t kXLargeHeapD = MB(128);
// "extra large" allocations on unified-memory devices may be packed in 1 GiB heaps
static const size_t kXLargeHeapU = MB(1024);
// largest "scalar" allocation
static const size_t kMaxScalarAlloc = (sizeof(int64_t));

// buffer pools can be customized with a combination of usage flags
enum UsageFlags : uint32_t {
  PRIVATE = 0,
  // pool packs allocations into small heaps (kSmallHeap) rather than large ones
  SMALL = (1 << 0),
  // buffers use shared storage mode (host-visible); used on unified-memory devices
  SHARED = (1 << 1),
  // buffers use managed storage mode
  MANAGED = (1 << 2),
  // enables automatic hazard tracking on resources allocated from the pool
  HAZARD = (1 << 3),
  // pool used to import CPU scalar values to the GPU
  SCALAR = (1 << 4),
};

// debug verbosity flags
enum DebugVerbosity : uint32_t {
  SILENT = 0,
  // print generic profiling data for total system memory usage
  PROFILING = (1 << 0),
  // print buffer allocations
  ALLOCATIONS = (1 << 1),
  // print buffer recycling
  RECYCLES = (1 << 2),
  // print buffer releases
  RELEASES = (1 << 3),
  // only log transactions on the large buffer pools
  LARGE_ONLY = (1 << 4),
};
struct HeapBlock;
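
// BufferBlock is the allocator's bookkeeping record for a single MTLBuffer
// sub-allocated from a HeapBlock.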
struct BufferBlock {
  id<MTLBuffer> buffer;
  // pointer to the CPU mapping of a shared-storage buffer
  void* cpu_ptr = nullptr;
  // size after alignment
  size_t size;
  // size originally requested by the client
  size_t requested_size;
  // shape metadata used to retrieve the base of views in cached graphs
  std::vector<int64_t> shape;
  bool in_use = false;
  HeapBlock* heap;
  id_t buf_id;
  // counts how long the buffer has sat unused, so garbage collection can evict
  // least-recently-used buffers first
  uint32_t gc_count = 0;
  uint32_t use_count = 0;
  // counter used to assign unique ids to buffer blocks
  static uint64_t buffer_counter;
  // event used to synchronize GPU/CPU access to shared-storage buffers
  MPSEventPtr event;

  BufferBlock(size_t Size, size_t RequestedSize = 0, const id<MTLBuffer> Buffer = nullptr, HeapBlock* Heap = nullptr)
      : buffer(Buffer), size(Size), requested_size(RequestedSize), heap(Heap), buf_id(Buffer ? ++buffer_counter : 0) {}

  static bool Comparator(const BufferBlock* a, const BufferBlock* b) {
    return (a->size != b->size) ? a->size < b->size : (uintptr_t)a->buffer < (uintptr_t)b->buffer;
  }
  // rounds Size up to the next multiple of Alignment (which must be a power of two),
  // e.g. alignUp(1000, 256) == 1024
  static size_t alignUp(size_t Size, size_t Alignment) {
    assert(((Alignment - 1) & Alignment) == 0);
    return ((Size + Alignment - 1) & ~(Alignment - 1));
  }
  uint32_t retainCount() const {
    return [buffer retainCount];
  }
};
typedef bool (*BufferComparison)(const BufferBlock*, const BufferBlock*);
struct BufferPool;
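
// AllocParams bundles the state of a single allocation request: the search key used
// to find a cached buffer of at least the aligned size, the target pool, the
// originally requested size, and flags indicating whether the allocator is under
// memory pressure and whether the device has unified memory (both influence how
// new heaps are sized).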
struct AllocParams {
  AllocParams(size_t Alloc_Size, size_t Requested_Size, BufferPool* Pool)
      : search_key(Alloc_Size), pool(Pool), requested_size(Requested_Size) {}
  size_t size() const {
    return search_key.size;
  }

  BufferBlock search_key;
  BufferPool* pool;
  BufferBlock* buffer_block = nullptr;
  size_t requested_size;
  bool has_memory_pressure = false;
  bool has_unified_memory = true;
};
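
// HeapBlock wraps an MTLHeap from which buffers are sub-allocated. It tracks the
// heap's total and currently available sizes, the number of live buffers, and
// whether the heap is shared by several buffers or dedicated to a single large one.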
struct HeapBlock {
  id<MTLHeap> heap;
  struct {
    size_t total, available;
  } size;
  BufferPool* pool;
  unsigned int n_buffers = 0;
  id_t heap_id;
  // true if this heap is shared by several sub-allocated buffers
  // (as opposed to being dedicated to a single large buffer)
  bool is_split;
  // counter used to assign unique ids to heap blocks
  static uint64_t heap_counter;

  HeapBlock(size_t Size, const id<MTLHeap> Heap = nullptr, BufferPool* Pool = nullptr)
      : heap(Heap),
        size({.total = Size, .available = Size}),
        pool(Pool),
        heap_id(Heap ? ++heap_counter : 0),
        is_split(true) {}

  static MTLResourceOptions getOptions(uint32_t usage) {
    MTLResourceOptions options = MTLResourceCPUCacheModeDefaultCache;

    if (usage & UsageFlags::MANAGED)
      options |= MTLResourceStorageModeManaged;
    else if (usage & UsageFlags::SHARED)
      options |= MTLResourceStorageModeShared;
    else
      options |= MTLResourceStorageModePrivate;

    options |=
        (usage & UsageFlags::HAZARD) ? MTLResourceHazardTrackingModeTracked : MTLResourceHazardTrackingModeUntracked;

    return options;
  }

  static HeapBlock* createHeapBlock(AllocParams& params, id<MTLDevice> device, uint32_t usage) {
    HeapBlock* heapBlock = nullptr;
    bool is_split = true;
    const size_t size = params.size();
    MTLHeapDescriptor* d = [MTLHeapDescriptor new];
    if (d) {
      const size_t kXLargeHeap = params.has_unified_memory ? kXLargeHeapU : kXLargeHeapD;
      if (size <= kMaxSmallAlloc) {
        d.size = kSmallHeap;
      } else if (size < kMinLargeAlloc) {
        d.size = kLargeHeap;
      } else if (size < kXLargeHeap / 2 && !params.has_memory_pressure) {
        d.size = kXLargeHeap;
      } else {
        // oversized requests get a dedicated (non-split) heap rounded up to kRoundLarge
        d.size = kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
        is_split = false;
      }
      d.storageMode = (usage & UsageFlags::SHARED) ? MTLStorageModeShared : MTLStorageModePrivate;
      d.cpuCacheMode = MTLCPUCacheModeDefaultCache;
      // hazard tracking automatically synchronizes accesses to the heap's resources,
      // at the cost of some performance
      d.hazardTrackingMode =
          (usage & UsageFlags::HAZARD) ? MTLHazardTrackingModeTracked : MTLHazardTrackingModeUntracked;
      d.resourceOptions = getOptions(usage);
      d.type = MTLHeapTypeAutomatic;
      id<MTLHeap> heap = [device newHeapWithDescriptor:d];
      if (heap) {
        [heap setPurgeableState:MTLPurgeableStateNonVolatile];
        const size_t heap_size = heapAvailableSize(heap);
        heapBlock = new HeapBlock(heap_size, heap, params.pool);
        if (heapBlock) {
          heapBlock->is_split = is_split;
        }
      }
      [d release];
    }
    return heapBlock;
  }
  static bool Comparator(const HeapBlock* a, const HeapBlock* b) {
    return (a->size.available != b->size.available) ? a->size.available < b->size.available
                                                     : (uintptr_t)a->heap < (uintptr_t)b->heap;
  }
  static NSUInteger heapAvailableSize(id<MTLHeap> heap, size_t Alignment = vm_page_size) {
    return [heap maxAvailableSizeWithAlignment:Alignment];
  }
  NSUInteger Size() {
    return [heap size];
  }
  id<MTLBuffer> newMTLBuffer(size_t length, uint32_t usage) {
    id<MTLBuffer> buf = [heap newBufferWithLength:length options:getOptions(usage)];
    if (buf) {
      updateAvailableSize();
      n_buffers++;
    }
    return buf;
  }
  // returns the retainCount before releasing the buffer
  uint32_t releaseMTLBuffer(id<MTLBuffer>& buffer) {
    const uint32_t retainCount = [buffer retainCount];
    [buffer release];
    buffer = nil;
    updateAvailableSize();
    n_buffers--;
    return retainCount;
  }
  // returns the retainCount before releasing the heap; the heap must be empty
  uint32_t releaseMTLHeap() {
    const uint32_t retainCount = [heap retainCount];
    TORCH_INTERNAL_ASSERT(!n_buffers); // heap must not contain any buffers
    [heap setPurgeableState:MTLPurgeableStateEmpty];
    [heap release];
    heap = nil;
    size.available = 0;
    return retainCount;
  }
  uint32_t retainCount() const {
    return [heap retainCount];
  }
  void updateAvailableSize() {
    size.available = heapAvailableSize(heap);
  }
};
typedef bool (*HeapComparison)(const HeapBlock*, const HeapBlock*);
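
// A BufferPool groups the heaps and cached buffers for one combination of usage
// flags (private/shared storage, small/large allocations, or scalar imports).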
struct BufferPool {
  enum class Kind {
    PRIVATE_SMALL,
    PRIVATE_LARGE,
    SHARED_SMALL,
    SHARED_LARGE,
    SCALAR,
  };

  BufferPool(const id<MTLDevice> Device, uint32_t Usage)
      : device(Device), usage(Usage), heaps(HeapBlock::Comparator), available_buffers(BufferBlock::Comparator) {}

  const id<MTLDevice> device;
  // usage flags that customize the pool (see UsageFlags)
  const uint32_t usage;
  // total number of buffers in the pool
  uint32_t n_buffers = 0;
  // total allocated size on this pool
  size_t allocated_size = 0;
  // total memory available in the pool
  size_t available_size = 0;
  // heaps owned by this pool, ordered by their available (not total) size
  std::set<HeapBlock*, HeapComparison> heaps;
  // cached buffers that are not currently in use, ordered by size
  std::set<BufferBlock*, BufferComparison> available_buffers;
  // buffers that have been freed by the client but are still retained by in-flight
  // command buffers; they are returned to the pool once the GPU work completes
  std::unordered_set<BufferBlock*> buffers_pending_free;
  // heaps whose available size still needs to be refreshed
  std::unordered_set<HeapBlock*> heaps_pending_update;
};
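
// MPSHeapAllocatorImpl is the allocator itself: it owns the buffer pools, tracks
// all allocated buffers, enforces the low/high watermark limits, and services the
// allocation and free requests exposed through at::Allocator.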
class MPSHeapAllocatorImpl {
 public:
  explicit MPSHeapAllocatorImpl()
      : m_device(at::mps::MPSDevice::getInstance()->device()),
        m_max_buffer_size([m_device maxBufferLength]),
        m_stream(getDefaultMPSStream()),
        m_event_pool(getMPSEventPool()) {
    init_allocator();
  }
  ~MPSHeapAllocatorImpl() {
    emptyCache();
  }
  // interface exposed to at::Allocator
  id<MTLBuffer> malloc(size_t size, uint32_t usage);
  // frees a buffer and returns it to its buffer pool
  void free(void* ptr);
  // releases all the cached buffers and their associated heaps
  void emptyCache();
  // frees the buffers that are pending free and no longer retained by command buffers
  void freeInactiveBuffers();
  // returns true if the buffer was allocated from a shared pool
  bool isSharedBuffer(const void* ptr);
  // returns the requested (unaligned) size of the buffer at this pointer
  ssize_t getUnalignedBufferSize(const void* ptr);
  // stores shape metadata on a buffer (used to retrieve the base of views in cached graphs)
  void setBufferShape(const void* ptr, const IntArrayRef& shape);
  // retrieves the shape metadata previously stored with setBufferShape()
  IntArrayRef getBufferShape(const void* ptr);
  // returns the unique id assigned to the buffer
  id_t getBufferId(const void* ptr);
  // allocates a buffer from the scalar pool and copies the CPU scalar value into it
  id<MTLBuffer> allocScalarBufferWithValue(void* value, size_t size);
  // returns the CPU mapping and retain count of a shared-storage buffer
  std::pair<const void*, uint32_t> getSharedBufferPtr(const void* buffer);
  // records events on the passed shared buffers; returns true if any event was recorded
  bool recordEvents(c10::ArrayRef<const void*> buffers);
  // waits for GPU work on the passed shared buffers to complete; returns true if any event was waited on
  bool waitForEvents(c10::ArrayRef<const void*> buffers);
  // indicates how far (in MB) the current total allocations are from the low watermark
  // limit (returns zero once the limit has been reached)
  ssize_t getLowWatermarkValue();
  // sets the ratio that determines the low watermark limit (see m_low_watermark_ratio)
  void setLowWatermarkRatio(double ratio);
  // sets the ratio that determines the high watermark limit (see m_high_watermark_ratio)
  void setHighWatermarkRatio(double ratio);

  size_t getLowWatermarkLimit() const {
    return m_low_watermark_limit;
  }
  size_t getHighWatermarkLimit() const {
    return m_max_total_allowed_size;
  }
  // total memory allocated by this allocator (both in-use and cached)
  size_t getTotalAllocatedMemory() const {
    return m_total_allocated_memory;
  }
  // memory currently held by in-use buffers
  size_t getCurrentAllocatedMemory() const {
    return m_current_allocated_memory;
  }
  // total memory the Metal driver has allocated for this device in the process
  // (includes allocations made outside this allocator)
  size_t getDriverAllocatedMemory() const {
    return current_allocated_size();
  }
  // recommended maximum working-set size reported by the device
  size_t getRecommendedMaxMemory() const {
    return max_device_size();
  }
  // debug verbosity bitmask (see DebugVerbosity)
  uint32_t getDebugVerbosity() const {
    return m_debug_verbosity;
  }
  inline id<MTLDevice> Device() const {
    return m_device;
  }
  // formats a byte count into a human-readable string
  inline std::string format_size(uint64_t size) const;

 private:
  // default ratio for the high watermark limit (relative to the recommended max working-set size)
  constexpr static double default_high_watermark_ratio = 1.7;
  // upper bound allowed for the high watermark ratio
  constexpr static double default_high_watermark_upper_bound = 2.0;
  // default ratios for the low watermark limit; unified-memory devices may
  // allocate beyond the recommended max working-set size
  constexpr static double default_low_watermark_ratio_unified = 1.4;
  constexpr static double default_low_watermark_ratio_discrete = 1.0;

  const id<MTLDevice> m_device;
  std::recursive_mutex m_mutex;
  // currently allocated buffers, keyed by their device pointers
  ska::flat_hash_map<const void*, BufferBlock*> m_allocated_buffers;
  // the buffer pools, keyed by their kind
  ska::flat_hash_map<BufferPool::Kind, std::unique_ptr<BufferPool>> m_pools;
  // total memory allocated by this allocator (in-use and cached)
  size_t m_total_allocated_memory = 0;
  // memory currently held by in-use buffers
  size_t m_current_allocated_memory = 0;
  // largest single buffer the device can allocate ([MTLDevice maxBufferLength])
  size_t m_max_buffer_size = 0;
  // hard cap on total allocations, derived from m_high_watermark_ratio
  size_t m_max_total_allowed_size = 0;
  // hard limit on total allocations, expressed as a multiple of the device's
  // recommended max working-set size (values > 1 allow allocating beyond it)
  double m_high_watermark_ratio;
  // soft limit: once total allocations exceed the low watermark limit, the allocator
  // considers itself under memory pressure and starts reclaiming cached buffers
  double m_low_watermark_ratio;
  // low watermark limit (in bytes), computed from m_low_watermark_ratio
  size_t m_low_watermark_limit;
  // debug verbosity bitmask (see DebugVerbosity)
  uint32_t m_debug_verbosity;
  // default MPS stream used for synchronization
  MPSStream* m_stream;
  std::shared_ptr<MPSEventPool> m_event_pool;

  void init_allocator();
  void init_buffer_pools();
  HeapBlock* get_free_heap(AllocParams& params);
  bool get_free_buffer(AllocParams& params);
  BufferBlock* get_allocated_buffer_block(const void* ptr);
  BufferBlock* alloc_buffer_block(size_t size, uint32_t usage);
  bool alloc_buffer(AllocParams& params);
  void free_buffer(BufferBlock* buffer_block);
  // releases a cached buffer back to the system (optionally removing its heap once empty)
  bool release_buffer(BufferBlock* buffer_block, bool remove_empty_heap = true);
  void release_buffers(BufferPool& pool);
  bool release_available_cached_buffers(AllocParams& params);
  bool release_cached_buffers();
  // frees unused cached buffers to reclaim GPU memory when under memory pressure
  void garbage_collect_cached_buffers(AllocParams& params);
  // returns the pool matching the usage flags and the requested/aligned sizes
  BufferPool& get_pool(size_t requested_size, size_t aligned_size, uint32_t usage);
  // returns the aligned allocation size, rounded to improve buffer reuse
  size_t get_allocation_size(size_t size, uint32_t usage) const;
  // recommended maximum working-set size reported by the device
  size_t max_device_size() const {
    return [m_device recommendedMaxWorkingSetSize];
  }
  // total memory the Metal driver has allocated for this device in the process
  size_t current_allocated_size() const {
    return [m_device currentAllocatedSize];
  }
  // notifies all registered allocator callbacks of an allocator event on the given buffer
  bool trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) const {
    for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) {
      MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(
          buffer_block ? buffer_block->buffer : nullptr, event);
    }
    return true;
  }
};

} // namespace at::mps::HeapAllocator