"""
KiyEngine V3: Mamba-MoE Chess Model.

Matched exactly with standalone_train.py structure for 100% weight compatibility.
"""
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from .configuration_kiyengine import KiyEngineConfig


# === Helper Classes (Copied & Adapted from Training Script) ===
class GaussianNoise(nn.Module):
    """Add zero-mean Gaussian noise to the input while in training mode.

    Args:
        sigma: Multiplier applied to the standard-normal noise. A value of 0
            disables the noise entirely.
    """

    def __init__(self, sigma: float = 0.01):
        super().__init__()
        self.sigma = sigma

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus scaled noise when training, otherwise ``x`` unchanged."""
        # At inference the noise is always off (sigma == 0 or eval mode).
        if self.training and self.sigma != 0:
            return x + torch.randn_like(x) * self.sigma
        return x


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    Args:
        d_model: Size of the last dimension; also the length of ``weight``.
        eps: Constant added to the RMS value before dividing.
    """

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Divide ``x`` by (RMS + eps) along the last dim, then scale by ``weight``."""
        # ||x||_2 / sqrt(d) is the root mean square of the last dimension.
        # eps is added to the RMS itself (not under the square root).
        norm = x.norm(2, dim=-1, keepdim=True) * (x.shape[-1] ** -0.5)
        return x / (norm + self.eps) * self.weight


class MambaBlock(nn.Module):
    """Gated depthwise-convolution block with Mamba-style parameter names.

    ``forward`` only uses ``in_proj``, ``conv1d``, ``D`` and ``out_proj``.
    ``x_proj``, ``dt_proj`` and ``A_log`` are not touched by ``forward``; they
    are declared exactly as in the training script so checkpoint keys match.

    Args:
        config: Object providing ``d_model``, ``d_state``, ``d_conv`` and
            ``expansion_factor``.
    """

    def __init__(self, config):
        super().__init__()
        # Read hyper-parameters from the config object.
        d_model = config.d_model
        d_state = config.d_state
        d_conv = config.d_conv
        exp_factor = config.expansion_factor
        d_inner = d_model * exp_factor

        # Defined exactly like the training script so the parameter keys match.
        self.in_proj = nn.Linear(d_model, 2 * d_inner, bias=False)
        self.conv1d = nn.Conv1d(
            in_channels=d_inner,
            out_channels=d_inner,
            kernel_size=d_conv,
            bias=True,
            groups=d_inner,
            padding=d_conv - 1,
        )
        self.x_proj = nn.Linear(d_inner, d_inner + 2 * d_state, bias=False)
        self.dt_proj = nn.Linear(d_inner, d_inner, bias=True)
        self.A_log = nn.Parameter(torch.randn(d_inner, d_state))
        self.D = nn.Parameter(torch.ones(d_inner))
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated convolution to a rank-3 input.

        Args:
            x: Tensor whose last dim is ``d_model`` and whose middle dim is the
                sequence axis (the conv runs along it).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The forward logic mirrors the training script, which uses a
        # simplified (gated CNN) model; it must be followed exactly to
        # reproduce the trained behaviour.
        seq_len = x.shape[1]
        xz = self.in_proj(x)
        x_inner, z = xz.chunk(2, dim=-1)

        # Conv1d expects (B, C, L). Padding of d_conv - 1 lengthens the output,
        # so it is trimmed back to the first seq_len positions.
        x_conv = self.conv1d(x_inner.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
        x_activated = F.silu(x_conv)

        # Element-wise gating with D
        y = x_activated * self.D.unsqueeze(0)
        y = y * F.silu(z)
        return self.out_proj(y)


class MoELayer(nn.Module):
    """Token-level top-k mixture of ``MambaBlock`` experts.

    Args:
        config: Object providing ``n_experts``, ``top_k`` and ``d_model`` (plus
            whatever ``MambaBlock`` reads).
    """

    def __init__(self, config):
        super().__init__()
        self.n_experts = config.n_experts
        self.top_k = config.top_k
        self.router = nn.Linear(config.d_model, self.n_experts)
        self.experts = nn.ModuleList([MambaBlock(config) for _ in range(self.n_experts)])

    def forward(self, x: torch.Tensor):
        """Route every token to its top-k experts and mix their outputs.

        Args:
            x: Rank-3 tensor ``(B, L, C)``.

        Returns:
            Tensor of shape ``(B, L, C)``: the weighted sum of the selected
            experts' outputs for each token.
        """
        B, L, C = x.shape
        # reshape (rather than view) so non-contiguous inputs are accepted too.
        x_flat = x.reshape(-1, C)
        router_logits = self.router(x_flat)
        router_probs = F.softmax(router_logits, dim=1)

        # --- SAFE ROUTING FIX ---
        # Kept to avoid a crash when the config asks for more experts per token
        # than the router actually provides.
        num_available = router_probs.size(-1)
        k_safe = min(self.top_k, num_available)

        top_k_weights, top_k_indices = torch.topk(router_probs, k_safe, dim=-1)
        # Renormalise the kept probabilities so they sum to (almost exactly) 1.
        top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)

        final_output = torch.zeros_like(x_flat)
        for i in range(k_safe):
            expert_idx = top_k_indices[:, i]
            weight = top_k_weights[:, i].unsqueeze(-1)
            for j in range(self.n_experts):
                mask = expert_idx == j
                if mask.any():
                    # Input (N, D) -> unsqueeze(1) -> (N, 1, D) -> expert -> squeeze(1).
                    # Each token is therefore seen by the expert as a length-1 sequence.
                    inp = x_flat[mask].unsqueeze(1)
                    out = self.experts[j](inp).squeeze(1)
                    final_output[mask] += out * weight[mask]

        return final_output.view(B, L, C)


# === Output Class for Hugging Face ===
@dataclass
class KiyEngineOutput(ModelOutput):
    """Container returned by ``KiyEngineModel.forward`` when ``return_dict`` is true."""

    # Never populated by KiyEngineModel.forward in this file.
    loss: Optional[torch.Tensor] = None
    # Output of the policy head for the last sequence position.
    policy_logits: Optional[torch.Tensor] = None
    # tanh of the value head output for the last sequence position.
    value: Optional[torch.Tensor] = None
    # Normalised hidden states for every sequence position.
    last_hidden_state: Optional[torch.Tensor] = None


# === Main Model Class ===
class KiyEngineModel(PreTrainedModel):
    """
    KiyEngine V3: Matches exactly the structure of 'standalone_train.py'
    """

    config_class = KiyEngineConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # --- MATCHING KEYS WITH TRAIN SCRIPT ---
        # Train script: self.embedding (NOT embeddings)
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.noise = GaussianNoise(sigma=0.0)  # Inference mode

        # Train script: self.layers = ModuleList of MoELayer
        self.layers = nn.ModuleList([MoELayer(config) for _ in range(config.n_layers)])
        self.norm = RMSNorm(config.d_model)

        # Train script has heads built-in
        self.policy_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.value_head = nn.Sequential(
            nn.Linear(config.d_model, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

        # Initialize weights
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Compute policy logits and a value estimate from the last position.

        Args:
            input_ids: Integer tensor of token ids fed to the embedding;
                indexed as (batch, sequence) after embedding.
            return_dict: Return a ``KiyEngineOutput`` when true, a plain tuple
                when false. ``None`` defers to ``config.use_return_dict``.
            **kwargs: Accepted and ignored.

        Returns:
            ``KiyEngineOutput`` or the tuple
            ``(policy_logits, value, last_hidden_state)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward pass matching training logic
        x = self.noise(self.embedding(input_ids))

        for layer in self.layers:
            # Training script logic: x = x + layer(norm(x))[0]
            # Our MoELayer returns just the tensor (the aux_loss return was
            # dropped for inference clean-up). The same norm module is shared
            # by every layer and by the final normalisation below.
            x = x + layer(self.norm(x))

        x = self.norm(x)

        # Last token logic
        last_token_state = x[:, -1, :]

        policy_logits = self.policy_head(last_token_state)
        value = torch.tanh(self.value_head(last_token_state))

        if not return_dict:
            return (policy_logits, value, x)

        return KiyEngineOutput(
            policy_logits=policy_logits,
            value=value,
            last_hidden_state=x,
        )