| """ |
| 工具函数 |
| """ |
| import os |
| import json |
| import numpy as np |
| import torch |
| from datetime import datetime |
| from typing import Dict, Any, Optional |
| import shutil |
|
|
|
|
def ensure_dir(path: str) -> str:
    """Create ``path`` (including missing parents) if it does not exist.

    An empty ``path`` is returned untouched. ``os.path.dirname`` yields ``''``
    for a bare filename, which refers to the current directory and cannot be
    passed to ``os.makedirs``.

    Args:
        path: Directory path to create.

    Returns:
        The ``path`` argument, unchanged, so the call can be used inline.
    """
    if path and not os.path.exists(path):
        # exist_ok guards against another process creating the directory
        # between the existence check above and this call.
        os.makedirs(path, exist_ok=True)
    return path
|
|
|
|
def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    stats: Dict[str, Any],
    path: str
) -> None:
    """Write a training checkpoint to ``path`` with ``torch.save``.

    The saved dictionary has the keys ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict``, ``stats`` and ``timestamp`` (ISO-8601 string
    from ``datetime.now()``).

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Epoch number recorded in the checkpoint.
        stats: Arbitrary statistics stored as-is under ``stats``.
        path: Destination file. Its parent directory is created if missing.
    """
    parent_dir = os.path.dirname(path)
    # A bare filename has no directory component; there is nothing to create
    # and passing '' on to directory creation would fail.
    if parent_dir:
        ensure_dir(parent_dir)

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'stats': stats,
        'timestamp': datetime.now().isoformat()
    }, path)
