Kiy-K commited on
Commit
f0bcf7d
·
verified ·
1 Parent(s): 7022ebf

Upload modeling_kiyengine.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_kiyengine.py +294 -0
modeling_kiyengine.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # === Imports ===
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch.optim as optim
7
+ import chess
8
+ import chess.pgn
9
+ import os
10
+ import random
11
+ import pickle
12
+ import time
13
+ import glob
14
+ from typing import Dict, List, Tuple
15
+ from tqdm import tqdm
16
+ from safetensors.torch import save_file
17
+ from torch.utils.data import Dataset, DataLoader
18
+ from torch.amp import GradScaler, autocast
19
+
20
# === Configuration (P100 Optimized & FIXED) ===
CONFIG = {
    'model': {
        # Embedding width, layer count, MoE expert counts, and Mamba kernel sizes.
        'd_model': 384, 'n_layers': 4, 'n_experts': 8, 'top_k': 2, 'd_state': 16,
        'd_conv': 4, 'expansion_factor': 2, 'vocab_size': 768,
    },
    'training': {
        'batch_size': 4096,
        'learning_rate': 4.0e-4,
        'epochs': 10,
        # Std-dev of GaussianNoise applied to embeddings during training.
        # NOTE(review): KiyEngineV3 is constructed from CONFIG['model'] only, which
        # has no 'training' key, so the model falls back to sigma=0.0 — confirm intended.
        'noise_sigma': 0.01,
        # Wall-clock minutes between rolling checkpoint saves.
        'save_every_mins': 15,
        'keep_checkpoints': 2,
        # --- [FIX] Restore the loss weights that had gone missing ---
        'policy_weight': 1.0,
        'value_weight': 1.0,
        'aux_loss_lambda': 0.01,
        # -----------------------------------------------------
    },
    'paths': {
        'train_data_path': "/kaggle/working/train_data.pgn",
        'save_path': "./snapshots",
        'model_save_name': "model.safetensors",
    },
}
45
+
46
# === Helper: Data Prefetcher ===
class DataPrefetcher:
    """Overlap host-to-device copies with compute using a side CUDA stream.

    [FIX] The original constructed ``torch.cuda.Stream()`` unconditionally,
    which raises on CPU-only hosts. When the target device is not CUDA (or
    CUDA is unavailable) this now falls back to plain synchronous iteration,
    so the same class also works on CPU.
    """

    def __init__(self, loader, device):
        self.loader = iter(loader)
        self.device = device
        # A dedicated stream is only meaningful (and only constructible) on CUDA.
        self.use_cuda = torch.cuda.is_available() and torch.device(device).type == 'cuda'
        self.stream = torch.cuda.Stream() if self.use_cuda else None
        self.preload()

    def preload(self):
        """Fetch the next batch and (on CUDA) start its async copy to the device."""
        try:
            self.next_batch = next(self.loader)
        except StopIteration:
            # Exhausted: next() will return None once, signalling end of epoch.
            self.next_batch = None
            return

        if self.use_cuda:
            with torch.cuda.stream(self.stream):
                self.next_batch = [x.to(self.device, non_blocking=True) for x in self.next_batch]
        else:
            self.next_batch = [x.to(self.device) for x in self.next_batch]

    def next(self):
        """Return the current batch (None when exhausted) and prefetch the next one."""
        if self.use_cuda:
            # Make sure the async copy issued in preload() has finished.
            torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.next_batch
        self.preload()
        return batch
69
+
70
# === Helper: Rolling Checkpoint Manager ===
def manage_checkpoints(save_dir, keep_n=2):
    """Keep only the `keep_n` most recent checkpoints in `save_dir`.

    Checkpoints are matched by the ``checkpoint_*.safetensors`` pattern and
    ranked by modification time; older files are removed best-effort (a
    failed delete is reported but never raises).
    """
    pattern = os.path.join(save_dir, "checkpoint_*.safetensors")
    checkpoints = sorted(glob.glob(pattern), key=os.path.getmtime)
    # Everything before the newest `keep_n` entries is stale.
    n_stale = max(len(checkpoints) - keep_n, 0)
    for stale in checkpoints[:n_stale]:
        try:
            os.remove(stale)
            print(f"🗑️ Cleaned up old checkpoint: {stale}")
        except OSError as e:
            print(f"⚠️ Error deleting file {stale}: {e}")
81
+
82
# === Model Architecture (Mamba + MoE) ===
class GaussianNoise(nn.Module):
    """Add zero-mean Gaussian noise (std = sigma) during training only.

    Acts as the identity in eval mode or when sigma is zero.
    """

    def __init__(self, sigma: float = 0.01):
        super().__init__()
        self.sigma = sigma

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Guard clause: no-op unless training with a non-zero noise scale.
        if not self.training or self.sigma == 0:
            return x
        return x + torch.randn_like(x) * self.sigma
