|
|
|
|
|
""" |
|
|
训练脚本 |
|
|
Training script for PAD Predictor |
|
|
|
|
|
该脚本实现了完整的训练流程,包括: |
|
|
- 命令行参数解析 |
|
|
- 配置文件加载和验证 |
|
|
- 数据加载和预处理 |
|
|
- 模型训练和验证 |
|
|
- 检查点保存和恢复 |
|
|
- 训练日志和可视化 |
|
|
- 早停机制和学习率调度 |
|
|
- 多GPU和混合精度训练支持 |
|
|
|
|
|
使用方法: |
|
|
python train.py --config configs/training_config.yaml --model-config configs/model_config.yaml |
|
|
python train.py --config configs/training_config.yaml --data-path data/train.csv --output-dir outputs/ |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import os |
|
|
import sys |
|
|
import yaml |
|
|
import json |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
from typing import Dict, Any, Optional, Union |
|
|
import logging |
|
|
import warnings |
|
|
from datetime import datetime |
|
|
|
|
|
|
|
|
# Put the directory two levels above this file's own directory at the front of
# sys.path so the ``src.*`` imports below resolve when the script is run
# directly (presumably this is the repository root -- confirm against layout).
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
|
|
|
|
|
from src.models.pad_predictor import PADPredictor, create_pad_predictor |
|
|
from src.data.data_loader import DataLoader, load_data_from_config |
|
|
from src.utils.trainer import Trainer, create_trainer |
|
|
from src.utils.logger import TrainingLogger, create_logger, ProgressLogger |
|
|
|
|
|
|
|
|
def parse_arguments() -> argparse.Namespace:
    """Define the training CLI and parse ``sys.argv``.

    Options cover config files, data locations, output/experiment naming,
    hyperparameter overrides, checkpoint resume, hardware selection,
    debug/diagnostic switches, synthetic data, seeding and verbosity.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser(
        description='PAD预测器训练脚本',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # (flags, add_argument keyword arguments), in the order shown by --help.
    option_specs = [
        # Configuration files
        (('--config', '-c'),
         dict(type=str, default='configs/training_config.yaml', help='训练配置文件路径')),
        (('--model-config', '-mc'),
         dict(type=str, default='configs/model_config.yaml', help='模型配置文件路径')),
        # Data sources
        (('--data-path', '-d'),
         dict(type=str, help='数据文件路径(可选,覆盖配置文件中的设置)')),
        (('--train-data-path',), dict(type=str, help='训练数据文件路径')),
        (('--val-data-path',), dict(type=str, help='验证数据文件路径')),
        (('--test-data-path',), dict(type=str, help='测试数据文件路径')),
        # Output / experiment naming
        (('--output-dir', '-o'), dict(type=str, default='outputs', help='输出目录')),
        (('--experiment-name', '-e'),
         dict(type=str, help='实验名称(覆盖配置文件中的设置)')),
        # Hyperparameter overrides
        (('--epochs',), dict(type=int, help='训练轮次(覆盖配置文件中的设置)')),
        (('--batch-size',), dict(type=int, help='批次大小(覆盖配置文件中的设置)')),
        (('--learning-rate', '-lr'),
         dict(type=float, help='学习率(覆盖配置文件中的设置)')),
        (('--resume',), dict(type=str, help='从检查点恢复训练的路径')),
        # Hardware
        (('--device',),
         dict(type=str, choices=['auto', 'cpu', 'cuda', 'mps'], default='auto', help='训练设备')),
        (('--gpu-id',), dict(type=int, default=0, help='GPU ID(当使用CUDA时)')),
        # Debugging / diagnostics
        (('--debug',), dict(action='store_true', help='启用调试模式')),
        (('--fast-train',), dict(action='store_true', help='快速训练模式(用于调试)')),
        (('--diagnostic',),
         dict(action='store_true', help='诊断模式:打印每个维度的详细指标(R²、MAE、RMSE、MAPE)')),
        # Synthetic data
        (('--synthetic-data',), dict(action='store_true', help='使用合成数据训练')),
        (('--num-samples',), dict(type=int, default=1000, help='合成数据样本数量')),
        # Misc
        (('--seed',), dict(type=int, help='随机种子(覆盖配置文件中的设置)')),
        (('--verbose', '-v'), dict(action='store_true', help='详细输出')),
        (('--dry-run',), dict(action='store_true', help='干运行(只检查配置,不实际训练)')),
    ]

    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    return cli.parse_args()
|
|
|
|
|
|
|
|
def load_config(config_path: str) -> Dict[str, Any]:
    """Load a YAML configuration file into a dict.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The parsed top-level mapping. An empty file yields ``{}`` (PyYAML
        returns ``None`` for an empty document, which downstream code that
        indexes into the config cannot handle).

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        ValueError: If the YAML document's top level is not a mapping.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"配置文件不存在: {config_path}")

    with open(config_path, 'r', encoding='utf-8') as f:
        # safe_load never constructs arbitrary Python objects from the file.
        config = yaml.safe_load(f)

    if config is None:
        return {}
    if not isinstance(config, dict):
        raise ValueError(f"配置文件顶层必须是字典(映射): {config_path}")
    return config
|
|
|
|
|
|
|
|
def merge_configs(base_config: Dict[str, Any], override_config: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively overlay ``override_config`` onto ``base_config``.

    Where both sides hold a dict under the same key, the two dicts are merged
    recursively. Otherwise the override value replaces the base value. Keys
    already in the base keep their position, and new keys are appended in
    override order. Neither input dict is modified at the top level (the
    result is a copy), although untouched nested values are shared.

    Args:
        base_config: Base configuration.
        override_config: Values that take precedence.

    Returns:
        The merged configuration.
    """
    def _overlay(lower: Dict, upper: Dict) -> Dict:
        combined = lower.copy()
        combined.update({
            name: (
                _overlay(lower[name], incoming)
                if isinstance(lower.get(name), dict) and isinstance(incoming, dict)
                else incoming
            )
            for name, incoming in upper.items()
        })
        return combined

    return _overlay(base_config, override_config)
|
|
|
|
|
|
|
|
def validate_config(config: Dict[str, Any]) -> bool:
    """Check that a training config has the sections this script relies on.

    Requirements:
      - top-level keys: training_info, data, training, validation, logging,
        checkpointing;
      - ``training`` is a dict containing ``optimizer`` and ``epochs``;
      - ``data`` is a dict containing ``dataloader``.

    Problems are reported via ``logging.error`` and the function returns False
    rather than raising, including when a section is not a dict (e.g. an
    empty ``training:`` entry in YAML parses to ``None``).

    Args:
        config: Configuration dict.

    Returns:
        True if valid, otherwise False.
    """
    if not isinstance(config, dict):
        logging.error("配置必须是字典类型")
        return False

    required_keys = [
        'training_info',
        'data',
        'training',
        'validation',
        'logging',
        'checkpointing',
    ]
    for key in required_keys:
        if key not in config:
            logging.error(f"配置文件缺少必需的键: {key}")
            return False

    training_config = config['training']
    if not isinstance(training_config, dict):
        logging.error("训练配置 'training' 必须是字典类型")
        return False
    for key in ('optimizer', 'epochs'):
        if key not in training_config:
            logging.error(f"训练配置缺少必需的键: {key}")
            return False

    data_config = config['data']
    if not isinstance(data_config, dict):
        logging.error("数据配置 'data' 必须是字典类型")
        return False
    if 'dataloader' not in data_config:
        logging.error("数据配置缺少 'dataloader' 键")
        return False

    return True
|
|
|
|
|
|
|
|
def setup_environment(config: Dict[str, Any], args: argparse.Namespace):
    """Prepare global training state: RNG seeds, CUDA device, cuDNN, warnings.

    Args:
        config: Configuration dict. ``config['training_info']['seed']`` is
            used when ``--seed`` was not given (default 42).
        args: Parsed CLI arguments. Reads ``seed``, ``device``, ``gpu_id``
            and ``verbose``.
    """
    # Use ``is not None`` so an explicit ``--seed 0`` is honoured instead of
    # silently falling back to the config seed.
    training_info = config.get('training_info') or {}
    seed = args.seed if args.seed is not None else training_info.get('seed', 42)

    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    if args.device == 'cuda' and torch.cuda.is_available():
        torch.cuda.set_device(args.gpu_id)
        logging.info(f"设置CUDA设备: {args.gpu_id}")

    # Speed over reproducibility: benchmark mode lets cuDNN pick algorithms
    # per input shape, and non-deterministic kernels are allowed. GPU runs are
    # therefore not bit-exact across runs even with the seeds set above.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

    # NOTE: this silences *all* warnings process-wide unless --verbose is set.
    if not args.verbose:
        warnings.filterwarnings('ignore')
|
|
|
|
|
|
|
|
def create_output_directory(base_dir: str, experiment_name: str) -> str:
    """Create a timestamped run directory with its standard subfolders.

    The layout is ``<base_dir>/<experiment_name>_<YYYYmmdd_HHMMSS>/``, with
    ``checkpoints``, ``logs``, ``plots`` and ``configs`` inside. Existing
    directories are reused.

    Args:
        base_dir: Parent directory; created if missing.
        experiment_name: Prefix for the run directory name.

    Returns:
        The run directory path as a string.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = Path(base_dir).joinpath("{}_{}".format(experiment_name, stamp))
    run_dir.mkdir(parents=True, exist_ok=True)

    for subfolder in ('checkpoints', 'logs', 'plots', 'configs'):
        run_dir.joinpath(subfolder).mkdir(exist_ok=True)

    return str(run_dir)
|
|
|
|
|
|
|
|
def load_data(config: Dict[str, Any], args: argparse.Namespace) -> tuple:
    """Build the train/validation/test data loaders.

    The data source is chosen in this order:
      1. ``--synthetic-data``: synthetic samples, split 70/15/15, or split
         70/20/10 with at most 100 samples under ``--fast-train``.
      2. ``--data-path``: one file, split by ``DataLoader.get_all_loaders``.
      3. ``--train-data-path`` and ``--val-data-path`` (``--test-data-path``
         optional; when omitted the validation loader is reused as the test
         loader).
      4. Otherwise, paths from ``config['data']`` (defaults
         ``data/train.csv``, ``data/val.csv``, ``data/test.csv``). If anything
         in this branch raises, the error is logged and the function retries
         with synthetic data.

    Args:
        config: Training configuration. ``config['data']['dataloader']`` is
            passed to ``DataLoader``.
        args: Parsed CLI arguments.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    data_section = config.get('data', {})
    loader_config = data_section.get('dataloader', {})

    if args.synthetic_data:
        logging.info(f"生成合成数据,样本数量: {args.num_samples}")

        from src.data.synthetic_generator import SyntheticDataGenerator
        generator = SyntheticDataGenerator(num_samples=args.num_samples)
        # NOTE(review): these arrays are never used -- the loaders below come
        # from DataLoader.get_synthetic_loaders. The call is kept because it
        # may advance global RNG state; confirm and drop if it is dead work.
        _data, _labels = generator.generate_data()

        data_loader = DataLoader(loader_config)
        if args.fast_train:
            # Fast mode: tiny dataset with a larger validation share.
            num_samples, split_ratio = min(args.num_samples, 100), (0.7, 0.2, 0.1)
        else:
            num_samples, split_ratio = args.num_samples, (0.7, 0.15, 0.15)
        train_loader, val_loader, test_loader = data_loader.get_synthetic_loaders(
            num_samples=num_samples,
            split_ratio=split_ratio,
        )

    elif args.data_path:
        logging.info(f"从单个文件加载数据: {args.data_path}")
        data_loader = DataLoader(loader_config)
        train_loader, val_loader, test_loader = data_loader.get_all_loaders(
            data_path=args.data_path
        )

    elif args.train_data_path and args.val_data_path:
        logging.info("分别加载训练和验证数据")
        data_loader = DataLoader(loader_config)
        train_loader = data_loader.get_train_loader(data_path=args.train_data_path)
        val_loader = data_loader.get_val_loader(data_path=args.val_data_path)

        if args.test_data_path:
            test_loader = data_loader.get_test_loader(data_path=args.test_data_path)
        else:
            # No separate test set: the validation loader doubles as the test
            # loader, so reported "test" metrics are validation metrics.
            logging.info("未提供测试数据,使用验证数据作为测试集")
            test_loader = val_loader

    else:
        # Explicit per-split paths are only honoured as a train+val pair;
        # warn instead of silently ignoring a partial set.
        if args.train_data_path or args.val_data_path or args.test_data_path:
            logging.warning(
                "--train-data-path 和 --val-data-path 需同时提供; "
                "已忽略单独指定的数据路径,改为从配置文件加载数据"
            )
        logging.info("从配置文件加载数据")

        try:
            train_path = data_section.get('train_data_path', 'data/train.csv')
            val_path = data_section.get('val_data_path', 'data/val.csv')
            test_path = data_section.get('test_data_path', 'data/test.csv')

            data_loader = DataLoader(loader_config)
            train_loader = data_loader.get_train_loader(train_path)
            val_loader = data_loader.get_val_loader(val_path)
            test_loader = data_loader.get_test_loader(test_path)

        except Exception as e:
            # Deliberate best-effort fallback: train on synthetic data rather
            # than abort when configured files cannot be loaded.
            logging.error(f"从文件加载数据失败: {e}")
            logging.info("回退到合成数据")

            fallback_args = argparse.Namespace(**vars(args))
            fallback_args.synthetic_data = True
            return load_data(config, fallback_args)

    logging.info("数据加载完成")
    logging.info(f"训练批次数: {len(train_loader)}")
    logging.info(f"验证批次数: {len(val_loader)}")
    logging.info(f"测试批次数: {len(test_loader)}")

    return train_loader, val_loader, test_loader
|
|
|
|
|
|
|
|
def create_model(model_config: Dict[str, Any], training_config: Dict[str, Any]) -> nn.Module:
    """Instantiate the PAD predictor and log a short summary of it.

    Args:
        model_config: Model configuration, passed to ``create_pad_predictor``.
        training_config: Training configuration. Accepted for interface
            compatibility and not read here.

    Returns:
        The constructed model.
    """
    net = create_pad_predictor(model_config)
    info = net.get_model_info()

    # (label, info key, format spec); parameter counts use thousands separators.
    summary_rows = (
        (" 模型类型", 'model_type', ''),
        (" 输入维度", 'input_dim', ''),
        (" 输出维度", 'output_dim', ''),
        (" 总参数数", 'total_parameters', ','),
        (" 可训练参数", 'trainable_parameters', ','),
    )

    logging.info("模型创建完成:")
    for label, key, spec in summary_rows:
        logging.info(f"{label}: {format(info[key], spec)}")

    return net
|
|
|
|
|
|
|
|
def apply_fast_train_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``config`` tuned for a quick debugging run.

    Applied only where the relevant section already exists as a dict:
      - ``training.epochs.max_epochs = 5``
      - ``training.optimizer.learning_rate = 1e-3``
      - ``data.dataloader.batch_size = 8`` and ``num_workers = 0``
      - ``debug.enabled = True`` and ``debug.fast_train = {enabled: True,
        max_epochs: 5}`` (``fast_train`` is created if missing)

    The input is deep-copied, so the caller's nested dicts are never mutated.
    A shallow ``dict.copy()`` would share and modify them.

    Args:
        config: Original configuration.

    Returns:
        The modified copy.
    """
    import copy

    fast_config = copy.deepcopy(config)

    training = fast_config.get('training')
    if isinstance(training, dict):
        if isinstance(training.get('epochs'), dict):
            training['epochs']['max_epochs'] = 5
        if isinstance(training.get('optimizer'), dict):
            training['optimizer']['learning_rate'] = 1e-3

    data = fast_config.get('data')
    if isinstance(data, dict) and isinstance(data.get('dataloader'), dict):
        data['dataloader']['batch_size'] = 8
        # Worker processes add startup overhead that dominates a tiny run.
        data['dataloader']['num_workers'] = 0

    debug = fast_config.get('debug')
    if isinstance(debug, dict):
        debug['enabled'] = True
        fast_train = debug.get('fast_train')
        if not isinstance(fast_train, dict):
            fast_train = debug['fast_train'] = {}
        fast_train['enabled'] = True
        fast_train['max_epochs'] = 5

    return fast_config
|
|
|
|
|
|
|
|
def _build_cli_overrides(args: argparse.Namespace) -> Dict[str, Any]:
    """Translate explicitly given CLI options into a nested config-override dict."""
    overrides: Dict[str, Any] = {}

    # ``is not None`` so explicit values such as ``--seed 0`` are honoured.
    if args.epochs is not None:
        overrides.setdefault('training', {})['epochs'] = {'max_epochs': args.epochs}
    if args.batch_size is not None:
        overrides.setdefault('data', {}).setdefault('dataloader', {})['batch_size'] = args.batch_size
    if args.learning_rate is not None:
        overrides.setdefault('training', {}).setdefault('optimizer', {})['learning_rate'] = args.learning_rate
    if args.seed is not None:
        overrides.setdefault('training_info', {})['seed'] = args.seed
    if args.experiment_name:
        overrides.setdefault('training_info', {})['experiment_name'] = args.experiment_name
    if args.device != 'auto':
        overrides.setdefault('hardware', {})['device'] = args.device

    return overrides


def main():
    """Run the full training pipeline.

    The steps are: parse the CLI, load and merge configs, validate them and
    set up the environment. Then create the output directory and save the
    effective configs. Unless ``--dry-run`` is given, it then loads data,
    builds the model and trainer, trains, replays the history into the
    experiment logger, evaluates on the test loader and logs a final summary.

    Exits with status 1 if config validation fails or an exception escapes.
    """
    args = parse_arguments()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    logger = logging.getLogger(__name__)
    logger.info("开始PAD预测器训练")

    try:
        logger.info(f"加载训练配置: {args.config}")
        training_config = load_config(args.config)

        logger.info(f"加载模型配置: {args.model_config}")
        model_config = load_config(args.model_config)

        overrides = _build_cli_overrides(args)

        # Fast-train defaults go first so explicit CLI overrides still win.
        if args.fast_train:
            logger.info("启用快速训练模式")
            training_config = apply_fast_train_config(training_config)

        if overrides:
            training_config = merge_configs(training_config, overrides)
            logger.info("应用命令行参数覆盖")

        if not validate_config(training_config):
            logger.error("配置验证失败")
            sys.exit(1)

        setup_environment(training_config, args)

        experiment_name = training_config.get('training_info', {}).get('experiment_name', 'pad_predictor')
        output_dir = create_output_directory(args.output_dir, experiment_name)
        logger.info(f"输出目录: {output_dir}")

        # Save the effective (merged) configs alongside the run's outputs.
        config_save_dir = Path(output_dir) / 'configs'
        with open(config_save_dir / 'training_config.yaml', 'w', encoding='utf-8') as f:
            yaml.dump(training_config, f, default_flow_style=False, allow_unicode=True)
        with open(config_save_dir / 'model_config.yaml', 'w', encoding='utf-8') as f:
            yaml.dump(model_config, f, default_flow_style=False, allow_unicode=True)

        if args.dry_run:
            logger.info("干运行模式,配置验证完成,退出")
            return

        with create_logger(training_config, experiment_name, output_dir) as training_logger:
            training_logger.log_config(training_config, 'training_config')
            training_logger.log_config(model_config, 'model_config')

            train_loader, val_loader, test_loader = load_data(training_config, args)

            model = create_model(model_config, training_config)
            training_logger.log_model_info(model.get_model_info())

            # None lets the trainer choose the device itself when --device=auto.
            device = args.device if args.device != 'auto' else None
            trainer = create_trainer(model, training_config, device, training_logger.logger,
                                     diagnostic_mode=args.diagnostic)

            if args.resume:
                logger.info(f"从检查点恢复训练: {args.resume}")
                trainer.load_checkpoint(args.resume)

            logger.info("开始训练...")
            start_time = datetime.now()

            max_epochs = training_config.get('training', {}).get('epochs', {}).get('max_epochs', 100)
            total_steps = len(train_loader) * max_epochs
            # NOTE(review): progress_logger is started but never updated or
            # stopped in this function -- confirm whether the trainer should
            # receive it.
            progress_logger = ProgressLogger(total_steps, log_frequency=10)
            progress_logger.start()

            history = trainer.train(
                train_loader=train_loader,
                val_loader=val_loader,
                save_dir=Path(output_dir) / 'checkpoints'
            )

            train_hist = history['train_history']
            val_hist = history['val_history']

            # Replay the per-epoch history into the experiment logger.
            num_epochs = len(train_hist.get('loss', []))
            for epoch in range(num_epochs):
                train_metrics = {key: values[epoch] for key, values in train_hist.items()
                                 if epoch < len(values)}
                training_logger.log_metrics(train_metrics, prefix='train', step=epoch)

                val_metrics = {key: values[epoch] for key, values in val_hist.items()
                               if epoch < len(values)}
                if val_metrics:
                    training_logger.log_metrics(val_metrics, prefix='val', step=epoch)

            training_logger.save_metrics_history()
            training_logger.plot_training_curves()

            # Default to "no results" so the summary below does not hit a
            # NameError when there is no (or an empty) test loader.
            evaluation_results = {}
            if test_loader:
                logger.info("在测试集上评估模型...")
                evaluation_results = trainer.evaluate(test_loader)

            train_losses = train_hist.get('loss', [])
            val_losses = val_hist.get('val_loss', [])
            final_metrics = {
                'final_train_loss': train_losses[-1] if train_losses else 0,
                'final_val_loss': val_losses[-1] if val_losses else 0,
                'best_val_loss': min(val_losses) if val_losses else 0,
                'total_epochs': len(train_losses),
            }

            if 'regression' in evaluation_results:
                regression_metrics = evaluation_results['regression']
                if 'overall' in regression_metrics:
                    overall = regression_metrics['overall']
                    final_metrics['test_mae'] = overall.get('mae', 0)
                    final_metrics['test_rmse'] = overall.get('rmse', 0)
                    final_metrics['test_r2_mean'] = overall.get('r2', 0)
                    final_metrics['test_r2_robust'] = overall.get('r2_robust', 0)
                    final_metrics['test_mape'] = overall.get('mape', 0)

            training_logger.log_metrics(final_metrics, prefix='final')
            training_logger.log_experiment_summary(final_metrics)

            end_time = datetime.now()
            training_time = end_time - start_time

            logger.info("训练完成!")
            logger.info(f"总训练时间: {training_time}")
            logger.info(f"输出目录: {output_dir}")
            logger.info(f"最佳模型保存在: {Path(output_dir) / 'checkpoints' / 'best_model.pth'}")

    except Exception as e:
        logger.error(f"训练过程中发生错误: {e}")
        if args.debug:
            import traceback
            logger.error(traceback.format_exc())
        sys.exit(1)
|
|
|
|
|
|
|
|
# Run the training pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()