|
|
|
|
def load_checkpoint(
    path: str,
    model: torch.nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    device: str = 'cpu'
) -> Dict[str, Any]:
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        path: Checkpoint file to read.
        model: Module that receives ``checkpoint['model_state_dict']``.
        optimizer: If given and the checkpoint has an
            ``optimizer_state_dict`` entry, that state is loaded into it.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The full checkpoint dictionary as loaded from disk.

    Raises:
        KeyError: If the checkpoint has no ``model_state_dict`` entry.
    """
    # SECURITY NOTE: weights_only=False allows full unpickling, so only load
    # checkpoint files that come from a trusted source.
    state = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(state['model_state_dict'])

    wants_optimizer = optimizer is not None
    if wants_optimizer and 'optimizer_state_dict' in state:
        optimizer.load_state_dict(state['optimizer_state_dict'])

    return state
|
|
|
|
def save_training_log(log_data: Dict[str, Any], path: str) -> None:
    """Append one timestamped entry to a JSON log file holding a list.

    The file at ``path`` is read (if present), the new entry is appended, and
    the whole list is written back as indented UTF-8 JSON.

    Args:
        log_data: Entry to record. It is copied, not modified; the stored copy
            gets a ``timestamp`` key (ISO-8601 string from ``datetime.now()``).
        path: Log file location. Its parent directory is created if missing.

    Raises:
        TypeError: If the entry is not JSON-serializable. The existing file is
            left untouched in that case.
    """
    parent_dir = os.path.dirname(path)
    # A bare filename has no directory component to create.
    if parent_dir:
        ensure_dir(parent_dir)

    logs = []
    if os.path.exists(path):
        with open(path, 'r', encoding='utf-8') as f:
            try:
                logs = json.load(f)
            except json.JSONDecodeError:
                # Best effort: an unreadable file starts a fresh log.
                logs = []
        if not isinstance(logs, list):
            # Keep pre-existing non-list content as the first entry rather
            # than failing on .append().
            logs = [logs]

    # Build a new dict so the caller's log_data is not mutated.
    entry = {**log_data, 'timestamp': datetime.now().isoformat()}
    logs.append(entry)

    # Serialize before opening the file for writing: opening with 'w'
    # truncates immediately, so a serialization error afterwards would
    # destroy the existing history.
    payload = json.dumps(logs, indent=2, ensure_ascii=False)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(payload)
|
|
|
|
def format_time(seconds: float) -> str:
    """Render a duration with one decimal in seconds, minutes or hours.

    Values below 60 are shown in seconds (``s``), values below 3600 in
    minutes (``m``), everything else in hours (``h``).
    """
    if seconds < 60:
        return f'{seconds:.1f}s'
    # Past the seconds range: choose between minutes and hours.
    divisor, suffix = (60, 'm') if seconds < 3600 else (3600, 'h')
    return f'{seconds / divisor:.1f}{suffix}'
|
|
|
|
def format_number(num: int) -> str:
    """Return ``num`` as a string with comma thousands separators."""
    return format(num, ',')
|
|
|
|
def calculate_ema(values: list, alpha: float = 0.1) -> list:
    """Compute the exponential moving average of a sequence.

    The first output equals the first input; each later output is
    ``alpha * value + (1 - alpha) * previous_output``.

    Args:
        values: Input samples; an empty input yields an empty list.
        alpha: Weight given to the newest sample.

    Returns:
        A list with one smoothed value per input sample.
    """
    if not values:
        return []

    smoothed = [values[0]]
    decay = 1 - alpha
    for sample in values[1:]:
        previous = smoothed[-1]
        smoothed.append(alpha * sample + decay * previous)
    return smoothed
|
|
|
|
def get_tile_color(value: int) -> str:
    """Return the background hex color for a tile value.

    Values 0 and the powers of two from 2 to 2048 have dedicated colors;
    any other value gets the fallback color ``#3c3a32``.
    """
    fallback = '#3c3a32'
    palette = {
        0: '#cdc1b4',
        2: '#eee4da',
        4: '#ede0c8',
        8: '#f2b179',
        16: '#f59563',
        32: '#f67c5f',
        64: '#f65e3b',
        128: '#edcf72',
        256: '#edcc61',
        512: '#edc850',
        1024: '#edc53f',
        2048: '#edc22e',
    }
    if value in palette:
        return palette[value]
    return fallback
|
|
|
|
def get_text_color(value: int) -> str:
    """Return the text hex color for a tile value.

    Values up to and including 4 use ``#776e65``; larger values use
    ``#f9f6f2``.
    """
    return '#776e65' if value <= 4 else '#f9f6f2'
|
|
|
|
class EarlyStopping:
    """Signal when a monitored metric has stopped improving.

    Call the instance once per evaluation with the current metric value.
    The first call only records the baseline. After that, a value counts as
    an improvement when it beats the best value seen so far by more than
    ``min_delta``. ``patience`` consecutive non-improving calls set
    ``should_stop``; once set, it is never reset by this class.
    """

    _VALID_MODES = ('max', 'min')

    def __init__(
        self,
        patience: int = 100,
        min_delta: float = 0.01,
        mode: str = 'max'
    ):
        """
        Args:
            patience: Number of consecutive non-improving calls tolerated.
            min_delta: Minimum change that counts as an improvement.
            mode: ``'max'`` if larger values are better, ``'min'`` if smaller
                values are better.

        Raises:
            ValueError: If ``mode`` is neither ``'max'`` nor ``'min'``.
        """
        # Reject unknown modes up front; otherwise a typo would silently be
        # handled as 'min' by the comparison in __call__.
        if mode not in self._VALID_MODES:
            raise ValueError(
                f"mode must be 'max' or 'min', got {mode!r}"
            )
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0  # consecutive calls without improvement
        self.best_value = None  # best metric seen so far; None before first call
        self.should_stop = False

    def __call__(self, value: float) -> bool:
        """Record ``value`` and report whether training should stop.

        Args:
            value: Current metric value.

        Returns:
            ``True`` once ``patience`` consecutive calls have failed to
            improve on the best value, ``False`` otherwise.
        """
        if self.best_value is None:
            # First observation becomes the baseline.
            self.best_value = value
            return False

        if self.mode == 'max':
            improved = value > self.best_value + self.min_delta
        else:
            improved = value < self.best_value - self.min_delta

        if improved:
            self.best_value = value
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True

        return self.should_stop
|
|
|
|
class MetricTracker:
    """Keep a bounded history of named metrics and summarize it.

    Each metric name maps to a list of its most recent values, capped at
    ``window_size`` entries.
    """

    def __init__(self, window_size: int = 100):
        self.window_size = window_size
        self.metrics = {}  # metric name -> list of recent values

    def update(self, name: str, value: float) -> None:
        """Append ``value`` to metric ``name`` and trim to the window."""
        history = self.metrics.setdefault(name, [])
        history.append(value)

        # Keep only the newest window_size values once the cap is exceeded.
        if len(history) > self.window_size:
            self.metrics[name] = history[-self.window_size:]

    def get_mean(self, name: str) -> float:
        """Return the mean of ``name``, or 0.0 if it has no values."""
        samples = self.metrics.get(name)
        if not samples:
            return 0.0
        return np.mean(samples)

    def get_std(self, name: str) -> float:
        """Return the standard deviation of ``name``.

        Returns 0.0 when fewer than two values are recorded.
        """
        samples = self.metrics.get(name, ())
        if len(samples) < 2:
            return 0.0
        return np.std(samples)

    def get_all_means(self) -> Dict[str, float]:
        """Return a mapping of every tracked metric name to its mean."""
        means = {}
        for name in self.metrics:
            means[name] = self.get_mean(name)
        return means
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    When CUDA is available, the CUDA generators are seeded as well.
    """
    import random

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
|
|
|
|
def get_device() -> str:
    """Return ``'cuda'`` when CUDA is available, otherwise ``'cpu'``."""
    return 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
