Kiy-K commited on
Commit
c335116
·
verified ·
1 Parent(s): 8a7c276

Update modeling_kiyengine.py

Browse files
Files changed (1) hide show
  1. modeling_kiyengine.py +143 -134
modeling_kiyengine.py CHANGED
@@ -1,185 +1,194 @@
 
 
 
 
1
  import torch
2
  import torch.nn as nn
 
3
  from transformers import PreTrainedModel
4
- from transformers.modeling_outputs import BaseModelOutput
 
 
 
5
  from .configuration_kiyengine import KiyEngineConfig
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  class MambaBlock(nn.Module):
8
- """Mamba SSM Block"""
9
- def __init__(self, d_model, d_state, d_conv, expansion_factor):
10
  super().__init__()
11
- self.d_model = d_model
12
- self.d_state = d_state
13
- self.d_conv = d_conv
14
- self.expansion = expansion_factor
 
 
 
15
 
16
- # Simplified Mamba components
17
- self.in_proj = nn.Linear(d_model, d_model * expansion_factor * 2)
18
  self.conv1d = nn.Conv1d(
19
- d_model * expansion_factor,
20
- d_model * expansion_factor,
21
- kernel_size=d_conv,
22
- padding=d_conv - 1,
23
- groups=d_model * expansion_factor
 
24
  )
25
- self.x_proj = nn.Linear(d_model * expansion_factor, d_state * 2)
26
- self.dt_proj = nn.Linear(d_model * expansion_factor, d_model)
27
- self.out_proj = nn.Linear(d_model * expansion_factor, d_model)
28
-
29
- def forward(self, x):
30
- # Simplified forward pass
31
- b, l, d = x.shape
32
-
33
- # Input projection
34
- x_and_res = self.in_proj(x)
35
- x, res = x_and_res.split(self.d_model * self.expansion, dim=-1)
36
-
37
- # Conv1d
38
- x = x.transpose(1, 2) # (B, D, L)
39
- x = self.conv1d(x)[:, :, :l]
40
- x = x.transpose(1, 2) # (B, L, D)
41
-
42
- # SSM
43
- x = nn.functional.silu(x)
44
-
45
- # Output projection
46
- y = self.out_proj(x * nn.functional.silu(res))
47
-
48
- return y
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  class MoELayer(nn.Module):
52
- """Mixture of Experts Layer (Fixed with Safe Routing)"""
53
- def __init__(self, d_model, n_experts, top_k):
54
  super().__init__()
55
- self.n_experts = n_experts
56
- self.top_k = top_k
57
-
58
- # Router
59
- self.gate = nn.Linear(d_model, n_experts)
60
-
61
- # Experts
62
- self.experts = nn.ModuleList([
63
- nn.Sequential(
64
- nn.Linear(d_model, d_model * 4),
65
- nn.GELU(),
66
- nn.Linear(d_model * 4, d_model)
67
- )
68
- for _ in range(n_experts)
69
- ])
70
-
71
- def forward(self, x):
72
- b, l, d = x.shape
73
 
74
- # Flatten for routing
75
- x_flat = x.view(-1, d)
76
-
77
- # Route to experts
78
- router_logits = self.gate(x_flat)
79
- router_probs = nn.functional.softmax(router_logits, dim=-1)
80
-
81
- # --- FIX: SAFE ROUTING LOGIC ---
82
- # Kiểm tra số lượng experts thực tế trong tensor
83
- num_available_experts = router_probs.size(-1)
84
-
85
- # Lấy min để đảm bảo k không bao giờ lớn hơn số expert hiện có
86
- k_safe = min(self.top_k, num_available_experts)
87
 
88
- # Select top-k experts using k_safe
89
- top_k_probs, top_k_indices = torch.topk(router_probs, k_safe, dim=-1)
 
 
90
 
91
- # Normalize probabilities
92
- top_k_probs = top_k_probs / (top_k_probs.sum(dim=-1, keepdim=True) + 1e-9) # Thêm epsilon tránh chia cho 0
93
 
94
- # Combine expert outputs
95
- expert_outputs = torch.zeros_like(x_flat)
96
 
97
- # Loop qua k_safe thay vì self.top_k
98
  for i in range(k_safe):
99
  expert_idx = top_k_indices[:, i]
100
- expert_prob = top_k_probs[:, i:i+1]
101
 
102
- for expert_id in range(self.n_experts):
103
- mask = (expert_idx == expert_id)
104
  if mask.any():
105
- expert_input = x_flat[mask]
106
- expert_output = self.experts[expert_id](expert_input)
107
- expert_outputs[mask] += expert_prob[mask] * expert_output
108
-
109
- return expert_outputs.view(b, l, d)
 
110
 
 
 
 
 
 
 
 
111
 
