Chordia / docs /CONFIGURATION.md
Corolin's picture
first commit
0a6452f

配置文件说明文档

本文档详细介绍了情绪与生理状态变化预测模型的所有配置选项、参数说明和使用示例。

目录

  1. 配置系统概述
  2. 模型配置
  3. 训练配置
  4. 数据配置
  5. 推理配置
  6. 日志配置
  7. 硬件配置
  8. 实验跟踪配置
  9. 配置最佳实践
  10. 配置验证

配置系统概述

配置文件格式

项目使用YAML格式的配置文件,支持:

  • 层次化结构
  • 注释支持
  • 变量引用
  • 环境变量替换
  • 配置继承

配置文件加载顺序

  1. 默认配置 (内置)
  2. 全局配置文件 (~/.emotion-prediction/config.yaml)
  3. 项目配置文件 (configs/)
  4. 命令行参数覆盖

配置管理器

from src.utils.config import ConfigManager

# 加载配置
config_manager = ConfigManager()
config = config_manager.load_config("configs/training_config.yaml")

# 访问配置
learning_rate = config.training.optimizer.learning_rate
batch_size = config.training.batch_size

# 配置验证
config_manager.validate_config(config)

模型配置

主配置文件: configs/model_config.yaml

# ========================================
# 模型配置文件
# ========================================

# 模型基本信息
model_info:
  name: "MLP_Emotion_Predictor"
  type: "MLP"
  version: "1.0"
  description: "基于MLP的情绪与生理状态变化预测模型"
  author: "Research Team"
  
# 输入输出维度配置
dimensions:
  input_dim: 7    # 输入维度:User PAD 3维 + Vitality 1维 + Current PAD 3维
  output_dim: 3   # 输出维度:ΔPAD 3维(ΔPleasure, ΔArousal, ΔDominance)
  
# 网络架构配置
architecture:
  # 隐藏层配置
  hidden_layers:
    - size: 128
      activation: "ReLU"
      dropout: 0.2
      batch_norm: false
      layer_norm: false
    - size: 64
      activation: "ReLU"
      dropout: 0.2
      batch_norm: false
      layer_norm: false
    - size: 32
      activation: "ReLU"
      dropout: 0.1
      batch_norm: false
      layer_norm: false
  
  # 输出层配置
  output_layer:
    activation: "Linear"  # 线性激活,用于回归任务
    
  # 正则化配置
  use_batch_norm: false
  use_layer_norm: false
  
# 权重初始化配置
initialization:
  weight_init: "xavier_uniform"  # 可选: xavier_uniform, xavier_normal, kaiming_uniform, kaiming_normal
  bias_init: "zeros"             # 可选: zeros, ones, uniform, normal
  
# 正则化配置
regularization:
  # L2正则化
  weight_decay: 0.0001
  
  # Dropout配置
  dropout_config:
    type: "standard"      # 标准 dropout
    rate: 0.2            # Dropout 概率
    
  # 批归一化
  batch_norm_config:
    momentum: 0.1
    eps: 1e-5
    
# 模型保存配置
model_saving:
  save_best_only: true          # 只保存最佳模型
  save_format: "pytorch"        # 保存格式: pytorch, onnx, torchscript
  checkpoint_interval: 10       # 每10个epoch保存一次检查点
  max_checkpoints: 5           # 最多保存5个检查点
  
# PAD情绪空间特殊配置
emotion_model:
  # PAD值的范围限制
  pad_space:
    pleasure_range: [-1.0, 1.0]     # 快乐维度范围
    arousal_range: [-1.0, 1.0]      # 激活度维度范围  
    dominance_range: [-1.0, 1.0]    # 支配度维度范围
    
  # 生理指标配置
  vitality:
    range: [0.0, 100.0]             # 活力值范围
    normalization: "min_max"        # 标准化方法: min_max, z_score, robust
    
  # 预测输出配置
  prediction:
    # ΔPAD的变化范围限制
    delta_pad_range: [-0.5, 0.5]    # PAD变化的合理范围
    # 压力值变化范围
    delta_pressure_range: [-0.3, 0.3]
    # 置信度范围
    confidence_range: [0.0, 1.0]

