"""
KiyEngine V3: Mamba-MoE Chess Model
Matched exactly with standalone_train.py structure for 100% weight compatibility.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from dataclasses import dataclass
from typing import Optional, Tuple
from .configuration_kiyengine import KiyEngineConfig
# === Helper Classes (Copied & Adapted from Training Script) ===
class GaussianNoise(nn.Module):
    """Additive Gaussian noise regularizer; a no-op during inference.

    Noise is added only in training mode and only when ``sigma`` is
    non-zero, so in eval mode the input passes through unchanged.
    """
    def __init__(self, sigma: float = 0.01):
        super().__init__()
        self.sigma = sigma
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Guard clause: eval mode or sigma == 0 means identity.
        if not self.training or self.sigma == 0:
            return x
        return x + torch.randn_like(x) * self.sigma
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned per-channel scale."""
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMS over the last dim: ||x||_2 * d^-0.5 == sqrt(mean(x^2)).
        rms = x.norm(2, dim=-1, keepdim=True) * (x.shape[-1] ** -0.5)
        return self.weight * (x / (rms + self.eps))
class MambaBlock(nn.Module):
    """Simplified Mamba-style block: a gated depthwise causal CNN.

    Parameter names and shapes mirror the training script exactly so that
    checkpoint keys load 1:1.  NOTE: ``x_proj``, ``dt_proj`` and ``A_log``
    are created purely for weight compatibility — the simplified forward
    pass (matching the training script's gated-CNN logic) never uses them.
    """
    def __init__(self, config):
        super().__init__()
        # Dimensions come from the shared config object.
        d_model = config.d_model
        d_state = config.d_state
        d_conv = config.d_conv
        d_inner = d_model * config.expansion_factor
        self.in_proj = nn.Linear(d_model, 2 * d_inner, bias=False)
        # Depthwise (groups == channels) conv; left-padded for causality.
        self.conv1d = nn.Conv1d(
            d_inner,
            d_inner,
            kernel_size=d_conv,
            groups=d_inner,
            padding=d_conv - 1,
            bias=True,
        )
        self.x_proj = nn.Linear(d_inner, d_inner + 2 * d_state, bias=False)
        self.dt_proj = nn.Linear(d_inner, d_inner, bias=True)
        self.A_log = nn.Parameter(torch.randn(d_inner, d_state))
        self.D = nn.Parameter(torch.ones(d_inner))
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.shape[1]
        # One joint projection, split into the conv path and the gate.
        conv_in, gate = self.in_proj(x).chunk(2, dim=-1)
        # Conv1d works on (B, C, L); trim the extra right-side padding back to L.
        conv_out = self.conv1d(conv_in.transpose(1, 2))[..., :seq_len]
        hidden = F.silu(conv_out.transpose(1, 2))
        # Per-channel scale by D followed by SiLU gating, as in training.
        gated = hidden * self.D.unsqueeze(0) * F.silu(gate)
        return self.out_proj(gated)
class MoELayer(nn.Module):
    """Mixture-of-Experts layer: a linear softmax router dispatches each
    token to its top-k MambaBlock experts and mixes their outputs."""
    def __init__(self, config):
        super().__init__()
        self.n_experts = config.n_experts
        self.top_k = config.top_k
        self.router = nn.Linear(config.d_model, self.n_experts)
        self.experts = nn.ModuleList(
            MambaBlock(config) for _ in range(self.n_experts)
        )
    def forward(self, x: torch.Tensor):
        batch, seq_len, dim = x.shape
        tokens = x.reshape(-1, dim)
        probs = F.softmax(self.router(tokens), dim=1)
        # Safe routing: never request more experts than actually exist,
        # even if top_k in the config is misconfigured.
        k = min(self.top_k, probs.size(-1))
        weights, indices = torch.topk(probs, k, dim=-1)
        # Renormalize the selected weights (epsilon guards against /0).
        weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-9)
        out = torch.zeros_like(tokens)
        for slot in range(k):
            chosen = indices[:, slot]
            w = weights[:, slot].unsqueeze(-1)
            for e, expert in enumerate(self.experts):
                sel = chosen == e
                if sel.any():
                    # Experts expect (B, L, D): feed tokens as length-1 sequences.
                    out[sel] += expert(tokens[sel].unsqueeze(1)).squeeze(1) * w[sel]
        return out.view(batch, seq_len, dim)
# === Output Class for Hugging Face ===
@dataclass
class KiyEngineOutput(ModelOutput):
    """Output container for KiyEngineModel.forward (Hugging Face ModelOutput)."""
    loss: Optional[torch.Tensor] = None  # never populated by this model's forward pass
    policy_logits: Optional[torch.Tensor] = None  # (B, vocab_size) logits from the last token's state
    value: Optional[torch.Tensor] = None  # (B, 1) tanh-squashed position evaluation
    last_hidden_state: Optional[torch.Tensor] = None  # (B, L, d_model) final normalized hidden states
# === Main Model Class ===
class KiyEngineModel(PreTrainedModel):
    """
    KiyEngine V3: Matches exactly the structure of 'standalone_train.py'
    so that checkpoints load with no key remapping.
    """
    config_class = KiyEngineConfig
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Attribute names mirror the training script on purpose
        # (e.g. 'embedding', not 'embeddings') to keep state-dict keys identical.
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.noise = GaussianNoise(sigma=0.0)  # disabled for inference
        self.layers = nn.ModuleList(
            MoELayer(config) for _ in range(config.n_layers)
        )
        self.norm = RMSNorm(config.d_model)
        # Policy/value heads are built into the model, as in the train script.
        self.policy_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.value_head = nn.Sequential(
            nn.Linear(config.d_model, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )
        self.post_init()  # standard HF weight-init / tying hook
    def forward(
        self,
        input_ids: torch.Tensor,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        if return_dict is None:
            return_dict = self.config.use_return_dict
        hidden = self.noise(self.embedding(input_ids))
        # Pre-norm residual stack.  The training script's layer returned
        # (output, aux_loss); our MoELayer returns just the tensor, so no [0].
        for block in self.layers:
            hidden = hidden + block(self.norm(hidden))
        hidden = self.norm(hidden)
        # Both heads read only the final token's state.
        final_state = hidden[:, -1, :]
        policy_logits = self.policy_head(final_state)
        value = torch.tanh(self.value_head(final_state))
        if not return_dict:
            return (policy_logits, value, hidden)
        return KiyEngineOutput(
            policy_logits=policy_logits,
            value=value,
            last_hidden_state=hidden,
        )