|
|
""" |
|
|
训练器模块 |
|
|
Trainer module for PAD Predictor training |
|
|
|
|
|
该模块实现了一个完整的训练器类,包含: |
|
|
- 训练循环和验证循环 |
|
|
- 早停机制和学习率调度 |
|
|
- 检查点保存和恢复 |
|
|
- 混合精度训练支持 |
|
|
- 多GPU训练支持 |
|
|
- 梯度裁剪和正则化 |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
from torch.cuda.amp import GradScaler, autocast |
|
|
from torch.utils.data import DataLoader |
|
|
from torch.optim.lr_scheduler import _LRScheduler |
|
|
import numpy as np |
|
|
from typing import Dict, List, Tuple, Optional, Any, Union |
|
|
import os |
|
|
import json |
|
|
import time |
|
|
from pathlib import Path |
|
|
import logging |
|
|
from collections import defaultdict |
|
|
import copy |
|
|
|
|
|
from ..models.metrics import PADMetrics |
|
|
from ..models.loss_functions import MultiTaskLoss |
|
|
|
|
|
|
|
|
class EarlyStopping:
    """Stop training once a monitored score has stopped improving.

    The instance is called once per validation round with the current score.
    It tracks the best score seen so far, counts consecutive rounds without a
    sufficient improvement, and reports when that count reaches ``patience``.
    """

    def __init__(self,
                 patience: int = 20,
                 min_delta: float = 1e-4,
                 mode: str = 'min',
                 restore_best_weights: bool = True):
        """
        Initialize the early-stopping tracker.

        Args:
            patience: Number of consecutive non-improving calls tolerated
                before stopping.
            min_delta: Minimum change that counts as an improvement.
            mode: ``'min'`` if lower scores are better, ``'max'`` otherwise.
            restore_best_weights: Whether to keep a copy of the best weights
                and load it back into the model when stopping.

        Raises:
            ValueError: If ``mode`` is neither ``'min'`` nor ``'max'``.
        """
        # A typo such as 'Min' used to fall through to max-mode silently,
        # which makes early stopping monitor the wrong direction.
        if mode not in ('min', 'max'):
            raise ValueError(f"不支持的早停模式: {mode}")

        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.restore_best_weights = restore_best_weights

        self.best_score = None
        self.counter = 0
        self.best_weights = None
        self.early_stop = False

        if mode == 'min':
            self.monitor_op = np.less
            # In min-mode the threshold is "best - min_delta", so the stored
            # delta is negated and simply added to the best score later on.
            self.min_delta *= -1
        else:
            self.monitor_op = np.greater

    def __call__(self, score: float, model: nn.Module) -> bool:
        """
        Record a new score and decide whether training should stop.

        Args:
            score: Current value of the monitored metric.
            model: Model whose weights are snapshotted on improvement and
                restored when stopping (if ``restore_best_weights`` is set).

        Returns:
            ``True`` when ``patience`` non-improving calls have accumulated,
            ``False`` otherwise.
        """
        # NaN can never compare as better (or be beaten once stored as the
        # best score), so it is always treated as "no improvement".
        is_nan = bool(np.isnan(score))

        if self.best_score is None and not is_nan:
            self.best_score = score
            self.counter = 0
            self.save_checkpoint(model)
        elif not is_nan and self.monitor_op(score, self.best_score + self.min_delta):
            self.best_score = score
            self.counter = 0
            self.save_checkpoint(model)
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
                if self.restore_best_weights and self.best_weights is not None:
                    model.load_state_dict(self.best_weights)
                return True

        return False

    def save_checkpoint(self, model: nn.Module):
        """Keep an in-memory deep copy of the model's current state dict."""
        if self.restore_best_weights:
            # deepcopy is required: state_dict() returns references to the
            # live parameter tensors, which keep changing during training.
            self.best_weights = copy.deepcopy(model.state_dict())
