"""
KuiXing (魁星) — HuggingFace 相容包裝層
AutoConfig           → KuiXingHFConfig
AutoModelForCausalLM → KuiXingForCausalLM

權重以 float32 儲存於 model.safetensors。
如需 bfloat16 推理:model = model.to(torch.bfloat16).eval()
"""
import math, os
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast


class KuiXingHFConfig(PretrainedConfig):
    model_type = "kuixing"
    def __init__(
        self,
        vocab_size=99384,
        hidden_size=2400,
        num_hidden_layers=12,
        num_attention_heads=32,
        intermediate_size=9600,
        max_position_embeddings=2048,
        dropout=0.1,
        pad_token_id=0, bos_token_id=2, eos_token_id=3,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id,
                         bos_token_id=bos_token_id,
                         eos_token_id=eos_token_id, **kwargs)
        self.vocab_size              = vocab_size
        self.hidden_size             = hidden_size
        self.num_hidden_layers       = num_hidden_layers
        self.num_attention_heads     = num_attention_heads
        self.intermediate_size       = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.dropout                 = dropout


class _Attention(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.n_heads = cfg.num_attention_heads
        self.d_head  = cfg.hidden_size // cfg.num_attention_heads
        self.q_proj  = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False)
        self.k_proj  = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False)
        self.v_proj  = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False)
        self.o_proj  = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False)

    def forward(self, x, mask=None):
        B, L, D = x.shape
        H, Dh   = self.n_heads, self.d_head
        q = self.q_proj(x).view(B, L, H, Dh).transpose(1, 2)
        k = self.k_proj(x).view(B, L, H, Dh).transpose(1, 2)
        v = self.v_proj(x).view(B, L, H, Dh).transpose(1, 2)
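        # Attention scores are computed in float32 for numerical stability, then cast back to the input dtype.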
        w = (q.float() @ k.float().transpose(-2, -1)) / math.sqrt(Dh)
        if mask is not None:
            w = w + mask
        w   = F.softmax(w, dim=-1).to(x.dtype)
        out = (w.float() @ v.float()).to(x.dtype)
        return self.o_proj(out.transpose(1, 2).reshape(B, L, D))


class _MLP(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.fc1  = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=False)
        self.act  = nn.GELU()
        self.fc2  = nn.Linear(cfg.intermediate_size, cfg.hidden_size, bias=False)
    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


class _Block(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.norm1     = nn.RMSNorm(cfg.hidden_size)
        self.attention = _Attention(cfg)
        self.norm2     = nn.RMSNorm(cfg.hidden_size)
        self.mlp       = _MLP(cfg)
    def forward(self, x, mask=None):
        x = x + self.attention(self.norm1(x), mask)
        x = x + self.mlp(self.norm2(x))
        return x


class _KuiXingCore(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.token_emb  = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
        self.pos_emb    = nn.Embedding(cfg.max_position_embeddings, cfg.hidden_size)
        self.layers     = nn.ModuleList([_Block(cfg) for _ in range(cfg.num_hidden_layers)])
        self.norm_final = nn.RMSNorm(cfg.hidden_size)
        self.lm_head    = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
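        # Weight tying: lm_head shares its weight matrix with the token embedding.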
        self.lm_head.weight = self.token_emb.weight

    def forward(self, input_ids):
        B, L = input_ids.shape
        pos  = torch.arange(L, device=input_ids.device)
        h    = self.token_emb(input_ids) + self.pos_emb(pos)
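        # Additive causal mask: positions above the diagonal get -inf before softmax.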
        mask = torch.triu(
            torch.full((L, L), float("-inf"), device=input_ids.device), diagonal=1
        ).unsqueeze(0).unsqueeze(0)
        for layer in self.layers:
            h = layer(h, mask)
        return self.lm_head(self.norm_final(h))


class KuiXingForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = KuiXingHFConfig
    supports_gradient_checkpointing = False

    def __init__(self, config):
        super().__init__(config)
        self.model = _KuiXingCore(config)
        self.post_init()

    @classmethod
    def from_pretrained(cls, model_path, **kwargs):
        import json
        with open(os.path.join(model_path, "config.json")) as f:
            cfg_dict = json.load(f)
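        # Keep only the entries of config.json that KuiXingHFConfig.__init__ accepts; extras are dropped.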
        valid = set(KuiXingHFConfig.__init__.__code__.co_varnames)
        hf_cfg = KuiXingHFConfig(**{k: v for k, v in cfg_dict.items() if k in valid})
        model  = cls(hf_cfg)
        sd     = load_file(os.path.join(model_path, "model.safetensors"))
        missing, unexpected = model.load_state_dict(sd, strict=False)
        # lm_head.weight is not stored in the safetensors file (weight tying); re-tie it manually after loading.
        model.model.lm_head.weight = model.model.token_emb.weight
        # lm_head.weight is omitted on purpose (weight tying), so drop it from `missing` before reporting.
        missing = [k for k in missing if k != "model.lm_head.weight"]
        if not missing and not unexpected:
            print("✅ 所有權重 key 完整對映,無缺漏。\n如需以 bfloat16 推理:model = model.to(torch.bfloat16).eval()")
        else:
            if missing:    print(f"⚠️  Missing keys ({len(missing)}): {missing[:5]}")
            if unexpected: print(f"⚠️  Unexpected keys ({len(unexpected)}): {unexpected[:5]}")
        return model.eval()

    def forward(self, input_ids=None, labels=None, **kwargs):
        logits = self.model(input_ids)
        loss   = None
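        # Shifted next-token loss: logits at position t predict the label at position t+1.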
        if labels is not None:
            loss = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                labels[:, 1:].reshape(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)

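    # No KV cache is implemented, so generation re-encodes the full sequence at every step.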
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}