模型配置参数详解

model_info 模型基本信息

参数 类型 必需 默认值 说明
name str - 模型名称
type str - 模型类型 (MLP, CNN, RNN等)
version str - 模型版本号
description str - 模型描述
author str - 作者信息

dimensions 输入输出维度

参数 类型 必需 默认值 说明
input_dim int 7 输入特征维度
output_dim int 3 输出预测维度(ΔPAD 3维)

architecture 网络架构

hidden_layers 隐藏层配置

每个隐藏层支持以下参数:

参数 类型 必需 默认值 说明
size int - 神经元数量
activation str ReLU 激活函数
dropout float 0.0 Dropout概率
batch_norm bool false 是否使用批归一化
layer_norm bool false 是否使用层归一化

激活函数选项:

  • ReLU: 修正线性单元
  • LeakyReLU: 泄漏ReLU
  • Tanh: 双曲正切
  • Sigmoid: Sigmoid函数
  • GELU: 高斯误差线性单元
  • Swish: Swish激活函数
output_layer 输出层配置
参数 类型 必需 默认值 说明
activation str Linear 输出激活函数

initialization 权重初始化

参数 类型 必需 默认值 说明
weight_init str xavier_uniform 权重初始化方法
bias_init str zeros 偏置初始化方法

权重初始化选项:

  • xavier_uniform: Xavier均匀初始化
  • xavier_normal: Xavier正态初始化
  • kaiming_uniform: Kaiming均匀初始化 (适合ReLU)
  • kaiming_normal: Kaiming正态初始化 (适合ReLU)
  • uniform: 均匀分布初始化
  • normal: 正态分布初始化

训练配置

主配置文件: configs/training_config.yaml

# ========================================
# 训练配置文件
# ========================================

# 训练基本信息
training_info:
  experiment_name: "emotion_prediction_v1"
  description: "基于MLP的情绪与生理状态变化预测模型训练"
  seed: 42
  tags: ["baseline", "mlp", "emotion_prediction"]
  
# 数据配置
data:
  # 数据路径
  paths:
    train_data: "data/train.csv"
    val_data: "data/val.csv"
    test_data: "data/test.csv"
    
  # 数据预处理
  preprocessing:
    # 特征标准化
    feature_scaling:
      method: "standard"        # standard, min_max, robust, none
      pad_features: "standard"  # PAD特征标准化方法
      vitality_feature: "min_max" # 活力值标准化方法
      
    # 标签标准化
    label_scaling:
      method: "standard"
      delta_pad: "standard"
      delta_pressure: "standard"
      confidence: "none"
    
    # 数据增强
    augmentation:
      enabled: false
      noise_std: 0.01
      mixup_alpha: 0.2
      augmentation_factor: 2
    
    # 数据验证
    validation:
      check_ranges: true
      check_missing: true
      check_outliers: true
      outlier_method: "iqr"  # iqr, zscore, isolation_forest
  
  # 数据加载器配置
  dataloader:
    batch_size: 32
    num_workers: 4
    pin_memory: true
    shuffle: true
    drop_last: false
    persistent_workers: true
    
  # 数据分割
  split:
    train_ratio: 0.8
    val_ratio: 0.1
    test_ratio: 0.1
    stratify: false
    random_seed: 42

# 训练超参数
training:
  # 训练轮次
  epochs:
    max_epochs: 200
    warmup_epochs: 5
    
  # 早停配置
  early_stopping:
    enabled: true
    patience: 15          # 监控轮数
    min_delta: 1e-4       # 最小改善
    monitor: "val_loss"   # 监控指标
    mode: "min"           # min/max
    restore_best_weights: true
    
  # 梯度配置
  gradient:
    clip_enabled: true
    clip_value: 1.0
    clip_norm: 2          # 1: L1 norm, 2: L2 norm
    
  # 混合精度训练
  mixed_precision:
    enabled: false
    opt_level: "O1"       # O0, O1, O2, O3
    
  # 梯度累积
  gradient_accumulation:
    enabled: false
    accumulation_steps: 4

