|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
#include <cstdint>
|
|
|
#include <utility>
|
|
|
|
|
|
#include <ATen/mps/MPSDevice.h>
|
|
|
#include <c10/core/DeviceGuard.h>
|
|
|
#include <c10/core/Stream.h>
|
|
|
#include <c10/util/Exception.h>
|
|
|
|
|
|
// Handle types used in the MPSStream API below.
// When this header is compiled as Objective-C++ (__OBJC__ defined) the aliases
// name the real Metal / MetalPerformanceShaders(Graph) types. In a plain C++
// translation unit those Objective-C types do not exist, so every alias
// collapses to an opaque void* and the header still compiles.
// NOTE(review): MPSStream stores these handles as data members, so each alias
// must stay pointer-sized in BOTH branches to keep the class layout identical
// across Objective-C++ and C++ translation units.
#ifdef __OBJC__
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

typedef MPSCommandBuffer* MPSCommandBuffer_t;
typedef id<MTLCommandQueue> MTLCommandQueue_t;
typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t;
typedef id<MTLSharedEvent> MTLSharedEvent_t;
typedef id<MTLDevice> MTLDevice_t;
typedef id<MTLBuffer> MTLBuffer_t;
#else
// Provides dispatch_queue_t (used by MPSStream::queue() and its _serialQueue
// member). In the Objective-C branch it presumably arrives transitively via
// Foundation — confirm.
#include <dispatch/dispatch.h>

typedef void* MPSCommandBuffer_t;
// Stand-ins for framework class names that MPSStream uses as `T*`. With these
// typedefs such a `T*` becomes `void**`, which is still pointer-sized.
typedef void* MPSGraph;
typedef void* MPSGraphExecutionDescriptor;
typedef void* MPSGraphCompilationDescriptor;
typedef void* MTLCommandQueue_t;
typedef void* MTLComputeCommandEncoder_t;
typedef void* MTLSharedEvent_t;
typedef void* MTLDevice_t;
typedef void* MTLBuffer_t;
// In Objective-C this is a block type; in C++ it is only an opaque pointer.
typedef void* MTLCommandBufferHandler;
typedef void* NSDictionary;
// Lets the `= nil` default member initializers in MPSStream compile outside
// Objective-C.
#define nil NULL
#endif
|
|
|
|
|
|
namespace at::mps {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// Selects how an MPSStream operation commits its command buffer once the work
/// has been encoded.
///
/// MPSStream::synchronize() takes a SyncType as its only argument. fill(), copy()
/// and executeMPSGraph() accept one as an optional trailing argument that
/// defaults to NONE.
///
/// NOTE(review): the per-mode behaviour is implemented in the out-of-line
/// MPSStream definition, not in this header. The descriptions below follow the
/// enumerator names and should be confirmed against that implementation.
enum class SyncType : int {
  NONE = 0,                 // presumably: encode only, do not commit
  COMMIT = 1,               // presumably: commit the current command buffer
  COMMIT_AND_WAIT = 2,      // presumably: commit, then block until it completes
  COMMIT_AND_CONTINUE = 3,  // presumably: commit and keep encoding on a follow-on buffer
  COMMIT_ADAPTIVE = 4,      // presumably: decide at run time whether to commit — TODO confirm criteria
};
|
|
|
|
|
|
/// Bundles a c10::Stream with the Metal objects used to submit MPS work on it:
/// a command queue, the current and previous command buffers, a compute
/// encoder, MPSGraph execution/compilation descriptors, and a dispatch queue.
/// Committing is only reachable through synchronize(SyncType), because the
/// commit/flush helpers are private.
///
/// NOTE(review): the class holds raw handles and declares a destructor, but it
/// declares no copy or move operations. If the out-of-line destructor releases
/// these handles, a copied MPSStream would release them twice. Consider
/// deleting the copy constructor and copy assignment (confirm that no caller
/// copies the stream first).
class TORCH_API MPSStream {
 public:
  // Tag type for an "unchecked" construction path.
  // NOTE(review): no constructor in this header takes Unchecked — confirm it
  // is still used.
  enum Unchecked { UNCHECKED };

  /// Wraps the generic c10::Stream `stream`.
  /// NOTE(review): presumably validates that `stream` really is an MPS stream;
  /// the definition is out of line — confirm.
  explicit MPSStream(Stream stream);

  /// Defined out of line (not visible here).
  ~MPSStream();

  /// Returns the stored Metal command queue handle (`_commandQueue`).
  MTLCommandQueue_t commandQueue() const {
    return _commandQueue;
  }

  /// Returns the stored dispatch queue (`_serialQueue`). It defaults to nullptr
  /// until the out-of-line constructor assigns it.
  dispatch_queue_t queue() const {
    return _serialQueue;
  }

  /// Returns the command buffer work is being encoded into.
  /// NOTE(review): presumably created lazily — confirm in the definition.
  MPSCommandBuffer_t commandBuffer();