def count_parameters(model: torch.nn.Module) -> int:
    """Count the scalar elements of all parameters with ``requires_grad``."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable
|
|
|
|
def print_model_info(model: torch.nn.Module) -> None:
    """Print the trainable parameter count and an estimated size in MB.

    The size estimate is the count from ``count_parameters`` multiplied by
    4 bytes per parameter (float32).
    """
    total_params = count_parameters(model)
    summary_lines = (
        f"模型参数数量: {format_number(total_params)}",
        f"模型大小: {total_params * 4 / 1024 / 1024:.2f} MB (float32)",
    )
    for line in summary_lines:
        print(line)
|
|
|
|
def export_to_onnx(
    model: torch.nn.Module,
    path: str,
    input_size: tuple = (1, 4, 4)
) -> None:
    """Export ``model`` to an ONNX file using random dummy inputs.

    The model is called with three inputs named ``board``, ``scores`` and
    ``mask``; the outputs are named ``policy`` and ``value``. The first axis
    of every input is declared dynamic (``batch_size``).

    Note that the model is switched to eval mode and left in that mode.

    Args:
        model: Module to export.
        path: Destination ``.onnx`` file. Its parent directory is created if
            missing.
        input_size: Shape of the dummy ``board`` input. Its first element is
            used as the batch size for the other dummy inputs.
    """
    model.eval()

    # All three dummy inputs must share one batch size and live on the same
    # device as the model, otherwise tracing fails for batch sizes other
    # than 1 or for models that are not on the CPU.
    batch_size = input_size[0]
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    dummy_input = torch.randn(*input_size, device=device)
    # NOTE(review): the trailing widths 2 (scores) and 4 (mask) are fixed
    # here; confirm they match the model's expected inputs.
    dummy_scores = torch.randn(batch_size, 2, device=device)
    dummy_mask = torch.ones(batch_size, 4, dtype=torch.bool, device=device)

    parent_dir = os.path.dirname(path)
    # A bare filename has no directory component to create.
    if parent_dir:
        ensure_dir(parent_dir)

    # NOTE(review): only input axes are declared dynamic; whether the
    # outputs also need a dynamic batch axis depends on the model.
    torch.onnx.export(
        model,
        (dummy_input, dummy_scores, dummy_mask),
        path,
        input_names=['board', 'scores', 'mask'],
        output_names=['policy', 'value'],
        dynamic_axes={
            'board': {0: 'batch_size'},
            'scores': {0: 'batch_size'},
            'mask': {0: 'batch_size'}
        }
    )
    print(f"模型已导出到: {path}")
|
|
|
|
if __name__ == "__main__":
    # Manual smoke run: prints sample outputs for inspection. Nothing is
    # asserted, so the final message only means no exception was raised.
    print("Testing utility functions...")

    # Duration and number formatting.
    print(f"Format time: {format_time(45.5)}, {format_time(125.3)}, {format_time(3661)}")
    print(f"Format number: {format_number(1234567)}")

    # Exponential moving average over 1..5.
    values = list(range(1, 6))
    print(f"EMA: {calculate_ema(values)}")

    # Early stopping on a score sequence that plateaus at 12.
    early_stop = EarlyStopping(min_delta=0.1, patience=3)
    for i, score in enumerate([10, 11] + [12] * 5):
        stop = early_stop(score)
        print(f"Epoch {i}: score={score}, stop={stop}")

    print("All tests passed!")
|
|