| #pragma once |
|
|
| #include <musa_runtime.h> |
| #include <musa.h> |
| #include <mublas.h> |
| #include <musa_bf16.h> |
| #include <musa_fp16.h> |
// ---------------------------------------------------------------------------
// cuBLAS -> muBLAS aliases.
// Code written against the cuBLAS API is redirected to muBLAS by renaming
// each cuBLAS identifier it uses.
// ---------------------------------------------------------------------------

// Handle, status and operation types.
#define cublasHandle_t       mublasHandle_t
#define cublasStatus_t       mublasStatus_t
#define cublasOperation_t    mublasOperation_t
// The compute-type enum is aliased to the data-type enum (cudaDataType_t, which
// is itself renamed elsewhere in this header); the CUBLAS_COMPUTE_* constants
// below resolve to data-type values accordingly.
#define cublasComputeType_t  cudaDataType_t

// Data-type constants, and compute-type constants routed through them.
#define CUDA_R_16F                   MUSA_R_16F
#define CUDA_R_32F                   MUSA_R_32F
#define CUBLAS_COMPUTE_16F           CUDA_R_16F
#define CUBLAS_COMPUTE_32F           CUDA_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F  MUBLAS_COMPUTE_32F_FAST_16F

// GEMM algorithm selection: both the plain and the tensor-op default resolve
// to the same muBLAS default algorithm.
#define CUBLAS_GEMM_DEFAULT            MUBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP  MUBLAS_GEMM_DEFAULT

// Math mode: a request for TF32 tensor-op math resolves to muBLAS's default mode.
#define CUBLAS_TF32_TENSOR_OP_MATH  MUBLAS_MATH_MODE_DEFAULT

// Transpose flags and status codes.
#define CUBLAS_OP_N            MUBLAS_OP_N
#define CUBLAS_OP_T            MUBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS  MUBLAS_STATUS_SUCCESS

// Library entry points.
#define cublasCreate                mublasCreate
#define cublasDestroy               mublasDestroy
#define cublasSetStream             mublasSetStream
#define cublasSetMathMode           mublasSetMathMode
#define cublasSgemm                 mublasSgemm
#define cublasGemmEx                mublasGemmEx
#define cublasGemmBatchedEx         mublasGemmBatchedEx
#define cublasGemmStridedBatchedEx  mublasGemmStridedBatchedEx
// The status-to-string helper is spelled differently in muBLAS.
#define cublasGetStatusString       mublasStatus_to_string
// ---------------------------------------------------------------------------
// CUDA runtime -> MUSA runtime aliases.
// Every name below follows one rule: the "cuda" prefix becomes "musa" and the
// remainder of the identifier is unchanged.
// ---------------------------------------------------------------------------

// Core types, success code and error reporting.
#define cudaDataType_t                     musaDataType_t
#define cudaError_t                        musaError_t
#define cudaSuccess                        musaSuccess
#define cudaGetErrorString                 musaGetErrorString
#define cudaGetLastError                   musaGetLastError
#define cudaErrorPeerAccessAlreadyEnabled  musaErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled      musaErrorPeerAccessNotEnabled

// Device selection, properties and synchronization.
#define cudaDeviceProp                      musaDeviceProp
#define cudaGetDevice                       musaGetDevice
#define cudaSetDevice                       musaSetDevice
#define cudaGetDeviceCount                  musaGetDeviceCount
#define cudaGetDeviceProperties             musaGetDeviceProperties
#define cudaDeviceSynchronize               musaDeviceSynchronize
#define cudaOccupancyMaxPotentialBlockSize  musaOccupancyMaxPotentialBlockSize

// Peer (device-to-device) access.
#define cudaDeviceCanAccessPeer      musaDeviceCanAccessPeer
#define cudaDeviceEnablePeerAccess   musaDeviceEnablePeerAccess
#define cudaDeviceDisablePeerAccess  musaDeviceDisablePeerAccess

// Allocation, host registration and memory queries.
#define cudaMalloc                musaMalloc
#define cudaFree                  musaFree
#define cudaMallocHost            musaMallocHost
#define cudaFreeHost              musaFreeHost
#define cudaMallocManaged         musaMallocManaged
#define cudaMemGetInfo            musaMemGetInfo
#define cudaHostRegister          musaHostRegister
#define cudaHostRegisterPortable  musaHostRegisterPortable
#define cudaHostRegisterReadOnly  musaHostRegisterReadOnly
#define cudaHostUnregister        musaHostUnregister

// Copies and fills.
#define cudaMemcpy                musaMemcpy
#define cudaMemcpyAsync           musaMemcpyAsync
#define cudaMemcpyPeerAsync       musaMemcpyPeerAsync
#define cudaMemcpy2DAsync         musaMemcpy2DAsync
#define cudaMemcpyKind            musaMemcpyKind
#define cudaMemcpyHostToDevice    musaMemcpyHostToDevice
#define cudaMemcpyDeviceToHost    musaMemcpyDeviceToHost
#define cudaMemcpyDeviceToDevice  musaMemcpyDeviceToDevice
#define cudaMemset                musaMemset
#define cudaMemsetAsync           musaMemsetAsync

// Streams and host callbacks.
#define cudaStream_t               musaStream_t
#define cudaStreamCreateWithFlags  musaStreamCreateWithFlags
#define cudaStreamDestroy          musaStreamDestroy
#define cudaStreamSynchronize      musaStreamSynchronize
#define cudaStreamWaitEvent        musaStreamWaitEvent
#define cudaStreamNonBlocking      musaStreamNonBlocking
#define cudaStreamPerThread        musaStreamPerThread
#define cudaStreamFireAndForget    musaStreamFireAndForget
#define cudaLaunchHostFunc         musaLaunchHostFunc

