File size: 2,786 Bytes
c1af2fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
// Copyright © 2023 Apple Inc.
#pragma once
#include <ATen/core/ATen_fwd.h>
#include <c10/core/Allocator.h>
#include <c10/util/Registry.h>
// Converts a size given in mebibytes (1 MiB = 1048576 = 2^20 bytes) into a
// byte count of type unsigned long. The argument is parenthesized so that a
// compound argument such as MB(a + b) scales the whole expression rather than
// only its last operand.
#define MB(x) ((x) * 1048576UL)
namespace at::mps {
// This is a public interface to access MPSAllocator; it extends c10::Allocator.
// Do not declare methods that would depend on MPS or Metal frameworks.
// Every method below is pure virtual and const-qualified; the concrete
// implementation lives elsewhere (see MPSAllocator.h).
// NOTE(review): keep the declaration order stable -- the order of virtual
// functions determines the vtable layout shared by the implementation and by
// every caller compiled against this header.
class IMPSAllocator : public c10::Allocator {
public:
// See the comments in MPSAllocator.h for the description of these methods.
// Cache maintenance (presumably releases cached / inactive buffers -- see
// MPSAllocator.h for the exact policy).
virtual void emptyCache() const = 0;
virtual void freeInactiveBuffers() const = 0;
// Per-buffer queries and metadata, each keyed by a buffer pointer `ptr`.
virtual ssize_t getUnalignedBufferSize(const void* ptr) const = 0;
virtual IntArrayRef getBufferShape(const void* ptr) const = 0;
virtual id_t getBufferId(const void* ptr) const = 0;
virtual void setBufferShape(const void* ptr, const IntArrayRef& shape)
const = 0;
virtual bool isSharedBuffer(const void* ptr) const = 0;
// Whether shared storage is supported (presumably CPU/GPU shared memory --
// TODO confirm in MPSAllocator.h).
virtual bool isSharedStorageSupported() const = 0;
// Presumably allocates a `size`-byte buffer initialised from the bytes at
// `value` (assumes `value` points to at least `size` readable bytes -- verify
// against the implementation).
virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size)
const = 0;
// Formats a byte count as a string (presumably with human-readable units).
virtual std::string formatSize(size_t size) const = 0;
// Memory watermark configuration and queries; the meaning of the ratios and
// limits is defined in MPSAllocator.h.
virtual void setLowWatermarkRatio(double ratio) const = 0;
virtual void setHighWatermarkRatio(double ratio) const = 0;
virtual ssize_t getLowWatermarkValue() const = 0;
virtual size_t getLowWatermarkLimit() const = 0;
virtual size_t getHighWatermarkLimit() const = 0;
// Memory usage statistics (presumably in bytes -- confirm in MPSAllocator.h).
virtual size_t getTotalAllocatedMemory() const = 0;
virtual size_t getCurrentAllocatedMemory() const = 0;
virtual size_t getDriverAllocatedMemory() const = 0;
virtual size_t getRecommendedMaxMemory() const = 0;
// Returns a pointer paired with a uint32_t for the buffer `ptr`; the meaning
// of the second member is defined by the implementation (presumably an
// offset into the shared buffer -- TODO confirm).
virtual std::pair<const void*, uint32_t> getSharedBufferPtr(
const void* ptr) const = 0;
// Event recording / waiting over a set of buffer pointers; the meaning of the
// bool result is defined by the implementation.
virtual bool recordEvents(c10::ArrayRef<const void*> buffers) const = 0;
virtual bool waitForEvents(c10::ArrayRef<const void*> buffers) const = 0;
};
// Interface for observers that the MPS allocator notifies about buffer events.
// Implementations override executeMPSAllocatorCallback().
class IMpsAllocatorCallback {
 public:
  // Kind of buffer event delivered to a callback.
  // NOTE: enumerator order defines the underlying values -- keep it stable.
  enum class EventType {
    // Buffer got allocated to be used immediately.
    ALLOCATED,
    // Buffer pulled from the free list to be reused.
    RECYCLED,
    // Buffer put on the free list for future recycling.
    FREED,
    // Buffer memory released.
    RELEASED,
    // Buffer allocation failed.
    ALLOCATION_FAILED,
  };

  // Virtual so implementations can be destroyed through a base pointer.
  virtual ~IMpsAllocatorCallback() = default;

  // Invoked with the affected buffer pointer and the event that occurred.
  virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0;
};
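// The MPS allocator executes every registered callback on buffer events (for
// example, when a block of memory is freed); each callback receives the buffer
// pointer and the event kind as an IMpsAllocatorCallback::EventType.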
// Declares MPSAllocatorCallbacksRegistry, a c10 registry of
// IMpsAllocatorCallback implementations (its definition is expected in a .cpp
// file elsewhere -- presumably via the matching define-registry macro).
TORCH_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);

// Convenience wrapper around C10_REGISTER_CLASS bound to
// MPSAllocatorCallbacksRegistry; `key` and any extra arguments are forwarded
// unchanged.
#define REGISTER_MPS_ALLOCATOR_CALLBACK(key, ...) \
  C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, key, __VA_ARGS__)
IMPSAllocator* getIMPSAllocator(bool sharedAllocator = false);
bool isMPSPinnedPtr(const void* data);
} // namespace at::mps
|