# 优化器配置
optimizer:
  type: "AdamW"          # Adam, SGD, AdamW, RMSprop, Adagrad
  
  # Adam/AdamW 参数
  adam_config:
    lr: 0.0005           # 学习率
    weight_decay: 0.01   # 权重衰减
    betas: [0.9, 0.999]  # Beta参数
    eps: 1e-8           # 数值稳定性
    amsgrad: false       # AMSGrad变体
    
  # SGD 参数
  sgd_config:
    lr: 0.01
    momentum: 0.9
    weight_decay: 0.0001
    nesterov: true
    
  # RMSprop 参数
  rmsprop_config:
    lr: 0.001
    alpha: 0.99
    weight_decay: 0.0
    momentum: 0.0

# 学习率调度器配置
scheduler:
  type: "CosineAnnealingLR"  # StepLR, CosineAnnealingLR, ReduceLROnPlateau, ExponentialLR
  
  # 余弦退火调度器
  cosine_config:
    T_max: 200              # 最大轮数
    eta_min: 1e-6          # 最小学习率
    last_epoch: -1
    
  # 步长调度器
  step_config:
    step_size: 30           # 步长
    gamma: 0.1             # 衰减因子
    
  # 平台衰减调度器
  plateau_config:
    patience: 10           # 耐心值
    factor: 0.5           # 衰减因子
    min_lr: 1e-7          # 最小学习率
    threshold: 1e-4       # 改善阈值
    verbose: true

# 损失函数配置
loss:
  type: "WeightedMSELoss"  # MSELoss, L1Loss, SmoothL1Loss, HuberLoss, WeightedMSELoss
  
  # 基础损失参数
  base_config:
    reduction: "mean"      # mean, sum, none
    
  # 加权损失配置
  weighted_config:
    delta_pad_weight: 1.0      # ΔPAD预测权重
    delta_pressure_weight: 1.0 # ΔPressure预测权重  
    confidence_weight: 0.5     # 置信度预测权重
    
  # Huber损失配置
  huber_config:
    delta: 1.0             # Huber阈值
    
  # 焦点损失配置 (可选)
  focal_config:
    alpha: 1.0
    gamma: 2.0

# 验证配置
validation:
  # 验证频率
  val_frequency: 1        # 每多少个epoch验证一次
  
  # 验证指标
  metrics:
    - "MSE"
    - "MAE"
    - "RMSE"
    - "R2"
    - "MAPE"
    
  # 模型选择
  model_selection:
    criterion: "val_loss"  # val_loss, val_mae, val_r2
    mode: "min"           # min/max
    
  # 验证数据增强
  val_augmentation:
    enabled: false
    methods: []

# 日志和监控配置
logging:
  # 日志级别
  level: "INFO"           # DEBUG, INFO, WARNING, ERROR
  
  # 日志文件
  log_dir: "logs"
  log_file: "training.log"
  max_file_size: "10MB"
  backup_count: 5
  
  # TensorBoard
  tensorboard:
    enabled: true
    log_dir: "runs"
    comment: ""
    flush_secs: 10
    
  # Wandb
  wandb:
    enabled: false
    project: "emotion-prediction"
    entity: "your-team"
    tags: []
    notes: ""
    
  # 进度条
  progress_bar:
    enabled: true
    update_frequency: 10   # 更新频率
    leave: true           # 训练完成后是否保留

# 检查点保存配置
checkpointing:
  # 保存目录
  save_dir: "checkpoints"
  
  # 保存策略
  save_strategy: "best"   # best, last, all
  
  # 文件命名
  filename_template: "model_epoch_{epoch}_val_{val_loss:.4f}.pth"
  
  # 保存内容
  save_items:
    - "model_state_dict"
    - "optimizer_state_dict"
    - "scheduler_state_dict"
    - "epoch"
    - "loss"
    - "metrics"
    - "config"
    
  # 保存频率
  save_frequency: 1       # 每多少个epoch保存一次
  
  # 最大检查点数量
  max_checkpoints: 5

# 硬件配置
hardware:
  # 设备选择
  device: "auto"          # auto, cpu, cuda, mps
  
  # GPU配置
  gpu:
    id: 0                 # GPU ID
    memory_fraction: 0.9  # GPU内存使用比例
    allow_growth: true    # 动态内存增长
    
  # 混合精度
  mixed_precision:
    enabled: false
    opt_level: "O1"
    
  # 分布式训练
  distributed:
    enabled: false
    backend: "nccl"
    init_method: "env://"
    world_size: 1
    rank: 0

