Spaces:
Running
Running
| import gc | |
| import glob | |
| import os | |
| import random | |
| import sys | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import DataLoader, Dataset | |
| from tqdm import tqdm | |
| # Add project root to path | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) | |
| from ai.models.network_torch import TransformerCardNet | |
| class ChunkDataset(Dataset): | |
| """Dataset for a single chunk of data.""" | |
| def __init__(self, states, policies, winners, scores, turns): | |
| self.states = torch.as_tensor(states, dtype=torch.float32) | |
| self.policies = torch.as_tensor(policies, dtype=torch.float32) | |
| self.winners = torch.as_tensor(winners, dtype=torch.float32).view(-1, 1) | |
| self.scores = torch.as_tensor(scores, dtype=torch.float32).view(-1, 1) | |
| self.turns = torch.as_tensor(turns, dtype=torch.float32).view(-1, 1) | |
| def __len__(self): | |
| return len(self.states) | |
| def __getitem__(self, idx): | |
| return (self.states[idx], self.policies[idx], self.winners[idx], self.scores[idx], self.turns[idx]) | |
| def get_data_from_files(files): | |
| """Loads and concatenates data from a list of files.""" | |
| all_states, all_policies, all_winners = [], [], [] | |
| all_scores, all_turns = [], [] | |
| for f in files: | |
| data = np.load(f) | |
| all_states.append(data["states"]) | |
| all_policies.append(data["policies"]) | |
| all_winners.append(data["winners"]) | |
| all_scores.append(data["scores"]) | |
| all_turns.append(data["turns_left"]) | |
| return ( | |
| np.concatenate(all_states), | |
| np.concatenate(all_policies), | |
| np.concatenate(all_winners), | |
| np.concatenate(all_scores), | |
| np.concatenate(all_turns), | |
| ) | |
| def train(data_pattern, epochs=20, batch_size=16384, lr=0.001, resume_path=None): | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| print(f"Training on {device} with batch size {batch_size}") | |
| # Locate all matching files | |
| all_files = sorted(glob.glob(data_pattern)) | |
| if not all_files: | |
| print(f"Error: No files found matching {data_pattern}") | |
| return | |
| print(f"Found {len(all_files)} data chunks.") | |
| # Reserve ~10% of files for validation (exclusive) | |
| # If only 1 file, use it for BOTH training and validation | |
| if len(all_files) == 1: | |
| train_files = val_files = all_files | |
| else: | |
| val_count = max(1, len(all_files) // 10) | |
| # Ensure at least one training file exists | |
| if val_count >= len(all_files): | |
| val_count = 0 | |
| rand_gen = random.Random(42) | |
| shuffled_files = all_files.copy() | |
| rand_gen.shuffle(shuffled_files) | |
| train_files = shuffled_files[:-val_count] if val_count > 0 else shuffled_files | |
| val_files = shuffled_files[-val_count:] if val_count > 0 else shuffled_files | |
| print(f"Split: {len(train_files)} training files, {len(val_files)} validation files.") | |
| # Initialize model | |
| model = TransformerCardNet().to(device) | |
| optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4) | |
| scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=3, factor=0.5) | |
| # Loss functions | |
| policy_loss_fn = nn.CrossEntropyLoss(reduction="none") | |
| win_loss_fn = nn.BCELoss() | |
| mse_loss_fn = nn.MSELoss() | |
| best_val_loss = float("inf") | |
| start_epoch = 0 | |
| # Resume Logic | |
| if resume_path and os.path.exists(resume_path): | |
| print(f" -> Resuming from checkpoint: {resume_path}") | |
| checkpoint = torch.load(resume_path, map_location=device) | |
| model.load_state_dict(checkpoint["model_state"]) | |
| if "optimizer_state" in checkpoint: | |
| optimizer.load_state_dict(checkpoint["optimizer_state"]) | |
| if "scheduler_state" in checkpoint: | |
| scheduler.load_state_dict(checkpoint["scheduler_state"]) | |
| start_epoch = checkpoint.get("epoch", -1) + 1 | |
| best_val_loss = checkpoint.get("val_loss", float("inf")) | |
| print(f" -> Starting from epoch {start_epoch + 1}") | |
| try: | |
| for epoch in range(start_epoch, epochs): | |
| model.train() | |
| train_losses = {"poly": 0, "win": 0, "score": 0, "turn": 0} | |
| correct_policy = 0 | |
| total_policy = 0 | |
| current_train_files = train_files.copy() | |
| random.shuffle(current_train_files) | |
| pbar = tqdm(current_train_files, desc=f"Epoch {epoch + 1}/{epochs}") | |
| for f_idx, f_path in enumerate(pbar): | |
| try: | |
| s, pol, w, sc, tn = get_data_from_files([f_path]) | |
| chunk_dataset = ChunkDataset(s, pol, w, sc, tn) | |
| chunk_loader = DataLoader(chunk_dataset, batch_size=batch_size, shuffle=True) | |
| for states, target_p, target_w, target_s, target_t in chunk_loader: | |
| states = states.to(device) | |
| target_p = target_p.to(device) | |
| target_w = target_w.to(device) | |
| target_s = target_s.to(device) | |
| target_t = target_t.to(device) | |
| optimizer.zero_grad() | |
| # Forward pass | |
| p_soft, w_pred, s_pred, t_pred = model(states) | |
| # 1. Policy Loss (Weighted) | |
| # Avoid log(0) by using logits or CrossEntropy internally | |
| # But TransformerCardNet returns softmax, so we use NLL or manual CE | |
| # For simplicity, let's use the probabilities directly or re-logit | |
| p_loss_raw = -torch.sum(target_p * torch.log(p_soft + 1e-8), dim=1) | |
| # Weight non-pass actions higher | |
| target_argmax = torch.max(target_p, dim=1)[1] | |
| weights = torch.ones(states.size(0), device=device) | |
| weights[target_argmax != 0] = 5.0 | |
| p_loss = (p_loss_raw * weights).mean() | |
| # 2. Value Losses | |
| loss_win = win_loss_fn(w_pred, target_w) | |
| loss_score = mse_loss_fn(s_pred, target_s) | |
| loss_turn = mse_loss_fn(t_pred, target_t) | |
| total_loss = p_loss + loss_win + loss_score + loss_turn | |
| total_loss.backward() | |
| optimizer.step() | |
| train_losses["poly"] += p_loss.item() | |
| train_losses["win"] += loss_win.item() | |
| train_losses["score"] += loss_score.item() | |
| train_losses["turn"] += loss_turn.item() | |
| _, pred_action = torch.max(p_soft, 1) | |
| _, target_action = torch.max(target_p, 1) | |
| correct_policy += (pred_action == target_action).sum().item() | |
| total_policy += states.size(0) | |
| pbar.set_postfix( | |
| { | |
| "acc": f"{100 * correct_policy / total_policy:.1f}%", | |
| "win": f"{train_losses['win'] / (f_idx + 1):.3f}", | |
| } | |
| ) | |
| del chunk_dataset, chunk_loader, s, pol, w, sc, tn | |
| gc.collect() | |
| except Exception as e: | |
| print(f"Error processing chunk {f_path}: {e}") | |
| continue | |
| # Validation | |
| model.eval() | |
| val_losses = {"poly": 0, "win": 0, "score": 0, "turn": 0} | |
| val_correct = 0 | |
| val_total = 0 | |
| num_batches_val = 0 | |
| print(f" [Epoch {epoch + 1}] Validating...") | |
| with torch.no_grad(): | |
| for f_path in val_files: | |
| try: | |
| s, pol, w, sc, tn = get_data_from_files([f_path]) | |
| val_chunk = ChunkDataset(s, pol, w, sc, tn) | |
| val_loader = DataLoader(val_chunk, batch_size=batch_size, shuffle=False) | |
| for states, target_p, target_w, target_s, target_t in val_loader: | |
| states = states.to(device) | |
| target_p = target_p.to(device) | |
| target_w = target_w.to(device) | |
| target_s = target_s.to(device) | |
| target_t = target_t.to(device) | |
| p_soft, w_pred, s_pred, t_pred = model(states) | |
| p_loss_raw = -torch.sum(target_p * torch.log(p_soft + 1e-8), dim=1) | |
| val_losses["poly"] += p_loss_raw.mean().item() | |
| val_losses["win"] += win_loss_fn(w_pred, target_w).item() | |
| val_losses["score"] += mse_loss_fn(s_pred, target_s).item() | |
| val_losses["turn"] += mse_loss_fn(t_pred, target_t).item() | |
| num_batches_val += 1 | |
| _, pred_action = torch.max(p_soft, 1) | |
| _, target_action = torch.max(target_p, 1) | |
| val_correct += (pred_action == target_action).sum().item() | |
| val_total += states.size(0) | |
| del val_chunk, val_loader, s, pol, w, sc, tn | |
| gc.collect() | |
| except Exception as e: | |
| print(f"Error validating chunk {f_path}: {e}") | |
| if val_total > 0: | |
| avg_val_win = val_losses["win"] / num_batches_val | |
| avg_val_total = (val_losses["poly"] + val_losses["win"] + val_losses["score"]) / num_batches_val | |
| scheduler.step(avg_val_total) | |
| print(f"Epoch {epoch + 1} | Val WinLoss: {avg_val_win:.4f} | Acc: {100 * val_correct / val_total:.1f}%") | |
| checkpoint = { | |
| "model_state": model.state_dict(), | |
| "optimizer_state": optimizer.state_dict(), | |
| "val_loss": avg_val_total, | |
| "epoch": epoch, | |
| } | |
| torch.save(checkpoint, f"ai/models/transformer_epoch_{epoch + 1}.pt") | |
| if avg_val_total < best_val_loss: | |
| best_val_loss = avg_val_total | |
| torch.save(checkpoint, "ai/models/transformer_best.pt") | |
| except KeyboardInterrupt: | |
| print("\nTraining interrupted by user. Saving current state to ai/models/alphanet_interrupted.pt...") | |
| # Save complete state for resumption | |
| interrupted_state = { | |
| "model_state": model.state_dict(), | |
| "optimizer_state": optimizer.state_dict(), | |
| "scheduler_state": scheduler.state_dict() if "scheduler" in locals() else None, | |
| "epoch": epoch if "epoch" in locals() else 0, | |
| "val_loss": best_val_loss, | |
| } | |
| torch.save(interrupted_state, "ai/models/alphanet_interrupted.pt") | |
| torch.save(model.state_dict(), "ai/models/alphanet_final.pt") | |
| print(f"Training complete. Best Val Loss: {best_val_loss:.4f}") | |
| if __name__ == "__main__": | |
| import argparse | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--data", type=str, default="ai/data/alphazero_nightly_chunk_*.npz") | |
| parser.add_argument("--epochs", type=int, default=10) | |
| parser.add_argument("--batch-size", type=int, default=16384) | |
| parser.add_argument("--lr", type=float, default=0.001) | |
| parser.add_argument( | |
| "--resume", type=str, help="Path to checkpoint to resume from (e.g. ai/models/alphanet_interrupted.pt)" | |
| ) | |
| args = parser.parse_args() | |
| if not os.path.exists("ai/models"): | |
| os.makedirs("ai/models") | |
| train(args.data, epochs=args.epochs, batch_size=args.batch_size, lr=args.lr, resume_path=args.resume) | |