|
|
""" |
|
|
KiyEngine V3: Mamba-MoE Chess Model |
|
|
Matched exactly with standalone_train.py structure for 100% weight compatibility. |
|
|
""" |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from transformers import PreTrainedModel |
|
|
from transformers.modeling_outputs import ModelOutput |
|
|
from dataclasses import dataclass |
|
|
from typing import Optional, Tuple |
|
|
|
|
|
from .configuration_kiyengine import KiyEngineConfig |
|
|
|
|
|
|
|
|
|
|
|
class GaussianNoise(nn.Module):
    """Additive zero-mean Gaussian noise, applied only while training.

    In eval mode, or when ``sigma`` is zero, the input passes through
    unchanged.
    """

    def __init__(self, sigma: float = 0.01):
        super().__init__()
        # Standard deviation of the injected noise.
        self.sigma = sigma

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Identity unless we are in training mode with a non-zero sigma.
        if not self.training or self.sigma == 0:
            return x
        return x + self.sigma * torch.randn_like(x)
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned per-channel gain."""

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        # Small constant keeping the division stable for near-zero inputs.
        self.eps = eps
        # Per-channel scale, initialized to identity.
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMS over the last dimension: L2 norm scaled by 1/sqrt(d_model).
        rms = x.norm(2, dim=-1, keepdim=True) * (x.shape[-1] ** -0.5)
        return self.weight * (x / (rms + self.eps))
|
|
|
|
|
class MambaBlock(nn.Module):
    """Simplified Mamba-style block: gated causal depthwise conv over an
    expanded inner dimension.

    NOTE: ``x_proj``, ``dt_proj`` and ``A_log`` are registered as parameters
    but are never referenced by ``forward`` in this file — they appear to be
    kept solely so the state_dict matches standalone_train.py checkpoints.
    Attribute names and registration order must therefore not change.
    """

    def __init__(self, config):
        super().__init__()

        d_model = config.d_model
        d_state = config.d_state
        d_conv = config.d_conv
        exp_factor = config.expansion_factor

        # Expanded inner width used by the conv/gate paths.
        d_inner = d_model * exp_factor

        # Joint projection producing both the conv path and the gate path.
        self.in_proj = nn.Linear(d_model, 2 * d_inner, bias=False)
        # Depthwise (groups == channels) conv; left-padded for causality.
        self.conv1d = nn.Conv1d(
            in_channels=d_inner,
            out_channels=d_inner,
            kernel_size=d_conv,
            bias=True,
            groups=d_inner,
            padding=d_conv - 1
        )
        # Unused in forward — retained for checkpoint compatibility.
        self.x_proj = nn.Linear(d_inner, d_inner + 2 * d_state, bias=False)
        self.dt_proj = nn.Linear(d_inner, d_inner, bias=True)
        self.A_log = nn.Parameter(torch.randn(d_inner, d_state))
        # Per-channel skip scale applied to the conv activations.
        self.D = nn.Parameter(torch.ones(d_inner))
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (batch, seq, d_model) -> (batch, seq, d_model)."""
        seq_len = x.shape[1]

        # Project once, then split into the conv path and the gate path.
        projected = self.in_proj(x)
        conv_path, gate = projected.chunk(2, dim=-1)

        # Causal conv: the left padding adds d_conv-1 extra steps on the
        # right, which we trim back to the original sequence length.
        conv_out = self.conv1d(conv_path.transpose(1, 2))
        conv_out = conv_out[:, :, :seq_len].transpose(1, 2)
        activated = F.silu(conv_out)

        # Scale by D, then gate with silu(z) (SiLU-gated output).
        gated = (activated * self.D.unsqueeze(0)) * F.silu(gate)

        return self.out_proj(gated)