# 调试配置
debug:
  # 调试模式
  enabled: false
  
  # 快速训练
  fast_train:
    enabled: false
    max_epochs: 5
    batch_size: 8
    subset_size: 100
    
  # 梯度检查
  gradient_checking:
    enabled: false
    clip_value: 1.0
    check_nan: true
    check_inf: true
    
  # 数据检查
  data_checking:
    enabled: true
    check_nan: true
    check_inf: true
    check_range: true
    sample_output: true
    
  # 模型检查
  model_checking:
    enabled: false
    count_parameters: true
    check_gradients: true
    visualize_model: false

# 实验跟踪配置
experiment_tracking:
  # 是否启用实验跟踪
  enabled: false
  
  # MLflow配置
  mlflow:
    tracking_uri: "http://localhost:5000"
    experiment_name: "emotion_prediction"
    run_name: null
    tags: {}
    params: {}
    
  # Weights & Biases配置
  wandb:
    project: "emotion-prediction"
    entity: null
    group: null
    job_type: "training"
    tags: []
    notes: ""
    config: {}
    
  # 本地实验跟踪
  local:
    save_dir: "experiments"
    save_config: true
    save_metrics: true
    save_model: true

训练配置参数详解

training.epochs 训练轮次

参数 类型 必需 默认值 说明
max_epochs int 200 最大训练轮数
warmup_epochs int 0 预热轮数

training.early_stopping 早停配置

参数 类型 必需 默认值 说明
enabled bool true 是否启用早停
patience int 10 耐心值(轮数)
min_delta float 1e-4 最小改善阈值
monitor str val_loss 监控指标
mode str min 监控模式 (min/max)
restore_best_weights bool true 恢复最佳权重

optimizer 优化器配置

支持的优化器类型:

  • Adam: 自适应矩估计
  • AdamW: Adam with Weight Decay
  • SGD: 随机梯度下降
  • RMSprop: RMSprop优化器
  • Adagrad: 自适应梯度算法

scheduler 学习率调度器

支持的调度器类型:

  • StepLR: 步长衰减
  • CosineAnnealingLR: 余弦退火
  • ReduceLROnPlateau: 平台衰减
  • ExponentialLR: 指数衰减

数据配置

数据配置文件: configs/data_config.yaml

# ========================================
# 数据配置文件
# ========================================

# 数据路径配置
paths:
  # 训练数据
  train_data: "data/train.csv"
  val_data: "data/val.csv"
  test_data: "data/test.csv"
  
  # 预处理器
  preprocessor: "models/preprocessor.pkl"
  
  # 数据统计
  statistics: "data/statistics.json"
  
  # 数据质量报告
  quality_report: "reports/data_quality.html"

# 数据源配置
data_source:
  type: "csv"              # csv, json, parquet, hdf5, database
  
  # CSV配置
  csv_config:
    delimiter: ","
    encoding: "utf-8"
    header: 0
    index_col: null
    
  # JSON配置
  json_config:
    orient: "records"      # records, index, values, columns
    lines: false
    
  # 数据库配置
  database_config:
    connection_string: "sqlite:///data.db"
    table: "emotion_data"
    query: null

# 数据预处理配置
preprocessing:
  # 特征处理
  features:
    # 缺失值处理
    missing_values:
      strategy: "drop"     # drop, fill_mean, fill_median, fill_mode, fill_constant
      fill_value: 0.0
      
    # 异常值处理
    outliers:
      method: "iqr"        # iqr, zscore, isolation_forest, none
      threshold: 1.5
      action: "clip"       # clip, remove, flag
      
    # 特征缩放
    scaling:
      method: "standard"   # standard, minmax, robust, none
      feature_range: [-1, 1]  # MinMax缩放范围
      
    # 特征选择
    selection:
      enabled: false
      method: "correlation" # correlation, mutual_info, rfe
      k_best: 10
      
  # 标签处理
  labels:
    # 缺失值处理
    missing_values:
      strategy: "fill_mean"
      
    # 标签缩放
    scaling:
      method: "standard"
      
    # 标签变换
    transformation:
      enabled: false
      method: "log"        # log, sqrt, boxcox