88
+
89
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned per-channel gain."""

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMS over the last dim equals the L2 norm scaled by 1/sqrt(d).
        rms = x.norm(2, dim=-1, keepdim=True) * (x.shape[-1] ** -0.5)
        normalized = x / (rms + self.eps)
        return normalized * self.weight
95
+
96
class MambaBlock(nn.Module):
    """Simplified Mamba-style gated block: in-projection -> depthwise causal
    conv -> SiLU gate -> out-projection.

    NOTE(review): ``x_proj``, ``dt_proj`` and ``A_log`` are created in
    __init__ but never used in forward — the selective-SSM scan appears to
    have been elided or simplified away; confirm this is intended.
    """
    def __init__(self, config: Dict):
        super().__init__()
        # Inner width is d_model * expansion_factor (the standard Mamba layout).
        d_model, d_state, d_conv, exp_factor = config['d_model'], config['d_state'], config['d_conv'], config['expansion_factor']
        d_inner = d_model * exp_factor
        # Single projection producing both the conv path and the gate path.
        self.in_proj = nn.Linear(d_model, 2 * d_inner, bias=False)
        # Depthwise conv over time; padding d_conv-1 is trimmed back to length L in forward.
        self.conv1d = nn.Conv1d(in_channels=d_inner, out_channels=d_inner, kernel_size=d_conv, bias=True, groups=d_inner, padding=d_conv - 1)
        self.x_proj = nn.Linear(d_inner, d_inner + 2 * d_state, bias=False)
        self.dt_proj = nn.Linear(d_inner, d_inner, bias=True)
        self.A_log = nn.Parameter(torch.randn(d_inner, d_state)); self.D = nn.Parameter(torch.ones(d_inner))
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (B, L, d_model) -> (B, L, d_model)."""
        # Split the doubled projection into the conv path (x_inner) and the gate (z).
        _, L, C = x.shape; xz = self.in_proj(x); x_inner, z = xz.chunk(2, dim=-1)
        # Conv expects (B, C, L); trim the causal padding back to length L.
        x_conv = self.conv1d(x_inner.transpose(1, 2))[:, :, :L].transpose(1, 2); x_activated = F.silu(x_conv)
        # Scale by the skip parameter D, then apply the SiLU gate from z.
        y = x_activated * self.D.unsqueeze(0); y = y * F.silu(z)
        return self.out_proj(y)
113
+
114
class MoELayer(nn.Module):
    """Mixture-of-Experts layer: a linear router sends each token to its
    top-k MambaBlock experts and mixes their outputs by routing weight.

    forward returns (output, aux_loss), where aux_loss is the sum of squared
    per-expert load fractions — smallest when load is spread evenly, so it
    serves as a load-balancing penalty.
    """
    def __init__(self, config: Dict):
        super().__init__(); self.n_experts, self.top_k = config['n_experts'], config['top_k']
        self.router = nn.Linear(config['d_model'], self.n_experts)
        self.experts = nn.ModuleList([MambaBlock(config) for _ in range(self.n_experts)])

    def forward(self, x: torch.Tensor) -> (torch.Tensor, torch.Tensor):
        # Flatten (B, L, C) -> (B*L, C): routing decisions are made per token.
        B, L, C = x.shape; x_flat = x.view(-1, C); router_logits = self.router(x_flat)
        # Softmax in float32 for numerical stability (forward may run under AMP).
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        top_k_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1)
        # Renormalize so each token's selected-expert weights sum to 1 (in place).
        top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
        # expert_load[j] = fraction of token-slots routed to expert j.
        expert_mask = F.one_hot(top_k_indices, self.n_experts).sum(dim=1); expert_load = expert_mask.float().mean(dim=0)
        aux_loss = (expert_load * expert_load).sum()
        final_output = torch.zeros_like(x_flat)
        # Accumulate each expert's contribution for the tokens routed to it,
        # weighted by that token's (renormalized) routing weight.
        for i in range(self.top_k):
            expert_idx = top_k_indices[:, i]; weight = top_k_weights[:, i].unsqueeze(-1)
            for j in range(self.n_experts):
                mask = expert_idx == j
                # Experts expect a sequence dim; each token is a length-1 sequence here.
                if mask.any(): final_output[mask] += (self.experts[j](x_flat[mask].unsqueeze(1)).squeeze(1) * weight[mask])
        return final_output.view(B, L, C), aux_loss
