"""
KuiXing (魁星) — HuggingFace-compatible wrapper layer.

AutoConfig           → KuiXingHFConfig
AutoModelForCausalLM → KuiXingForCausalLM

Weights are stored as float32 in model.safetensors.
For bfloat16 inference: model = model.to(torch.bfloat16).eval()
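
Example usage (a minimal sketch; the checkpoint path "path/to/kuixing" and the
tokenizer are placeholders, not defined by this file):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("path/to/kuixing")
    model = AutoModelForCausalLM.from_pretrained("path/to/kuixing", trust_remote_code=True)
    model = model.to(torch.bfloat16).eval()

    ids = tok("你好", return_tensors="pt").input_ids
    out = model.generate(ids, max_new_tokens=32)
    print(tok.decode(out[0], skip_special_tokens=True))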
"""
import json
import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
class KuiXingHFConfig(PretrainedConfig):
    model_type = "kuixing"

    def __init__(
        self,
        vocab_size=99384,
        hidden_size=2400,
        num_hidden_layers=12,
        num_attention_heads=32,
        intermediate_size=9600,
        max_position_embeddings=2048,
        dropout=0.1,
        pad_token_id=0, bos_token_id=2, eos_token_id=3,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id,
                         bos_token_id=bos_token_id,
                         eos_token_id=eos_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.dropout = dropout


class _Attention(nn.Module):
    """Multi-head self-attention (full-sequence recompute; no KV cache)."""

    def __init__(self, cfg):
        super().__init__()
        self.n_heads = cfg.num_attention_heads
        self.d_head = cfg.hidden_size // cfg.num_attention_heads
        self.q_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False)
        self.k_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False)
        self.v_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False)
        self.o_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False)

    def forward(self, x, mask=None):
        B, L, D = x.shape
        H, Dh = self.n_heads, self.d_head
        q = self.q_proj(x).view(B, L, H, Dh).transpose(1, 2)
        k = self.k_proj(x).view(B, L, H, Dh).transpose(1, 2)
        v = self.v_proj(x).view(B, L, H, Dh).transpose(1, 2)
        # Scores are computed in float32 for numerical stability, then cast
        # back to the input dtype (e.g. bfloat16) after the softmax.
        w = (q.float() @ k.float().transpose(-2, -1)) / math.sqrt(Dh)
        if mask is not None:
            w = w + mask
        w = F.softmax(w, dim=-1).to(x.dtype)
        out = (w.float() @ v.float()).to(x.dtype)
        return self.o_proj(out.transpose(1, 2).reshape(B, L, D))


class _MLP(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.fc1 = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=False)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(cfg.intermediate_size, cfg.hidden_size, bias=False)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


class _Block(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.norm1 = nn.RMSNorm(cfg.hidden_size)
        self.attention = _Attention(cfg)
        self.norm2 = nn.RMSNorm(cfg.hidden_size)
        self.mlp = _MLP(cfg)

    def forward(self, x, mask=None):
        # Pre-norm residual block: normalize, transform, then add back.
        x = x + self.attention(self.norm1(x), mask)
        x = x + self.mlp(self.norm2(x))
        return x


class _KuiXingCore(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.token_emb = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
        self.pos_emb = nn.Embedding(cfg.max_position_embeddings, cfg.hidden_size)
        self.layers = nn.ModuleList([_Block(cfg) for _ in range(cfg.num_hidden_layers)])
        self.norm_final = nn.RMSNorm(cfg.hidden_size)
        self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
        # Weight tying: the output projection shares its parameters with the token embedding.
        self.lm_head.weight = self.token_emb.weight

    def forward(self, input_ids):
        B, L = input_ids.shape
        pos = torch.arange(L, device=input_ids.device)
        h = self.token_emb(input_ids) + self.pos_emb(pos)
        # Causal mask: upper-triangular -inf so each position attends only to itself and the past.
        mask = torch.triu(
            torch.full((L, L), float("-inf"), device=input_ids.device), diagonal=1
        ).unsqueeze(0).unsqueeze(0)
        for layer in self.layers:
            h = layer(h, mask)
        return self.lm_head(self.norm_final(h))


class KuiXingForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = KuiXingHFConfig
    supports_gradient_checkpointing = False

    def __init__(self, config):
        super().__init__(config)
        self.model = _KuiXingCore(config)
        self.post_init()

    @classmethod
    def from_pretrained(cls, model_path, **kwargs):
        with open(os.path.join(model_path, "config.json")) as f:
            cfg_dict = json.load(f)
        valid = set(KuiXingHFConfig.__init__.__code__.co_varnames)
        hf_cfg = KuiXingHFConfig(**{k: v for k, v in cfg_dict.items() if k in valid})
        model = cls(hf_cfg)
        sd = load_file(os.path.join(model_path, "model.safetensors"))
        missing, unexpected = model.load_state_dict(sd, strict=False)
        # lm_head.weight is not stored in the safetensors file (weight tying);
        # re-establish the shared parameter manually after loading.
        model.model.lm_head.weight = model.model.token_emb.weight
        # lm_head.weight is omitted deliberately (weight tying), so drop it from
        # the missing list before reporting.
        missing = [k for k in missing if k != "model.lm_head.weight"]
        if not missing and not unexpected:
            print(
                "✅ All weight keys mapped; nothing missing.\n"
                "For bfloat16 inference: model = model.to(torch.bfloat16).eval()"
            )
        else:
            if missing:
                print(f"⚠️ Missing keys ({len(missing)}): {missing[:5]}")
            if unexpected:
                print(f"⚠️ Unexpected keys ({len(unexpected)}): {unexpected[:5]}")
        return model.eval()

    def forward(self, input_ids=None, labels=None, **kwargs):
        logits = self.model(input_ids)
        loss = None
        if labels is not None:
            # Standard next-token objective: shift logits left and labels right by one position.
            loss = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                labels[:, 1:].reshape(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No KV cache: the full sequence is re-encoded at every generation step.
        return {"input_ids": input_ids}
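

# Minimal smoke-test sketch (not part of the original release flow): builds a
# randomly initialized model from the default config, runs one forward pass on
# dummy token ids, and picks the greedy next token from the logits. Note that
# the default config is ~1B parameters, so this needs a few GB of RAM.
if __name__ == "__main__":
    cfg = KuiXingHFConfig()
    model = KuiXingForCausalLM(cfg).eval()
    dummy = torch.randint(0, cfg.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(input_ids=dummy, labels=dummy)
    print("logits shape:", tuple(out.logits.shape), "loss:", float(out.loss))
    next_id = out.logits[:, -1].argmax(dim=-1)
    print("greedy next token id:", next_id.tolist())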