Spaces:
Paused
Paused
// Metal handle types.
// Objective-C++ translation units get the real protocol-qualified `id` types;
// plain C++ translation units get opaque `void*` stand-ins. Both branches
// define MTLBuffer_t and MTLComputeCommandEncoder_t, so they must never be
// visible together — the guard selects exactly one set.
// NOTE(review): the Objective-C++ branch assumes <Metal/Metal.h> has already
// been included for this TU — TODO confirm.
#ifdef __OBJC__
typedef id<MTLBuffer> MTLBuffer_t;
typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t;
#else
typedef void* MTLBuffer;
typedef void* MTLBuffer_t;
typedef void* MTLComputeCommandEncoder;
typedef void* MTLComputeCommandEncoder_t;
#endif
| // utils | |
| static inline MTLBuffer_t getMTLBufferStorage(const at::Tensor& tensor) { | |
| return __builtin_bit_cast(MTLBuffer_t, tensor.storage().data()); | |
| } | |
| template <typename T, | |
| std::enable_if_t<!std::is_same<std::decay_t<T>, at::Tensor>::value, bool> = true> | |
| void setMTLArg(MTLComputeCommandEncoder_t encoder, int index, T&& t); | |
/// Binds an at::Tensor argument as a Metal buffer at argument slot `index`.
///
/// The buffer offset is the tensor's storage offset converted to bytes. A
/// view or slice that starts partway into a shared allocation is therefore
/// bound at its own first element, not at the start of the storage.
/// Assumes `t` is an MPS-resident tensor — TODO confirm at call sites.
template <typename T,
          std::enable_if_t<std::is_same<std::decay_t<T>, at::Tensor>::value, bool> = true>
void setMTLArg(MTLComputeCommandEncoder_t encoder, int index, T&& t) {
  const auto byteOffset = t.storage_offset() * t.element_size();
  [encoder setBuffer:getMTLBufferStorage(t) offset:byteOffset atIndex:index];
}
/// Definition of the by-value setMTLArg overload, for any non-Tensor
/// argument. It passes the bytes of `t` inline via setBytes, so the kernel
/// receives a copy of the host value.
///
/// This template head carries no default argument; the default lives on the
/// earlier forward declaration. It must stay token-identical to that
/// declaration's head.
template <typename T, std::enable_if_t<!std::is_same<std::decay_t<T>, at::Tensor>::value, bool>>
void setMTLArg(MTLComputeCommandEncoder_t encoder, int index, T&& t) {
  // setBytes copies the object representation verbatim. A type with
  // non-trivial copy semantics (e.g. an owning container) would upload its
  // host-side pointers rather than its contents, so reject it at compile time.
  static_assert(std::is_trivially_copyable<std::remove_reference_t<T>>::value,
                "setMTLArg: by-value Metal arguments must be trivially copyable");
  [encoder setBytes:&t length:sizeof(t) atIndex:index];
}
| inline void setMTLArgsImpl(MTLComputeCommandEncoder_t, int) {} | |
| template <typename T, typename... Args> | |
| void setMTLArgsImpl(MTLComputeCommandEncoder_t encoder, int index, T&& t, Args&&... args) { | |
| setMTLArg(encoder, index, std::forward<T>(t)); | |
| setMTLArgsImpl(encoder, index + 1, std::forward<Args>(args)...); | |
| } | |
/// Installs `pso` as the encoder's compute pipeline state. It then binds
/// `args...` to argument slots 0, 1, 2, … in order: tensors via setBuffer,
/// every other value via setBytes.
template <typename... Params>
void setMTLArgs(MTLComputeCommandEncoder_t encoder, MTLComputePipelineState_t pso, Params&&... args) {
  [encoder setComputePipelineState:pso];
  constexpr int kFirstArgIndex = 0;
  setMTLArgsImpl(encoder, kFirstArgIndex, std::forward<Params>(args)...);
}