112
- class KiyEngineMambaBlock(nn.Module):
113
- """Combined Mamba + MoE Block"""
114
- def __init__(self, config):
115
- super().__init__()
116
- self.mamba = MambaBlock(
117
- config.d_model,
118
- config.d_state,
119
- config.d_conv,
120
- config.expansion_factor
121
- )
122
- self.moe = MoELayer(config.d_model, config.n_experts, config.top_k)
123
- self.norm1 = nn.LayerNorm(config.d_model)
124
- self.norm2 = nn.LayerNorm(config.d_model)
125
-
126
- def forward(self, x):
127
- # Mamba branch
128
- x = x + self.mamba(self.norm1(x))
129
- # MoE branch
130
- x = x + self.moe(self.norm2(x))
131
- return x
132
-
133
 
134
  class KiyEngineModel(PreTrainedModel):
135
  """
136
- KiyEngine V3: Mamba-MoE Chess Evaluation Model
137
  """
138
  config_class = KiyEngineConfig
139
-
140
  def __init__(self, config):
141
  super().__init__(config)
142
  self.config = config
143
 
144
- # Embedding layer
145
- self.embeddings = nn.Embedding(config.vocab_size, config.d_model)
 
 
146
 
147
- # Mamba-MoE blocks
148
- self.layers = nn.ModuleList([
149
- KiyEngineMambaBlock(config)
150
- for _ in range(config.n_layers)
151
- ])
152
 
153
- # Final layer norm
154
- self.norm = nn.LayerNorm(config.d_model)
 
 
 
 
 
 
 
155
 
156
  # Initialize weights
157
  self.post_init()
158
-
159
  def forward(
160
  self,
161
- input_ids=None,
162
- attention_mask=None,
163
- return_dict=None,
164
  **kwargs
165
  ):
166
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
167
 
168
- # Embed input
169
- hidden_states = self.embeddings(input_ids)
170
 
171
- # Pass through layers
172
  for layer in self.layers:
173
- hidden_states = layer(hidden_states)
 
 
 
 
174
 
175
- # Final norm
176
- hidden_states = self.norm(hidden_states)
177
 
178
- if not return_dict:
179
- return (hidden_states,)
180
 
181
- return BaseModelOutput(
182
- last_hidden_state=hidden_states,
183
- hidden_states=None,
184
- attentions=None,
 
 
 
185
  )
 
1
+ """
2
+ KiyEngine V3: Mamba-MoE Chess Model
3
+ Matched exactly with standalone_train.py structure for 100% weight compatibility.
4
+ """
5
  import torch
6
  import torch.nn as nn
7
+ import torch.nn.functional as F
8
  from transformers import PreTrainedModel
9
+ from transformers.modeling_outputs import ModelOutput
10
+ from dataclasses import dataclass
11
+ from typing import Optional, Tuple
12
+
13
  from .configuration_kiyengine import KiyEngineConfig
14
 
15
+ # === Helper Classes (Copied & Adapted from Training Script) ===
16
+
17
class GaussianNoise(nn.Module):
    """Additive Gaussian noise regularizer; a no-op outside training.

    At inference (``eval()`` mode, or ``sigma == 0``) the input passes
    through unchanged.
    """

    def __init__(self, sigma: float = 0.01):
        super().__init__()
        self.sigma = sigma

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Guard clause: noise is disabled in eval mode or when sigma is zero.
        if not self.training or self.sigma == 0:
            return x
        return x + torch.randn_like(x) * self.sigma
27
+
28
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned per-channel gain.

    Normalizes over the last dimension: ``x / (rms(x) + eps) * weight``.
    """

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # L2 norm scaled by 1/sqrt(d) equals the root-mean-square of x.
        rms = x.norm(2, dim=-1, keepdim=True) * (x.shape[-1] ** -0.5)
        return x / (rms + self.eps) * self.weight
37
+
38
class MambaBlock(nn.Module):
    """Simplified Mamba block (gated depthwise CNN, no SSM scan).

    Submodule/parameter names deliberately mirror the training script so
    checkpoint state-dict keys match one-to-one.  ``x_proj``, ``dt_proj``
    and ``A_log`` are created only for key compatibility — the simplified
    forward path below never reads them.
    """

    def __init__(self, config):
        super().__init__()
        d_model = config.d_model
        d_state = config.d_state
        d_conv = config.d_conv
        d_inner = d_model * config.expansion_factor

        self.in_proj = nn.Linear(d_model, 2 * d_inner, bias=False)
        self.conv1d = nn.Conv1d(
            in_channels=d_inner,
            out_channels=d_inner,
            kernel_size=d_conv,
            bias=True,
            groups=d_inner,        # depthwise: one filter per channel
            padding=d_conv - 1,    # causal pad; tail is trimmed in forward
        )
        # Unused at inference — kept so trained weights load cleanly.
        self.x_proj = nn.Linear(d_inner, d_inner + 2 * d_state, bias=False)
        self.dt_proj = nn.Linear(d_inner, d_inner, bias=True)
        self.A_log = nn.Parameter(torch.randn(d_inner, d_state))
        self.D = nn.Parameter(torch.ones(d_inner))
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated-CNN path; input/output shape (B, L, d_model)."""
        seq_len = x.shape[1]

        # Split the up-projection into the conv branch and the gate branch.
        branch, gate = self.in_proj(x).chunk(2, dim=-1)

        # Conv1d wants (B, C, L); drop the causal-padding tail back to L.
        conv_out = self.conv1d(branch.transpose(1, 2))[:, :, :seq_len]
        activated = F.silu(conv_out.transpose(1, 2))

        # Per-channel scale by D, then SiLU-gate — matches training logic.
        gated = activated * self.D.unsqueeze(0) * F.silu(gate)
        return self.out_proj(gated)
