trioskosmos committed on
Commit
f6a813d
·
verified ·
1 Parent(s): c6d22b8

Upload ai/training/train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/training/train.py +280 -0
ai/training/train.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import glob
3
+ import os
4
+ import random
5
+ import sys
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.utils.data import DataLoader, Dataset
11
+ from tqdm import tqdm
12
+
13
+ # Add project root to path
14
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
15
+ from ai.models.network_torch import TransformerCardNet
16
+
17
+
18
class ChunkDataset(Dataset):
    """Torch Dataset wrapping one chunk of self-play training arrays.

    Each sample is a 5-tuple ``(state, policy, winner, score, turns_left)``.
    The three scalar targets (winner, score, turns) are reshaped into
    column vectors of shape ``(N, 1)`` so they match the network heads.
    """

    def __init__(self, states, policies, winners, scores, turns):
        def as_f32(arr):
            # Single conversion point: everything is stored as float32 tensors.
            return torch.as_tensor(arr, dtype=torch.float32)

        self.states = as_f32(states)
        self.policies = as_f32(policies)
        self.winners = as_f32(winners).view(-1, 1)
        self.scores = as_f32(scores).view(-1, 1)
        self.turns = as_f32(turns).view(-1, 1)

    def __len__(self):
        # Number of samples == leading dimension of the state tensor.
        return self.states.shape[0]

    def __getitem__(self, idx):
        return (
            self.states[idx],
            self.policies[idx],
            self.winners[idx],
            self.scores[idx],
            self.turns[idx],
        )
33
+
34
+
35
def get_data_from_files(files):
    """Load and concatenate training arrays from a list of ``.npz`` files.

    Each archive must contain the keys ``"states"``, ``"policies"``,
    ``"winners"``, ``"scores"`` and ``"turns_left"``.

    Args:
        files: Iterable of paths to ``.npz`` archives.

    Returns:
        Tuple of five ``np.ndarray``s
        ``(states, policies, winners, scores, turns_left)``, each the
        concatenation (along axis 0) over all files, in input order.

    Raises:
        KeyError: if an archive is missing one of the expected keys.
        ValueError: if ``files`` is empty (``np.concatenate`` of nothing).
    """
    all_states, all_policies, all_winners = [], [], []
    all_scores, all_turns = [], []

    for f in files:
        # np.load on an .npz returns a lazy NpzFile that keeps the file
        # handle open; use the context manager so handles are released
        # promptly instead of leaking one per chunk.
        with np.load(f) as data:
            all_states.append(data["states"])
            all_policies.append(data["policies"])
            all_winners.append(data["winners"])
            all_scores.append(data["scores"])
            all_turns.append(data["turns_left"])

    return (
        np.concatenate(all_states),
        np.concatenate(all_policies),
        np.concatenate(all_winners),
        np.concatenate(all_scores),
        np.concatenate(all_turns),
    )
