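# Configuration File Reference
This document describes every configuration option of the emotion and physiological state change prediction model, with parameter descriptions and usage examples.
## Table of Contents
1. [Configuration System Overview](#configuration-system-overview)
2. [Model Configuration](#model-configuration)
3. [Training Configuration](#training-configuration)
4. [Data Configuration](#data-configuration)
5. [Inference Configuration](#inference-configuration)
6. [Logging Configuration](#logging-configuration)
7. [Hardware Configuration](#hardware-configuration)
8. [Experiment Tracking Configuration](#experiment-tracking-configuration)
9. [Configuration Best Practices](#configuration-best-practices)
10. [Configuration Validation](#configuration-validation)
## Configuration System Overview
### Configuration File Format
The project uses YAML configuration files, which support:
- Hierarchical structure
- Comments
- Variable references
- Environment variable substitution
- Configuration inheritance
### Configuration Loading Order
1. Default configuration (built in)
2. Global configuration file (`~/.emotion-prediction/config.yaml`)
3. Project configuration files (`configs/`)
4. Command-line argument overrides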
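
The loading order above can be illustrated with a recursive merge. This is a minimal sketch, assuming that each later source overrides the earlier ones key by key; `deep_merge`, `load_layered`, and the four sample dictionaries are hypothetical stand-ins, not the project's `ConfigManager` implementation.

```python
from copy import deepcopy
from typing import Any, Dict


def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict with `override` merged recursively into `base`."""
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


def load_layered(*layers: Dict[str, Any]) -> Dict[str, Any]:
    """Merge layers in order; later layers take precedence."""
    result: Dict[str, Any] = {}
    for layer in layers:
        result = deep_merge(result, layer)
    return result


# Stand-ins for the four sources, from lowest to highest priority.
defaults = {"training": {"epochs": {"max_epochs": 200}}, "dataloader": {"batch_size": 32}}
global_cfg = {"dataloader": {"batch_size": 64}}
project_cfg = {"training": {"epochs": {"max_epochs": 500}}}
cli_overrides = {"dataloader": {"batch_size": 8}}

config = load_layered(defaults, global_cfg, project_cfg, cli_overrides)
print(config)
```

Keys that only appear in a lower-priority layer survive the merge, so a project file only needs to list the values it changes.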
### Configuration Manager
```python
from src.utils.config import ConfigManager

# Load the configuration
config_manager = ConfigManager()
config = config_manager.load_config("configs/training_config.yaml")

# Access configuration values
learning_rate = config.training.optimizer.learning_rate
batch_size = config.training.batch_size

# Validate the configuration
config_manager.validate_config(config)
```
## Model Configuration
### Main configuration file: `configs/model_config.yaml`
```yaml
# ========================================
# Model configuration file
# ========================================
# Basic model information
model_info:
  name: "MLP_Emotion_Predictor"
  type: "MLP"
  version: "1.0"
  description: "MLP-based model for predicting changes in emotional and physiological state"
  author: "Research Team"
# Input/output dimensions
dimensions:
  input_dim: 7   # Input: User PAD (3) + Vitality (1) + Current PAD (3)
  output_dim: 3  # Output: ΔPAD (ΔPleasure, ΔArousal, ΔDominance)
# Network architecture
architecture:
  # Hidden layers
  hidden_layers:
    - size: 128
      activation: "ReLU"
      dropout: 0.2
      batch_norm: false
      layer_norm: false
    - size: 64
      activation: "ReLU"
      dropout: 0.2
      batch_norm: false
      layer_norm: false
    - size: 32
      activation: "ReLU"
      dropout: 0.1
      batch_norm: false
      layer_norm: false
  # Output layer
  output_layer:
    activation: "Linear"  # Linear activation for regression
  # Normalization switches
  use_batch_norm: false
  use_layer_norm: false
# Weight initialization
initialization:
  weight_init: "xavier_uniform"  # Options: xavier_uniform, xavier_normal, kaiming_uniform, kaiming_normal
  bias_init: "zeros"             # Options: zeros, ones, uniform, normal
# Regularization
regularization:
  # L2 regularization
  weight_decay: 0.0001
  # Dropout
  dropout_config:
    type: "standard"  # Standard dropout
    rate: 0.2         # Dropout probability
  # Batch normalization
  batch_norm_config:
    momentum: 0.1
    eps: 1e-5
# Model saving
model_saving:
  save_best_only: true     # Save only the best model
  save_format: "pytorch"   # Save format: pytorch, onnx, torchscript
  checkpoint_interval: 10  # Save a checkpoint every 10 epochs
  max_checkpoints: 5       # Keep at most 5 checkpoints
# PAD emotion-space settings
emotion_model:
  # Value ranges of the PAD dimensions
  pad_space:
    pleasure_range: [-1.0, 1.0]   # Pleasure range
    arousal_range: [-1.0, 1.0]    # Arousal range
    dominance_range: [-1.0, 1.0]  # Dominance range
  # Physiological indicator
  vitality:
    range: [0.0, 100.0]       # Vitality range
    normalization: "min_max"  # Normalization method: min_max, z_score, robust
  # Prediction output
  prediction:
    # Allowed range of ΔPAD
    delta_pad_range: [-0.5, 0.5]  # Plausible range of PAD changes
    # Allowed range of the pressure change
    delta_pressure_range: [-0.3, 0.3]
    # Confidence range
    confidence_range: [0.0, 1.0]
```
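### Model Configuration Parameters
#### `model_info` Basic Model Information
| Parameter | Type | Required | Default | Description |
|------|------|------|--------|------|
| `name` | str | Yes | - | Model name |
| `type` | str | Yes | - | Model type (MLP, CNN, RNN, etc.) |
| `version` | str | Yes | - | Model version |
| `description` | str | No | - | Model description |
| `author` | str | No | - | Author information |
#### `dimensions` Input and Output Dimensions
| Parameter | Type | Required | Default | Description |
|------|------|------|--------|------|
| `input_dim` | int | Yes | 7 | Number of input features |
| `output_dim` | int | Yes | 3 | Number of predicted outputs (ΔPAD, 3 values) |
#### `architecture` Network Architecture
##### `hidden_layers` Hidden Layers
Each hidden layer accepts the following parameters:
| Parameter | Type | Required | Default | Description |
|------|------|------|--------|------|
| `size` | int | Yes | - | Number of neurons |
| `activation` | str | No | ReLU | Activation function |
| `dropout` | float | No | 0.0 | Dropout probability |
| `batch_norm` | bool | No | false | Whether to apply batch normalization |
| `layer_norm` | bool | No | false | Whether to apply layer normalization |
**Activation function options**:
- `ReLU`: rectified linear unit
- `LeakyReLU`: leaky ReLU
- `Tanh`: hyperbolic tangent
- `Sigmoid`: sigmoid function
- `GELU`: Gaussian error linear unit
- `Swish`: Swish activation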
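
To get a feel for the size of the example architecture (7 inputs, hidden layers of 128, 64, and 32 units, 3 outputs), the trainable parameters can be counted by hand. The sketch below assumes plain fully connected layers with biases and no batch or layer normalization, as in the sample config; `count_mlp_parameters` is a hypothetical helper, not part of the project.

```python
from typing import List


def count_mlp_parameters(input_dim: int, hidden_sizes: List[int], output_dim: int) -> int:
    """Count the weights and biases of a fully connected network."""
    total = 0
    previous = input_dim
    for size in hidden_sizes + [output_dim]:
        total += previous * size + size  # weight matrix + bias vector
        previous = size
    return total


n_params = count_mlp_parameters(7, [128, 64, 32], 3)
print(n_params)  # 11459
```

Activation functions and dropout add no trainable parameters, so they do not change this count.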
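##### `output_layer` Output Layer
| Parameter | Type | Required | Default | Description |
|------|------|------|--------|------|
| `activation` | str | No | Linear | Output activation function |
#### `initialization` Weight Initialization
| Parameter | Type | Required | Default | Description |
|------|------|------|--------|------|
| `weight_init` | str | No | xavier_uniform | Weight initialization method |
| `bias_init` | str | No | zeros | Bias initialization method |
**Weight initialization options**:
- `xavier_uniform`: Xavier uniform initialization
- `xavier_normal`: Xavier normal initialization
- `kaiming_uniform`: Kaiming uniform initialization (suited to ReLU)
- `kaiming_normal`: Kaiming normal initialization (suited to ReLU)
- `uniform`: uniform distribution initialization
- `normal`: normal distribution initialization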
## Training Configuration
### Main configuration file: `configs/training_config.yaml`
```yaml
# ========================================
# Training configuration file
# ========================================
# Basic training information
training_info:
  experiment_name: "emotion_prediction_v1"
  description: "Training of the MLP-based emotion and physiological state change prediction model"
  seed: 42
  tags: ["baseline", "mlp", "emotion_prediction"]
# Data configuration
data:
  # Data paths
  paths:
    train_data: "data/train.csv"
    val_data: "data/val.csv"
    test_data: "data/test.csv"
  # Data preprocessing
  preprocessing:
    # Feature scaling
    feature_scaling:
      method: "standard"           # standard, min_max, robust, none
      pad_features: "standard"     # Scaling method for the PAD features
      vitality_feature: "min_max"  # Scaling method for vitality
    # Label scaling
    label_scaling:
      method: "standard"
      delta_pad: "standard"
      delta_pressure: "standard"
      confidence: "none"
  # Data augmentation
  augmentation:
    enabled: false
    noise_std: 0.01
    mixup_alpha: 0.2
    augmentation_factor: 2
  # Data validation
  validation:
    check_ranges: true
    check_missing: true
    check_outliers: true
    outlier_method: "iqr"  # iqr, zscore, isolation_forest
  # DataLoader configuration
  dataloader:
    batch_size: 32
    num_workers: 4
    pin_memory: true
    shuffle: true
    drop_last: false
    persistent_workers: true
  # Data split
  split:
    train_ratio: 0.8
    val_ratio: 0.1
    test_ratio: 0.1
    stratify: false
    random_seed: 42
# Training hyperparameters
training:
  # Epochs
  epochs:
    max_epochs: 200
    warmup_epochs: 5
  # Early stopping
  early_stopping:
    enabled: true
    patience: 15         # Epochs to wait without improvement
    min_delta: 1e-4      # Minimum change that counts as an improvement
    monitor: "val_loss"  # Monitored metric
    mode: "min"          # min/max
    restore_best_weights: true
  # Gradient settings
  gradient:
    clip_enabled: true
    clip_value: 1.0
    clip_norm: 2  # 1: L1 norm, 2: L2 norm
  # Mixed-precision training
  mixed_precision:
    enabled: false
    opt_level: "O1"  # O0, O1, O2, O3
  # Gradient accumulation
  gradient_accumulation:
    enabled: false
    accumulation_steps: 4
# Optimizer configuration
optimizer:
  type: "AdamW"  # Adam, SGD, AdamW, RMSprop, Adagrad
  # Adam/AdamW parameters
  adam_config:
    lr: 0.0005           # Learning rate
    weight_decay: 0.01   # Weight decay
    betas: [0.9, 0.999]  # Beta coefficients
    eps: 1e-8            # Numerical stability term
    amsgrad: false       # AMSGrad variant
  # SGD parameters
  sgd_config:
    lr: 0.01
    momentum: 0.9
    weight_decay: 0.0001
    nesterov: true
  # RMSprop parameters
  rmsprop_config:
    lr: 0.001
    alpha: 0.99
    weight_decay: 0.0
    momentum: 0.0
# Learning-rate scheduler configuration
scheduler:
  type: "CosineAnnealingLR"  # StepLR, CosineAnnealingLR, ReduceLROnPlateau, ExponentialLR
  # Cosine annealing scheduler
  cosine_config:
    T_max: 200      # Maximum number of epochs
    eta_min: 1e-6   # Minimum learning rate
    last_epoch: -1
  # Step scheduler
  step_config:
    step_size: 30  # Epochs between decays
    gamma: 0.1     # Decay factor
  # Plateau scheduler
  plateau_config:
    patience: 10     # Epochs to wait before reducing the rate
    factor: 0.5      # Decay factor
    min_lr: 1e-7     # Minimum learning rate
    threshold: 1e-4  # Improvement threshold
    verbose: true
# Loss function configuration
loss:
  type: "WeightedMSELoss"  # MSELoss, L1Loss, SmoothL1Loss, HuberLoss, WeightedMSELoss
  # Base loss parameters
  base_config:
    reduction: "mean"  # mean, sum, none
  # Weighted loss
  weighted_config:
    delta_pad_weight: 1.0       # Weight of the ΔPAD prediction
    delta_pressure_weight: 1.0  # Weight of the ΔPressure prediction
    confidence_weight: 0.5      # Weight of the confidence prediction
  # Huber loss
  huber_config:
    delta: 1.0  # Huber threshold
  # Focal loss (optional)
  focal_config:
    alpha: 1.0
    gamma: 2.0
# Validation configuration
validation:
  # Validation frequency
  val_frequency: 1  # Validate every N epochs
  # Validation metrics
  metrics:
    - "MSE"
    - "MAE"
    - "RMSE"
    - "R2"
    - "MAPE"
  # Model selection
  model_selection:
    criterion: "val_loss"  # val_loss, val_mae, val_r2
    mode: "min"            # min/max
  # Augmentation during validation
  val_augmentation:
    enabled: false
    methods: []
# Logging and monitoring configuration
logging:
  # Log level
  level: "INFO"  # DEBUG, INFO, WARNING, ERROR
  # Log files
  log_dir: "logs"
  log_file: "training.log"
  max_file_size: "10MB"
  backup_count: 5
  # TensorBoard
  tensorboard:
    enabled: true
    log_dir: "runs"
    comment: ""
    flush_secs: 10
  # Wandb
  wandb:
    enabled: false
    project: "emotion-prediction"
    entity: "your-team"
    tags: []
    notes: ""
  # Progress bar
  progress_bar:
    enabled: true
    update_frequency: 10  # Update frequency
    leave: true           # Keep the bar after training finishes
# Checkpoint saving configuration
checkpointing:
  # Output directory
  save_dir: "checkpoints"
  # Saving strategy
  save_strategy: "best"  # best, last, all
  # File naming
  filename_template: "model_epoch_{epoch}_val_{val_loss:.4f}.pth"
  # Saved items
  save_items:
    - "model_state_dict"
    - "optimizer_state_dict"
    - "scheduler_state_dict"
    - "epoch"
    - "loss"
    - "metrics"
    - "config"
  # Saving frequency
  save_frequency: 1  # Save every N epochs
  # Maximum number of checkpoints
  max_checkpoints: 5
# Hardware configuration
hardware:
  # Device selection
  device: "auto"  # auto, cpu, cuda, mps
  # GPU settings
  gpu:
    id: 0                 # GPU ID
    memory_fraction: 0.9  # Fraction of GPU memory to use
    allow_growth: true    # Grow memory usage dynamically
  # Mixed precision
  mixed_precision:
    enabled: false
    opt_level: "O1"
  # Distributed training
  distributed:
    enabled: false
    backend: "nccl"
    init_method: "env://"
    world_size: 1
    rank: 0
# Debug configuration
debug:
  # Debug mode
  enabled: false
  # Fast training
  fast_train:
    enabled: false
    max_epochs: 5
    batch_size: 8
    subset_size: 100
  # Gradient checks
  gradient_checking:
    enabled: false
    clip_value: 1.0
    check_nan: true
    check_inf: true
  # Data checks
  data_checking:
    enabled: true
    check_nan: true
    check_inf: true
    check_range: true
    sample_output: true
  # Model checks
  model_checking:
    enabled: false
    count_parameters: true
    check_gradients: true
    visualize_model: false
# Experiment tracking configuration
experiment_tracking:
  # Whether experiment tracking is enabled
  enabled: false
  # MLflow
  mlflow:
    tracking_uri: "http://localhost:5000"
    experiment_name: "emotion_prediction"
    run_name: null
    tags: {}
    params: {}
  # Weights & Biases
  wandb:
    project: "emotion-prediction"
    entity: null
    group: null
    job_type: "training"
    tags: []
    notes: ""
    config: {}
  # Local experiment tracking
  local:
    save_dir: "experiments"
    save_config: true
    save_metrics: true
    save_model: true
```
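### Training Configuration Parameters
#### `training.epochs` Epochs
| Parameter | Type | Required | Default | Description |
|------|------|------|--------|------|
| `max_epochs` | int | Yes | 200 | Maximum number of training epochs |
| `warmup_epochs` | int | No | 0 | Number of warm-up epochs |
#### `training.early_stopping` Early Stopping
| Parameter | Type | Required | Default | Description |
|------|------|------|--------|------|
| `enabled` | bool | No | true | Whether early stopping is enabled |
| `patience` | int | No | 10 | Number of epochs to wait without improvement |
| `min_delta` | float | No | 1e-4 | Minimum change that counts as an improvement |
| `monitor` | str | No | val_loss | Monitored metric |
| `mode` | str | No | min | Monitoring mode (min/max) |
| `restore_best_weights` | bool | No | true | Restore the best weights when stopping |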
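
How `patience`, `min_delta`, and `mode` interact can be sketched in a few lines. This is a minimal illustration under the assumption that an epoch counts as an improvement only when the monitored value beats the best one by more than `min_delta`; the `EarlyStopping` class here is hypothetical and does not restore weights.

```python
class EarlyStopping:
    """Track a monitored metric and signal when training should stop."""

    def __init__(self, patience: int = 10, min_delta: float = 1e-4, mode: str = "min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best = None      # best value seen so far
        self.bad_epochs = 0   # consecutive epochs without improvement

    def step(self, value: float) -> bool:
        """Record one value; return True when `patience` is exhausted."""
        if self.best is None:
            improved = True
        elif self.mode == "min":
            improved = value < self.best - self.min_delta
        else:
            improved = value > self.best + self.min_delta
        if improved:
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=3, min_delta=1e-4, mode="min")
val_losses = [0.50, 0.40, 0.40, 0.41, 0.39995, 0.30]
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.step(loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best)
```

In this run the loss of 0.39995 does not count as an improvement over 0.40 because the gain is smaller than `min_delta`, so training stops before the final value is ever seen.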
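#### `optimizer` Optimizer
Supported optimizer types:
- `Adam`: adaptive moment estimation
- `AdamW`: Adam with decoupled weight decay
- `SGD`: stochastic gradient descent
- `RMSprop`: RMSprop optimizer
- `Adagrad`: adaptive gradient algorithm
#### `scheduler` Learning-Rate Scheduler
Supported scheduler types:
- `StepLR`: step decay
- `CosineAnnealingLR`: cosine annealing
- `ReduceLROnPlateau`: reduce on plateau
- `ExponentialLR`: exponential decay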
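
For the cosine annealing option, the learning rate over the first `T_max` epochs follows the closed form `eta_min + (base_lr - eta_min) * (1 + cos(pi * epoch / T_max)) / 2`. The sketch below evaluates that formula with the sample values from the training config (`lr: 0.0005`, `eta_min: 1e-6`, `T_max: 200`); `cosine_annealing_lr` is a hypothetical helper for illustration, not the scheduler the project uses.

```python
import math


def cosine_annealing_lr(epoch: int, base_lr: float, eta_min: float, t_max: int) -> float:
    """Closed-form cosine annealing schedule for 0 <= epoch <= t_max."""
    cosine = math.cos(math.pi * epoch / t_max)
    return eta_min + (base_lr - eta_min) * (1 + cosine) / 2


for epoch in (0, 50, 100, 150, 200):
    lr = cosine_annealing_lr(epoch, base_lr=0.0005, eta_min=1e-6, t_max=200)
    print(f"epoch {epoch:3d}: lr = {lr:.6g}")
```

The rate starts at the base value, passes the midpoint between `base_lr` and `eta_min` halfway through, and reaches `eta_min` at `T_max`.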
## Data Configuration
### Data configuration file: `configs/data_config.yaml`
```yaml
# ========================================
# Data configuration file
# ========================================
# Data paths
paths:
  # Datasets
  train_data: "data/train.csv"
  val_data: "data/val.csv"
  test_data: "data/test.csv"
  # Preprocessor
  preprocessor: "models/preprocessor.pkl"
  # Data statistics
  statistics: "data/statistics.json"
  # Data quality report
  quality_report: "reports/data_quality.html"
# Data source
data_source:
  type: "csv"  # csv, json, parquet, hdf5, database
  # CSV settings
  csv_config:
    delimiter: ","
    encoding: "utf-8"
    header: 0
    index_col: null
  # JSON settings
  json_config:
    orient: "records"  # records, index, values, columns
    lines: false
  # Database settings
  database_config:
    connection_string: "sqlite:///data.db"
    table: "emotion_data"
    query: null
# Data preprocessing
preprocessing:
  # Feature processing
  features:
    # Missing values
    missing_values:
      strategy: "drop"  # drop, fill_mean, fill_median, fill_mode, fill_constant
      fill_value: 0.0
    # Outliers
    outliers:
      method: "iqr"  # iqr, zscore, isolation_forest, none
      threshold: 1.5
      action: "clip"  # clip, remove, flag
    # Feature scaling
    scaling:
      method: "standard"      # standard, minmax, robust, none
      feature_range: [-1, 1]  # Target range for MinMax scaling
    # Feature selection
    selection:
      enabled: false
      method: "correlation"  # correlation, mutual_info, rfe
      k_best: 10
  # Label processing
  labels:
    # Missing values
    missing_values:
      strategy: "fill_mean"
    # Label scaling
    scaling:
      method: "standard"
    # Label transformation
    transformation:
      enabled: false
      method: "log"  # log, sqrt, boxcox
# Data augmentation
augmentation:
  enabled: false
  # Noise injection
  noise_injection:
    enabled: true
    noise_type: "gaussian"  # gaussian, uniform
    noise_std: 0.01
    feature_wise: true
  # Mixup
  mixup:
    enabled: true
    alpha: 0.2
  # SMOTE (for imbalanced data)
  smote:
    enabled: false
    k_neighbors: 5
    sampling_strategy: "auto"
# Data validation
validation:
  # Value range checks
  range_validation:
    enabled: true
    features:
      user_pleasure: [-1.0, 1.0]
      user_arousal: [-1.0, 1.0]
      user_dominance: [-1.0, 1.0]
      vitality: [0.0, 100.0]
      current_pleasure: [-1.0, 1.0]
      current_arousal: [-1.0, 1.0]
      current_dominance: [-1.0, 1.0]
    labels:
      delta_pleasure: [-0.5, 0.5]
      delta_arousal: [-0.5, 0.5]
      delta_dominance: [-0.5, 0.5]
      delta_pressure: [-0.3, 0.3]
      confidence: [0.0, 1.0]
  # Data quality checks
  quality_checks:
    check_duplicates: true
    check_missing: true
    check_outliers: true
    check_correlations: true
    check_distribution: true
  # Statistics report
  statistics:
    compute_descriptive: true
    compute_correlations: true
    compute_distributions: true
    save_plots: true
# Synthetic data
synthetic_data:
  enabled: false
  # Generation parameters
  generation:
    num_samples: 1000
    seed: 42
  # Data distribution
  distribution:
    type: "multivariate_normal"  # normal, uniform, multivariate_normal
    mean: null
    cov: null
  # Correlation
  correlation:
    enabled: true
    strength: 0.5
    structure: "block"  # block, random, toeplitz
  # Noise
  noise:
    add_noise: true
    noise_type: "gaussian"
    noise_std: 0.1
```
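
The `range_validation` entries above map each feature name to a closed `[min, max]` interval. A minimal sketch of such a check follows; `FEATURE_RANGES` copies the sample values from the config, while `find_range_violations` and its message format are hypothetical and only illustrate the idea.

```python
from typing import Dict, List, Tuple

# Feature ranges copied from the sample data config.
FEATURE_RANGES: Dict[str, Tuple[float, float]] = {
    "user_pleasure": (-1.0, 1.0),
    "user_arousal": (-1.0, 1.0),
    "user_dominance": (-1.0, 1.0),
    "vitality": (0.0, 100.0),
    "current_pleasure": (-1.0, 1.0),
    "current_arousal": (-1.0, 1.0),
    "current_dominance": (-1.0, 1.0),
}


def find_range_violations(
    sample: Dict[str, float], ranges: Dict[str, Tuple[float, float]]
) -> List[str]:
    """Return one message per missing or out-of-range field."""
    violations = []
    for name, (low, high) in ranges.items():
        if name not in sample:
            violations.append(f"{name}: missing")
        elif not low <= sample[name] <= high:
            violations.append(f"{name}: {sample[name]} outside [{low}, {high}]")
    return violations


sample = {
    "user_pleasure": 0.3, "user_arousal": -0.2, "user_dominance": 0.1,
    "vitality": 120.0,  # deliberately out of range
    "current_pleasure": 0.0, "current_arousal": 0.5, "current_dominance": -0.4,
}
print(find_range_violations(sample, FEATURE_RANGES))
```

Returning the full list of violations rather than failing on the first one makes it easier to report every problem in a row at once.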
## Inference Configuration
### Inference configuration file: `configs/inference_config.yaml`
```yaml
# ========================================
# Inference configuration file
# ========================================
# Basic inference information
inference_info:
  model_path: "models/best_model.pth"
  preprocessor_path: "models/preprocessor.pkl"
  device: "auto"
  batch_size: 32
# Input configuration
input:
  # Input format
  format: "auto"  # auto, list, numpy, pandas, json, csv
  # Input validation
  validation:
    enabled: true
    check_shape: true
    check_range: true
    check_type: true
  # Input preprocessing
  preprocessing:
    normalize: true
    handle_missing: "error"  # error, fill, skip
    missing_value: 0.0
# Output configuration
output:
  # Output format
  format: "dict"  # dict, json, csv, numpy
  # Output contents
  include:
    predictions: true
    confidence: true
    components: true  # delta_pad, delta_pressure, confidence
    metadata: false   # inference_time, model_info
  # Output post-processing
  postprocessing:
    clip_predictions: true
    round_decimals: 6
    format_confidence: "percentage"  # decimal, percentage
# Performance optimization
optimization:
  # Model optimization
  model_optimization:
    enabled: true
    torch_script: false
    onnx: false
    quantization: false
  # Inference optimization
  inference_optimization:
    warmup: true
    warmup_samples: 5
    batch_optimization: true
    memory_optimization: true
  # Caching
  caching:
    enabled: false
    cache_size: 1000
    cache_policy: "lru"  # lru, fifo
# Monitoring
monitoring:
  # Performance monitoring
  performance:
    enabled: true
    track_latency: true
    track_memory: true
    track_throughput: true
  # Quality monitoring
  quality:
    enabled: false
    confidence_threshold: 0.5
    prediction_validation: true
  # Anomaly detection
  anomaly_detection:
    enabled: false
    method: "statistical"  # statistical, isolation_forest
    threshold: 2.0
# Service configuration (for deployment)
service:
  # API
  api:
    host: "0.0.0.0"
    port: 8000
    workers: 1
    timeout: 30
  # Rate limiting
  rate_limiting:
    enabled: false
    requests_per_minute: 100
  # Authentication
  authentication:
    enabled: false
    method: "api_key"  # api_key, jwt, basic
  # Logging
  logging:
    level: "INFO"
    format: "json"
```
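
The `output.postprocessing` options can be read as a small pipeline: clip, round, then format. The sketch below is one possible interpretation, assuming predictions are clipped to the `delta_pad_range` of `[-0.5, 0.5]` from the model config and that `"percentage"` means a string with one decimal place; the `postprocess` function and its output layout are hypothetical.

```python
from typing import Dict, List, Tuple, Union


def postprocess(
    delta_pad: List[float],
    confidence: float,
    delta_range: Tuple[float, float] = (-0.5, 0.5),
    decimals: int = 6,
    confidence_format: str = "percentage",
) -> Dict[str, Union[List[float], float, str]]:
    """Clip and round ΔPAD predictions, then format the confidence."""
    low, high = delta_range
    clipped = [round(min(max(value, low), high), decimals) for value in delta_pad]
    bounded = min(max(confidence, 0.0), 1.0)  # keep confidence inside [0, 1]
    if confidence_format == "percentage":
        formatted: Union[float, str] = f"{bounded * 100:.1f}%"
    else:
        formatted = round(bounded, decimals)
    return {"delta_pad": clipped, "confidence": formatted}


result = postprocess([0.7, -0.1234567891, -0.9], confidence=0.875)
print(result)
```

Clipping before rounding keeps the reported values inside the configured range even when the raw prediction overshoots it.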
## Logging Configuration
### Logging configuration file: `configs/logging_config.yaml`
```yaml
# ========================================
# Logging configuration file
# ========================================
# Logging system configuration
logging:
  # Root logger
  root:
    level: "INFO"
    handlers: ["console", "file"]
  # Loggers
  loggers:
    training:
      level: "INFO"
      handlers: ["console", "file", "tensorboard"]
      propagate: false
    inference:
      level: "WARNING"
      handlers: ["console", "file"]
      propagate: false
    data:
      level: "DEBUG"
      handlers: ["file"]
      propagate: false
  # Handlers
  handlers:
    # Console handler
    console:
      class: "StreamHandler"
      level: "INFO"
      formatter: "console"
      stream: "ext://sys.stdout"
    # File handler
    file:
      class: "RotatingFileHandler"
      level: "DEBUG"
      formatter: "detailed"
      filename: "logs/app.log"
      maxBytes: 10485760  # 10MB
      backupCount: 5
      encoding: "utf8"
    # Error file handler
    error_file:
      class: "RotatingFileHandler"
      level: "ERROR"
      formatter: "detailed"
      filename: "logs/error.log"
      maxBytes: 10485760
      backupCount: 3
      encoding: "utf8"
  # Formatters
  formatters:
    # Console format
    console:
      format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
      datefmt: "%H:%M:%S"
    # Detailed format
    detailed:
      format: "%(asctime)s - %(name)s - %(levelname)s - %(module)s:%(lineno)d - %(funcName)s - %(message)s"
      datefmt: "%Y-%m-%d %H:%M:%S"
    # JSON format
    json:
      format: '{"timestamp": "%(asctime)s", "level": "%(levelname)s", "logger": "%(name)s", "module": "%(module)s", "line": %(lineno)d, "message": "%(message)s"}'
      datefmt: "%Y-%m-%dT%H:%M:%S"
  # Log filters
  filters:
    # Performance filter
    performance:
      class: "PerformanceFilter"
      threshold: 0.1
    # Sensitive data filter
    sensitive:
      class: "SensitiveDataFilter"
      patterns: ["password", "token", "key"]
```
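
The config names a `SensitiveDataFilter` with a list of `patterns` but does not show how it behaves. One plausible reading is a standard `logging.Filter` that masks values following a sensitive key. The class below is a hypothetical sketch under that assumption, not the project's implementation.

```python
import logging
import re
from typing import Iterable


class SensitiveDataFilter(logging.Filter):
    """Mask values that follow a sensitive key, e.g. `token=abc` -> `token=***`."""

    def __init__(self, patterns: Iterable[str]):
        super().__init__()
        joined = "|".join(re.escape(p) for p in patterns)
        # Matches "<word containing a pattern>" followed by "=" or ":" and a value.
        self._regex = re.compile(
            rf"(\w*(?:{joined})\w*)\s*([=:])\s*\S+", re.IGNORECASE
        )

    def filter(self, record: logging.LogRecord) -> bool:
        # Render the message first so that %-style arguments are masked too.
        record.msg = self._regex.sub(r"\1\2***", record.getMessage())
        record.args = ()
        return True  # never drop the record, only rewrite it


logger = logging.getLogger("inference.sketch")
logger.addFilter(SensitiveDataFilter(["password", "token", "key"]))

record = logging.LogRecord(
    "inference.sketch", logging.INFO, "example.py", 1,
    "login ok token=%s user=%s", ("abc123", "alice"), None,
)
logger.filter(record)
masked = record.getMessage()
print(masked)
```

Because the filter returns `True`, records are always passed on to the handlers; only the sensitive part of the text is replaced.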
## Hardware Configuration
### Hardware configuration file: `configs/hardware_config.yaml`
```yaml
# ========================================
# Hardware configuration file
# ========================================
# Device configuration
device:
  # Automatic selection
  auto:
    priority: ["cuda", "mps", "cpu"]  # Device priority
    memory_threshold: 0.8             # Memory usage threshold
  # CPU settings
  cpu:
    num_threads: null  # null = detect automatically
    use_openmp: true
    use_mkl: true
  # GPU settings
  gpu:
    # GPU selection
    device_id: 0          # GPU ID
    memory_fraction: 0.9  # Fraction of GPU memory to use
    allow_growth: true    # Grow memory usage dynamically
    # CUDA settings
    cuda:
      allow_tf32: true      # Enable TF32
      benchmark: true       # Enable cuDNN benchmarking
      deterministic: false  # Deterministic mode
    # Mixed precision
    mixed_precision:
      enabled: false
      opt_level: "O1"        # O0, O1, O2, O3
      loss_scale: "dynamic"  # static, dynamic
    # Multi-GPU settings
    multi_gpu:
      enabled: false
      device_ids: [0, 1]
      output_device: 0
      dim: 0  # Data-parallel dimension
# Memory configuration
memory:
  # System memory
  system:
    max_usage: 0.8          # Maximum usage fraction
    cleanup_threshold: 0.9  # Cleanup threshold
  # GPU memory
  gpu:
    max_usage: 0.9
    cleanup_interval: 100  # Cleanup interval (steps)
  # Memory optimization
  optimization:
    enable_gc: true     # Enable garbage collection
    gc_threshold: 0.8   # GC trigger threshold
    pin_memory: true    # Pinned (page-locked) memory
    share_memory: true  # Shared memory
# Performance configuration
performance:
  # Parallelism
  parallel:
    num_workers: 4      # Number of DataLoader worker processes
    prefetch_factor: 2  # Prefetch factor
  # Caching
  cache:
    model_cache: true  # Model cache
    data_cache: true   # Data cache
    cache_size: 1024   # Cache size (MB)
  # Compilation
  compilation:
    torch_compile: false  # PyTorch 2.0 compilation
    jit_script: true      # TorchScript
    mode: "default"       # default, reduce-overhead, max-autotune
```
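
The `device.auto.priority` list suggests walking the candidates in order and taking the first one that is available. A minimal sketch of that logic follows, assuming PyTorch as the backend; `select_device` is a hypothetical helper, and it falls back to `"cpu"` when PyTorch is not installed so the snippet runs anywhere.

```python
from typing import Callable, Dict, Sequence


def select_device(priority: Sequence[str] = ("cuda", "mps", "cpu")) -> str:
    """Return the first available device name from `priority`."""
    try:
        import torch  # optional dependency in this sketch
    except ImportError:
        return "cpu"

    def mps_available() -> bool:
        mps = getattr(torch.backends, "mps", None)  # absent in older PyTorch
        return mps is not None and mps.is_available()

    checks: Dict[str, Callable[[], bool]] = {
        "cuda": torch.cuda.is_available,
        "mps": mps_available,
        "cpu": lambda: True,
    }
    for name in priority:
        # Unknown device names are treated as unavailable.
        if checks.get(name, lambda: False)():
            return name
    return "cpu"


device = select_device()
print(device)
```

The `memory_threshold` option is not modeled here; a fuller version might skip a GPU whose memory usage exceeds it.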
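## Configuration Best Practices
### 1. Organizing Configuration Files
```
configs/
├── model_config.yaml          # Model configuration
├── training_config.yaml       # Training configuration
├── data_config.yaml           # Data configuration
├── inference_config.yaml      # Inference configuration
├── logging_config.yaml        # Logging configuration
├── hardware_config.yaml       # Hardware configuration
├── environments/              # Environment-specific configuration
│   ├── development.yaml
│   ├── staging.yaml
│   └── production.yaml
└── experiments/               # Experiment-specific configuration
    ├── baseline.yaml
    ├── large_model.yaml
    └── fast_train.yaml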
```
### 2. Configuration Inheritance
```yaml
# configs/experiments/large_model.yaml
_base_: "../training_config.yaml"
training:
  epochs:
    max_epochs: 500
model:
  architecture:
    hidden_layers:
      - size: 256
        activation: "ReLU"
        dropout: 0.3
      - size: 128
        activation: "ReLU"
        dropout: 0.2
      - size: 64
        activation: "ReLU"
        dropout: 0.1
experiment_tracking:
  enabled: true
  mlflow:
    experiment_name: "large_model_experiment"
```
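
Resolving `_base_` amounts to loading the referenced file first and merging the child on top. The sketch below assumes that `_base_` is a path relative to the child file, that dictionaries merge recursively, and that lists such as `hidden_layers` are replaced wholesale; `resolve_config`, `merge`, and the in-memory `files` mapping are hypothetical stand-ins for real YAML loading.

```python
import posixpath
from copy import deepcopy
from typing import Any, Callable, Dict


def merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge dicts; any other value (including lists) is replaced."""
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


def resolve_config(path: str, load: Callable[[str], Dict[str, Any]]) -> Dict[str, Any]:
    """Load `path`, then apply it on top of its `_base_` chain, if any."""
    raw = dict(load(path))
    base_ref = raw.pop("_base_", None)
    if base_ref is None:
        return raw
    base_path = posixpath.normpath(posixpath.join(posixpath.dirname(path), base_ref))
    return merge(resolve_config(base_path, load), raw)


# In-memory stand-ins for parsed YAML files.
files = {
    "configs/training_config.yaml": {
        "training": {"epochs": {"max_epochs": 200}, "early_stopping": {"patience": 15}},
        "experiment_tracking": {"enabled": False},
    },
    "configs/experiments/large_model.yaml": {
        "_base_": "../training_config.yaml",
        "training": {"epochs": {"max_epochs": 500}},
        "experiment_tracking": {"enabled": True},
    },
}

resolved = resolve_config("configs/experiments/large_model.yaml", files.__getitem__)
print(resolved)
```

Because the function recurses on the base path, a base file may itself declare a `_base_`; cycle detection is left out of this sketch.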
### 3. Environment Variable Substitution
```yaml
# Use environment variables, written as ${NAME:default}
model_path: "${MODEL_PATH:/models/default.pth}"
learning_rate: "${LEARNING_RATE:0.001}"
batch_size: "${BATCH_SIZE:32}"
```
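
The `${NAME:default}` syntax can be expanded with a single regular expression. This is a minimal sketch, assuming the text after the first colon is the fallback used when the variable is unset; `substitute_env` is a hypothetical helper and returns strings, so numeric values such as `learning_rate` would still need to be converted afterwards.

```python
import os
import re
from typing import Mapping, Optional

# ${NAME} or ${NAME:default}; the default may contain anything except "}".
_PLACEHOLDER = re.compile(r"\$\{(?P<name>[A-Za-z_][A-Za-z0-9_]*)(?::(?P<default>[^}]*))?\}")


def substitute_env(value: str, env: Optional[Mapping[str, str]] = None) -> str:
    """Replace every ${NAME:default} placeholder in `value`."""
    variables = os.environ if env is None else env

    def replace(match: "re.Match[str]") -> str:
        name = match.group("name")
        default = match.group("default")
        if name in variables:
            return variables[name]
        if default is not None:
            return default
        raise KeyError(f"environment variable {name} is not set and has no default")

    return _PLACEHOLDER.sub(replace, value)


print(substitute_env("${MODEL_PATH:/models/default.pth}", env={}))
print(substitute_env("${LEARNING_RATE:0.001}", env={"LEARNING_RATE": "0.01"}))
```

Passing `env` explicitly keeps the function easy to test; leaving it out reads from the real process environment.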
### 4. Configuration Validation
```python
from src.utils.config import ConfigValidator, ValidationError

# Create a validator
validator = ConfigValidator()

# Add validation rules
validator.add_rule("training.optimizer.lr", lambda x: 0 < x <= 1)
validator.add_rule("model.hidden_dims", lambda x: len(x) > 0)

# Validate the configuration
try:
    validator.validate(config)
except ValidationError as e:
    print(f"Configuration validation failed: {e}")
```
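### 5. Configuration Versioning
```yaml
# Configuration file version
config_version: "1.0"
compatibility_version: ">=0.9.0"
# Changelog
changelog:
  - version: "1.0"
    changes: ["Added mixed-precision support", "Updated the learning-rate scheduler"]
  - version: "0.9"
    changes: ["Initial version"]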
```
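
The document does not say how `compatibility_version` is evaluated. One possible interpretation is a comparison operator followed by a dotted version that some other version must satisfy. The sketch below works under that assumption only; `is_compatible` and `parse_version` are hypothetical helpers that handle plain numeric versions.

```python
import operator
import re
from typing import Tuple

_OPERATORS = {
    ">=": operator.ge,
    "<=": operator.le,
    "==": operator.eq,
    ">": operator.gt,
    "<": operator.lt,
}


def parse_version(text: str) -> Tuple[int, ...]:
    """Convert "1.0.2" into (1, 0, 2)."""
    return tuple(int(part) for part in text.split("."))


def is_compatible(version: str, spec: str) -> bool:
    """Check `version` against a spec such as ">=0.9.0"."""
    match = re.fullmatch(r"\s*(>=|<=|==|>|<)\s*(\d+(?:\.\d+)*)\s*", spec)
    if match is None:
        raise ValueError(f"unsupported version spec: {spec!r}")
    op, required = match.groups()
    left, right = parse_version(version), parse_version(required)
    width = max(len(left), len(right))
    # Pad with zeros so that "1.0" and "1.0.0" compare as equal.
    left += (0,) * (width - len(left))
    right += (0,) * (width - len(right))
    return _OPERATORS[op](left, right)


print(is_compatible("1.0", ">=0.9.0"))
print(is_compatible("0.8.5", ">=0.9.0"))
```

Comparing integer tuples avoids the pitfall of string comparison, where `"0.10"` would sort before `"0.9"`.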
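## Configuration Validation
### Configuration Validator
```python
from typing import Any, Callable, Dict, Optional


class ValidationError(Exception):
    """Raised when a configuration value fails validation."""


class ConfigValidator:
    """Configuration validator."""

    def __init__(self):
        self.rules = {}
        self.schemas = {}

    def add_rule(self, path: str, validator: Callable[[Any], bool],
                 message: Optional[str] = None):
        """Register a validation rule for a dotted config path."""
        self.rules[path] = {
            'validator': validator,
            'message': message or f"Invalid value at {path}"
        }

    def add_schema(self, section: str, schema: Dict):
        """Register a schema for a config section."""
        self.schemas[section] = schema

    def validate(self, config: Dict) -> bool:
        """Check every rule; raise ValidationError on the first failure."""
        for path, rule in self.rules.items():
            value = self._get_nested_value(config, path)
            # A missing value fails the rule instead of reaching the
            # validator as None, where a comparison would raise TypeError.
            if value is None or not rule['validator'](value):
                raise ValidationError(rule['message'])
        return True

    def _get_nested_value(self, config: Dict, path: str):
        """Return the value at a dotted path, or None if a key is missing."""
        value = config
        for key in path.split('.'):
            if not isinstance(value, dict):
                return None
            value = value.get(key)
        return value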
```
### Common Validation Rules
```python
# Numeric range checks
validator.add_rule("training.optimizer.lr", lambda x: 0 < x <= 1, "Learning rate must be in the range (0, 1]")
validator.add_rule("model.dropout_rate", lambda x: 0 <= x < 1, "Dropout rate must be in the range [0, 1)")
# List checks
validator.add_rule("model.hidden_dims", lambda x: isinstance(x, list) and len(x) > 0