# 数据增强配置
augmentation:
  enabled: false
  
  # 噪声注入
  noise_injection:
    enabled: true
    noise_type: "gaussian" # gaussian, uniform
    noise_std: 0.01
    feature_wise: true
    
  # Mixup增强
  mixup:
    enabled: true
    alpha: 0.2
    
  # SMOTE增强 (用于不平衡数据)
  smote:
    enabled: false
    k_neighbors: 5
    sampling_strategy: "auto"

# 数据验证配置
validation:
  # 数值范围检查
  range_validation:
    enabled: true
    features:
      user_pleasure: [-1.0, 1.0]
      user_arousal: [-1.0, 1.0]
      user_dominance: [-1.0, 1.0]
      vitality: [0.0, 100.0]
      current_pleasure: [-1.0, 1.0]
      current_arousal: [-1.0, 1.0]
      current_dominance: [-1.0, 1.0]
    labels:
      delta_pleasure: [-0.5, 0.5]
      delta_arousal: [-0.5, 0.5]
      delta_dominance: [-0.5, 0.5]
      delta_pressure: [-0.3, 0.3]
      confidence: [0.0, 1.0]
      
  # 数据质量检查
  quality_checks:
    check_duplicates: true
    check_missing: true
    check_outliers: true
    check_correlations: true
    check_distribution: true
    
  # 统计报告
  statistics:
    compute_descriptive: true
    compute_correlations: true
    compute_distributions: true
    save_plots: true

# 合成数据配置
synthetic_data:
  enabled: false
  
  # 生成参数
  generation:
    num_samples: 1000
    seed: 42
    
  # 数据分布
  distribution:
    type: "multivariate_normal"  # normal, uniform, multivariate_normal
    mean: null
    cov: null
    
  # 相关性配置
  correlation:
    enabled: true
    strength: 0.5
    structure: "block"  # block, random, toeplitz
    
  # 噪声配置
  noise:
    add_noise: true
    noise_type: "gaussian"
    noise_std: 0.1

推理配置

推理配置文件: configs/inference_config.yaml

# ========================================
# 推理配置文件
# ========================================

# 推理基本信息
inference_info:
  model_path: "models/best_model.pth"
  preprocessor_path: "models/preprocessor.pkl"
  device: "auto"
  batch_size: 32
  
# 输入配置
input:
  # 输入格式
  format: "auto"          # auto, list, numpy, pandas, json, csv
  
  # 输入验证
  validation:
    enabled: true
    check_shape: true
    check_range: true
    check_type: true
    
  # 输入预处理
  preprocessing:
    normalize: true
    handle_missing: "error"  # error, fill, skip
    missing_value: 0.0

# 输出配置
output:
  # 输出格式
  format: "dict"          # dict, json, csv, numpy
  
  # 输出内容
  include:
    predictions: true
    confidence: true
    components: true       # delta_pad, delta_pressure, confidence
    metadata: false        # inference_time, model_info
    
  # 输出后处理
  postprocessing:
    clip_predictions: true
    round_decimals: 6
    format_confidence: "percentage"  # decimal, percentage

# 性能优化配置
optimization:
  # 模型优化
  model_optimization:
    enabled: true
    torch_script: false
    onnx: false
    quantization: false
    
  # 推理优化
  inference_optimization:
    warmup: true
    warmup_samples: 5
    batch_optimization: true
    memory_optimization: true
    
  # 缓存配置
  caching:
    enabled: false
    cache_size: 1000
    cache_policy: "lru"   # lru, fifo

# 监控配置
monitoring:
  # 性能监控
  performance:
    enabled: true
    track_latency: true
    track_memory: true
    track_throughput: true
    
  # 质量监控
  quality:
    enabled: false
    confidence_threshold: 0.5
    prediction_validation: true
    
  # 异常检测
  anomaly_detection:
    enabled: false
    method: "statistical"  # statistical, isolation_forest
    threshold: 2.0