|
|
|
|
|
|
|
|
class Trainer:
    """
    Trainer for the PAD predictor.

    Features:
    - AdamW / Adam / SGD optimizers (with weight decay)
    - CosineAnnealingLR / ReduceLROnPlateau / StepLR scheduling
    - mixed-precision training (CUDA only)
    - multi-GPU training via ``nn.DataParallel``
    - early stopping
    - checkpoint saving and loading
    - gradient-norm clipping

    NOTE: ``train()`` always starts counting from epoch 0, even after
    ``load_checkpoint()`` restored ``current_epoch``.
    """

    # Metric values of these types are treated as scalars when metric
    # dictionaries are flattened. NumPy scalars are listed explicitly because
    # e.g. ``np.float32`` is not a subclass of ``float``.
    _SCALAR_TYPES = (int, float, np.integer, np.floating)

    def __init__(self,
                 model: nn.Module,
                 config: Dict[str, Any],
                 device: Optional[Union[str, torch.device]] = None,
                 logger: Optional[logging.Logger] = None,
                 diagnostic_mode: bool = False):
        """
        Initialize the trainer.

        Args:
            model: Model to train.
            config: Nested training configuration dictionary.
            device: Training device; resolved from the config when ``None``.
            logger: Logger to use; defaults to this module's logger.
            diagnostic_mode: Use the per-dimension diagnostic metrics when
                validating and evaluating.
        """
        self.model = model
        self.config = config
        self.logger = logger or logging.getLogger(__name__)
        self.diagnostic_mode = diagnostic_mode

        self.device = self._setup_device(device)
        self.model.to(self.device)

        # Must run before the optimizer is built so that the optimizer sees
        # the parameters of the (possibly wrapped) model.
        self._setup_multi_gpu()

        self.optimizer = self._setup_optimizer()
        self.scheduler = self._setup_scheduler()
        self.loss_fn = self._setup_loss_function()
        self.metrics = PADMetrics()

        self.current_epoch = 0
        self.best_score = None
        self.train_history = defaultdict(list)
        self.val_history = defaultdict(list)

        self.early_stopping = self._setup_early_stopping()

        amp_requested = config.get('hardware', {}).get('mixed_precision', {}).get('enabled', False)
        # torch.cuda.amp only has an effect on CUDA devices; enabling it
        # elsewhere just produces warnings and a disabled scaler.
        self.mixed_precision = bool(amp_requested) and self.device.type == 'cuda'
        if amp_requested and not self.mixed_precision:
            self.logger.warning("已请求混合精度训练,但当前设备不是CUDA,已禁用混合精度")
        self.scaler = GradScaler() if self.mixed_precision else None

        clip_value = config.get('debug', {}).get('gradient_checking', {}).get('clip_value', 1.0)
        # An explicit ``None`` in the config means "no clipping".
        self.grad_clip_val = 0.0 if clip_value is None else clip_value

        self.logger.info("训练器初始化完成")
        self._log_model_info()

    # ------------------------------------------------------------------
    # Setup helpers
    # ------------------------------------------------------------------

    def _setup_device(self, device: Optional[Union[str, torch.device]]) -> torch.device:
        """Resolve the training device from the argument or the config."""
        if device is not None:
            return torch.device(device)

        device_config = self.config.get('hardware', {}).get('device', 'auto')
        if device_config != 'auto':
            return torch.device(device_config)

        if torch.cuda.is_available():
            self.logger.info(f"使用GPU训练: {torch.cuda.get_device_name()}")
            return torch.device('cuda')

        self.logger.info("使用CPU训练")
        return torch.device('cpu')

    def _setup_multi_gpu(self):
        """Wrap the model in ``nn.DataParallel`` when that is usable.

        DataParallel requires the module to live on its first device
        (``cuda:0`` by default), so the model is only wrapped when training
        on CUDA with the default/first GPU and more than one GPU is visible.
        """
        on_primary_gpu = self.device.type == 'cuda' and self.device.index in (None, 0)
        if on_primary_gpu and torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
            self.logger.info(f"使用多GPU训练: {torch.cuda.device_count()} 个GPU")

    def _unwrap_model(self) -> nn.Module:
        """Return the underlying model, stripping a DataParallel wrapper."""
        if isinstance(self.model, nn.DataParallel):
            return self.model.module
        return self.model

    def _setup_optimizer(self) -> optim.Optimizer:
        """Build the optimizer described by ``config['training']['optimizer']``.

        Raises:
            ValueError: If the configured optimizer type is not supported.
        """
        optimizer_config = self.config.get('training', {}).get('optimizer', {})

        optimizer_type = optimizer_config.get('type', 'AdamW')
        learning_rate = optimizer_config.get('learning_rate', 1e-3)
        weight_decay = optimizer_config.get('weight_decay', 1e-4)
        betas = tuple(optimizer_config.get('betas', [0.9, 0.999]))
        eps = optimizer_config.get('eps', 1e-8)

        params = self.model.parameters()
        if optimizer_type in ('AdamW', 'Adam'):
            optimizer_cls = optim.AdamW if optimizer_type == 'AdamW' else optim.Adam
            optimizer = optimizer_cls(
                params,
                lr=learning_rate,
                weight_decay=weight_decay,
                betas=betas,
                eps=eps,
            )
        elif optimizer_type == 'SGD':
            optimizer = optim.SGD(
                params,
                lr=learning_rate,
                momentum=optimizer_config.get('momentum', 0.9),
                weight_decay=weight_decay,
            )
        else:
            raise ValueError(f"不支持的优化器类型: {optimizer_type}")

        self.logger.info(f"优化器: {optimizer_type}, 学习率: {learning_rate}, 权重衰减: {weight_decay}")
        return optimizer

    def _setup_scheduler(self) -> Optional[_LRScheduler]:
        """Build the LR scheduler from ``config['training']['scheduler']``.

        Returns:
            The scheduler, or ``None`` when no scheduler type is configured.

        Raises:
            ValueError: If the configured scheduler type is not supported.
        """
        scheduler_config = self.config.get('training', {}).get('scheduler', {})
        scheduler_type = scheduler_config.get('type')
        if not scheduler_type:
            return None

        if scheduler_type == 'CosineAnnealingLR':
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=scheduler_config.get('T_max', 100),
                eta_min=scheduler_config.get('eta_min', 1e-6),
            )
        elif scheduler_type == 'ReduceLROnPlateau':
            # ``verbose`` is deprecated in recent PyTorch releases; the
            # current LR is logged every epoch anyway, so it is not passed.
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode=scheduler_config.get('mode', 'min'),
                factor=scheduler_config.get('factor', 0.5),
                patience=scheduler_config.get('patience', 10),
                min_lr=scheduler_config.get('min_lr', 1e-6),
            )
        elif scheduler_type == 'StepLR':
            scheduler = optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=scheduler_config.get('step_size', 30),
                gamma=scheduler_config.get('gamma', 0.1),
            )
        else:
            raise ValueError(f"不支持的调度器类型: {scheduler_type}")

        self.logger.info(f"学习率调度器: {scheduler_type}")
        return scheduler

    def _setup_loss_function(self) -> nn.Module:
        """Build the loss from ``config['training']['loss']``.

        Raises:
            ValueError: If the configured loss type is not supported.
        """
        loss_config = self.config.get('training', {}).get('loss', {})
        loss_type = loss_config.get('type', 'MSELoss')

        if loss_type == 'MultiTaskLoss':
            weights_config = loss_config.get('multi_task_weights', {})
            if 'delta_pad_p' in weights_config:
                # Per-dimension weights (P, A, D) were given explicitly.
                task_weights = [
                    weights_config.get('delta_pad_p', 1.0),
                    weights_config.get('delta_pad_a', 1.0),
                    weights_config.get('delta_pad_d', 1.0),
                ]
            else:
                # A single shared weight is applied to all three tasks.
                task_weights = [weights_config.get('delta_pad', 1.0)] * 3
            return MultiTaskLoss(num_tasks=3, task_weights=task_weights)

        if loss_type == 'MSELoss':
            return nn.MSELoss(reduction='mean')
        if loss_type == 'L1Loss':
            return nn.L1Loss(reduction='mean')
        if loss_type == 'SmoothL1Loss':
            return nn.SmoothL1Loss(reduction='mean')
        if loss_type == 'HuberLoss':
            # delta is fixed at 0.05 and is not read from the config.
            return nn.HuberLoss(reduction='mean', delta=0.05)

        raise ValueError(f"不支持的损失函数类型: {loss_type}")

    def _setup_early_stopping(self) -> Optional["EarlyStopping"]:
        """Build the early-stopping tracker, or ``None`` when disabled."""
        early_stopping_config = (
            self.config.get('training', {}).get('epochs', {}).get('early_stopping', {})
        )
        if not early_stopping_config.get('enabled', False):
            return None

        return EarlyStopping(
            patience=early_stopping_config.get('patience', 20),
            min_delta=early_stopping_config.get('min_delta', 1e-4),
            mode=early_stopping_config.get('mode', 'min'),
            restore_best_weights=True,
        )

    def _log_model_info(self):
        """Log parameter counts and the training device."""
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

        self.logger.info(f"模型参数总数: {total_params:,}")
        self.logger.info(f"可训练参数: {trainable_params:,}")
        self.logger.info(f"训练设备: {self.device}")

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _as_builtin(value: Any) -> Any:
        """Convert a NumPy scalar to the matching built-in Python number.

        Keeping histories free of NumPy scalars lets checkpoints be read
        back by ``torch.load`` when it only accepts plain data.
        """
        return value.item() if isinstance(value, np.generic) else value

    def _flatten_scalars(self, metrics: Dict[str, Any], prefix: str = '',
                         nested: bool = True) -> Dict[str, Any]:
        """Collect scalar entries of ``metrics`` into a flat dictionary.

        Top-level scalars are stored as ``{prefix}{key}``. When ``nested`` is
        true, scalars inside one level of sub-dictionaries are stored as
        ``{prefix}{key}_{sub_key}``. Everything else is ignored.
        """
        flat: Dict[str, Any] = {}
        for key, value in metrics.items():
            if isinstance(value, dict):
                if not nested:
                    continue
                for sub_key, sub_value in value.items():
                    if isinstance(sub_value, self._SCALAR_TYPES):
                        flat[f"{prefix}{key}_{sub_key}"] = self._as_builtin(sub_value)
            elif isinstance(value, self._SCALAR_TYPES):
                flat[f"{prefix}{key}"] = self._as_builtin(value)
        return flat

    def _forward(self, features: torch.Tensor,
                 targets: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the model and the loss, under autocast if AMP is enabled."""
        if self.mixed_precision:
            with autocast():
                predictions = self.model(features)
                loss = self.loss_fn(predictions, targets)
        else:
            predictions = self.model(features)
            loss = self.loss_fn(predictions, targets)
        return predictions, loss

    def _run_inference(self, loader: DataLoader,
                       compute_loss: bool) -> Tuple[torch.Tensor, torch.Tensor, List[float]]:
        """Run the model over ``loader`` in eval mode without gradients.

        Returns:
            Concatenated predictions, concatenated targets (both on CPU) and
            the list of per-batch losses (empty when ``compute_loss`` is
            false).

        Raises:
            ValueError: If the loader yields no batches.
        """
        self.model.eval()
        losses: List[float] = []
        prediction_chunks: List[torch.Tensor] = []
        target_chunks: List[torch.Tensor] = []

        with torch.no_grad():
            for features, targets in loader:
                features = features.to(self.device)
                targets = targets.to(self.device)

                if compute_loss:
                    predictions, loss = self._forward(features, targets)
                    losses.append(loss.item())
                else:
                    predictions = self.model(features)

                # Autocast can return half-precision outputs; metrics are
                # computed in float32 to avoid needless precision loss.
                if predictions.dtype in (torch.float16, torch.bfloat16):
                    predictions = predictions.float()

                prediction_chunks.append(predictions.cpu())
                target_chunks.append(targets.cpu())

        if not prediction_chunks:
            raise ValueError("数据加载器为空,无法计算评估指标")

        return torch.cat(prediction_chunks, dim=0), torch.cat(target_chunks, dim=0), losses

    def _compute_metrics(self, predictions: torch.Tensor, targets: torch.Tensor) -> Dict[str, Any]:
        """Evaluate predictions with the standard or diagnostic metric set."""
        if self.diagnostic_mode:
            return self.metrics.evaluate_predictions_diagnostic(
                predictions, targets,
                component_names=['ΔPAD_P', 'ΔPAD_A', 'ΔPAD_D']
            )
        return self.metrics.evaluate_predictions(predictions, targets)

    # ------------------------------------------------------------------
    # Epoch loops
    # ------------------------------------------------------------------

    def train_epoch(self, train_loader: DataLoader) -> Dict[str, float]:
        """
        Train for one epoch.

        Args:
            train_loader: Loader yielding ``(features, targets)`` batches.

        Returns:
            Dictionary with the mean ``loss``, the current ``lr`` and the
            per-batch mean of every scalar metric reported by the metrics
            object.
        """
        self.model.train()
        epoch_losses: List[float] = []
        epoch_metrics = defaultdict(list)

        for features, targets in train_loader:
            features = features.to(self.device)
            targets = targets.to(self.device)

            self.optimizer.zero_grad()
            predictions, loss = self._forward(features, targets)

            if self.mixed_precision:
                self.scaler.scale(loss).backward()
                if self.grad_clip_val > 0:
                    # Gradients must be unscaled before their norm is clipped.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_val)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                if self.grad_clip_val > 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_val)
                self.optimizer.step()

            epoch_losses.append(loss.item())

            with torch.no_grad():
                batch_metrics = self.metrics.evaluate_predictions(predictions, targets)
            for name, value in self._flatten_scalars(batch_metrics).items():
                epoch_metrics[name].append(value)

        epoch_results = {
            # An empty loader yields NaN rather than a NumPy warning.
            'loss': float(np.mean(epoch_losses)) if epoch_losses else float('nan'),
            'lr': self._as_builtin(self.optimizer.param_groups[0]['lr']),
        }
        for name, values in epoch_metrics.items():
            epoch_results[name] = float(np.mean(values))

        return epoch_results

    def validate_epoch(self, val_loader: DataLoader) -> Dict[str, float]:
        """
        Run one validation pass.

        Args:
            val_loader: Loader yielding ``(features, targets)`` batches.

        Returns:
            Dictionary with ``val_loss`` plus every scalar metric, each key
            prefixed with ``val_``.

        Raises:
            ValueError: If the loader yields no batches.
        """
        all_predictions, all_targets, epoch_losses = self._run_inference(
            val_loader, compute_loss=True
        )
        detailed_metrics = self._compute_metrics(all_predictions, all_targets)

        epoch_results: Dict[str, Any] = {'val_loss': float(np.mean(epoch_losses))}

        regression_metrics = detailed_metrics.get('regression', {})
        epoch_results.update(self._flatten_scalars(regression_metrics, prefix='val_'))

        # Short aliases for the overall regression metrics.
        overall = regression_metrics.get('overall')
        if isinstance(overall, dict):
            aliases = (
                ('val_mae', 'mae'),
                ('val_rmse', 'rmse'),
                ('val_r2_mean', 'r2'),
                ('val_r2_robust', 'r2_robust'),
                ('val_mape', 'mape'),
            )
            for alias, source_key in aliases:
                epoch_results[alias] = self._as_builtin(overall.get(source_key, 0))

        # Only top-level scalars are taken from these two groups.
        for group in ('calibration', 'pad_specific'):
            epoch_results.update(
                self._flatten_scalars(detailed_metrics.get(group, {}), prefix='val_', nested=False)
            )

        return epoch_results

    # ------------------------------------------------------------------
    # Full training loop
    # ------------------------------------------------------------------

    def _step_scheduler(self, train_metrics: Dict[str, float],
                        val_metrics: Optional[Dict[str, float]],
                        has_val_loader: bool):
        """Advance the LR scheduler once for the finished epoch.

        ReduceLROnPlateau is stepped with the validation loss on epochs
        where validation ran; without any validation loader it falls back to
        the training loss. All other schedulers step every epoch.
        """
        if self.scheduler is None:
            return

        if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            if val_metrics is not None and 'val_loss' in val_metrics:
                self.scheduler.step(val_metrics['val_loss'])
            elif not has_val_loader:
                self.scheduler.step(train_metrics['loss'])
        else:
            self.scheduler.step()

    def train(self,
              train_loader: DataLoader,
              val_loader: Optional[DataLoader] = None,
              save_dir: Optional[str] = None) -> Dict[str, Dict[str, List[float]]]:
        """
        Run the complete training procedure.

        Args:
            train_loader: Training data loader.
            val_loader: Optional validation data loader.
            save_dir: Optional directory for checkpoints; created if missing.

        Returns:
            ``{'train_history': ..., 'val_history': ...}`` where each history
            maps a metric name to its per-epoch values.
        """
        self.logger.info("开始训练...")

        epochs_config = self.config.get('training', {}).get('epochs', {})
        max_epochs = epochs_config.get('max_epochs', 200)
        monitor_metric = epochs_config.get('early_stopping', {}).get('monitor', 'val_loss')
        # Guard against 0/None, which would break the modulo below.
        val_frequency = max(1, int(self.config.get('validation', {}).get('val_frequency', 1) or 1))

        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        start_time = time.time()

        for epoch in range(max_epochs):
            self.current_epoch = epoch

            train_metrics = self.train_epoch(train_loader)
            for key, value in train_metrics.items():
                self.train_history[key].append(value)

            # ``None`` marks epochs on which validation did not run, so no
            # stale or undefined validation result is ever consumed below.
            val_metrics = None
            if val_loader is not None and epoch % val_frequency == 0:
                val_metrics = self.validate_epoch(val_loader)
                for key, value in val_metrics.items():
                    self.val_history[key].append(value)

            self._step_scheduler(train_metrics, val_metrics, val_loader is not None)

            should_stop = False
            if self.early_stopping is not None and val_metrics is not None:
                score = val_metrics.get(monitor_metric)
                if score is None:
                    self.logger.warning(
                        f"早停监控指标 '{monitor_metric}' 不在验证指标中,跳过早停检查"
                    )
                else:
                    should_stop = bool(self.early_stopping(score, self.model))

            if should_stop:
                # The early-stopping tracker may already have restored the
                # best weights, so no per-epoch checkpoint is written here.
                self._log_epoch_progress(epoch + 1, train_metrics, val_metrics)
                self.logger.info(f"早停触发,在第 {epoch + 1} 轮停止训练")
                break

            if save_dir and val_metrics is not None and self._is_best_model(val_metrics):
                self._save_checkpoint(save_dir, is_best=True)

            if save_dir and (epoch + 1) % 10 == 0:
                self._save_checkpoint(save_dir, epoch=epoch + 1)

            self._log_epoch_progress(epoch + 1, train_metrics, val_metrics)

        training_time = time.time() - start_time
        self.logger.info(f"训练完成,总耗时: {training_time:.2f} 秒")

        if save_dir:
            self._save_checkpoint(save_dir, is_final=True)

        return {
            'train_history': dict(self.train_history),
            'val_history': dict(self.val_history),
        }

    def _is_best_model(self, val_metrics: Optional[Dict[str, float]]) -> bool:
        """Return whether ``val_metrics`` beats the best score seen so far.

        The metric name and direction come from
        ``config['validation']['model_selection']``. ``self.best_score`` is
        updated whenever this returns ``True``.
        """
        if not val_metrics:
            return False

        selection = self.config.get('validation', {}).get('model_selection', {})
        monitor_metric = selection.get('criterion', 'val_loss')
        mode = selection.get('mode', 'min')

        current_score = val_metrics.get(monitor_metric)
        if current_score is None:
            return False

        if self.best_score is None:
            improved = True
        elif mode == 'min':
            improved = current_score < self.best_score
        else:
            improved = current_score > self.best_score

        if improved:
            self.best_score = current_score
        return bool(improved)

    # ------------------------------------------------------------------
    # Checkpointing
    # ------------------------------------------------------------------

    def _save_checkpoint(self, save_dir: str, epoch: Optional[int] = None,
                         is_best: bool = False, is_final: bool = False):
        """Write a checkpoint file into ``save_dir``.

        The file is named ``best_model.pth``, ``final_model.pth`` or
        ``checkpoint_epoch_{epoch}.pth`` depending on the flags.
        """
        checkpoint = {
            'epoch': self.current_epoch,
            # Saved without the DataParallel ``module.`` prefix so that the
            # file loads into a plain model as well.
            'model_state_dict': self._unwrap_model().state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'config': self.config,
            'train_history': {
                key: [self._as_builtin(v) for v in values]
                for key, values in self.train_history.items()
            },
            'val_history': {
                key: [self._as_builtin(v) for v in values]
                for key, values in self.val_history.items()
            },
            'best_score': self._as_builtin(self.best_score),
        }

        if self.scheduler is not None:
            checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()

        if self.mixed_precision and self.scaler is not None:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        if is_best:
            filename = 'best_model.pth'
            message = "保存最佳模型到"
        elif is_final:
            filename = 'final_model.pth'
            message = "保存最终模型到"
        else:
            # Fall back to the current epoch rather than writing
            # "checkpoint_epoch_None.pth".
            epoch_label = self.current_epoch + 1 if epoch is None else epoch
            filename = f'checkpoint_epoch_{epoch_label}.pth'
            message = "保存检查点到"

        filepath = os.path.join(save_dir, filename)
        torch.save(checkpoint, filepath)
        self.logger.info(f"{message}: {filepath}")

    def load_checkpoint(self, checkpoint_path: str, load_optimizer: bool = True,
                        load_scheduler: bool = True):
        """
        Restore trainer state from a checkpoint file.

        Args:
            checkpoint_path: Path of the checkpoint to load.
            load_optimizer: Whether to restore the optimizer state.
            load_scheduler: Whether to restore the scheduler state.
        """
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        target_model = self._unwrap_model()
        state_dict = checkpoint['model_state_dict']
        prefix = 'module.'
        # Older checkpoints written from a DataParallel model carry a
        # "module." prefix on every key; strip it unless the target model
        # genuinely has parameters under that name.
        saved_is_prefixed = bool(state_dict) and all(k.startswith(prefix) for k in state_dict)
        target_is_prefixed = any(k.startswith(prefix) for k in target_model.state_dict())
        if saved_is_prefixed and not target_is_prefixed:
            state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
        target_model.load_state_dict(state_dict)

        self.current_epoch = checkpoint.get('epoch', 0)
        self.train_history = defaultdict(list, checkpoint.get('train_history', {}))
        self.val_history = defaultdict(list, checkpoint.get('val_history', {}))
        self.best_score = checkpoint.get('best_score')

        if load_optimizer and 'optimizer_state_dict' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if load_scheduler and self.scheduler is not None and 'scheduler_state_dict' in checkpoint:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        if self.mixed_precision and self.scaler is not None and 'scaler_state_dict' in checkpoint:
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        self.logger.info(f"从 {checkpoint_path} 恢复训练状态,当前epoch: {self.current_epoch}")

    # ------------------------------------------------------------------
    # Logging and evaluation
    # ------------------------------------------------------------------

    def _log_epoch_progress(self, epoch: int, train_metrics: Dict[str, float],
                            val_metrics: Optional[Dict[str, float]] = None):
        """Log a one-line summary of the epoch."""
        train_loss = train_metrics.get('loss', 0)
        train_lr = train_metrics.get('lr', 0)
        log_msg = f"Epoch {epoch:3d} | Train Loss: {train_loss:.6f} | LR: {train_lr:.2e}"

        if val_metrics:
            val_loss = val_metrics.get('val_loss', 0)
            # validate_epoch() publishes the overall metrics as ``val_mae``
            # and ``val_r2_mean``; the older key names are kept as fallback.
            val_mae = val_metrics.get('val_mae', val_metrics.get('val_regression_mae', 0))
            val_r2 = val_metrics.get('val_r2_mean', val_metrics.get('val_regression_r2', 0))
            log_msg += f" | Val Loss: {val_loss:.6f} | Val MAE: {val_mae:.6f} | Val R²: {val_r2:.4f}"

        self.logger.info(log_msg)

    def evaluate(self, test_loader: DataLoader) -> Dict[str, Any]:
        """
        Evaluate the model on a test set.

        Args:
            test_loader: Loader yielding ``(features, targets)`` batches.

        Returns:
            The metrics dictionary produced by the metrics object.

        Raises:
            ValueError: If the loader yields no batches.
        """
        self.logger.info("开始模型评估...")

        all_predictions, all_targets, _ = self._run_inference(test_loader, compute_loss=False)
        evaluation_results = self._compute_metrics(all_predictions, all_targets)

        self.logger.info("模型评估完成")
        return evaluation_results
|
|
|
|
|
|
|
|
def create_trainer(model: nn.Module,
                   config: Dict[str, Any],
                   device: Optional[Union[str, torch.device]] = None,
                   logger: Optional[logging.Logger] = None,
                   diagnostic_mode: bool = False) -> "Trainer":
    """
    Factory function that builds a ``Trainer``.

    Args:
        model: Model to train.
        config: Training configuration dictionary.
        device: Training device, forwarded unchanged to ``Trainer``.
        logger: Logger, forwarded unchanged to ``Trainer``.
        diagnostic_mode: Diagnostic-mode flag, forwarded to ``Trainer``.

    Returns:
        The constructed ``Trainer`` instance.
    """
    return Trainer(model, config, device, logger, diagnostic_mode=diagnostic_mode)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: train a PADPredictor for a few epochs on synthetic data.
    # The relative imports below only resolve when this file is executed as
    # a module of its package (``python -m ...``), not as a standalone script.
    from ..models.pad_predictor import PADPredictor
    # Aliased so the project loader does not rebind the module-level
    # ``torch.utils.data.DataLoader`` name.
    from ..data.data_loader import DataLoader as PADDataLoader
    from ..data.synthetic_generator import SyntheticDataGenerator

    test_config = {
        'training': {
            'optimizer': {
                'type': 'AdamW',
                'learning_rate': 1e-3,
                'weight_decay': 1e-4,
            },
            'scheduler': {
                'type': 'CosineAnnealingLR',
                'T_max': 100,
            },
            'epochs': {
                'max_epochs': 5,
                'early_stopping': {
                    'enabled': True,
                    'patience': 10,
                },
            },
            'loss': {
                'type': 'MSELoss',
            },
        },
        'validation': {
            'val_frequency': 1,
            'model_selection': {
                'criterion': 'val_loss',
                'mode': 'min',
            },
        },
        'hardware': {
            'device': 'cpu',
            'mixed_precision': {
                'enabled': False,
            },
        },
        'debug': {
            'gradient_checking': {
                'clip_value': 1.0,
            },
        },
    }

    # Build the model and the trainer.
    model = PADPredictor()
    trainer = create_trainer(model, test_config)

    # Generate synthetic samples and wrap them in data loaders.
    generator = SyntheticDataGenerator(num_samples=100)
    data, labels = generator.generate_data()

    data_loader = PADDataLoader(test_config.get('data', {}))
    train_loader, val_loader, _ = data_loader.get_all_loaders(data=np.hstack([data, labels]))

    print("开始训练测试...")
    history = trainer.train(train_loader, val_loader)

    print("训练测试完成!")
    print(f"训练历史: {len(history['train_history']['loss'])} 个epoch")