File size: 6,934 Bytes
c335116
 
 
 
f0bcf7d
 
c335116
95feed4
c335116
 
 
 
4fcbdfa
f0bcf7d
c335116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0bcf7d
c335116
f0bcf7d
c335116
 
 
 
 
 
 
95feed4
c335116
 
95feed4
c335116
 
 
 
 
 
95feed4
c335116
 
 
 
 
f0bcf7d
c335116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0bcf7d
 
c335116
95feed4
c335116
 
95feed4
c335116
 
 
 
 
 
 
 
8a7c276
c335116
 
 
 
8a7c276
c335116
 
95feed4
c335116
8a7c276
 
95feed4
c335116
95feed4
c335116
 
95feed4
c335116
 
 
 
 
 
f0bcf7d
c335116
 
 
 
 
 
 
f0bcf7d
c335116
f0bcf7d
95feed4
 
c335116
95feed4
 
c335116
95feed4
 
 
 
c335116
 
 
 
95feed4
c335116
 
95feed4
c335116
 
 
 
 
 
 
 
 
95feed4
 
 
c335116
95feed4
 
c335116
 
95feed4
 
 
f0bcf7d
c335116
 
f0bcf7d
95feed4
c335116
 
 
 
 
f0bcf7d
c335116
 
95feed4
c335116
 
