File size: 3,842 Bytes
4b9fefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
// hccl_comm.h — minimal HCCL wrapper for TP=N AllReduce.
//
// Multi-process mode (each rank is a separate process, device 0 each):
//   - Rank 0 calls HcclGetRootInfo, writes to /tmp/hccl_root_info.bin
//   - Rank 1..N-1 wait for that file, read it
//   - All ranks call HcclCommInitRootInfo → shared HcclComm
//   - hccl_allreduce_bf16() does an in-place HcclAllReduce with the SUM op
//
// Launcher sets HCCL_WHITELIST_DISABLE=1, ASCEND_RT_VISIBLE_DEVICES=<rank>, etc.
#pragma once
#include <hccl/hccl.h>
#include <hccl/hccl_types.h>
#include <acl/acl.h>

#include <chrono>
#include <cstdio>
#include <cstring>
#include <string>
#include <thread>

#define HCCL_ROOT_INFO_PATH "/tmp/hccl_root_info.bin"

// Per-process HCCL state passed to the hccl_* helpers in this header.
struct HcclCtx {
    HcclComm comm{nullptr};      // communicator handle; null until hccl_init creates one
    int      tp_size{1};         // number of tensor-parallel ranks (1 = single rank)
    int      tp_rank{0};         // this process's rank
    bool     initialized{false}; // set by hccl_init on success, cleared by hccl_shutdown
};

// Sets up `ctx` for collectives across `tp_size` ranks, with this process as `tp_rank`.
//
// tp_size <= 1: no communicator is created; ctx is marked initialized and the
//               collective helpers become no-ops.
// tp_size  > 1: rank 0 obtains the HCCL root info and publishes it at
//               HCCL_ROOT_INFO_PATH; every other rank polls that file (up to ~60 s).
//               All ranks then call HcclCommInitRootInfo to create ctx.comm.
//
// Returns true on success. On failure a message is printed to stderr, false is
// returned and ctx.initialized is left unchanged.
//
// NOTE(review): a root-info file left over from a previous run can still be picked
// up by a non-root rank that starts polling before rank 0 of this run removes it.
// The launcher should delete HCCL_ROOT_INFO_PATH before starting the ranks.
inline bool hccl_init(HcclCtx& ctx, int tp_size, int tp_rank) {
    if (tp_size <= 1) {
        ctx.tp_size = 1;
        ctx.tp_rank = 0;
        ctx.initialized = true;
        return true;
    }

    // Reject ranks that would wrap or fall out of range once cast to uint32_t.
    if (tp_rank < 0 || tp_rank >= tp_size) {
        fprintf(stderr, "[HCCL] invalid rank %d for tp_size %d\n", tp_rank, tp_size);
        return false;
    }
    ctx.tp_size = tp_size;
    ctx.tp_rank = tp_rank;

    HcclRootInfo rootInfo;
    std::memset(&rootInfo, 0, sizeof(rootInfo));

    if (tp_rank == 0) {
        // Best effort: drop a stale file from an earlier run. The result is ignored
        // because the file usually does not exist.
        std::remove(HCCL_ROOT_INFO_PATH);

        if (HcclGetRootInfo(&rootInfo) != HCCL_SUCCESS) {
            fprintf(stderr, "[HCCL] HcclGetRootInfo failed\n");
            return false;
        }

        // Write to a temporary file and rename it into place, so that pollers only
        // ever observe a complete file.
        const char* tmp_path = HCCL_ROOT_INFO_PATH ".tmp";
        FILE* f = fopen(tmp_path, "wb");
        if (!f) {
            fprintf(stderr, "[HCCL] cannot write %s\n", HCCL_ROOT_INFO_PATH);
            return false;
        }
        const bool wrote = fwrite(&rootInfo, sizeof(rootInfo), 1, f) == 1;
        const bool closed = fclose(f) == 0;   // always close, even if the write failed
        if (!wrote || !closed || std::rename(tmp_path, HCCL_ROOT_INFO_PATH) != 0) {
            fprintf(stderr, "[HCCL] failed to publish %s\n", HCCL_ROOT_INFO_PATH);
            std::remove(tmp_path);
            return false;
        }
    } else {
        // 600 attempts x 100 ms = 60 s timeout.
        constexpr int kMaxAttempts = 600;
        constexpr auto kPollInterval = std::chrono::milliseconds(100);

        bool found = false;
        for (int attempt = 0; attempt < kMaxAttempts; ++attempt) {
            FILE* f = fopen(HCCL_ROOT_INFO_PATH, "rb");
            if (f) {
                const size_t rd = fread(&rootInfo, 1, sizeof(rootInfo), f);
                fclose(f);
                // A short read means the file is not complete yet; retry.
                if (rd == sizeof(rootInfo)) { found = true; break; }
            }
            std::this_thread::sleep_for(kPollInterval);
        }
        if (!found) {
            fprintf(stderr, "[HCCL] rank %d timeout waiting for root info\n", tp_rank);
            return false;
        }
    }

    const HcclResult r = HcclCommInitRootInfo(static_cast<uint32_t>(tp_size), &rootInfo,
                                              static_cast<uint32_t>(tp_rank), &ctx.comm);
    if (r != HCCL_SUCCESS) {
        fprintf(stderr, "[HCCL] HcclCommInitRootInfo failed: %d (rank=%d)\n",
                static_cast<int>(r), tp_rank);
        return false;
    }
    ctx.initialized = true;
    fprintf(stderr, "[HCCL] rank %d/%d comm OK\n", tp_rank, tp_size);
    return true;
}