  /// Returns a compute command encoder for the current command buffer.
  /// NOTE(review): presumably reused across kernels until
  /// endKernelCoalescing() is called — confirm.
  MTLComputeCommandEncoder_t commandEncoder();

  /// Ends the current run of coalesced kernels. Presumably this closes
  /// `_commandEncoder` — confirm.
  void endKernelCoalescing();

  /// Commits pending work according to `syncType`. This is the public entry
  /// point to the private commit()/commitAndWait()/commitAndContinue()/flush()
  /// helpers.
  void synchronize(SyncType syncType);

  /// Fills `length` units (presumably bytes) of `buffer`, starting at
  /// `offset`, with `value`. Afterwards it synchronizes according to
  /// `syncType`, which defaults to NONE.
  void fill(MTLBuffer_t buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE);

  /// Copies `length` units (presumably bytes) from `srcBuffer` at `srcOffset`
  /// into `dstBuffer` at `dstOffset`.
  /// `profileId` presumably identifies the copy for profiling — confirm.
  /// `syncType` defaults to NONE.
  void copy(MTLBuffer_t srcBuffer,
            MTLBuffer_t dstBuffer,
            size_t length,
            size_t srcOffset,
            size_t dstOffset,
            uint64_t profileId,
            SyncType syncType = SyncType::NONE);

  /// Same parameters as copy(), with `non_blocking` in place of a SyncType.
  /// NOTE(review): presumably `non_blocking == false` waits for completion —
  /// confirm in the definition.
  void copy_and_sync(MTLBuffer_t srcBuffer,
                     MTLBuffer_t dstBuffer,
                     size_t length,
                     size_t srcOffset,
                     size_t dstOffset,
                     bool non_blocking,
                     uint64_t profileId);

  /// Runs `mpsGraph` on this stream with input `feeds`, writing to `results`.
  /// `syncType` defaults to NONE.
  /// In plain C++ translation units these parameters are opaque `void**`.
  void executeMPSGraph(MPSGraph* mpsGraph,
                       NSDictionary* feeds,
                       NSDictionary* results,
                       SyncType syncType = SyncType::NONE);

  /// Registers `block` as a completion handler.
  /// NOTE(review): presumably attached to the current command buffer — confirm.
  void addCompletedHandler(MTLCommandBufferHandler block);

  /// Returns the device index of the wrapped c10::Stream.
  c10::DeviceIndex device_index() const {
    return _stream.device_index();
  }

  /// Returns `_commandQueue`. This is the same value as commandQueue().
  MTLCommandQueue_t stream() const {
    return _commandQueue;
  }

  /// Returns the Metal device used by this stream. Defined out of line;
  /// presumably it is the device of `_commandQueue` — confirm.
  MTLDevice_t device() const;

  /// Returns the wrapped c10::Stream by value.
  Stream unwrap() const {
    return _stream;
  }

 private:
  // The declaration order below fixes both the object layout and the
  // member-initialization order. Do not reorder.
  Stream _stream;
  MTLCommandQueue_t _commandQueue = nil;
  MPSCommandBuffer_t _commandBuffer = nil;
  // NOTE(review): presumably the previously committed buffer, kept for later
  // waiting — confirm.
  MPSCommandBuffer_t _prevCommandBuffer = nil;
  MTLComputeCommandEncoder_t _commandEncoder = nil;
  MPSGraphExecutionDescriptor* _executionDescriptor = nil;
  MPSGraphCompilationDescriptor* _compilationDescriptor = nil;
  dispatch_queue_t _serialQueue = nullptr;

  // Defaults to enabled; presumably consulted when handling
  // SyncType::COMMIT_AND_CONTINUE — confirm.
  bool _enableCommitAndContinue = true;

  // Private commit primitives. Code outside MPSStream must go through
  // synchronize(SyncType).
  void commit();
  void commitAndWait();
  void commitAndContinue();
  void flush();
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// Returns the MPS stream that is current for the caller.
/// NOTE(review): only the declaration is visible here. Ownership of the
/// returned pointer presumably stays with the library (e.g. the
/// MPSStreamImpl singleton) — confirm before freeing or caching it.
TORCH_API MPSStream* getCurrentMPSStream();

/// Returns the default MPS stream.
/// NOTE(review): only the declaration is visible here. Whether it can differ
/// from getCurrentMPSStream() depends on the out-of-line definition — confirm.
TORCH_API MPSStream* getDefaultMPSStream();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TORCH_API MPSStreamImpl {
|
|
|
public:
|
|
|
|
|
|
|
|
|
|
|
|
static MPSStream* getInstance();
|
|
|
|
|
|
private:
|
|
|
static MPSStream* _stream;
|
|
|
MPSStreamImpl();
|
|
|
};
|
|
|
|
|
|
}
|
|
|
|