134
+
135
class KiyEngineV3(nn.Module):
    """Mamba-MoE chess engine: token embedding -> n_layers MoE blocks (with
    pre-norm residual connections) -> policy and value heads on the last token.
    """

    def __init__(self, config: Dict):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config['vocab_size'], config['d_model'])
        # NOTE(review): callers pass CONFIG['model'], which has no 'training'
        # key, so sigma falls back to 0.0 (noise disabled) — confirm intended.
        self.noise = GaussianNoise(sigma=config.get('training', {}).get('noise_sigma', 0.0))
        self.layers = nn.ModuleList([MoELayer(config) for _ in range(config['n_layers'])])
        self.norm = RMSNorm(config['d_model'])
        self.policy_head = nn.Linear(config['d_model'], config['vocab_size'], bias=False)
        self.value_head = nn.Sequential(nn.Linear(config['d_model'], 128), nn.ReLU(), nn.Linear(128, 1))

    def forward(self, input_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (policy_logits, value in [-1, 1], mean auxiliary MoE load loss)."""
        x = self.noise(self.embedding(input_ids))
        total_aux_loss = 0.0
        for layer in self.layers:
            # [FIX] The original invoked each layer TWICE — once for the output
            # and again (on the already-updated x) just to read the aux loss —
            # doubling compute and pairing the aux loss with the wrong input.
            # Call once and unpack both results.
            layer_out, aux_loss = layer(self.norm(x))
            x = x + layer_out
            total_aux_loss = total_aux_loss + aux_loss
        x = self.norm(x)
        # Heads read only the final position's state (next-move prediction).
        last_token_state = x[:, -1, :]
        policy_logits = self.policy_head(last_token_state)
        value = torch.tanh(self.value_head(last_token_state))
        return policy_logits, value, total_aux_loss / self.config['n_layers']
151
+
152
# === Data Pipeline (Header Only + Robust) ===
def move_to_token(move, board):
    """Encode a move as ``piece_index * 64 + destination_square``.

    Returns 0 when the source square is empty (defensive fallback for
    malformed game data).
    """
    moving_piece = board.piece_at(move.from_square)
    if moving_piece is None:
        return 0
    # Promotions are tokenized by the promoted-to piece type, not the pawn.
    if move.promotion:
        kind = move.promotion - 1
    else:
        kind = moving_piece.piece_type - 1
    # Black pieces occupy the second bank of six piece indices.
    if moving_piece.color == chess.BLACK:
        kind += 6
    return kind * 64 + move.to_square
159
+
160
class ChessDataset(Dataset):
    """Randomly samples (move-history tokens, next-move token, game result)
    tuples from a PGN file, using a cached byte-offset index for O(1) access.
    """

    def __init__(self, pgn_file_path, context_length=16):
        self.pgn_file_path = pgn_file_path
        # Number of plies of move history fed to the model per sample.
        self.context_length = context_length
        # List of (byte_offset, result_value) pairs, one per usable game.
        self.games = self._index_games(pgn_file_path)

    def _index_games(self, pgn_file_path):
        """Build (or load a cached) index of game byte-offsets and result values.

        Only PGN headers are parsed, which is much faster than reading full
        games. Games without a conclusive Result tag are skipped. The index is
        pickled next to the PGN file for instant reuse on subsequent runs.
        """
        index_path = pgn_file_path + ".index.pkl"
        if os.path.exists(index_path):
            print(f"🚀 Loading cached index from {index_path}...")
            with open(index_path, "rb") as f: return pickle.load(f)

        print(f"⚡ Turbo Indexing {pgn_file_path} (Header Only Mode)...")
        offsets = []
        count = 0
        with open(pgn_file_path) as pgn:
            while True:
                # Record where this game starts so __getitem__ can seek to it.
                offset = pgn.tell()
                headers = chess.pgn.read_headers(pgn)
                if headers is None: break
                # Map the PGN result to a value target from White's perspective.
                res = headers.get("Result", "*")
                val = 0.0
                if res == "1-0": val = 1.0
                elif res == "0-1": val = -1.0
                elif res == "1/2-1/2": val = 0.0
                else: continue  # unfinished/unknown result: skip this game
                offsets.append((offset, val))
                count += 1
                if count % 50000 == 0: print(f"Indexed {count} games...", end='\r')

        print(f"\n✅ Done! Found {len(offsets)} valid games.")
        with open(index_path, "wb") as f: pickle.dump(offsets, f)
        return offsets

    def __len__(self): return len(self.games)

    def __getitem__(self, idx):
        """Return (context tokens, target move token, result value) for game idx.

        A random window of `context_length` plies is sampled from the game; the
        target is the move immediately after the window. Unreadable or too-short
        games yield a zeroed dummy sample so the DataLoader never crashes.
        """
        offset, value = self.games[idx]
        try:
            with open(self.pgn_file_path) as f:
                f.seek(offset)
                game = chess.pgn.read_game(f)
                if game is None or game.errors: return torch.zeros(self.context_length, dtype=torch.long), torch.tensor(0, dtype=torch.long), torch.tensor([0.0])
                moves = list(game.mainline_moves())
                if len(moves) <= self.context_length: return torch.zeros(self.context_length, dtype=torch.long), torch.tensor(0, dtype=torch.long), torch.tensor([0.0])

                # Pick a random window so repeated epochs see different positions.
                start_ply = random.randint(0, len(moves) - self.context_length - 1)
                move_history = moves[start_ply : start_ply + self.context_length]
                target_move = moves[start_ply + self.context_length]

                # Replay the game up to the window start to get board context.
                board = chess.Board()
                for i in range(start_ply): board.push(moves[i])
                temp_board = board.copy()
                seq = []
                for move in move_history:
                    seq.append(move_to_token(move, temp_board))
                    temp_board.push(move)
                # temp_board now sits at the position where target_move is played.
                target_token = move_to_token(target_move, temp_board)
                return torch.tensor(seq), torch.tensor(target_token), torch.tensor([value])
        except Exception:
            # Deliberate best-effort: corrupt games become dummy samples.
            return torch.zeros(self.context_length, dtype=torch.long), torch.tensor(0, dtype=torch.long), torch.tensor([0.0])
221
+
222
# === Training Loop (Single GPU Optimized) ===
def train_main():
    """Train KiyEngineV3 on the configured PGN dataset with AMP + prefetching.

    Side effects: creates the snapshot directory, writes time-based rolling
    checkpoints and a final safetensors model, and appends loss lines to
    training_progress.log.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): get_device_name(0) raises on CPU-only hosts despite the
    # fallback device above — confirm a GPU is always present.
    print(f"🔥 Hardware: {torch.cuda.get_device_name(0)}")

    os.makedirs(CONFIG['paths']['save_path'], exist_ok=True)

    model = KiyEngineV3(CONFIG['model']).to(device)
    dataset = ChessDataset(CONFIG['paths']['train_data_path'])

    dataloader = DataLoader(dataset, batch_size=CONFIG['training']['batch_size'],
                            shuffle=True, num_workers=os.cpu_count(), pin_memory=True)

    optimizer = optim.Adam(model.parameters(), lr=CONFIG['training']['learning_rate'])
    # Gradient scaler for mixed-precision (AMP) training.
    scaler = GradScaler('cuda')

    print("🚀 Starting P100 Turbo Training...")

    last_save_time = time.time()

    for epoch in range(CONFIG['training']['epochs']):
        # Prefetcher overlaps host->device copies with compute on a side stream.
        prefetcher = DataPrefetcher(dataloader, device)
        batch = prefetcher.next()

        pbar = tqdm(total=len(dataloader), desc=f"Epoch {epoch+1}")

        batch_idx = 0
        while batch is not None:
            input_seq, policy_target, value_target = batch

            optimizer.zero_grad()
            with autocast('cuda'):
                policy_logits, value_pred, aux_loss = model(input_seq)
                policy_loss = F.cross_entropy(policy_logits, policy_target)
                value_loss = F.mse_loss(value_pred.squeeze(), value_target.squeeze())
                # Combined objective: weighted policy + value + MoE load-balancing term.
                loss = CONFIG['training']['policy_weight'] * policy_loss + CONFIG['training']['value_weight'] * value_loss + CONFIG['training']['aux_loss_lambda'] * aux_loss

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Time-based rolling checkpoints; older ones are pruned below.
            if (time.time() - last_save_time) > (CONFIG['training']['save_every_mins'] * 60):
                checkpoint_name = f"checkpoint_ep{epoch+1}_step{batch_idx}.safetensors"
                save_path = os.path.join(CONFIG['paths']['save_path'], checkpoint_name)

                model_to_save = model
                # safetensors needs a flat name->tensor mapping.
                tensors = {name: param for name, param in model_to_save.state_dict().items()}
                save_file(tensors, save_path)

                print(f"\n💾 Auto-saved: {checkpoint_name}")
                manage_checkpoints(CONFIG['paths']['save_path'], keep_n=CONFIG['training']['keep_checkpoints'])
                last_save_time = time.time()

            # Lightweight progress log every 100 batches.
            if batch_idx % 100 == 0:
                with open("training_progress.log", "a") as f:
                    f.write(f"Epoch {epoch+1} | Batch {batch_idx} | Loss: {loss.item():.4f}\n")

            pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
            pbar.update(1)

            batch = prefetcher.next()
            batch_idx += 1

        pbar.close()

    # Final full-model export after all epochs.
    final_path = os.path.join(CONFIG['paths']['save_path'], CONFIG['paths']['model_save_name'])
    tensors = {name: param for name, param in model.state_dict().items()}
    save_file(tensors, final_path)
    print(f"🏁 Model saved to {final_path}")

if __name__ == "__main__":
    train_main()