| |
| |
| |
| |
| |
| |
| |
| |
| |
| #pragma once |
| #include <hccl/hccl.h> |
| #include <hccl/hccl_types.h> |
| #include <acl/acl.h> |
|
|
| #include <chrono> |
| #include <cstdio> |
| #include <cstring> |
| #include <string> |
| #include <thread> |
|
|
// File used by hccl_init to hand the HcclRootInfo from rank 0 (which writes it) to the other ranks (which poll and read it).
| #define HCCL_ROOT_INFO_PATH "/tmp/hccl_root_info.bin" |
|
|
| struct HcclCtx { |
| HcclComm comm = nullptr; |
| int tp_size = 1; |
| int tp_rank = 0; |
| bool initialized = false; |
| }; |
|
|
| inline bool hccl_init(HcclCtx& ctx, int tp_size, int tp_rank) { |
| if (tp_size <= 1) { ctx.tp_size = 1; ctx.tp_rank = 0; ctx.initialized = true; return true; } |
| ctx.tp_size = tp_size; |
| ctx.tp_rank = tp_rank; |
|
|
| HcclRootInfo rootInfo; |
| std::memset(&rootInfo, 0, sizeof(rootInfo)); |
|
|
| if (tp_rank == 0) { |
| if (HcclGetRootInfo(&rootInfo) != HCCL_SUCCESS) { |
| fprintf(stderr, "[HCCL] HcclGetRootInfo failed\n"); return false; |
| } |
| FILE* f = fopen(HCCL_ROOT_INFO_PATH, "wb"); |
| if (!f) { fprintf(stderr, "[HCCL] cannot write %s\n", HCCL_ROOT_INFO_PATH); return false; } |
| fwrite(&rootInfo, sizeof(rootInfo), 1, f); |
| fclose(f); |
| } else { |
| bool found = false; |
| for (int r = 0; r < 600; r++) { |
| FILE* f = fopen(HCCL_ROOT_INFO_PATH, "rb"); |
| if (f) { |
| size_t rd = fread(&rootInfo, 1, sizeof(rootInfo), f); |
| fclose(f); |
| if (rd == sizeof(rootInfo)) { found = true; break; } |
| } |
| std::this_thread::sleep_for(std::chrono::milliseconds(100)); |
| } |
| if (!found) { fprintf(stderr, "[HCCL] rank %d timeout waiting for root info\n", tp_rank); return false; } |
| } |
|
|
| HcclResult r = HcclCommInitRootInfo((uint32_t)tp_size, &rootInfo, (uint32_t)tp_rank, &ctx.comm); |
| if (r != HCCL_SUCCESS) { |
| fprintf(stderr, "[HCCL] HcclCommInitRootInfo failed: %d (rank=%d)\n", (int)r, tp_rank); |
| return false; |
| } |
| ctx.initialized = true; |
| fprintf(stderr, "[HCCL] rank %d/%d comm OK\n", tp_rank, tp_size); |
| return true; |
| } |
|
|
| |
| inline bool hccl_allreduce_bf16(const HcclCtx& ctx, void* data, int64_t count, aclrtStream stream) { |
| if (!ctx.initialized) return false; |
| if (ctx.tp_size <= 1) return true; |
|
|
| HcclResult r = HcclAllReduce(data, data, (uint64_t)count, |
| HCCL_DATA_TYPE_BFP16, HCCL_REDUCE_SUM, |
| ctx.comm, stream); |
| if (r != HCCL_SUCCESS) { |
| fprintf(stderr, "[HCCL] AllReduce failed: %d\n", (int)r); |
| return false; |
| } |
| return true; |
| } |
|
|
| |
| |
| inline bool hccl_broadcast(const HcclCtx& ctx, void* data_dev, int64_t count, |
| HcclDataType dtype, uint32_t root, aclrtStream stream) { |
| if (!ctx.initialized) return false; |
| if (ctx.tp_size <= 1) return true; |
|
|
| HcclResult r = HcclBroadcast(data_dev, (uint64_t)count, dtype, root, ctx.comm, stream); |
| if (r != HCCL_SUCCESS) { |
| fprintf(stderr, "[HCCL] Broadcast failed: %d\n", (int)r); |
| return false; |
| } |
| return true; |
| } |
|
|
| inline void hccl_shutdown(HcclCtx& ctx) { |
| if (ctx.comm) { |
| HcclCommDestroy(ctx.comm); |
| ctx.comm = nullptr; |
| } |
| ctx.initialized = false; |
| } |
|
|