""" KuiXing (魁星) — HuggingFace 相容包裝層 AutoConfig → KuiXingHFConfig AutoModelForCausalLM → KuiXingForCausalLM 權重以 float32 儲存於 model.safetensors。 如需 bfloat16 推理:model = model.to(torch.bfloat16).eval() """ import math, os import torch import torch.nn as nn import torch.nn.functional as F from safetensors.torch import load_file from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin from transformers.modeling_outputs import CausalLMOutputWithPast class KuiXingHFConfig(PretrainedConfig): model_type = "kuixing" def __init__( self, vocab_size=99384, hidden_size=2400, num_hidden_layers=12, num_attention_heads=32, intermediate_size=9600, max_position_embeddings=2048, dropout=0.1, pad_token_id=0, bos_token_id=2, eos_token_id=3, **kwargs, ): super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.intermediate_size = intermediate_size self.max_position_embeddings = max_position_embeddings self.dropout = dropout class _Attention(nn.Module): def __init__(self, cfg): super().__init__() self.n_heads = cfg.num_attention_heads self.d_head = cfg.hidden_size // cfg.num_attention_heads self.q_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False) self.k_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False) self.v_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False) self.o_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False) def forward(self, x, mask=None): B, L, D = x.shape H, Dh = self.n_heads, self.d_head q = self.q_proj(x).view(B, L, H, Dh).transpose(1, 2) k = self.k_proj(x).view(B, L, H, Dh).transpose(1, 2) v = self.v_proj(x).view(B, L, H, Dh).transpose(1, 2) w = (q.float() @ k.float().transpose(-2, -1)) / math.sqrt(Dh) if mask is not None: w = w + mask w = F.softmax(w, dim=-1).to(x.dtype) out = (w.float() @ v.float()).to(x.dtype) return self.o_proj(out.transpose(1, 2).reshape(B, L, D)) class _MLP(nn.Module): def __init__(self, cfg): super().__init__() self.fc1 = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=False) self.act = nn.GELU() self.fc2 = nn.Linear(cfg.intermediate_size, cfg.hidden_size, bias=False) def forward(self, x): return self.fc2(self.act(self.fc1(x))) class _Block(nn.Module): def __init__(self, cfg): super().__init__() self.norm1 = nn.RMSNorm(cfg.hidden_size) self.attention = _Attention(cfg) self.norm2 = nn.RMSNorm(cfg.hidden_size) self.mlp = _MLP(cfg) def forward(self, x, mask=None): x = x + self.attention(self.norm1(x), mask) x = x + self.mlp(self.norm2(x)) return x class _KuiXingCore(nn.Module): def __init__(self, cfg): super().__init__() self.token_emb = nn.Embedding(cfg.vocab_size, cfg.hidden_size) self.pos_emb = nn.Embedding(cfg.max_position_embeddings, cfg.hidden_size) self.layers = nn.ModuleList([_Block(cfg) for _ in range(cfg.num_hidden_layers)]) self.norm_final = nn.RMSNorm(cfg.hidden_size) self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False) self.lm_head.weight = self.token_emb.weight def forward(self, input_ids): B, L = input_ids.shape pos = torch.arange(L, device=input_ids.device) h = self.token_emb(input_ids) + self.pos_emb(pos) mask = torch.triu( torch.full((L, L), float("-inf"), device=input_ids.device), diagonal=1 ).unsqueeze(0).unsqueeze(0) for layer in self.layers: h = layer(h, mask) return self.lm_head(self.norm_final(h)) class KuiXingForCausalLM(PreTrainedModel, GenerationMixin): 
class KuiXingForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = KuiXingHFConfig
    supports_gradient_checkpointing = False

    def __init__(self, config):
        super().__init__(config)
        self.model = _KuiXingCore(config)
        self.post_init()

    @classmethod
    def from_pretrained(cls, model_path, **kwargs):
        with open(os.path.join(model_path, "config.json")) as f:
            cfg_dict = json.load(f)
        # Keep only the keys that KuiXingHFConfig.__init__ actually accepts.
        valid = set(inspect.signature(KuiXingHFConfig.__init__).parameters)
        hf_cfg = KuiXingHFConfig(**{k: v for k, v in cfg_dict.items() if k in valid})
        model = cls(hf_cfg)
        sd = load_file(os.path.join(model_path, "model.safetensors"))
        missing, unexpected = model.load_state_dict(sd, strict=False)
        # lm_head.weight is not stored in the safetensors file (weight tying);
        # re-establish the shared tensor after loading.
        model.model.lm_head.weight = model.model.token_emb.weight
        # Its omission is intentional, so drop it from `missing` before
        # deciding whether anything is actually wrong.
        missing = [k for k in missing if k != "model.lm_head.weight"]
        if not missing and not unexpected:
            print(
                "✅ All weight keys mapped cleanly, nothing missing.\n"
                "For bfloat16 inference: model = model.to(torch.bfloat16).eval()"
            )
        else:
            if missing:
                print(f"⚠️ Missing keys ({len(missing)}): {missing[:5]}")
            if unexpected:
                print(f"⚠️ Unexpected keys ({len(unexpected)}): {unexpected[:5]}")
        return model.eval()

    def forward(self, input_ids=None, labels=None, **kwargs):
        logits = self.model(input_ids)
        loss = None
        if labels is not None:
            # Standard next-token objective: predict token t+1 from position t.
            loss = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                labels[:, 1:].reshape(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No KV cache: the full sequence is re-encoded at every decoding step.
        return {"input_ids": input_ids}
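
# --- Usage sketch ------------------------------------------------------------
# A minimal end-to-end example, assuming a checkpoint directory containing
# config.json, model.safetensors, and tokenizer files. The path below is a
# placeholder, and AutoTokenizer is an assumption about how the tokenizer is
# shipped; substitute whatever the actual checkpoint provides.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    ckpt = "./kuixing-checkpoint"  # placeholder path
    tok = AutoTokenizer.from_pretrained(ckpt)
    model = KuiXingForCausalLM.from_pretrained(ckpt)
    model = model.to(torch.bfloat16).eval()  # optional bfloat16 inference

    ids = tok("魁星", return_tensors="pt").input_ids
    out = model.generate(
        ids,
        max_new_tokens=64,
        do_sample=True,
        top_p=0.9,
        pad_token_id=model.config.pad_token_id,
    )
    print(tok.decode(out[0], skip_special_tokens=True))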