82
 
83
class MoELayer(nn.Module):
    """Top-k mixture-of-experts layer whose experts are MambaBlocks.

    Attribute names (``router``, ``experts``) match the training
    checkpoint.  Routing is "safe": k is clamped to the number of experts
    actually present so a mismatched config cannot crash topk.
    """

    def __init__(self, config):
        super().__init__()
        self.n_experts = config.n_experts
        self.top_k = config.top_k
        self.router = nn.Linear(config.d_model, self.n_experts)
        self.experts = nn.ModuleList(
            [MambaBlock(config) for _ in range(self.n_experts)]
        )

    def forward(self, x: torch.Tensor):
        """Route each token to its top-k experts; shape (B, L, C) in/out."""
        batch, seq_len, dim = x.shape
        tokens = x.view(-1, dim)
        probs = F.softmax(self.router(tokens), dim=1)

        # Never request more experts than the router layer provides.
        k = min(self.top_k, probs.size(-1))
        weights, indices = torch.topk(probs, k, dim=-1)
        # Renormalize selected weights; epsilon guards divide-by-zero.
        weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-9)

        mixed = torch.zeros_like(tokens)
        for slot in range(k):
            chosen = indices[:, slot]
            w = weights[:, slot].unsqueeze(-1)
            for expert_id in range(self.n_experts):
                picked = chosen == expert_id
                if picked.any():
                    # Experts are sequence models: feed tokens as length-1
                    # sequences, (N, D) -> (N, 1, D) -> expert -> (N, D).
                    expert_in = tokens[picked].unsqueeze(1)
                    expert_out = self.experts[expert_id](expert_in).squeeze(1)
                    mixed[picked] += expert_out * w[picked]

        return mixed.view(batch, seq_len, dim)
121
 
122
+ # === Output Class for Hugging Face ===
123
@dataclass
class KiyEngineOutput(ModelOutput):
    """Structured output of ``KiyEngineModel.forward`` (Hugging Face style)."""

    # Present for HF API compatibility; the inference forward never sets it.
    loss: Optional[torch.Tensor] = None
    # Move logits over the vocabulary, computed from the last sequence position.
    policy_logits: Optional[torch.Tensor] = None
    # Position evaluation, tanh-squashed into [-1, 1].
    value: Optional[torch.Tensor] = None
    # Final normalized hidden states for the whole sequence.
    last_hidden_state: Optional[torch.Tensor] = None
129
 
130
+ # === Main Model Class ===
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
class KiyEngineModel(PreTrainedModel):
    """KiyEngine V3: Mamba-MoE chess model.

    Module attribute names intentionally mirror ``standalone_train.py``
    (``embedding`` singular, ``layers`` of MoELayer, shared ``norm``) so
    trained checkpoints load without any key remapping.
    """

    config_class = KiyEngineConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Train script uses `embedding` (NOT `embeddings`) — keep the key.
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.noise = GaussianNoise(sigma=0.0)  # no-op at inference

        self.layers = nn.ModuleList(
            [MoELayer(config) for _ in range(config.n_layers)]
        )

        # One RMSNorm instance is shared as every layer's pre-norm AND the
        # final norm — exactly as in the training script.
        self.norm = RMSNorm(config.d_model)

        # Prediction heads are part of the base model in the train script.
        self.policy_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.value_head = nn.Sequential(
            nn.Linear(config.d_model, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

        # Standard HF weight init / tying hook.
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """Run the model; heads read only the last sequence position.

        Args:
            input_ids: (batch, seq) token ids.
            return_dict: when falsy, return ``(policy_logits, value, hidden)``.
            **kwargs: accepted for HF compatibility (e.g. attention_mask)
                and ignored.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        hidden = self.noise(self.embedding(input_ids))

        # Pre-norm residual stack; MoELayer returns just the branch tensor
        # (the training script's aux-loss return was dropped for inference).
        for block in self.layers:
            hidden = hidden + block(self.norm(hidden))
        hidden = self.norm(hidden)

        final_state = hidden[:, -1, :]
        policy_logits = self.policy_head(final_state)
        value = torch.tanh(self.value_head(final_state))

        if not return_dict:
            return (policy_logits, value, hidden)

        return KiyEngineOutput(
            policy_logits=policy_logits,
            value=value,
            last_hidden_state=hidden,
        )