llm_mutil_npu / include /hccl_comm.h
xianglarry's picture
Initial C++ aclnn EAGER inference for Qwen3-235B-A22B MoE on Ascend 910 × 16 NPU
4b9fefd
// hccl_comm.h — minimal HCCL wrapper for tensor-parallel (TP=N) collectives
// (AllReduce and Broadcast).
//
// Multi-process mode (one process per rank; each process sees its NPU as device 0):
// - Rank 0 calls HcclGetRootInfo and writes the result to /tmp/hccl_root_info.bin.
// - Ranks 1..N-1 poll for that file and read it.
// - All ranks call HcclCommInitRootInfo to join a shared HcclComm.
// - hccl_allreduce_bf16() performs an in-place HcclAllReduce with the SUM op.
//
// The launcher sets HCCL_WHITELIST_DISABLE=1, ASCEND_RT_VISIBLE_DEVICES=<rank>, etc.
#pragma once
#include <hccl/hccl.h>
#include <hccl/hccl_types.h>
#include <acl/acl.h>
#include <chrono>
#include <cstdio>
#include <cstring>
#include <string>
#include <thread>
// Default rendezvous file through which rank 0 hands the HcclRootInfo to the
// other ranks. Override at build time (e.g. -DHCCL_ROOT_INFO_PATH='"/tmp/job42.bin"')
// when several independent jobs run on the same host, so they do not share a file.
#ifndef HCCL_ROOT_INFO_PATH
#define HCCL_ROOT_INFO_PATH "/tmp/hccl_root_info.bin"
#endif
// Per-process communicator state consumed by the hccl_* helpers in this header.
// A value-initialised context describes a single rank that is not yet initialised.
struct HcclCtx {
    HcclComm comm{nullptr};   // communicator; stays nullptr until hccl_init creates one (never for tp_size <= 1)
    int tp_size{1};           // number of tensor-parallel ranks
    int tp_rank{0};           // this process's rank, expected in [0, tp_size)
    bool initialized{false};  // set by hccl_init on success, cleared by hccl_shutdown
};
inline bool hccl_init(HcclCtx& ctx, int tp_size, int tp_rank) {
if (tp_size <= 1) { ctx.tp_size = 1; ctx.tp_rank = 0; ctx.initialized = true; return true; }
ctx.tp_size = tp_size;
ctx.tp_rank = tp_rank;
HcclRootInfo rootInfo;
std::memset(&rootInfo, 0, sizeof(rootInfo));
if (tp_rank == 0) {
if (HcclGetRootInfo(&rootInfo) != HCCL_SUCCESS) {
fprintf(stderr, "[HCCL] HcclGetRootInfo failed\n"); return false;
}
FILE* f = fopen(HCCL_ROOT_INFO_PATH, "wb");
if (!f) { fprintf(stderr, "[HCCL] cannot write %s\n", HCCL_ROOT_INFO_PATH); return false; }
fwrite(&rootInfo, sizeof(rootInfo), 1, f);
fclose(f);
} else {
bool found = false;
for (int r = 0; r < 600; r++) { // 60s timeout
FILE* f = fopen(HCCL_ROOT_INFO_PATH, "rb");
if (f) {
size_t rd = fread(&rootInfo, 1, sizeof(rootInfo), f);
fclose(f);
if (rd == sizeof(rootInfo)) { found = true; break; }
}
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
if (!found) { fprintf(stderr, "[HCCL] rank %d timeout waiting for root info\n", tp_rank); return false; }
}
HcclResult r = HcclCommInitRootInfo((uint32_t)tp_size, &rootInfo, (uint32_t)tp_rank, &ctx.comm);
if (r != HCCL_SUCCESS) {
fprintf(stderr, "[HCCL] HcclCommInitRootInfo failed: %d (rank=%d)\n", (int)r, tp_rank);
return false;
}
ctx.initialized = true;
fprintf(stderr, "[HCCL] rank %d/%d comm OK\n", tp_rank, tp_size);
return true;
}
// In-place AllReduce (SUM) of `count` BF16 elements at `data` across all TP ranks.
// dtype = HCCL_DATA_TYPE_BFP16.
//
// @param ctx     Context from hccl_init; returns false if it is not initialised.
// @param data    Buffer used as both input and output (presumably device memory,
//                as HCCL collectives require — confirm at call sites).
// @param count   Element count (not bytes). Negative counts are rejected because
//                they would wrap to a huge uint64_t. 0 is a no-op.
// @param stream  Stream on which the collective is issued.
// @return true on success or when there is nothing to do (single rank, count == 0).
inline bool hccl_allreduce_bf16(const HcclCtx& ctx, void* data, int64_t count, aclrtStream stream) {
    if (!ctx.initialized) return false;
    if (ctx.tp_size <= 1) return true;  // single rank: nothing to reduce
    if (count < 0 || (count > 0 && data == nullptr)) {
        fprintf(stderr, "[HCCL] AllReduce: invalid args (data=%p, count=%lld)\n",
                data, static_cast<long long>(count));
        return false;
    }
    if (count == 0) return true;
    HcclResult r = HcclAllReduce(data, data, static_cast<uint64_t>(count),
                                 HCCL_DATA_TYPE_BFP16, HCCL_REDUCE_SUM,
                                 ctx.comm, stream);
    if (r != HCCL_SUCCESS) {
        fprintf(stderr, "[HCCL] AllReduce failed: %d\n", (int)r);
        return false;
    }
    return true;
}
// Broadcast `count` elements of type `dtype` from rank `root` to every rank, in place.
// Used to share prompt tokens across ranks.
//
// @param ctx       Context from hccl_init; returns false if it is not initialised.
// @param data_dev  Device memory buffer: the source on `root`, the destination elsewhere.
// @param count     Element count (not bytes). Negative counts are rejected because
//                  they would wrap to a huge uint64_t. 0 is a no-op.
// @param dtype     Element type, e.g. HCCL_DATA_TYPE_INT32.
// @param root      Rank that owns the data; must be < ctx.tp_size.
// @param stream    Stream on which the collective is issued.
// @return true on success or when there is nothing to do (single rank, count == 0).
inline bool hccl_broadcast(const HcclCtx& ctx, void* data_dev, int64_t count,
                           HcclDataType dtype, uint32_t root, aclrtStream stream) {
    if (!ctx.initialized) return false;
    if (ctx.tp_size <= 1) return true;
    if (count < 0 || root >= static_cast<uint32_t>(ctx.tp_size) ||
        (count > 0 && data_dev == nullptr)) {
        fprintf(stderr, "[HCCL] Broadcast: invalid args (count=%lld, root=%u, tp_size=%d)\n",
                static_cast<long long>(count), root, ctx.tp_size);
        return false;
    }
    if (count == 0) return true;
    HcclResult r = HcclBroadcast(data_dev, static_cast<uint64_t>(count), dtype, root, ctx.comm, stream);
    if (r != HCCL_SUCCESS) {
        fprintf(stderr, "[HCCL] Broadcast failed: %d\n", (int)r);
        return false;
    }
    return true;
}
// Destroy the communicator, if one exists, and mark the context as uninitialised.
// Calling it again, or on a single-rank context whose comm is nullptr, only clears the flag.
inline void hccl_shutdown(HcclCtx& ctx) {
    ctx.initialized = false;
    if (ctx.comm == nullptr) return;
    HcclCommDestroy(ctx.comm);
    ctx.comm = nullptr;
}