# 服务配置 (用于部署)
service:
  # API配置
  api:
    host: "0.0.0.0"
    port: 8000
    workers: 1
    timeout: 30
    
  # 限流配置
  rate_limiting:
    enabled: false
    requests_per_minute: 100
    
  # 认证配置
  authentication:
    enabled: false
    method: "api_key"     # api_key, jwt, basic
    
  # 日志配置
  logging:
    level: "INFO"
    format: "json"

日志配置

日志配置文件: configs/logging_config.yaml

# ========================================
# 日志配置文件
# ========================================

# 日志系统配置
logging:
  # 根日志器
  root:
    level: "INFO"
    handlers: ["console", "file"]
    
  # 日志器配置
  loggers:
    training:
      level: "INFO"
      handlers: ["console", "file", "tensorboard"]
      propagate: false
      
    inference:
      level: "WARNING"
      handlers: ["console", "file"]
      propagate: false
      
    data:
      level: "DEBUG"
      handlers: ["file"]
      propagate: false

# 处理器配置
handlers:
  # 控制台处理器
  console:
    class: "StreamHandler"
    level: "INFO"
    formatter: "console"
    stream: "ext://sys.stdout"
    
  # 文件处理器
  file:
    class: "RotatingFileHandler"
    level: "DEBUG"
    formatter: "detailed"
    filename: "logs/app.log"
    maxBytes: 10485760     # 10MB
    backupCount: 5
    encoding: "utf8"
    
  # 错误文件处理器
  error_file:
    class: "RotatingFileHandler"
    level: "ERROR"
    formatter: "detailed"
    filename: "logs/error.log"
    maxBytes: 10485760
    backupCount: 3
    encoding: "utf8"

# 格式化器配置
formatters:
  # 控制台格式
  console:
    format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    datefmt: "%H:%M:%S"
    
  # 详细格式
  detailed:
    format: "%(asctime)s - %(name)s - %(levelname)s - %(module)s:%(lineno)d - %(funcName)s - %(message)s"
    datefmt: "%Y-%m-%d %H:%M:%S"
    
  # JSON格式
  json:
    format: '{"timestamp": "%(asctime)s", "level": "%(levelname)s", "logger": "%(name)s", "module": "%(module)s", "line": %(lineno)d, "message": "%(message)s"}'
    datefmt: "%Y-%m-%dT%H:%M:%S"

# 日志过滤配置
filters:
  # 性能过滤器
  performance:
    class: "PerformanceFilter"
    threshold: 0.1
    
  # 敏感信息过滤器
  sensitive:
    class: "SensitiveDataFilter"
    patterns: ["password", "token", "key"]

硬件配置

硬件配置文件: configs/hardware_config.yaml

# ========================================
# 硬件配置文件
# ========================================

# 设备配置
device:
  # 自动选择
  auto:
    priority: ["cuda", "mps", "cpu"]  # 设备优先级
    memory_threshold: 0.8              # 内存使用阈值
    
  # CPU配置
  cpu:
    num_threads: null                  # null为自动检测
    use_openmp: true
    use_mkl: true
    
  # GPU配置
  gpu:
    # GPU选择
    device_id: 0                       # GPU ID
    memory_fraction: 0.9               # GPU内存使用比例
    allow_growth: true                 # 动态内存增长
    
    # CUDA配置
    cuda:
      allow_tf32: true                 # 启用TF32
      benchmark: true                  # 启用cuDNN基准
      deterministic: false             # 确定性模式
      
    # 混合精度
    mixed_precision:
      enabled: false
      opt_level: "O1"                  # O0, O1, O2, O3
      loss_scale: "dynamic"            # static, dynamic
      
    # 多GPU配置
    multi_gpu:
      enabled: false
      device_ids: [0, 1]
      output_device: 0
      dim: 0                           # 数据并行维度

# 内存配置
memory:
  # 系统内存
  system:
    max_usage: 0.8                     # 最大使用比例
    cleanup_threshold: 0.9             # 清理阈值
    
  # GPU内存
  gpu:
    max_usage: 0.9
    cleanup_interval: 100              # 清理间隔(步数)
    
  # 内存优化
  optimization:
    enable_gc: true                    # 启用垃圾回收
    gc_threshold: 0.8                  # GC触发阈值
    pin_memory: true                   # 锁页内存
    share_memory: true                 # 共享内存

