|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
#include <ATen/core/ATen_fwd.h>
|
|
|
#include <c10/core/Allocator.h>
|
|
|
#include <c10/util/Registry.h>
|
|
|
|
|
|
// Converts a count of mebibytes to bytes (1048576 == 2^20).
// The argument is parenthesized so that compound expressions scale as a
// whole: MB(a + b) must mean (a + b) MiB, not a + b * 1048576.
#define MB(x) ((x) * 1048576UL)
|
|
|
|
|
|
namespace at::mps {
|
|
|
|
|
|
|
|
|
|
|
|
// Abstract interface to the MPS allocator, exposed as a c10::Allocator.
// Every method is a const pure virtual; the concrete implementation is not in
// this header (NOTE(review): presumably the MPS heap allocator implementation
// elsewhere in ATen/mps — confirm).
//
// Most methods identify a buffer by an opaque `const void* ptr`; presumably
// this is a data pointer previously handed out by this allocator — confirm.
//
// Declaration order defines the vtable layout: do not reorder these methods or
// insert new ones in the middle unless every implementation is rebuilt with it.
class IMPSAllocator : public c10::Allocator {
 public:
  // Judging by the name, releases cached buffers held by the allocator.
  // NOTE(review): confirm whether in-use buffers are affected.
  virtual void emptyCache() const = 0;
  // Judging by the name, frees buffers that are no longer active.
  // NOTE(review): exact criteria for "inactive" live in the implementation.
  virtual void freeInactiveBuffers() const = 0;
  // Size of the buffer at `ptr` without alignment padding, judging by the name.
  // Signed return type; NOTE(review): meaning of negative values (e.g. an
  // unknown ptr) is not visible here.
  virtual ssize_t getUnalignedBufferSize(const void* ptr) const = 0;
  // Shape associated with the buffer at `ptr`; presumably the value stored by
  // setBufferShape(). NOTE(review): lifetime of the returned view — confirm.
  virtual IntArrayRef getBufferShape(const void* ptr) const = 0;
  // Identifier of the buffer at `ptr`; uniqueness guarantees are not shown here.
  virtual id_t getBufferId(const void* ptr) const = 0;
  // Associates `shape` with the buffer at `ptr` (read back via getBufferShape).
  // IntArrayRef is taken by const& — kept as-is because overrides must match.
  virtual void setBufferShape(const void* ptr, const IntArrayRef& shape)
      const = 0;
  // True if the buffer at `ptr` is a "shared" buffer in the implementation's
  // sense. NOTE(review): exact definition of "shared" — confirm.
  virtual bool isSharedBuffer(const void* ptr) const = 0;
  // Whether the allocator/device supports shared storage at all.
  virtual bool isSharedStorageSupported() const = 0;
  // Allocates a buffer of `size` bytes, presumably initialized from the bytes
  // at `value` (e.g. to upload a scalar) — confirm against the implementation.
  virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size)
      const = 0;
  // Renders a byte count as a string; presumably human-readable units.
  virtual std::string formatSize(size_t size) const = 0;
  // Sets the low watermark as a ratio (double); valid range not shown here.
  virtual void setLowWatermarkRatio(double ratio) const = 0;
  // Sets the high watermark as a ratio (double); valid range not shown here.
  virtual void setHighWatermarkRatio(double ratio) const = 0;
  // Current value relative to the low watermark. Signed, unlike the limit
  // getters below; NOTE(review): units and sign convention — confirm.
  virtual ssize_t getLowWatermarkValue() const = 0;
  // Low watermark limit; NOTE(review): presumably in bytes — confirm.
  virtual size_t getLowWatermarkLimit() const = 0;
  // High watermark limit; NOTE(review): presumably in bytes — confirm.
  virtual size_t getHighWatermarkLimit() const = 0;
  // Memory-accounting getters; presumably byte counts — confirm in the impl.
  virtual size_t getTotalAllocatedMemory() const = 0;
  virtual size_t getCurrentAllocatedMemory() const = 0;
  virtual size_t getDriverAllocatedMemory() const = 0;
  virtual size_t getRecommendedMaxMemory() const = 0;
  // For the buffer at `ptr`, returns a (pointer, uint32_t) pair. Presumably a
  // host-accessible pointer plus a count for shared buffers — NOTE(review):
  // confirm both fields and the result for non-shared buffers.
  virtual std::pair<const void*, uint32_t> getSharedBufferPtr(
      const void* ptr) const = 0;
  // Records events for each buffer in `buffers`; the bool presumably reports
  // whether any event was recorded — confirm.
  virtual bool recordEvents(c10::ArrayRef<const void*> buffers) const = 0;
  // Waits on events for each buffer in `buffers`; the bool presumably reports
  // whether any wait actually happened — confirm.
  virtual bool waitForEvents(c10::ArrayRef<const void*> buffers) const = 0;
};
|
|
|
|
|
|
/// Observer interface for MPS allocator buffer events.
///
/// Implementations override executeMPSAllocatorCallback(). How and when the
/// allocator invokes it is defined by the allocator implementation, not here
/// (NOTE(review): confirm the call sites).
class IMpsAllocatorCallback {
 public:
  /// Kind of event being reported. The underlying type (int) and the values
  /// 0..4 are what a scoped enum gets implicitly; they are spelled out here so
  /// the numbering is visible. Meanings below are inferred from the names.
  enum class EventType : int {
    ALLOCATED = 0,        // presumably: a buffer was allocated
    RECYCLED = 1,         // presumably: a cached buffer was reused
    FREED = 2,            // presumably: a buffer was returned to the allocator
    RELEASED = 3,         // presumably: buffer memory was released
    ALLOCATION_FAILED = 4 // presumably: an allocation attempt failed
  };

  // Virtual so that callbacks can be destroyed through a base-class pointer.
  virtual ~IMpsAllocatorCallback() = default;

  /// Called to report `event` for the buffer identified by `ptr`.
  virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0;
};
|
|
|
|
|
|
|
|
|
|
|
|
// Declares a registry of IMpsAllocatorCallback implementations. Only the
// declaration is here; the matching definition is expected in a source file
// (not visible in this header).
TORCH_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);

// Registers an IMpsAllocatorCallback implementation under `name`, e.g.
//   REGISTER_MPS_ALLOCATOR_CALLBACK(my_callback, MyCallback);
// The registry is referenced without namespace qualification, so the macro
// presumably has to be expanded where at::mps names are visible (e.g. inside
// `namespace at::mps`) — NOTE(review): confirm against c10's registry macros.
// The continuation line must follow the backslash directly: an empty line
// after `\` terminates the macro and strands the C10_REGISTER_CLASS call.
#define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \
  C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__)
|
|
|
|
|
|
IMPSAllocator* getIMPSAllocator(bool sharedAllocator = false);
|
|
|
|
|
|
bool isMPSPinnedPtr(const void* data);
|
|
|
|
|
|
}
|
|
|
|