// Events.
#define cudaEvent_t               musaEvent_t
#define cudaEventCreateWithFlags  musaEventCreateWithFlags
#define cudaEventDisableTiming    musaEventDisableTiming
#define cudaEventRecord           musaEventRecord
#define cudaEventSynchronize      musaEventSynchronize
#define cudaEventDestroy          musaEventDestroy
|
|
// ---------------------------------------------------------------------------
// CUDA driver API (cu*/CU*) -> MUSA driver API (mu*/MU*) aliases for virtual
// memory management, followed by a few additional runtime-level names.
// ---------------------------------------------------------------------------

// Driver types.
#define CUdevice                      MUdevice
#define CUdeviceptr                   MUdeviceptr
#define CUmemAccessDesc               MUmemAccessDesc
#define CUmemAllocationProp           MUmemAllocationProp
#define CUmemGenericAllocationHandle  MUmemGenericAllocationHandle

// Driver enumerators. The device attribute is spelled differently: CUDA uses
// "VIRTUAL_MEMORY_MANAGEMENT", MUSA uses "VIRTUAL_ADDRESS_MANAGEMENT".
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED \
    MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE    MU_MEM_ACCESS_FLAGS_PROT_READWRITE
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED  MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
#define CU_MEM_ALLOCATION_TYPE_PINNED         MU_MEM_ALLOCATION_TYPE_PINNED
#define CU_MEM_LOCATION_TYPE_DEVICE           MU_MEM_LOCATION_TYPE_DEVICE

// Driver entry points: device queries, then address reservation,
// physical allocation, mapping and access control.
#define cuDeviceGet                    muDeviceGet
#define cuDeviceGetAttribute           muDeviceGetAttribute
#define cuMemGetAllocationGranularity  muMemGetAllocationGranularity
#define cuMemAddressReserve            muMemAddressReserve
#define cuMemAddressFree               muMemAddressFree
#define cuMemCreate                    muMemCreate
#define cuMemRelease                   muMemRelease
#define cuMemMap                       muMemMap
#define cuMemUnmap                     muMemUnmap
#define cuMemSetAccess                 muMemSetAccess

// Runtime-level names: kernel attributes and 3D peer-copy helpers.
#define cudaFuncSetAttribute                         musaFuncSetAttribute
#define cudaFuncAttributeMaxDynamicSharedMemorySize  musaFuncAttributeMaxDynamicSharedMemorySize
#define cudaMemcpy3DPeerParms                        musaMemcpy3DPeerParms
#define make_cudaExtent                              make_musaExtent
#define make_cudaPitchedPtr                          make_musaPitchedPtr
|
|
// ---------------------------------------------------------------------------
// Aliases for graph capture / replay and the related driver error helpers.
// ---------------------------------------------------------------------------

// Driver result codes and error strings.
#define CUDA_SUCCESS      MUSA_SUCCESS
#define CUresult          MUresult
#define cuGetErrorString  muGetErrorString

// Runtime error codes checked around graph operations.
#define cudaErrorGraphExecUpdateFailure  musaErrorGraphExecUpdateFailure
#define cudaErrorInvalidDeviceFunction   musaErrorInvalidDeviceFunction

// Graph and executable-graph lifetime, update and launch.
#define cudaGraph_t           musaGraph_t
#define cudaGraphExec_t       musaGraphExec_t
#define cudaGraphDestroy      musaGraphDestroy
#define cudaGraphExecDestroy  musaGraphExecDestroy
#define cudaGraphInstantiate  musaGraphInstantiate
#define cudaGraphLaunch       musaGraphLaunch
#define cudaGraphExecUpdate   musaGraphExecUpdate
// The CUDA name cudaGraphExecUpdateResultInfo is mapped onto the differently
// named musaGraphExecUpdateResult. NOTE(review): confirm that callers pass this
// object to musaGraphExecUpdate in the form the MUSA signature expects.
#define cudaGraphExecUpdateResultInfo  musaGraphExecUpdateResult

// Node inspection and kernel-node parameter access.
#define cudaGraphNode_t               musaGraphNode_t
#define cudaGraphNodeType             musaGraphNodeType
#define cudaGraphNodeTypeKernel       musaGraphNodeTypeKernel
#define cudaGraphGetNodes             musaGraphGetNodes
#define cudaGraphNodeGetType          musaGraphNodeGetType
#define cudaKernelNodeParams          musaKernelNodeParams
#define cudaGraphKernelNodeGetParams  musaGraphKernelNodeGetParams
#define cudaGraphKernelNodeSetParams  musaGraphKernelNodeSetParams

// Stream capture. The "begin" call was previously unmapped even though the
// capture mode and the "end" call were, so any code that starts a capture
// would reference an undefined CUDA name. A macro only expands where it is
// used, so this mapping is harmless for code that never begins a capture.
#define cudaStreamCaptureModeRelaxed  musaStreamCaptureModeRelaxed
#define cudaStreamBeginCapture        musaStreamBeginCapture
#define cudaStreamEndCapture          musaStreamEndCapture
|
|
// CUDA code spells the bfloat16 type nv_bfloat16; alias it to MUSA's mt_bfloat16
// (presumably declared by <musa_bf16.h> — NOTE(review): confirm).
using nv_bfloat16 = mt_bfloat16;
|
|