// In-place AllReduce SUM on BF16 tensor. dtype = HCCL_DATA_TYPE_BFP16.
inline bool hccl_allreduce_bf16(const HcclCtx& ctx, void* data, int64_t count, aclrtStream stream) {
    if (!ctx.initialized) return false;
    if (ctx.tp_size <= 1) return true;   // no-op

    HcclResult r = HcclAllReduce(data, data, (uint64_t)count,
                                  HCCL_DATA_TYPE_BFP16, HCCL_REDUCE_SUM,
                                  ctx.comm, stream);
    if (r != HCCL_SUCCESS) {
        fprintf(stderr, "[HCCL] AllReduce failed: %d\n", (int)r);
        return false;
    }
    return true;
}

// Broadcast buffer from root (rank 0) to all ranks. Used to share prompt tokens across ranks.
// `data_dev` must be device memory. dtype generic (e.g., HCCL_DATA_TYPE_INT32).
inline bool hccl_broadcast(const HcclCtx& ctx, void* data_dev, int64_t count,
                           HcclDataType dtype, uint32_t root, aclrtStream stream) {
    if (!ctx.initialized) return false;
    if (ctx.tp_size <= 1) return true;

    HcclResult r = HcclBroadcast(data_dev, (uint64_t)count, dtype, root, ctx.comm, stream);
    if (r != HCCL_SUCCESS) {
        fprintf(stderr, "[HCCL] Broadcast failed: %d\n", (int)r);
        return false;
    }
    return true;
}

// Destroys the communicator held by `ctx` (if any) and marks ctx uninitialized.
// Safe to call on a context that was never initialized or was set up with
// tp_size <= 1 (comm is null in both cases). tp_size / tp_rank are left as-is.
inline void hccl_shutdown(HcclCtx& ctx) {
    if (ctx.comm != nullptr) {
        const HcclResult r = HcclCommDestroy(ctx.comm);
        if (r != HCCL_SUCCESS) {
            // Nothing to recover here; report it and still drop the handle so it
            // is not destroyed a second time.
            fprintf(stderr, "[HCCL] HcclCommDestroy failed: %d\n", static_cast<int>(r));
        }
        ctx.comm = nullptr;
    }
    ctx.initialized = false;
}