55
+
56
+
57
def train(data_pattern, epochs=20, batch_size=16384, lr=0.001, resume_path=None):
    """Train TransformerCardNet on self-play data chunks matched by a glob.

    Chunks are streamed one file at a time (loaded, trained on, freed) to
    bound memory. Roughly 10% of the files are held out for validation.
    A checkpoint is written to ``ai/models/`` after every epoch; the best
    model by validation loss is additionally saved to
    ``ai/models/transformer_best.pt``. Ctrl-C saves a resumable snapshot.

    Args:
        data_pattern: Glob pattern matching the ``.npz`` chunk files.
        epochs: Total number of epochs to run (including resumed ones).
        batch_size: Minibatch size for training and validation.
        lr: Initial AdamW learning rate.
        resume_path: Optional checkpoint path to resume from.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on {device} with batch size {batch_size}")

    # Locate all matching files
    all_files = sorted(glob.glob(data_pattern))
    if not all_files:
        print(f"Error: No files found matching {data_pattern}")
        return

    print(f"Found {len(all_files)} data chunks.")

    # Reserve ~10% of files for validation (exclusive).
    # If only 1 file, use it for BOTH training and validation.
    if len(all_files) == 1:
        train_files = val_files = all_files
    else:
        val_count = max(1, len(all_files) // 10)
        # Ensure at least one training file exists
        if val_count >= len(all_files):
            val_count = 0

        # Fixed seed so the train/val split is identical across runs/resumes.
        rand_gen = random.Random(42)
        shuffled_files = all_files.copy()
        rand_gen.shuffle(shuffled_files)

        train_files = shuffled_files[:-val_count] if val_count > 0 else shuffled_files
        val_files = shuffled_files[-val_count:] if val_count > 0 else shuffled_files

    print(f"Split: {len(train_files)} training files, {len(val_files)} validation files.")

    # Initialize model, optimizer and LR schedule (halve LR when the
    # validation loss plateaus for `patience` epochs).
    model = TransformerCardNet().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=3, factor=0.5)

    # Loss functions. Policy CE is computed manually from the softmax
    # output below (the network returns probabilities, not logits), so no
    # CrossEntropyLoss instance is needed here.
    win_loss_fn = nn.BCELoss()
    mse_loss_fn = nn.MSELoss()

    best_val_loss = float("inf")
    start_epoch = 0

    # Resume Logic: restore model/optimizer/scheduler state and epoch counter.
    if resume_path and os.path.exists(resume_path):
        print(f" -> Resuming from checkpoint: {resume_path}")
        checkpoint = torch.load(resume_path, map_location=device)
        model.load_state_dict(checkpoint["model_state"])
        if "optimizer_state" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
        # The interrupted-checkpoint path may store scheduler_state=None;
        # guard against passing None to load_state_dict.
        if checkpoint.get("scheduler_state") is not None:
            scheduler.load_state_dict(checkpoint["scheduler_state"])
        start_epoch = checkpoint.get("epoch", -1) + 1
        best_val_loss = checkpoint.get("val_loss", float("inf"))
        print(f" -> Starting from epoch {start_epoch + 1}")

    try:
        for epoch in range(start_epoch, epochs):
            model.train()
            train_losses = {"poly": 0, "win": 0, "score": 0, "turn": 0}
            correct_policy = 0
            total_policy = 0

            # Re-shuffle chunk order every epoch (unseeded on purpose).
            current_train_files = train_files.copy()
            random.shuffle(current_train_files)

            pbar = tqdm(current_train_files, desc=f"Epoch {epoch + 1}/{epochs}")
            for f_idx, f_path in enumerate(pbar):
                try:
                    s, pol, w, sc, tn = get_data_from_files([f_path])
                    chunk_dataset = ChunkDataset(s, pol, w, sc, tn)
                    chunk_loader = DataLoader(chunk_dataset, batch_size=batch_size, shuffle=True)

                    for states, target_p, target_w, target_s, target_t in chunk_loader:
                        states = states.to(device)
                        target_p = target_p.to(device)
                        target_w = target_w.to(device)
                        target_s = target_s.to(device)
                        target_t = target_t.to(device)

                        optimizer.zero_grad()

                        # Forward pass: policy softmax + three value heads.
                        p_soft, w_pred, s_pred, t_pred = model(states)

                        # 1. Policy loss: manual cross-entropy over the
                        # softmax output, with epsilon to avoid log(0).
                        p_loss_raw = -torch.sum(target_p * torch.log(p_soft + 1e-8), dim=1)

                        # Weight non-pass actions (argmax != 0) 5x higher so
                        # the net doesn't collapse to always passing.
                        target_argmax = torch.max(target_p, dim=1)[1]
                        weights = torch.ones(states.size(0), device=device)
                        weights[target_argmax != 0] = 5.0
                        p_loss = (p_loss_raw * weights).mean()

                        # 2. Value losses: win probability (BCE), final
                        # score and turns-left (MSE).
                        loss_win = win_loss_fn(w_pred, target_w)
                        loss_score = mse_loss_fn(s_pred, target_s)
                        loss_turn = mse_loss_fn(t_pred, target_t)

                        total_loss = p_loss + loss_win + loss_score + loss_turn
                        total_loss.backward()
                        optimizer.step()

                        train_losses["poly"] += p_loss.item()
                        train_losses["win"] += loss_win.item()
                        train_losses["score"] += loss_score.item()
                        train_losses["turn"] += loss_turn.item()

                        _, pred_action = torch.max(p_soft, 1)
                        _, target_action = torch.max(target_p, 1)
                        correct_policy += (pred_action == target_action).sum().item()
                        total_policy += states.size(0)

                        # NOTE(review): "win" here is a batch-sum divided by
                        # the file count, so it's a rough progress indicator,
                        # not a true per-batch average.
                        pbar.set_postfix(
                            {
                                "acc": f"{100 * correct_policy / total_policy:.1f}%",
                                "win": f"{train_losses['win'] / (f_idx + 1):.3f}",
                            }
                        )

                    # Free the chunk before loading the next one.
                    del chunk_dataset, chunk_loader, s, pol, w, sc, tn
                    gc.collect()
                except Exception as e:
                    # Best-effort: a corrupt chunk is logged and skipped so
                    # one bad file doesn't kill an overnight run.
                    print(f"Error processing chunk {f_path}: {e}")
                    continue

            # Validation over the held-out files.
            model.eval()
            val_losses = {"poly": 0, "win": 0, "score": 0, "turn": 0}
            val_correct = 0
            val_total = 0
            num_batches_val = 0

            print(f" [Epoch {epoch + 1}] Validating...")
            with torch.no_grad():
                for f_path in val_files:
                    try:
                        s, pol, w, sc, tn = get_data_from_files([f_path])
                        val_chunk = ChunkDataset(s, pol, w, sc, tn)
                        val_loader = DataLoader(val_chunk, batch_size=batch_size, shuffle=False)

                        for states, target_p, target_w, target_s, target_t in val_loader:
                            states = states.to(device)
                            target_p = target_p.to(device)
                            target_w = target_w.to(device)
                            target_s = target_s.to(device)
                            target_t = target_t.to(device)

                            p_soft, w_pred, s_pred, t_pred = model(states)

                            # Unweighted policy CE for validation.
                            p_loss_raw = -torch.sum(target_p * torch.log(p_soft + 1e-8), dim=1)
                            val_losses["poly"] += p_loss_raw.mean().item()
                            val_losses["win"] += win_loss_fn(w_pred, target_w).item()
                            val_losses["score"] += mse_loss_fn(s_pred, target_s).item()
                            val_losses["turn"] += mse_loss_fn(t_pred, target_t).item()

                            num_batches_val += 1
                            _, pred_action = torch.max(p_soft, 1)
                            _, target_action = torch.max(target_p, 1)
                            val_correct += (pred_action == target_action).sum().item()
                            val_total += states.size(0)

                        del val_chunk, val_loader, s, pol, w, sc, tn
                        gc.collect()
                    except Exception as e:
                        print(f"Error validating chunk {f_path}: {e}")

            if val_total > 0:
                avg_val_win = val_losses["win"] / num_batches_val
                # NOTE(review): the turn loss is deliberately excluded from
                # the checkpoint-selection metric here — confirm this is
                # intended rather than an oversight.
                avg_val_total = (val_losses["poly"] + val_losses["win"] + val_losses["score"]) / num_batches_val

                scheduler.step(avg_val_total)

                print(f"Epoch {epoch + 1} | Val WinLoss: {avg_val_win:.4f} | Acc: {100 * val_correct / val_total:.1f}%")

                # Include scheduler state so resuming restores the
                # ReduceLROnPlateau plateau counter (the resume path above
                # expects this key).
                checkpoint = {
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                    "val_loss": avg_val_total,
                    "epoch": epoch,
                }
                torch.save(checkpoint, f"ai/models/transformer_epoch_{epoch + 1}.pt")

                if avg_val_total < best_val_loss:
                    best_val_loss = avg_val_total
                    torch.save(checkpoint, "ai/models/transformer_best.pt")

    except KeyboardInterrupt:
        print("\nTraining interrupted by user. Saving current state to ai/models/alphanet_interrupted.pt...")
        # Save complete state for resumption. `epoch` may not exist if the
        # interrupt arrived before the first iteration started.
        interrupted_state = {
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "epoch": epoch if "epoch" in locals() else 0,
            "val_loss": best_val_loss,
        }
        torch.save(interrupted_state, "ai/models/alphanet_interrupted.pt")

    torch.save(model.state_dict(), "ai/models/alphanet_final.pt")
    print(f"Training complete. Best Val Loss: {best_val_loss:.4f}")
262
+
263
+
264
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, default="ai/data/alphazero_nightly_chunk_*.npz")
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=16384)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument(
        "--resume", type=str, help="Path to checkpoint to resume from (e.g. ai/models/alphanet_interrupted.pt)"
    )
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs("ai/models", exist_ok=True)

    train(args.data, epochs=args.epochs, batch_size=args.batch_size, lr=args.lr, resume_path=args.resume)