"""
2048 AI Trainer - main entry point.
A Transformer-based AI trainer for the game 2048.

Usage:
    python main.py           # launch the GUI
    python main.py --train   # command-line training mode
    python main.py --demo    # demo mode (loads an existing model)
"""
import sys
import os
import argparse
import torch
import time

# Put this file's directory at the front of sys.path so the local modules
# imported below (game, model, trainer, parallel, utils) resolve to the
# copies that live next to this file.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# NOTE(review): count_parameters, RolloutBuffer, ParallelGameEnv,
# TrainingWorker and EarlyStopping are not referenced anywhere else in this
# file as seen here; they may be kept intentionally (e.g. for re-export) —
# confirm before removing.
from game import Game2048
from model import Game2048Transformer, count_parameters
from trainer import PPOTrainer, RolloutBuffer, TrainingStats
from parallel import TrainingLoop, ParallelGameEnv, TrainingWorker
from utils import (
    set_seed, get_device, print_model_info, save_checkpoint,
    load_checkpoint, EarlyStopping, format_time, format_number
)
|
|
|
|
def run_gui():
    """Start the graphical interface by delegating to ``gui.main``.

    The ``gui`` module is imported here rather than at the top of the file,
    so it is only loaded when the GUI mode is actually selected (presumably
    to keep the CLI modes free of GUI dependencies).
    """
    from gui import main as gui_main
    gui_main()
|
|
|
|
def run_training(
    num_games: int = 10000,
    num_envs: int = 4,
    save_interval: int = 100,
    checkpoint_dir: str = "checkpoints",
    seed: int = 42,
    lr: float = 1e-4,
    steps_per_update: int = 256,
    stop_threshold: int = 200,
    log_interval: int = 10,
):
    """
    Run training from the command line.

    Builds a ``Game2048Transformer`` and a ``PPOTrainer``, drives them with a
    ``TrainingLoop``, prints a progress line every ``log_interval`` finished
    games, writes a checkpoint every ``save_interval`` finished games, and
    always writes ``final_model.pt`` when training ends (including after
    Ctrl+C).

    Args:
        num_games: Total number of games to train for.
        num_envs: Number of parallel environments.
        save_interval: Write a checkpoint every this many finished games.
            A value <= 0 disables periodic checkpoints (the final model is
            still saved).
        checkpoint_dir: Directory checkpoints are written to; created if it
            does not exist.
        seed: Random seed passed to ``set_seed``.
        lr: Learning rate passed to ``PPOTrainer``.
        steps_per_update: Passed through to ``TrainingLoop``.
        stop_threshold: Passed through to ``TrainingLoop.train``.
        log_interval: Print a progress line every this many finished games.
            A value <= 0 disables progress lines.
    """
    print("=" * 50)
    print("2048 AI Training")
    print("=" * 50)

    set_seed(seed)

    device = get_device()
    print(f"Device: {device}")

    model = Game2048Transformer()
    print_model_info(model)
    model.to(device)

    trainer = PPOTrainer(model, lr=lr, device=device)

    # Create the output directory before anything that is handed the path
    # (the training loop, the checkpoint callback) gets a chance to use it.
    os.makedirs(checkpoint_dir, exist_ok=True)

    training_loop = TrainingLoop(
        model=model,
        trainer=trainer,
        num_envs=num_envs,
        device=device,
        steps_per_update=steps_per_update,
        checkpoint_dir=checkpoint_dir,
    )

    stats = TrainingStats()
    start_time = time.time()
    best_score = 0
    games_completed = 0

    print(f"\nStarting training for {num_games} games...")
    print(f"Parallel environments: {num_envs}")
    print("-" * 50)

    def on_game_end(game_stats):
        """Record one finished game, then log and/or checkpoint on schedule."""
        nonlocal games_completed, best_score

        stats.record_game(
            score=game_stats['score'],
            situational_score=game_stats['situational_score'],
            max_tile=game_stats['max_tile'],
            steps=game_stats['moves'],
        )
        games_completed += 1
        best_score = max(best_score, game_stats['score'])

        should_log = log_interval > 0 and games_completed % log_interval == 0
        should_save = save_interval > 0 and games_completed % save_interval == 0
        if not (should_log or should_save):
            return

        # Computed once for both branches: previously this only happened in
        # the logging branch, so a save_interval that was not a multiple of
        # the logging interval raised UnboundLocalError when saving.
        avg_stats = stats.get_avg_stats(window=100)

        if should_log:
            # Guard against a zero elapsed time on very fast first games.
            elapsed = max(time.time() - start_time, 1e-9)
            print(
                f"Games: {games_completed} | "
                f"Avg Score: {avg_stats['avg_score']:.0f} | "
                f"Best: {best_score} | "
                f"Max Tile: {avg_stats['avg_max_tile']:.0f} | "
                f"Speed: {games_completed/elapsed:.2f} games/s"
            )

        if should_save:
            checkpoint_path = os.path.join(
                checkpoint_dir,
                f"checkpoint_{games_completed}.pt"
            )
            save_checkpoint(
                model, trainer.optimizer, games_completed,
                avg_stats, checkpoint_path
            )
            print(f"Checkpoint saved: {checkpoint_path}")

    training_loop.on_game_end_callback = on_game_end
    try:
        training_loop.train(total_games=num_games, stop_threshold=stop_threshold)
    except KeyboardInterrupt:
        # Ctrl+C is a normal way to stop: fall through so the summary is
        # still printed and the final model is still saved.
        print("\nTraining interrupted by user.")

    elapsed = time.time() - start_time
    final_stats = stats.get_avg_stats()

    print("\n" + "=" * 50)
    print("Training Complete!")
    print("=" * 50)
    print(f"Total games: {format_number(final_stats['games_played'])}")
    print(f"Total time: {format_time(elapsed)}")
    print(f"Average score: {final_stats['avg_score']:.0f}")
    print(f"Best score: {final_stats['best_score']}")
    print(f"Best max tile: {final_stats['best_max_tile']}")

    final_path = os.path.join(checkpoint_dir, "final_model.pt")
    save_checkpoint(
        model, trainer.optimizer, final_stats['games_played'],
        final_stats, final_path
    )
    print(f"Final model saved: {final_path}")