95feed4
c335116
 
 
 
 
 
 
8a7c276
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
"""
KiyEngine V3: Mamba-MoE Chess Model
Matched exactly with standalone_train.py structure for 100% weight compatibility.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from dataclasses import dataclass
from typing import Optional, Tuple

from .configuration_kiyengine import KiyEngineConfig

# === Helper Classes (Copied & Adapted from Training Script) ===

class GaussianNoise(nn.Module):
    """Additive Gaussian noise regularizer; a no-op outside training."""

    def __init__(self, sigma: float = 0.01):
        super().__init__()
        self.sigma = sigma

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Inference always disables noise (sigma == 0 or eval mode).
        if not self.training or self.sigma == 0:
            return x
        return x + torch.randn_like(x) * self.sigma

class RMSNorm(nn.Module):
    """Root-mean-square layer norm with a learnable per-channel gain."""

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMS over the last dim: L2 norm scaled by 1/sqrt(d) == sqrt(mean(x^2)).
        inv_sqrt_d = x.shape[-1] ** -0.5
        rms = x.norm(2, dim=-1, keepdim=True) * inv_sqrt_d
        return x / (rms + self.eps) * self.weight

class MambaBlock(nn.Module):
    """Simplified gated depthwise-conv "Mamba" block.

    NOTE: the forward pass is a gated CNN, not a full selective-SSM scan —
    this deliberately mirrors the training script so the checkpoint weights
    load and reproduce identical outputs.  ``x_proj``, ``dt_proj`` and
    ``A_log`` are created (and must remain) purely so state-dict keys match
    the trained checkpoint, even though ``forward`` never reads them.
    """

    def __init__(self, config):
        super().__init__()
        # Hyper-parameters come from the config object.
        d_model = config.d_model
        d_state = config.d_state
        d_conv = config.d_conv
        d_inner = d_model * config.expansion_factor

        # Module definitions must match the training script key-for-key.
        self.in_proj = nn.Linear(d_model, 2 * d_inner, bias=False)
        self.conv1d = nn.Conv1d(
            in_channels=d_inner,
            out_channels=d_inner,
            kernel_size=d_conv,
            bias=True,
            groups=d_inner,
            padding=d_conv - 1,
        )
        self.x_proj = nn.Linear(d_inner, d_inner + 2 * d_state, bias=False)
        self.dt_proj = nn.Linear(d_inner, d_inner, bias=True)
        self.A_log = nn.Parameter(torch.randn(d_inner, d_state))
        self.D = nn.Parameter(torch.ones(d_inner))
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.shape[1]

        # Project and split into the conv branch and the gate branch.
        x_inner, z = self.in_proj(x).chunk(2, dim=-1)

        # Depthwise causal conv: Conv1d wants (B, C, L); trim the right pad
        # introduced by padding=d_conv-1 back to the original length.
        conv_out = self.conv1d(x_inner.transpose(1, 2))
        x_act = F.silu(conv_out[:, :, :seq_len].transpose(1, 2))

        # Per-channel scale by D, then SiLU-gate with z.
        gated = x_act * self.D.unsqueeze(0) * F.silu(z)
        return self.out_proj(gated)

class MoELayer(nn.Module):
    """Per-token top-k mixture-of-experts over MambaBlock experts.

    A linear router scores every expert per token; the top-k probabilities
    are renormalized and each selected expert's output is accumulated with
    its routing weight.
    """

    def __init__(self, config):
        super().__init__()
        self.n_experts = config.n_experts
        self.top_k = config.top_k

        self.router = nn.Linear(config.d_model, self.n_experts)
        self.experts = nn.ModuleList(
            [MambaBlock(config) for _ in range(self.n_experts)]
        )

    def forward(self, x: torch.Tensor):
        B, L, C = x.shape
        tokens = x.view(-1, C)

        # Routing distribution over experts for every token.
        probs = F.softmax(self.router(tokens), dim=1)

        # --- SAFE ROUTING ---
        # Never ask topk for more experts than exist; guards against a
        # mismatched config instead of crashing.
        k_safe = min(self.top_k, probs.size(-1))
        weights, indices = torch.topk(probs, k_safe, dim=-1)
        weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-9)

        out = torch.zeros_like(tokens)
        for slot in range(k_safe):
            chosen = indices[:, slot]
            w = weights[:, slot].unsqueeze(-1)

            for expert_id, expert in enumerate(self.experts):
                mask = chosen == expert_id
                if not mask.any():
                    continue
                # Experts expect (N, L, D): add a length-1 sequence dim,
                # run the expert, then drop it again.
                expert_out = expert(tokens[mask].unsqueeze(1)).squeeze(1)
                out[mask] += expert_out * w[mask]

        return out.view(B, L, C)

# === Output Class for Hugging Face ===
@dataclass
class KiyEngineOutput(ModelOutput):
    """Output container returned by KiyEngineModel.forward (return_dict=True)."""
    # Optional training loss; never populated by the inference forward below.
    loss: Optional[torch.Tensor] = None
    # Move logits over the vocabulary, computed from the last token's state.
    policy_logits: Optional[torch.Tensor] = None
    # Tanh-squashed scalar evaluation (range [-1, 1]) from the value head.
    value: Optional[torch.Tensor] = None
    # Full hidden-state sequence after the final RMSNorm.
    last_hidden_state: Optional[torch.Tensor] = None

# === Main Model Class ===

class KiyEngineModel(PreTrainedModel):
    """KiyEngine V3 Mamba-MoE chess model.

    The module tree (attribute names included) mirrors 'standalone_train.py'
    exactly so trained checkpoints load with zero state-dict key remapping.
    """

    config_class = KiyEngineConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # --- KEYS MUST MATCH THE TRAIN SCRIPT ---
        # Train script uses 'embedding' (singular), not 'embeddings'.
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        # Noise layer kept for key parity; sigma=0 makes it a no-op here.
        self.noise = GaussianNoise(sigma=0.0)

        # Train script: a plain ModuleList of MoE layers.
        self.layers = nn.ModuleList(
            [MoELayer(config) for _ in range(config.n_layers)]
        )

        self.norm = RMSNorm(config.d_model)

        # Prediction heads are part of the model, as in the train script.
        self.policy_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.value_head = nn.Sequential(
            nn.Linear(config.d_model, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

        # Standard HF weight initialization hook.
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Embed tokens, then run the pre-norm residual MoE stack.
        hidden = self.noise(self.embedding(input_ids))
        for layer in self.layers:
            # Training script does x = x + layer(norm(x))[0]; our MoELayer
            # returns just the tensor (aux-loss output dropped for inference).
            hidden = hidden + layer(self.norm(hidden))
        hidden = self.norm(hidden)

        # Both heads read only the final token's state.
        last_token_state = hidden[:, -1, :]
        policy_logits = self.policy_head(last_token_state)
        value = torch.tanh(self.value_head(last_token_state))

        if not return_dict:
            return (policy_logits, value, hidden)

        return KiyEngineOutput(
            policy_logits=policy_logits,
            value=value,
            last_hidden_state=hidden,
        )