|
|
|
|
|
class MoELayer(nn.Module):
    """Sparse mixture-of-experts layer: each token is routed to its top-k
    MambaBlock experts and the expert outputs are combined with renormalized
    router weights."""

    def __init__(self, config):
        super().__init__()
        self.n_experts = config.n_experts
        self.top_k = config.top_k

        # One routing logit per expert, computed per token.
        self.router = nn.Linear(config.d_model, self.n_experts)
        self.experts = nn.ModuleList([MambaBlock(config) for _ in range(self.n_experts)])

    def forward(self, x: torch.Tensor):
        batch, seq_len, dim = x.shape
        tokens = x.view(-1, dim)

        gate_probs = F.softmax(self.router(tokens), dim=1)

        # Never ask topk for more experts than actually exist.
        available = gate_probs.size(-1)
        k = min(self.top_k, available)

        slot_weights, slot_indices = torch.topk(gate_probs, k, dim=-1)
        # Renormalize so the selected weights sum to one per token.
        slot_weights = slot_weights / (slot_weights.sum(dim=-1, keepdim=True) + 1e-9)

        combined = torch.zeros_like(tokens)

        # Dense dispatch: for each top-k slot, run each expert on the tokens
        # that selected it and accumulate the weighted outputs.
        for slot in range(k):
            chosen_expert = slot_indices[:, slot]
            weight_col = slot_weights[:, slot].unsqueeze(-1)

            for expert_id in range(self.n_experts):
                selected = chosen_expert == expert_id
                if selected.any():
                    # Each routed token is treated as a length-1 sequence.
                    expert_in = tokens[selected].unsqueeze(1)
                    expert_out = self.experts[expert_id](expert_in).squeeze(1)
                    combined[selected] += expert_out * weight_col[selected]

        return combined.view(batch, seq_len, dim)
|
|
|
|
|
|
|
|
@dataclass
class KiyEngineOutput(ModelOutput):
    """Structured output of KiyEngineModel.forward (HF ModelOutput subclass)."""
    # Optional training loss; never populated by KiyEngineModel.forward in this file.
    loss: Optional[torch.Tensor] = None
    # Per-move logits for the final sequence position, shape (batch, vocab_size).
    policy_logits: Optional[torch.Tensor] = None
    # tanh-squashed position evaluation in [-1, 1], shape (batch, 1).
    value: Optional[torch.Tensor] = None
    # Hidden states after the final RMSNorm, shape (batch, seq, d_model).
    last_hidden_state: Optional[torch.Tensor] = None
|
|
|
|
|
|
|
|
|
|
|
class KiyEngineModel(PreTrainedModel):
    """
    KiyEngine V3: Matches exactly the structure of 'standalone_train.py'

    Pipeline: embedding -> (optional noise) -> pre-norm residual MoE/Mamba
    stack -> final RMSNorm -> policy + value heads on the last token.
    Attribute names and registration order must stay in sync with
    standalone_train.py so existing checkpoints remain loadable.
    """
    config_class = KiyEngineConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Token embedding over the move vocabulary.
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        # sigma=0.0 makes this a no-op; presumably kept for structural parity
        # with the training script — confirm before removing.
        self.noise = GaussianNoise(sigma=0.0)

        # Stack of sparse MoE layers; each expert is a MambaBlock.
        self.layers = nn.ModuleList([MoELayer(config) for _ in range(config.n_layers)])

        # NOTE(review): a single RMSNorm instance is shared as the pre-layer
        # norm of every block AND as the final norm — unusual (typically one
        # norm per layer), but it must match the checkpoint layout; do not
        # split without retraining/converting weights.
        self.norm = RMSNorm(config.d_model)

        # Policy head: logits over the move vocabulary.
        # Value head: small MLP, tanh-squashed to [-1, 1] in forward().
        self.policy_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.value_head = nn.Sequential(
            nn.Linear(config.d_model, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Run the trunk and score the final sequence position.

        Args:
            input_ids: (batch, seq) token ids.
            return_dict: if False, returns the tuple
                (policy_logits, value, hidden_states). NOTE(review): the
                tuple has no loss slot — loss is never computed here.

        Returns:
            KiyEngineOutput with policy_logits (batch, vocab_size),
            value (batch, 1) in [-1, 1], and last_hidden_state
            (batch, seq, d_model).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        x = self.noise(self.embedding(input_ids))

        # Pre-norm residual blocks (shared norm instance — see __init__).
        for layer in self.layers:
            x = x + layer(self.norm(x))

        x = self.norm(x)

        # Both heads read only the last position of the sequence.
        last_token_state = x[:, -1, :]

        policy_logits = self.policy_head(last_token_state)
        value = torch.tanh(self.value_head(last_token_state))

        if not return_dict:
            return (policy_logits, value, x)

        return KiyEngineOutput(
            policy_logits=policy_logits,
            value=value,
            last_hidden_state=x
        )