// hccl_comm.h — minimal HCCL wrapper for TP=N AllReduce. // // Multi-process mode (each rank is a separate process, device 0 each): // - Rank 0 calls HcclGetRootInfo, writes to /tmp/hccl_root_info.bin // - Rank 1..N-1 wait for that file, read it // - All ranks call HcclCommInitRootInfo → shared HcclComm // - allreduce() does in-place HcclAllReduce with SUM op // // Launcher sets HCCL_WHITELIST_DISABLE=1, ASCEND_RT_VISIBLE_DEVICES=, etc. #pragma once #include #include #include #include #include #include #include #include #define HCCL_ROOT_INFO_PATH "/tmp/hccl_root_info.bin" struct HcclCtx { HcclComm comm = nullptr; int tp_size = 1; int tp_rank = 0; bool initialized = false; }; inline bool hccl_init(HcclCtx& ctx, int tp_size, int tp_rank) { if (tp_size <= 1) { ctx.tp_size = 1; ctx.tp_rank = 0; ctx.initialized = true; return true; } ctx.tp_size = tp_size; ctx.tp_rank = tp_rank; HcclRootInfo rootInfo; std::memset(&rootInfo, 0, sizeof(rootInfo)); if (tp_rank == 0) { if (HcclGetRootInfo(&rootInfo) != HCCL_SUCCESS) { fprintf(stderr, "[HCCL] HcclGetRootInfo failed\n"); return false; } FILE* f = fopen(HCCL_ROOT_INFO_PATH, "wb"); if (!f) { fprintf(stderr, "[HCCL] cannot write %s\n", HCCL_ROOT_INFO_PATH); return false; } fwrite(&rootInfo, sizeof(rootInfo), 1, f); fclose(f); } else { bool found = false; for (int r = 0; r < 600; r++) { // 60s timeout FILE* f = fopen(HCCL_ROOT_INFO_PATH, "rb"); if (f) { size_t rd = fread(&rootInfo, 1, sizeof(rootInfo), f); fclose(f); if (rd == sizeof(rootInfo)) { found = true; break; } } std::this_thread::sleep_for(std::chrono::milliseconds(100)); } if (!found) { fprintf(stderr, "[HCCL] rank %d timeout waiting for root info\n", tp_rank); return false; } } HcclResult r = HcclCommInitRootInfo((uint32_t)tp_size, &rootInfo, (uint32_t)tp_rank, &ctx.comm); if (r != HCCL_SUCCESS) { fprintf(stderr, "[HCCL] HcclCommInitRootInfo failed: %d (rank=%d)\n", (int)r, tp_rank); return false; } ctx.initialized = true; fprintf(stderr, "[HCCL] rank %d/%d comm OK\n", tp_rank, tp_size); return true; } // In-place AllReduce SUM on BF16 tensor. dtype = HCCL_DATA_TYPE_BFP16. inline bool hccl_allreduce_bf16(const HcclCtx& ctx, void* data, int64_t count, aclrtStream stream) { if (!ctx.initialized) return false; if (ctx.tp_size <= 1) return true; // no-op HcclResult r = HcclAllReduce(data, data, (uint64_t)count, HCCL_DATA_TYPE_BFP16, HCCL_REDUCE_SUM, ctx.comm, stream); if (r != HCCL_SUCCESS) { fprintf(stderr, "[HCCL] AllReduce failed: %d\n", (int)r); return false; } return true; } // Broadcast buffer from root (rank 0) to all ranks. Used to share prompt tokens across ranks. // `data_dev` must be device memory. dtype generic (e.g., HCCL_DATA_TYPE_INT32). inline bool hccl_broadcast(const HcclCtx& ctx, void* data_dev, int64_t count, HcclDataType dtype, uint32_t root, aclrtStream stream) { if (!ctx.initialized) return false; if (ctx.tp_size <= 1) return true; HcclResult r = HcclBroadcast(data_dev, (uint64_t)count, dtype, root, ctx.comm, stream); if (r != HCCL_SUCCESS) { fprintf(stderr, "[HCCL] Broadcast failed: %d\n", (int)r); return false; } return true; } inline void hccl_shutdown(HcclCtx& ctx) { if (ctx.comm) { HcclCommDestroy(ctx.comm); ctx.comm = nullptr; } ctx.initialized = false; }