# KiyEngine-V3 / modeling_kiyengine.py
# Kiy-K's picture
# Update modeling_kiyengine.py
# c335116 verified
"""
KiyEngine V3: Mamba-MoE Chess Model
Matched exactly with standalone_train.py structure for 100% weight compatibility.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from dataclasses import dataclass
from typing import Optional, Tuple
from .configuration_kiyengine import KiyEngineConfig
# === Helper Classes (Copied & Adapted from Training Script) ===
class GaussianNoise(nn.Module):
    """Additive Gaussian noise regularizer.

    Acts as the identity whenever the module is in eval mode or ``sigma``
    is zero, so inference outputs are never perturbed.
    """

    def __init__(self, sigma: float = 0.01):
        super().__init__()
        self.sigma = sigma

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Identity path: eval mode (or a zero sigma) disables the noise.
        if not self.training or self.sigma == 0:
            return x
        return x + torch.randn_like(x) * self.sigma
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering, no bias)."""

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMS over the last dim: ||x||_2 / sqrt(d) == sqrt(mean(x^2)).
        rms = x.norm(2, dim=-1, keepdim=True) * (x.shape[-1] ** -0.5)
        return self.weight * (x / (rms + self.eps))
class MambaBlock(nn.Module):
    """Simplified Mamba-style block (a gated depthwise-conv unit).

    Parameter names and shapes mirror the training script exactly so that
    checkpoint state-dict keys line up.  ``x_proj``, ``dt_proj`` and
    ``A_log`` are created only for key compatibility; the simplified
    forward pass below never reads them.
    """

    def __init__(self, config):
        super().__init__()
        # Pull hyper-parameters off the config object.
        d_model = config.d_model
        d_state = config.d_state
        d_conv = config.d_conv
        d_inner = d_model * config.expansion_factor
        # Defined exactly as in the training script so keys match.
        self.in_proj = nn.Linear(d_model, 2 * d_inner, bias=False)
        self.conv1d = nn.Conv1d(
            in_channels=d_inner,
            out_channels=d_inner,
            kernel_size=d_conv,
            bias=True,
            groups=d_inner,      # depthwise convolution
            padding=d_conv - 1,  # left-style padding; excess is trimmed in forward
        )
        self.x_proj = nn.Linear(d_inner, d_inner + 2 * d_state, bias=False)
        self.dt_proj = nn.Linear(d_inner, d_inner, bias=True)
        self.A_log = nn.Parameter(torch.randn(d_inner, d_state))
        self.D = nn.Parameter(torch.ones(d_inner))
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE: the training script used a simplified (gated CNN) model,
        # so we follow the same math to reproduce its outputs exactly.
        seq_len = x.shape[1]
        branch, gate = self.in_proj(x).chunk(2, dim=-1)
        # Conv1d works on (B, C, L); trim the padding tail back to seq_len.
        conv_out = self.conv1d(branch.transpose(1, 2))
        conv_out = conv_out[:, :, :seq_len].transpose(1, 2)
        activated = F.silu(conv_out)
        # Per-channel scale by D, then SiLU gating by the second branch.
        gated = activated * self.D.unsqueeze(0)
        gated = gated * F.silu(gate)
        return self.out_proj(gated)
class MoELayer(nn.Module):
    """Mixture-of-Experts layer: per-token top-k routing over MambaBlock experts."""

    def __init__(self, config):
        super().__init__()
        self.n_experts = config.n_experts
        self.top_k = config.top_k
        self.router = nn.Linear(config.d_model, self.n_experts)
        self.experts = nn.ModuleList(
            MambaBlock(config) for _ in range(self.n_experts)
        )

    def forward(self, x: torch.Tensor):
        B, L, C = x.shape
        tokens = x.view(-1, C)
        probs = F.softmax(self.router(tokens), dim=1)
        # --- SAFE ROUTING ---
        # Clamp k so a misconfigured top_k > n_experts cannot crash topk().
        k = min(self.top_k, probs.size(-1))
        weights, indices = torch.topk(probs, k, dim=-1)
        # Renormalize the selected weights so they sum to 1 per token.
        weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-9)
        out = torch.zeros_like(tokens)
        for slot in range(k):
            chosen = indices[:, slot]
            w = weights[:, slot].unsqueeze(-1)
            for expert_id in range(self.n_experts):
                hit = chosen == expert_id
                if not hit.any():
                    continue
                # Experts expect (N, L, D); feed tokens as length-1 sequences.
                expert_out = self.experts[expert_id](tokens[hit].unsqueeze(1)).squeeze(1)
                out[hit] += expert_out * w[hit]
        return out.view(B, L, C)
# === Output Class for Hugging Face ===
@dataclass
class KiyEngineOutput(ModelOutput):
    """Output container for KiyEngineModel.forward (Hugging Face ModelOutput)."""
    # Optional training loss; not populated by the inference forward pass shown here.
    loss: Optional[torch.Tensor] = None
    # Move logits over the vocabulary, computed from the last token position.
    policy_logits: Optional[torch.Tensor] = None
    # tanh-squashed scalar evaluation from the value head.
    value: Optional[torch.Tensor] = None
    # Hidden states after the final RMSNorm.
    last_hidden_state: Optional[torch.Tensor] = None
# === Main Model Class ===
class KiyEngineModel(PreTrainedModel):
    """
    KiyEngine V3: Matches exactly the structure of 'standalone_train.py'
    """
    config_class = KiyEngineConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # --- KEY COMPATIBILITY WITH THE TRAINING SCRIPT ---
        # The training script names this attribute 'embedding' (NOT 'embeddings').
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        # sigma=0 makes this a no-op at inference; kept so module keys match.
        self.noise = GaussianNoise(sigma=0.0)
        # Training script: self.layers is a ModuleList of MoELayer.
        self.layers = nn.ModuleList(MoELayer(config) for _ in range(config.n_layers))
        # NOTE: one RMSNorm instance is shared by every layer and the final norm,
        # exactly as the attribute layout of the training script implies.
        self.norm = RMSNorm(config.d_model)
        # Policy/value heads are built into the model, as in the training script.
        self.policy_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.value_head = nn.Sequential(
            nn.Linear(config.d_model, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )
        # Hugging Face weight-initialization hook.
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        if return_dict is None:
            return_dict = self.config.use_return_dict
        # Embed, then (no-op at inference) noise, matching the training path.
        hidden = self.noise(self.embedding(input_ids))
        for block in self.layers:
            # Pre-norm residual. The training script wrote
            # x = x + layer(norm(x))[0]; our MoELayer returns just the tensor
            # (the aux-loss return was dropped for inference clean-up).
            hidden = hidden + block(self.norm(hidden))
        hidden = self.norm(hidden)
        # Both heads read only the final token's state.
        final_state = hidden[:, -1, :]
        policy_logits = self.policy_head(final_state)
        value = torch.tanh(self.value_head(final_state))
        if not return_dict:
            return (policy_logits, value, hidden)
        return KiyEngineOutput(
            policy_logits=policy_logits,
            value=value,
            last_hidden_state=hidden
        )