|
|
|
|
def run_demo(model_path: str = None, num_games: int = 5, max_steps: int = 10000):
    """
    Play demo games with an (optionally loaded) model and print the results.

    Actions come from ``model.get_action(..., deterministic=True)`` under
    ``torch.no_grad()``. Each game ends when ``game.game_over`` becomes true
    or after ``max_steps`` moves, whichever comes first.

    Args:
        model_path: Checkpoint file to load with ``load_checkpoint``. If it
            is None or the file does not exist, the model keeps its freshly
            initialised weights.
        num_games: Number of demo games to play. If no games are played,
            the average/best/worst summary is skipped.
        max_steps: Safety cap on the number of moves per game.
    """
    print("=" * 50)
    print("2048 AI Demo Mode")
    print("=" * 50)

    device = get_device()
    print(f"Device: {device}")

    model = Game2048Transformer()

    if model_path and os.path.exists(model_path):
        print(f"Loading model from: {model_path}")
        checkpoint = load_checkpoint(model_path, model, device=device)
        print(f"Loaded checkpoint: {checkpoint.get('epoch', 'unknown')} games")
    else:
        if model_path:
            # An explicit path that does not exist is most likely a typo;
            # say so instead of silently falling back to random weights.
            print(f"Model file not found: {model_path}")
        print("No model loaded, using random weights.")

    model.to(device)
    model.eval()
    print_model_info(model)

    print(f"\nRunning {num_games} demo games...")
    print("-" * 50)

    total_scores = []

    for game_idx in range(num_games):
        game = Game2048()
        game.reset()

        steps = 0
        while not game.game_over and steps < max_steps:
            state = game.get_state()
            # NOTE(review): assumes the last two entries of
            # get_state_with_scores() are the score features get_action
            # expects — confirm against the game module.
            scores = game.get_state_with_scores()[-2:]
            valid_actions = game.get_valid_actions()

            # Build each input directly on the target device and add a
            # leading batch dimension of size 1.
            state_t = torch.as_tensor(
                state, dtype=torch.float32, device=device
            ).unsqueeze(0)
            scores_t = torch.as_tensor(
                scores, dtype=torch.float32, device=device
            ).unsqueeze(0)
            valid_t = torch.as_tensor(
                valid_actions, dtype=torch.bool, device=device
            ).unsqueeze(0)

            with torch.no_grad():
                action, _, _ = model.get_action(
                    state_t, scores_t, valid_t, deterministic=True
                )

            game.move(action)
            steps += 1

        total_scores.append(game.accumulated_score)

        print(
            f"Game {game_idx + 1}: "
            f"Score = {game.accumulated_score}, "
            f"Max Tile = {game.get_max_tile()}, "
            f"Steps = {steps}"
        )

    print("\n" + "-" * 50)
    # With zero games the average would divide by zero and max()/min()
    # would raise on an empty list.
    if not total_scores:
        print("No demo games were played.")
        return
    print(f"Average score: {sum(total_scores) / len(total_scores):.0f}")
    print(f"Best score: {max(total_scores)}")
    print(f"Worst score: {min(total_scores)}")
|
|
|
|
def main():
    """Parse command-line arguments and dispatch to the selected mode.

    ``--train`` runs ``run_training``; ``--demo`` runs ``run_demo`` with the
    checkpoint given by ``--model``; with neither flag the GUI is started.
    ``--train`` and ``--demo`` are mutually exclusive.
    """
    parser = argparse.ArgumentParser(
        description='2048 AI Trainer - 基于Transformer的2048游戏AI'
    )

    # Previously both flags could be given together and --train silently
    # won; argparse now reports the conflict instead.
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument(
        '--train',
        action='store_true',
        help='运行命令行训练模式'
    )
    mode.add_argument(
        '--demo',
        action='store_true',
        help='运行演示模式'
    )

    parser.add_argument(
        '--model',
        type=str,
        default=None,
        help='模型路径(用于演示模式或继续训练)'
    )

    parser.add_argument(
        '--games',
        type=int,
        default=10000,
        help='训练游戏数(默认10000)'
    )

    parser.add_argument(
        '--envs',
        type=int,
        default=4,
        help='并行环境数(默认4)'
    )

    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help='随机种子(默认42)'
    )

    # Expose run_training parameters that previously had no CLI switch;
    # defaults match run_training's own defaults.
    parser.add_argument(
        '--save-interval',
        type=int,
        default=100,
        help='检查点保存间隔(局数,默认100)'
    )

    parser.add_argument(
        '--checkpoint-dir',
        type=str,
        default='checkpoints',
        help='检查点目录(默认checkpoints)'
    )

    args = parser.parse_args()

    if args.train:
        # Fail fast with a usage message instead of crashing mid-training
        # (a zero save interval is used as a modulo divisor).
        if args.envs < 1:
            parser.error('--envs 必须为正整数')
        if args.save_interval < 1:
            parser.error('--save-interval 必须为正整数')
        run_training(
            num_games=args.games,
            num_envs=args.envs,
            save_interval=args.save_interval,
            checkpoint_dir=args.checkpoint_dir,
            seed=args.seed,
        )
    elif args.demo:
        run_demo(model_path=args.model, num_games=5)
    else:
        run_gui()
|
|
|
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|