// runner.h — multi-layer transformer Runner for Qwen3-235B-A22B.
//
// Owns: shared weights, per-layer attention + MoE weights, KV cache, scratch buffers.
// Provides: prefill(tokens) and decode(new_token) methods returning logits [vocab] on device.
//
// At TP=1 the memory budget only allows loading a SUBSET of layers (num_layers_to_load <= 94);
// full 94-layer inference requires TP=16, where each rank's MoE shard fits in ~28 GB.
#pragma once
#include "acl_common.h"
#include "acl_runtime.h"
#include "aclnn_ops.h"
#include "device_weights.h"
#include "engine.h"
#include "hccl_comm.h"
#include "model_config.h"
#include "safetensors_loader.h"

#include <cstdint>
#include <string>
#include <vector>

class Runner {
public:
    Runner() = default;
    ~Runner() = default;
    Runner(const Runner&) = delete;
    Runner& operator=(const Runner&) = delete;

    // Initialize runtime, open safetensors, load shared weights. tp_size/tp_rank configure
    // MoE + attention sharding. num_layers_to_load is how many transformer blocks to load (1..94).
    // max_seq is the maximum sequence length (for KV cache allocation).
    bool init(const std::string& model_dir, int tp_size, int tp_rank,
              int num_layers_to_load, int64_t max_seq, int device_id = 0);

    // Prefill: ingest S >= 1 tokens, produce logits [vocab] for the LAST position. Populates the
    // KV cache starting at position 0.
    bool prefill(const int32_t* tokens, int64_t S, DeviceBuffer& logits_out);

    // Decode: take 1 new token, produce logits [vocab] from the new position.
    bool decode(int32_t token, DeviceBuffer& logits_out);
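
    // Typical generation loop (illustrative sketch; tokenization, the host-side argmax helper,
    // and the layer/seq-length values are assumptions, not part of this API):
    //
    //     Runner r;
    //     r.init(model_dir, /*tp_size=*/1, /*tp_rank=*/0, /*num_layers_to_load=*/4, /*max_seq=*/2048);
    //     std::vector<int32_t> prompt = /* already-tokenized prompt */;
    //     DeviceBuffer logits;
    //     r.prefill(prompt.data(), prompt.size(), logits);   // past_len -> prompt.size()
    //     for (int step = 0; step < max_new_tokens; ++step) {
    //         int32_t next = argmax_on_host(logits);         // hypothetical: copy BF16 logits to host, argmax
    //         r.decode(next, logits);                        // past_len advances by 1
    //     }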

    // Batched decode: take S tokens as "candidate verify batch" at positions [past_len..past_len+S),
    // produce logits [S, vocab]. Uses causal-with-past mask (token i sees past+tokens[0..i]).
    // Foundation for speculative decoding / PLD.
    //   tokens: [S] int32
    //   S: 1 .. 16
    //   all_logits_out: will hold S * vocab_size * 2 bytes BF16, row-major [S, V]
    // Updates past_len by +S on success.
    bool decode_batch(const int32_t* tokens, int64_t S, DeviceBuffer& all_logits_out);
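
    // Verify-step sketch for speculative decoding (illustrative; the draft proposer and the
    // host-side acceptance check are placeholders, and acceptance rules vary by scheme):
    //
    //     std::vector<int32_t> draft = propose_draft();               // hypothetical draft source (PLD, small model, ...)
    //     DeviceBuffer all_logits;
    //     r.decode_batch(draft.data(), draft.size(), all_logits);     // past_len += draft.size()
    //     int64_t accepted = count_accepted(all_logits, draft);       // hypothetical host-side check
    //     r.rewind_cache((int64_t)draft.size() - accepted);           // drop rejected tokens' KV entries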

    // Warmup: run N dummy decode() calls (resetting cache) to pre-compile aclnn executors,
    // warm HCCL collective buffers, and stabilize NPU thermals. Improves first-N-token latency
    // by ~1 s (especially noticeable on short generations or REPL cold start).
    // Call after init(); safe to call multiple times. Does NOT affect past_len.
    void warmup(int iterations = 3);

    // Accessors
    const ModelConfig& cfg() const { return cfg_; }
    aclrtStream stream() { return rt_.stream(); }
    int64_t past_len() const { return past_len_; }
    void reset_cache() { past_len_ = 0; }
    // Rewind past_len by n. Used by speculative decoding to discard rejected draft tokens'
    // KV cache entries (they'll be overwritten by subsequent writes).
    void rewind_cache(int64_t n) { if (n > 0 && n <= past_len_) past_len_ -= n; }
    HcclCtx& hccl_ctx() { return hccl_ctx_; }

    // Profiling: set via LCA_PROFILE=1 env in main_cli. If enabled, decode() accumulates
    // per-phase wall-clock ms into the timer accumulators below.
    bool profile_enabled = false;
    double t_embed_ms = 0, t_layers_ms = 0, t_final_ms = 0;
    int64_t profile_calls = 0;
    void print_profile_summary() const;

private:
    // One-layer forward: x_in [S, D] → x_out [S, D] via attention + residual + MoE + residual.
    // Uses this layer's KV cache starting at past_len; caller updates past_len after each call.
    // batch_decode_mode: true for S>1 at past_len>0 (spec decoding) — uses custom causal mask
    //                    with past instead of the 2048×2048 prefill mask.
    void layer_forward_(int layer_idx, int64_t S, void* x_in, void* x_out,
                        bool batch_decode_mode = false);
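
    // How prefill()/decode() are expected to chain layers with the ping-pong buffers below
    // (sketch; assumes DeviceBuffer exposes a data() pointer accessor, which may differ):
    //
    //     void* x_in  = x_buf_a_.data();   // holds the embedded input [S, D]
    //     void* x_out = x_buf_b_.data();
    //     for (int l = 0; l < num_layers_; ++l) {
    //         layer_forward_(l, S, x_in, x_out, batch_decode_mode);
    //         std::swap(x_in, x_out);      // output becomes the next layer's input
    //     }
    //     // x_in now points at the final hidden states [S, D]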

    // Build causal-with-past mask in batch_mask_dev_ for decode_batch at current past_len.
    // Shape [1, 1, S, past_len+S] bool, mask[i, j] = 1 iff j > past_len+i.
    void build_batch_decode_mask_(int64_t S);
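
    // Equivalent host-side construction of the mask contents (sketch; the upload into
    // batch_mask_dev_ is omitted):
    //
    //     const int64_t kv = past_len_ + S;
    //     std::vector<uint8_t> mask(S * kv);
    //     for (int64_t i = 0; i < S; ++i)
    //         for (int64_t j = 0; j < kv; ++j)
    //             mask[i * kv + j] = (j > past_len_ + i) ? 1 : 0;   // 1 = masked out (future position)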

    // Final: final_norm + lm_head on last position → logits [vocab].
    void final_logits_(void* hidden_last /*[1, D]*/, DeviceBuffer& logits_out);

    // Batched final: final_norm + lm_head on [S, D] → logits [S, V].
    void final_logits_batch_(void* hidden /*[S, D]*/, int64_t S, DeviceBuffer& logits_out);

    AclRuntime rt_;
    SafetensorsLoader st_;
    ModelConfig cfg_;
    HcclCtx hccl_ctx_;
    int num_layers_ = 0;
    int64_t max_seq_ = 0;

    SharedWeights                 shared_;
    std::vector<LayerAttnWeights> attn_;
    std::vector<LayerMoEWeights>  moe_;

    // Per-layer KV cache
    std::vector<DeviceBuffer> k_cache_;
    std::vector<DeviceBuffer> v_cache_;

    // Scratch (reallocated per-call sized by current S)
    DeviceBuffer q_sc_, k_sc_, v_sc_, xn_sc_, rstd_sc_, rope_sc_, attn_fias_sc_, attn_out_sc_;
    DeviceBuffer moe_xn_, moe_rstd_, moe_logits_;
    DeviceBuffer moe_topk_w_, moe_topk_idx_, moe_row_idx_;
    DeviceBuffer moe_ex_x_, moe_ex_ri_, moe_tpe_;
    DeviceBuffer moe_fwd_;
    DeviceBuffer moe_gate_, moe_up_, moe_down_;
    DeviceBuffer moe_packed_, moe_weighted_, moe_out_;
    DeviceBuffer moe_norm_sum_;       // BF16 [S, 1] for on-device topk_w normalize
    DeviceBuffer x_buf_a_, x_buf_b_;   // ping-pong for residual chain

    // Causal mask for prefill (2048 x 2048 bool); decode uses nullptr
    DeviceBuffer prefill_mask_dev_;

    // Batch decode mask: S_MAX × KV_MAX bool, where mask[i, j] = 1 (masked out) if
    // j > past_len + i. Built on-demand per-call (past_len changes).
    DeviceBuffer batch_mask_dev_;

    // Pre-computed RoPE cos/sin table (sized for max_seq_)
    RopeCache rope_cache_;

    int64_t past_len_ = 0;
    int64_t cur_S_capacity_ = 0;   // scratch sized for this many tokens
};