# 性能配置
performance:
  # 并行配置
  parallel:
    num_workers: 4                     # 数据加载器工作进程数
    prefetch_factor: 2                 # 预取因子
    
  # 缓存配置
  cache:
    model_cache: true                  # 模型缓存
    data_cache: true                   # 数据缓存
    cache_size: 1024                   # 缓存大小(MB)
    
  # 编译优化
  compilation:
    torch_compile: false               # PyTorch 2.0编译
    jit_script: true                   # TorchScript
    mode: "default"                    # default, reduce-overhead, max-autotune

配置最佳实践

1. 配置文件组织

configs/
├── model_config.yaml          # 模型配置
├── training_config.yaml       # 训练配置
├── data_config.yaml          # 数据配置
├── inference_config.yaml     # 推理配置
├── logging_config.yaml       # 日志配置
├── hardware_config.yaml      # 硬件配置
├── environments/             # 环境特定配置
│   ├── development.yaml
│   ├── staging.yaml
│   └── production.yaml
└── experiments/              # 实验特定配置
    ├── baseline.yaml
    ├── large_model.yaml
    └── fast_train.yaml

2. 配置继承

# configs/experiments/large_model.yaml
_base_: "../training_config.yaml"

training:
  epochs:
    max_epochs: 500
    
model:
  architecture:
    hidden_layers:
      - size: 256
        activation: "ReLU"
        dropout: 0.3
      - size: 128
        activation: "ReLU"
        dropout: 0.2
      - size: 64
        activation: "ReLU"
        dropout: 0.1

experiment_tracking:
  enabled: true
  mlflow:
    experiment_name: "large_model_experiment"

3. 环境变量替换

# 使用环境变量
model_path: "${MODEL_PATH:/models/default.pth}"
learning_rate: "${LEARNING_RATE:0.001}"
batch_size: "${BATCH_SIZE:32}"

4. 配置验证

from src.utils.config import ConfigValidator
from src.utils.config import ValidationError

# 创建验证器
validator = ConfigValidator()

# 添加验证规则
validator.add_rule("training.optimizer.lr", lambda x: 0 < x <= 1)
validator.add_rule("model.hidden_dims", lambda x: len(x) > 0)

# 验证配置
try:
    validator.validate(config)
except ValidationError as e:
    print(f"配置验证失败: {e}")

5. 配置版本管理

# 配置文件版本
config_version: "1.0"
compatibility_version: ">=0.9.0"

# 变更日志
changelog:
  - version: "1.0"
    changes: ["添加混合精度支持", "更新学习率调度器"]
  - version: "0.9"
    changes: ["初始版本"]

配置验证

配置验证器

class ConfigValidator:
    """配置验证器"""
    
    def __init__(self):
        self.rules = {}
        self.schemas = {}
    
    def add_rule(self, path: str, validator: callable, message: str = None):
        """添加验证规则"""
        self.rules[path] = {
            'validator': validator,
            'message': message or f"Invalid value at {path}"
        }
    
    def add_schema(self, section: str, schema: Dict):
        """添加配置模式"""
        self.schemas[section] = schema
    
    def validate(self, config: Dict) -> bool:
        """验证配置"""
        for path, rule in self.rules.items():
            value = self._get_nested_value(config, path)
            if not rule['validator'](value):
                raise ValidationError(rule['message'])
        return True
    
    def _get_nested_value(self, config: Dict, path: str):
        """获取嵌套值"""
        keys = path.split('.')
        value = config
        for key in keys:
            value = value.get(key)
            if value is None:
                return None
        return value

常用验证规则

# 数值范围验证
validator.add_rule("training.optimizer.lr", lambda x: 0 < x <= 1, "学习率必须在(0, 1]范围内")
validator.add_rule("model.dropout_rate", lambda x: 0 <= x < 1, "Dropout率必须在[0, 1)范围内")

# 列表验证
validator.add_rule("model.hidden_dims", lambda x: isinstance(x, list) and len(x) > 0