Chordia / configs /quick_training_config.yaml
Corolin's picture
first commit
0a6452f
# 快速训练配置文件
# Quick Training Configuration - 用于快速验证和调试
# 训练基本信息
training_info:
experiment_name: "emotion_prediction_quick"
description: "基于MLP的情绪与生理状态变化预测模型快速训练"
seed: 42
# 数据配置
data:
# 数据路径
train_data_path: "data/train.csv"
val_data_path: "data/val.csv"
test_data_path: "data/test.csv"
# 数据预处理
preprocessing:
# 特征标准化
feature_scaling:
method: "standard" # standard, min_max, robust, none
pad_features: "standard" # PAD特征标准化方法
vitality_feature: "min_max" # 活力值标准化方法
# 数据增强
augmentation:
enabled: false
noise_std: 0.01
mixup_alpha: 0.2
# 数据加载 - 较小的批次大小用于快速训练
dataloader:
batch_size: 32
num_workers: 2
pin_memory: true
shuffle: true
drop_last: false
# 训练超参数 - 快速训练设置
training:
# 优化器配置
optimizer:
type: "AdamW"
learning_rate: 0.001 # 稍高的学习率
weight_decay: 0.01
betas: [0.9, 0.999]
eps: 1e-8
# 学习率调度
scheduler:
type: "CosineAnnealingLR"
T_max: 50 # 与max_epochs相同
eta_min: 1e-6
verbose: true
# 训练轮次 - 快速训练
epochs:
max_epochs: 50
early_stopping:
enabled: true
patience: 10 # 较短的耐心
min_delta: 1e-4
monitor: "val_loss"
mode: "min"
# 损失函数
loss:
type: "MSELoss"
reduction: "mean"
# 多任务损失权重
multi_task_weights:
delta_pad: 1.0
delta_pressure: 1.0
confidence: 0.5
# 验证配置
validation:
# 验证频率
val_frequency: 1
# 验证指标
metrics:
- "MSE"
- "MAE"
- "RMSE"
- "R2"
- "MAPE"
# 模型选择
model_selection:
criterion: "val_loss"
mode: "min"
# 日志和保存配置
logging:
# 日志级别
level: "INFO"
# 日志文件
log_dir: "logs"
log_file: "training.log"
# TensorBoard
tensorboard:
enabled: true
log_dir: "runs"
comment: "_quick_train"
# 进度条
progress_bar:
enabled: true
update_frequency: 5 # 更频繁的更新
# 检查点保存
checkpointing:
# 保存目录
save_dir: "checkpoints"
# 保存策略
save_strategy: "best"
# 文件命名
filename_template: "model_epoch_{epoch}_val_{val_loss:.4f}.pth"
# 保存内容
save_items:
- "model_state_dict"
- "optimizer_state_dict"
- "scheduler_state_dict"
- "epoch"
- "loss"
- "metrics"
- "config"
# 硬件配置
hardware:
# 设备选择
device: "auto"
# GPU配置
gpu:
id: 0
memory_fraction: 0.8
allow_growth: true
# 混合精度训练
mixed_precision:
enabled: false
opt_level: "O1"
# 调试配置
debug:
# 调试模式
enabled: true
# 快速训练(用于调试)
fast_train:
enabled: true
max_epochs: 50
batch_size: 32
subset_size: 1000
# 梯度检查
gradient_checking:
enabled: true
clip_value: 1.0
# 数据检查
data_checking:
enabled: true
check_nan: true
check_inf: true
check_range: true
# 实验跟踪
experiment_tracking:
# 是否启用实验跟踪
enabled: false
# MLflow配置
mlflow:
tracking_uri: "http://localhost:5000"
experiment_name: "emotion_prediction_quick"
run_name: null
tags: {}
params: {}