File size: 3,509 Bytes
4b9fefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
// device_weights.h — load safetensors weights to device memory with proper TP shard.
//
// For M3 (attention only): loads attention + norm weights. MoE expert weights come in M4.
//
#pragma once
#include "acl_common.h"
#include "model_config.h"
#include "safetensors_loader.h"

#include <string>
#include <unordered_map>
#include <vector>

// Per-layer MoE weights on device (BF16).
// After loading: weights are in GMM-ready layout [E, K_in, N_out] row-major contiguous.
// For gate/up:  K_in=D,   N_out=I_per_rank
// For down:     K_in=I,   N_out=D
// Device-resident MoE weights for one layer.
// After load_moe() completes, expert tensors are already permuted into the
// GMM-ready row-major layout [E, K_in, N_out] described above, so kernels can
// consume them without a per-step transpose.
struct LayerMoEWeights {
    DeviceBuffer router;       // [E, D]              BF16, replicated on every rank
    DeviceBuffer gate_exps;    // [E, D, I_per_rank]  permuted from HF layout [E, I, D]
    DeviceBuffer up_exps;      // [E, D, I_per_rank]  same permute as gate_exps
    DeviceBuffer down_exps;    // [E, I_per_rank, D]  permuted from HF layout [E, D, I]
};

// Per-layer attention weights on device (BF16 unless noted).
// Device-resident attention weights for one layer (BF16 unless noted).
// Projections keep the HF [out, in] storage order; the matmul wrappers are
// expected to account for the transpose via aclnnMm semantics (per the file
// header, this is the M3 interim scheme).
struct LayerAttnWeights {
    DeviceBuffer input_layernorm;          // [D]            BF16
    DeviceBuffer post_attention_layernorm; // [D]            BF16
    // Q/K/V/O projections. HF stores as [out, in] BF16.
    // For M3 we keep HF layout as-is; matmul wrappers handle the transpose via aclnnMm semantics.
    DeviceBuffer q_proj;   // logically [Q_full, D]; physically [Q_rank, D] (TP-sliced by head)
    DeviceBuffer k_proj;   // [KV, D]       (replicated if tp_size > num_kv_heads)
    DeviceBuffer v_proj;   // [KV, D]       (same replication rule as k_proj)
    DeviceBuffer o_proj;   // [D, Q_rank]   (row-parallel: sliced on the Q input dim)
    DeviceBuffer q_norm;   // [head_dim]    BF16 (Qwen3 per-head RMS norm)
    DeviceBuffer k_norm;   // [head_dim]    BF16
};

// Shared model weights (replicated across ranks).
// Non-layer model weights. Replicated (not sharded) across all TP ranks.
struct SharedWeights {
    DeviceBuffer embed_tokens;   // [vocab, D]  input embedding table
    DeviceBuffer lm_head;        // [vocab, D]  output projection (HF [out, in] order)
    DeviceBuffer final_norm;     // [D]         final pre-logits norm weight
};

// Copies model weights from a memory-mapped safetensors file into device
// memory, applying the tensor-parallel sharding scheme described on each
// struct above. The loader borrows (does not own) the safetensors reader and
// model config; both must outlive this object.
class DeviceWeightsLoader {
public:
    DeviceWeightsLoader(SafetensorsLoader& st, const ModelConfig& cfg)
        : st_(st), cfg_(cfg) {}

    // Load shared (embed, norm, lm_head). Replicated on every rank.
    // Returns false on lookup/allocation/copy failure.
    [[nodiscard]] bool load_shared(SharedWeights& out);

    // Load ONE attention layer's weights with TP sharding.
    // Returns false on lookup/allocation/copy failure.
    [[nodiscard]] bool load_attention(int layer_idx, LayerAttnWeights& out);

    // Load ONE MoE layer's weights. Stacks 128 experts and permutes to GMM-ready layout.
    // stream: ACL stream for the permute op (aclnnInplaceCopy).
    // Returns false on lookup/allocation/copy/permute failure.
    [[nodiscard]] bool load_moe(int layer_idx, aclrtStream stream, LayerMoEWeights& out);

    // Expose underlying safetensors for direct access (diagnostic use).
    SafetensorsLoader& st() { return st_; }

private:
    SafetensorsLoader& st_;       // borrowed; must outlive the loader
    const ModelConfig& cfg_;      // borrowed; must outlive the loader

    // Helper: load HF tensor (full shape) into device buffer (simple H2D).
    [[nodiscard]] bool load_tensor_full_(const std::string& name, DeviceBuffer& buf);

    // Helper: load HF tensor and keep only [row_lo, row_hi) of first dim (TP shard by "out" dim).
    // HF format: tensor has shape [D0, D1, ...] stored row-major. We take rows [lo, hi) to form
    // a sharded tensor of shape [hi-lo, D1, ...].
    [[nodiscard]] bool load_tensor_row_slice_(const std::string& name,
                                               int64_t row_lo, int64_t row_hi,
                                               DeviceBuffer& buf);

    // TP shard by "in" dim (second axis for 2D, etc.) — used for o_proj (row-parallel).
    [[nodiscard]] bool load_tensor_col_slice_(const std::string& name,
                                               int64_t col_lo, int64_t col_hi,
                                               DeviceBuffer& buf);
};