File size: 2,292 Bytes
// model_config.h — Qwen3 hparams loaded from HF config.json, plus TP-derived per-rank sizes.
#pragma once
#include <cstdint>
#include <string>
// Plain aggregate describing a Qwen3 model: the hyper-parameters read from the
// HF config.json, the tensor-parallel (TP) placement of this process, and the
// per-rank sizes derived from the two.
struct ModelConfig {
    // ======== Hyper-parameters as they appear in config.json ========
    std::int64_t vocab_size{0};
    std::int64_t hidden_size{0};              // D
    std::int64_t intermediate_size{0};        // dense FFN width (unused by MoE layers; kept for completeness)
    std::int64_t moe_intermediate_size{0};    // I, per expert
    std::int64_t num_hidden_layers{0};        // 94 for Qwen3-235B
    std::int64_t num_attention_heads{0};      // 64
    std::int64_t num_key_value_heads{0};      // 4 (GQA)
    std::int64_t head_dim{0};                 // 128
    std::int64_t num_experts{0};              // 128
    std::int64_t num_experts_per_tok{0};      // top_k, 8
    std::int64_t max_position_embeddings{0};
    float rope_theta{0.0f};
    float rms_norm_eps{1e-6f};
    bool norm_topk_prob{true};
    bool tie_word_embeddings{false};
    std::int64_t bos_token_id{0};
    std::int64_t eos_token_id{0};

    // ======== Tensor-parallel placement ========
    int tp_size{1};
    int tp_rank{0};

    // ======== Sizes derived for this rank ========

    // Query projection is head-parallel (split along num_heads):
    //   n_heads_per_rank = num_attention_heads / tp_size
    //   q_dim_per_rank   = n_heads_per_rank * head_dim
    std::int64_t n_heads_per_rank{0};
    std::int64_t q_dim_per_rank{0};

    // Key/value projection: GQA needs special handling when
    // num_kv_heads < tp_size. Example, Qwen3-235B: num_kv_heads = 4 and
    // tp_size = 16, so each KV head is replicated 4x across ranks.
    //   Simple scheme: every rank computes ALL KV heads (cheap: 4 * 128 = 512
    //                  features) and then slices the attention output for
    //                  its own Q heads.
    //   Alternative:   split the KV heads when tp_size <= num_kv_heads.
    // NOTE(review): which scheme is actually applied is decided in
    // compute_derived(), whose definition is not in this header — confirm there.
    std::int64_t n_kv_heads_per_rank{0};
    std::int64_t kv_dim_per_rank{0};

    // MoE experts are split along the intermediate dimension; each rank holds
    // 1/tp_size of every expert's intermediate_size:
    //   i_per_rank = moe_intermediate_size / tp_size
    std::int64_t i_per_rank{0};

    // Reads the config from the JSON file at `path`.
    // NOTE(review): the bool result presumably signals success — verify against
    // the definition.
    bool load_from_json(const std::string& path);

    // Takes the TP world size and this process's rank; presumably stores them
    // and fills the derived per-rank fields above — verify against the definition.
    void compute_derived(int tp_size, int tp_rank);

    // Returns a string description of this configuration.
    std::string describe() const;
};
|