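# Configuration File Reference

This document describes every configuration option for the emotion and physiological state change prediction model, with parameter descriptions and usage examples.

## Table of contents

1. [Configuration system overview](#configuration-system-overview)
2. [Model configuration](#model-configuration)
3. [Training configuration](#training-configuration)
4. [Data configuration](#data-configuration)
5. [Inference configuration](#inference-configuration)
6. [Logging configuration](#logging-configuration)
7. [Hardware configuration](#hardware-configuration)
8. [Experiment tracking configuration](#experiment-tracking-configuration)
9. [Configuration best practices](#configuration-best-practices)
10. [Configuration validation](#configuration-validation)

## Configuration system overview

### Configuration file format

The project uses YAML configuration files. The configuration system supports:

- Hierarchical structure
- Comments
- Variable references
- Environment variable substitution
- Configuration inheritance

### Configuration loading order

1. Default configuration (built in)
2. Global configuration file (`~/.emotion-prediction/config.yaml`)
3. Project configuration files (`configs/`)
4. Command-line argument overrides

### Configuration manager

```python
from src.utils.config import ConfigManager

# Load the configuration
config_manager = ConfigManager()
config = config_manager.load_config("configs/training_config.yaml")

# Access values (attribute paths mirror the layout of training_config.yaml)
learning_rate = config.optimizer.adam_config.lr
batch_size = config.data.dataloader.batch_size

# Validate the configuration
config_manager.validate_config(config)
```

## Model configuration

### Main configuration file: `configs/model_config.yaml`

```yaml
# ========================================
# Model configuration file
# ========================================

# Basic model information
model_info:
  name: "MLP_Emotion_Predictor"
  type: "MLP"
  version: "1.0"
  description: "MLP-based model for predicting changes in emotion and physiological state"
  author: "Research Team"

# Input and output dimensions
dimensions:
  input_dim: 7   # Input: User PAD (3) + Vitality (1) + Current PAD (3)
  output_dim: 3  # Output: ΔPAD (3): ΔPleasure, ΔArousal, ΔDominance

# Network architecture
architecture:
  # Hidden layers
  hidden_layers:
    - size: 128
      activation: "ReLU"
      dropout: 0.2
      batch_norm: false
      layer_norm: false
    - size: 64
      activation: "ReLU"
      dropout: 0.2
      batch_norm: false
      layer_norm: false
    - size: 32
      activation: "ReLU"
      dropout: 0.1
      batch_norm: false
      layer_norm: false

  # Output layer
  output_layer:
    activation: "Linear"  # Linear activation, used for regression

  # Normalization switches
  use_batch_norm: false
  use_layer_norm: false

# Weight initialization
initialization:
  weight_init: "xavier_uniform"  # Options: xavier_uniform, xavier_normal, kaiming_uniform, kaiming_normal
  bias_init: "zeros"             # Options: zeros, ones, uniform, normal

# Regularization
regularization:
  # L2 regularization
  weight_decay: 0.0001

  # Dropout
  dropout_config:
    type: "standard"  # Standard dropout
    rate: 0.2         # Dropout probability

  # Batch normalization
  batch_norm_config:
    momentum: 0.1
    eps: 1e-5

# Model saving
model_saving:
  save_best_only: true  # Save only the best model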
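  save_format: "pytorch"   # Save format: pytorch, onnx, torchscript
  checkpoint_interval: 10  # Save a checkpoint every 10 epochs
  max_checkpoints: 5       # Keep at most 5 checkpoints

# Settings specific to the PAD emotion space
emotion_model:
  # Allowed ranges for PAD values
  pad_space:
    pleasure_range: [-1.0, 1.0]   # Pleasure dimension
    arousal_range: [-1.0, 1.0]    # Arousal dimension
    dominance_range: [-1.0, 1.0]  # Dominance dimension

  # Physiological indicators
  vitality:
    range: [0.0, 100.0]       # Vitality range
    normalization: "min_max"  # Normalization method: min_max, z_score, robust

  # Prediction output
  prediction:
    # Limit on the change in PAD
    delta_pad_range: [-0.5, 0.5]  # Plausible range of PAD change
    # Range of the change in pressure
    delta_pressure_range: [-0.3, 0.3]
    # Confidence range
    confidence_range: [0.0, 1.0]
```

### Model configuration parameters

#### `model_info`: basic model information

| Parameter | Type | Required | Default | Description |
|------|------|------|--------|------|
| `name` | str | Yes | - | Model name |
| `type` | str | Yes | - | Model type (MLP, CNN, RNN, etc.) |
| `version` | str | Yes | - | Model version |
| `description` | str | No | - | Model description |
| `author` | str | No | - | Author information |

#### `dimensions`: input and output dimensions

| Parameter | Type | Required | Default | Description |
|------|------|------|--------|------|
| `input_dim` | int | Yes | 7 | Input feature dimension |
| `output_dim` | int | Yes | 3 | Output prediction dimension (ΔPAD, 3 values) |

#### `architecture`: network architecture

##### `hidden_layers`: hidden layer configuration

Each hidden layer accepts the following parameters:

| Parameter | Type | Required | Default | Description |
|------|------|------|--------|------|
| `size` | int | Yes | - | Number of neurons |
| `activation` | str | No | ReLU | Activation function |
| `dropout` | float | No | 0.0 | Dropout probability |
| `batch_norm` | bool | No | false | Whether to apply batch normalization |
| `layer_norm` | bool | No | false | Whether to apply layer normalization |

**Activation function options**:

- `ReLU`: rectified linear unit
- `LeakyReLU`: leaky ReLU
- `Tanh`: hyperbolic tangent
- `Sigmoid`: sigmoid function
- `GELU`: Gaussian error linear unit
- `Swish`: Swish activation function

##### `output_layer`: output layer configuration

| Parameter | Type | Required | Default | Description |
|------|------|------|--------|------|
| `activation` | str | No | Linear | Output activation function |

#### `initialization`: weight initialization

| Parameter | Type | Required | Default | Description |
|------|------|------|--------|------|
| `weight_init` | str | No | xavier_uniform | Weight initialization method |
| `bias_init` | str | No | zeros | Bias initialization method |

**Weight initialization options**:

- `xavier_uniform`: Xavier uniform initialization
- `xavier_normal`: Xavier normal initialization
- `kaiming_uniform`: Kaiming uniform initialization (suited to ReLU)
- `kaiming_normal`: Kaiming normal initialization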
  (suited to ReLU)
- `uniform`: uniform-distribution initialization
- `normal`: normal-distribution initialization

## Training configuration

### Main configuration file: `configs/training_config.yaml`

```yaml
# ========================================
# Training configuration file
# ========================================

# Basic training information
training_info:
  experiment_name: "emotion_prediction_v1"
  description: "Training of the MLP-based emotion and physiological state change prediction model"
  seed: 42
  tags: ["baseline", "mlp", "emotion_prediction"]

# Data
data:
  # Data paths
  paths:
    train_data: "data/train.csv"
    val_data: "data/val.csv"
    test_data: "data/test.csv"

  # Preprocessing
  preprocessing:
    # Feature scaling
    feature_scaling:
      method: "standard"           # standard, min_max, robust, none
      pad_features: "standard"     # Scaling method for PAD features
      vitality_feature: "min_max"  # Scaling method for vitality

    # Label scaling
    label_scaling:
      method: "standard"
      delta_pad: "standard"
      delta_pressure: "standard"
      confidence: "none"

  # Data augmentation
  augmentation:
    enabled: false
    noise_std: 0.01
    mixup_alpha: 0.2
    augmentation_factor: 2

  # Data validation
  validation:
    check_ranges: true
    check_missing: true
    check_outliers: true
    outlier_method: "iqr"  # iqr, zscore, isolation_forest

  # Data loader
  dataloader:
    batch_size: 32
    num_workers: 4
    pin_memory: true
    shuffle: true
    drop_last: false
    persistent_workers: true

  # Data split
  split:
    train_ratio: 0.8
    val_ratio: 0.1
    test_ratio: 0.1
    stratify: false
    random_seed: 42

# Training hyperparameters
training:
  # Epochs
  epochs:
    max_epochs: 200
    warmup_epochs: 5

  # Early stopping
  early_stopping:
    enabled: true
    patience: 15         # Epochs to wait without improvement
    min_delta: 1e-4      # Minimum improvement
    monitor: "val_loss"  # Monitored metric
    mode: "min"          # min/max
    restore_best_weights: true

  # Gradients
  gradient:
    clip_enabled: true
    clip_value: 1.0
    clip_norm: 2  # 1: L1 norm, 2: L2 norm

  # Mixed-precision training
  mixed_precision:
    enabled: false
    opt_level: "O1"  # O0, O1, O2, O3

  # Gradient accumulation
  gradient_accumulation:
    enabled: false
    accumulation_steps: 4

# Optimizer
optimizer:
  type: "AdamW"  # Adam, SGD, AdamW, RMSprop, Adagrad

  # Adam/AdamW parameters
  adam_config:
    lr: 0.0005           # Learning rate
    weight_decay: 0.01   # Weight decay
    betas: [0.9, 0.999]  # Beta coefficients
    eps: 1e-8            # Numerical stability term
    amsgrad: false       # AMSGrad variant

  # SGD parameters
  sgd_config:
    lr: 0.01
    momentum: 0.9
    weight_decay: 0.0001
    nesterov: true

  # RMSprop parameters
  rmsprop_config:
    lr: 0.001
    alpha: 0.99
    weight_decay: 0.0
    momentum: 0.0

# Learning-rate scheduler
scheduler:
  type: "CosineAnnealingLR"  # StepLR, CosineAnnealingLR, ReduceLROnPlateau, ExponentialLR

  # Cosine annealing scheduler
  cosine_config:
    T_max: 200     # Maximum number of epochs
    eta_min: 1e-6  # Minimum learning rate
    last_epoch: -1

  # Step scheduler
  step_config:
    step_size: 30  # Step size
    gamma: 0.1     # Decay factor

  # Reduce-on-plateau scheduler
  plateau_config:
    patience: 10     # Patience
    factor: 0.5      # Decay factor
    min_lr: 1e-7     # Minimum learning rate
    threshold: 1e-4  # Improvement threshold
    verbose: true

# Loss function
loss:
  type: "WeightedMSELoss"  # MSELoss, L1Loss, SmoothL1Loss, HuberLoss, WeightedMSELoss

  # Base loss parameters
  base_config:
    reduction: "mean"  # mean, sum, none

  # Weighted loss
  weighted_config:
    delta_pad_weight: 1.0       # Weight of the ΔPAD prediction
    delta_pressure_weight: 1.0  # Weight of the ΔPressure prediction
    confidence_weight: 0.5      # Weight of the confidence prediction

  # Huber loss
  huber_config:
    delta: 1.0  # Huber threshold

  # Focal loss (optional)
  focal_config:
    alpha: 1.0
    gamma: 2.0

# Validation
validation:
  # Validation frequency
  val_frequency: 1  # Validate every N epochs

  # Validation metrics
  metrics:
    - "MSE"
    - "MAE"
    - "RMSE"
    - "R2"
    - "MAPE"

  # Model selection
  model_selection:
    criterion: "val_loss"  # val_loss, val_mae, val_r2
    mode: "min"            # min/max

  # Augmentation of validation data
  val_augmentation:
    enabled: false
    methods: []

# Logging and monitoring
logging:
  # Log level
  level: "INFO"  # DEBUG, INFO, WARNING, ERROR

  # Log files
  log_dir: "logs"
  log_file: "training.log"
  max_file_size: "10MB"
  backup_count: 5

  # TensorBoard
  tensorboard:
    enabled: true
    log_dir: "runs"
    comment: ""
    flush_secs: 10

  # Wandb
  wandb:
    enabled: false
    project: "emotion-prediction"
    entity: "your-team"
    tags: []
    notes: ""

  # Progress bar
  progress_bar:
    enabled: true
    update_frequency: 10  # Update frequency
    leave: true           # Keep the bar after training finishes

# Checkpointing
checkpointing:
  # Save directory
  save_dir: "checkpoints"

  # Save strategy
  save_strategy: "best"  # best, last, all

  # File naming
  filename_template: "model_epoch_{epoch}_val_{val_loss:.4f}.pth"

  # Saved items
  save_items:
    - "model_state_dict"
    - "optimizer_state_dict"
    - "scheduler_state_dict"
    - "epoch"
    - "loss"
    - "metrics"
    - "config"

  # Save frequency
  save_frequency: 1  # Save every N epochs

  # Maximum number of checkpoints
  max_checkpoints: 5

# Hardware
hardware:
  # Device selection
  device: "auto"  # auto, cpu, cuda, mps

  # GPU
  gpu:
    id: 0                 # GPU ID
    memory_fraction: 0.9  # Fraction of GPU memory to use
    allow_growth: true    # Grow memory usage dynamically

  # Mixed precision
  mixed_precision:
    enabled: false
    opt_level: "O1"

  # Distributed training
  distributed:
    enabled: false
    backend: "nccl"
    init_method: "env://"
    world_size: 1
    rank: 0

# Debugging
debug:
  # Debug mode
  enabled: false

  # Fast training
  fast_train:
    enabled: false
    max_epochs: 5
    batch_size: 8
    subset_size: 100

  # Gradient checks
  gradient_checking:
    enabled: false
    clip_value: 1.0
    check_nan: true
    check_inf: true

  # Data checks
  data_checking:
    enabled: true
    check_nan: true
    check_inf: true
    check_range: true
    sample_output: true

  # Model checks
  model_checking:
    enabled: false
    count_parameters: true
    check_gradients: true
    visualize_model: false

# Experiment tracking
experiment_tracking:
  # Whether experiment tracking is enabled
  enabled: false

  # MLflow
  mlflow:
    tracking_uri: "http://localhost:5000"
    experiment_name: "emotion_prediction"
    run_name: null
    tags: {}
    params: {}

  # Weights & Biases
  wandb:
    project: "emotion-prediction"
    entity: null
    group: null
    job_type: "training"
    tags: []
    notes: ""
    config: {}

  # Local experiment tracking
  local:
    save_dir: "experiments"
    save_config: true
    save_metrics: true
    save_model: true
```

### Training configuration parameters

#### `training.epochs`: training epochs

| Parameter | Type | Required | Default | Description |
|------|------|------|--------|------|
| `max_epochs` | int | Yes | 200 | Maximum number of training epochs |
| `warmup_epochs` | int | No | 0 | Number of warm-up epochs |

#### `training.early_stopping`: early stopping

| Parameter | Type | Required | Default | Description |
|------|------|------|--------|------|
| `enabled` | bool | No | true | Whether early stopping is enabled |
| `patience` | int | No | 10 | Patience (in epochs) |
| `min_delta` | float | No | 1e-4 | Minimum improvement threshold |
| `monitor` | str | No | val_loss | Monitored metric |
| `mode` | str | No | min | Monitoring mode (min/max) |
| `restore_best_weights` | bool | No | true | Restore the best weights when stopping |

#### `optimizer`: optimizer configuration

Supported optimizer types:

- `Adam`: adaptive moment estimation
- `AdamW`: Adam with decoupled weight decay
- `SGD`: stochastic gradient descent
- `RMSprop`: RMSprop optimizer
- `Adagrad`: adaptive gradient algorithm

#### `scheduler`: learning-rate scheduler

Supported scheduler types:

- `StepLR`: step decay
- `CosineAnnealingLR`: cosine annealing
- `ReduceLROnPlateau`: decay on plateau
- `ExponentialLR`: exponential decay

## Data configuration

### Data configuration file: `configs/data_config.yaml`

```yaml
# ========================================
# Data configuration file
# ========================================

# Data paths
paths:
  # Training data
  train_data: "data/train.csv"
  val_data: "data/val.csv"
  test_data: "data/test.csv"

  # Preprocessor
  preprocessor: "models/preprocessor.pkl"

  # Data statistics
  statistics: "data/statistics.json"

  # Data quality report
  quality_report: "reports/data_quality.html"

# Data source
data_source:
  type: "csv"  # csv, json, parquet, hdf5, database

  # CSV
  csv_config:
    delimiter: ","
    encoding: "utf-8"
    header: 0
    index_col: null

  # JSON
  json_config:
    orient: "records"  # records, index, values, columns
    lines: false

  # Database
  database_config:
    connection_string: "sqlite:///data.db"
    table: "emotion_data"
    query: null

# Preprocessing
preprocessing:
  # Feature processing
  features:
    # Missing values
    missing_values:
      strategy: "drop"  # drop, fill_mean, fill_median, fill_mode, fill_constant
      fill_value: 0.0

    # Outliers
    outliers:
      method: "iqr"  # iqr, zscore, isolation_forest, none
      threshold: 1.5
      action: "clip"  # clip, remove, flag

    # Feature scaling
    scaling:
      method: "standard"      # standard, minmax, robust, none
      feature_range: [-1, 1]  # Range used by MinMax scaling

    # Feature selection
    selection:
      enabled: false
      method: "correlation"  # correlation, mutual_info, rfe
      k_best: 10

  # Label processing
  labels:
    # Missing values
    missing_values:
      strategy: "fill_mean"

    # Label scaling
    scaling:
      method: "standard"

    # Label transformation
    transformation:
      enabled: false
      method: "log"  # log, sqrt, boxcox

# Data augmentation
augmentation:
  enabled: false

  # Noise injection
  noise_injection:
    enabled: true
    noise_type: "gaussian"  # gaussian, uniform
    noise_std: 0.01
    feature_wise: true

  # Mixup
  mixup:
    enabled: true
    alpha: 0.2

  # SMOTE (for imbalanced data)
  smote:
    enabled: false
    k_neighbors: 5
    sampling_strategy: "auto"

# Data validation
validation:
  # Value range checks
  range_validation:
    enabled: true
    features:
      user_pleasure: [-1.0, 1.0]
      user_arousal: [-1.0, 1.0]
      user_dominance: [-1.0, 1.0]
      vitality: [0.0, 100.0]
      current_pleasure: [-1.0, 1.0]
      current_arousal: [-1.0, 1.0]
      current_dominance: [-1.0, 1.0]
    labels:
      delta_pleasure: [-0.5, 0.5]
      delta_arousal: [-0.5, 0.5]
      delta_dominance: [-0.5, 0.5]
      delta_pressure: [-0.3, 0.3]
      confidence: [0.0, 1.0]

  # Data quality checks
  quality_checks:
    check_duplicates: true
    check_missing: true
    check_outliers: true
    check_correlations: true
    check_distribution: true

  # Statistics report
  statistics:
    compute_descriptive: true
    compute_correlations: true
    compute_distributions: true
    save_plots: true

# Synthetic data
synthetic_data:
  enabled: false

  # Generation parameters
  generation:
    num_samples: 1000
    seed: 42

  # Data distribution
  distribution:
    type: "multivariate_normal"  # normal, uniform, multivariate_normal
    mean: null
    cov: null

  # Correlation
  correlation:
    enabled: true
    strength: 0.5
    structure: "block"  # block, random, toeplitz

  # Noise
  noise:
    add_noise: true
    noise_type: "gaussian"
    noise_std: 0.1
```

## Inference configuration

### Inference configuration file: `configs/inference_config.yaml`

```yaml
# ========================================
# Inference configuration file
# ========================================

# Basic inference information
inference_info:
  model_path: "models/best_model.pth"
  preprocessor_path: "models/preprocessor.pkl"
  device: "auto"
  batch_size: 32

# Input
input:
  # Input format
  format: "auto"  # auto, list, numpy, pandas, json, csv

  # Input validation
  validation:
    enabled: true
    check_shape: true
    check_range: true
    check_type: true

  # Input preprocessing
  preprocessing:
    normalize: true
    handle_missing: "error"  # error, fill, skip
    missing_value: 0.0

# Output
output:
  # Output format
  format: "dict"  # dict, json, csv, numpy

  # Output contents
  include:
    predictions: true
    confidence: true
    components: true  # delta_pad, delta_pressure, confidence
    metadata: false   # inference_time, model_info

  # Output postprocessing
  postprocessing:
    clip_predictions: true
    round_decimals: 6
    format_confidence: "percentage"  # decimal, percentage

# Performance optimization
optimization:
  # Model optimization
  model_optimization:
    enabled: true
    torch_script: false
    onnx: false
    quantization: false

  # Inference optimization
  inference_optimization:
    warmup: true
    warmup_samples: 5
    batch_optimization: true
    memory_optimization: true

  # Caching
  caching:
    enabled: false
    cache_size: 1000
    cache_policy: "lru"  # lru, fifo

# Monitoring
monitoring:
  # Performance monitoring
  performance:
    enabled: true
    track_latency: true
    track_memory: true
    track_throughput: true

  # Quality monitoring
  quality:
    enabled: false
    confidence_threshold: 0.5
    prediction_validation: true

  # Anomaly detection
  anomaly_detection:
    enabled: false
    method: "statistical"  # statistical, isolation_forest
    threshold: 2.0

# Service (for deployment)
service:
  # API
  api:
    host: "0.0.0.0"
    port: 8000
    workers: 1
    timeout: 30

  # Rate limiting
  rate_limiting:
    enabled: false
    requests_per_minute: 100

  # Authentication
  authentication:
    enabled: false
    method: "api_key"  # api_key, jwt, basic

  # Logging
  logging:
    level: "INFO"
    format: "json"
```

## Logging configuration

### Logging configuration file: `configs/logging_config.yaml`

```yaml
# ========================================
# Logging configuration file
# ========================================

# Logging system
logging:
  # Root logger
  root:
    level: "INFO"
    handlers: ["console", "file"]

  # Loggers
  loggers:
    training:
      level: "INFO"
      handlers: ["console", "file", "tensorboard"]
      propagate: false
    inference:
      level: "WARNING"
      handlers: ["console", "file"]
      propagate: false
    data:
      level: "DEBUG"
      handlers: ["file"]
      propagate: false

  # Handlers
  handlers:
    # Console handler
    console:
      class: "StreamHandler"
      level: "INFO"
      formatter: "console"
      stream: "ext://sys.stdout"

    # File handler
    file:
      class: "RotatingFileHandler"
      level: "DEBUG"
      formatter: "detailed"
      filename: "logs/app.log"
      maxBytes: 10485760  # 10MB
      backupCount: 5
      encoding: "utf8"

    # Error file handler
    error_file:
      class: "RotatingFileHandler"
      level: "ERROR"
      formatter: "detailed"
      filename: "logs/error.log"
      maxBytes: 10485760
      backupCount: 3
      encoding: "utf8"

  # Formatters
  formatters:
    # Console format
    console:
      format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
      datefmt: "%H:%M:%S"

    # Detailed format
    detailed:
      format: "%(asctime)s - %(name)s - %(levelname)s - %(module)s:%(lineno)d - %(funcName)s - %(message)s"
      datefmt: "%Y-%m-%d %H:%M:%S"

    # JSON format
    json:
      format: '{"timestamp": "%(asctime)s", "level": "%(levelname)s", "logger": "%(name)s", "module": "%(module)s", "line": %(lineno)d, "message": "%(message)s"}'
      datefmt: "%Y-%m-%dT%H:%M:%S"

  # Log filters
  filters:
    # Performance filter
    performance:
      class: "PerformanceFilter"
      threshold: 0.1

    # Sensitive data filter
    sensitive:
      class: "SensitiveDataFilter"
      patterns: ["password", "token", "key"]
```

## Hardware configuration

### Hardware configuration file: `configs/hardware_config.yaml`

```yaml
# ========================================
# Hardware configuration file
# ========================================

# Device
device:
  # Automatic selection
  auto:
    priority: ["cuda", "mps", "cpu"]  # Device priority
    memory_threshold: 0.8             # Memory usage threshold

  # CPU
  cpu:
    num_threads: null  # null means auto-detect
    use_openmp: true
    use_mkl: true

  # GPU
  gpu:
    # GPU selection
    device_id: 0          # GPU ID
    memory_fraction: 0.9  # Fraction of GPU memory to use
    allow_growth: true    # Grow memory usage dynamically

    # CUDA
    cuda:
      allow_tf32: true      # Enable TF32
      benchmark: true       # Enable cuDNN benchmarking
      deterministic: false  # Deterministic mode

    # Mixed precision
    mixed_precision:
      enabled: false
      opt_level: "O1"        # O0, O1, O2, O3
      loss_scale: "dynamic"  # static, dynamic

    # Multi-GPU
    multi_gpu:
      enabled: false
      device_ids: [0, 1]
      output_device: 0
      dim: 0  # Data-parallel dimension

# Memory
memory:
  # System memory
  system:
    max_usage: 0.8          # Maximum usage fraction
    cleanup_threshold: 0.9  # Cleanup threshold

  # GPU memory
  gpu:
    max_usage: 0.9
    cleanup_interval: 100  # Cleanup interval (in steps)

  # Memory optimization
  optimization:
    enable_gc: true     # Enable garbage collection
    gc_threshold: 0.8   # GC trigger threshold
    pin_memory: true    # Page-locked (pinned) memory
    share_memory: true  # Shared memory

# Performance
performance:
  # Parallelism
  parallel:
    num_workers: 4      # Number of data loader worker processes
    prefetch_factor: 2  # Prefetch factor

  # Caching
  cache:
    model_cache: true  # Model cache
    data_cache: true   # Data cache
    cache_size: 1024   # Cache size (MB)

  # Compilation
  compilation:
    torch_compile: false  # PyTorch 2.0 compilation
    jit_script: true      # TorchScript
    mode: "default"       # default, reduce-overhead, max-autotune
```

## Configuration best practices

### 1. Organizing configuration files

```
configs/
├── model_config.yaml          # Model configuration
├── training_config.yaml       # Training configuration
├── data_config.yaml           # Data configuration
├── inference_config.yaml      # Inference configuration
├── logging_config.yaml        # Logging configuration
├── hardware_config.yaml       # Hardware configuration
├── environments/              # Environment-specific configuration
│   ├── development.yaml
│   ├── staging.yaml
│   └── production.yaml
└── experiments/               # Experiment-specific configuration
    ├── baseline.yaml
    ├── large_model.yaml
    └── fast_train.yaml
```

### 2. Configuration inheritance

```yaml
# configs/experiments/large_model.yaml
_base_: "../training_config.yaml"

training:
  epochs:
    max_epochs: 500

model:
  architecture:
    hidden_layers:
      - size: 256
        activation: "ReLU"
        dropout: 0.3
      - size: 128
        activation: "ReLU"
        dropout: 0.2
      - size: 64
        activation: "ReLU"
        dropout: 0.1

experiment_tracking:
  enabled: true
  mlflow:
    experiment_name: "large_model_experiment"
```

### 3. Environment variable substitution

```yaml
# Use environment variables, with a default after the colon
model_path: "${MODEL_PATH:/models/default.pth}"
learning_rate: "${LEARNING_RATE:0.001}"
batch_size: "${BATCH_SIZE:32}"
```

### 4. Configuration validation

```python
from src.utils.config import ConfigValidator, ValidationError

# Create the validator
validator = ConfigValidator()

# Add validation rules
validator.add_rule("training.optimizer.lr", lambda x: 0 < x <= 1)
validator.add_rule("model.hidden_dims", lambda x: len(x) > 0)

# Validate the configuration (`config` is a previously loaded configuration)
try:
    validator.validate(config)
except ValidationError as e:
    print(f"Configuration validation failed: {e}")
```

### 5. Configuration versioning

```yaml
# Configuration file version
config_version: "1.0"
compatibility_version: ">=0.9.0"

# Changelog
changelog:
  - version: "1.0"
    changes: ["Added mixed-precision support", "Updated the learning-rate scheduler"]
  - version: "0.9"
    changes: ["Initial version"]
```

## Configuration validation

### Configuration validator

```python
from typing import Any, Callable, Dict, Optional


# Assumed minimal definition so the listing is self-contained;
# the project's own ValidationError is imported from src.utils.config.
class ValidationError(Exception):
    """Raised when a configuration value fails a validation rule."""


class ConfigValidator:
    """Configuration validator."""

    def __init__(self):
        self.rules = {}
        self.schemas = {}

    def add_rule(self, path: str, validator: Callable[[Any], bool],
                 message: Optional[str] = None):
        """Register a validation rule for a dotted configuration path."""
        self.rules[path] = {
            'validator': validator,
            'message': message or f"Invalid value at {path}"
        }

    def add_schema(self, section: str, schema: Dict):
        """Register a schema for a configuration section."""
        self.schemas[section] = schema

    def validate(self, config: Dict) -> bool:
        """Check every rule; raise ValidationError on the first failure."""
        for path, rule in self.rules.items():
            value = self._get_nested_value(config, path)
            if not rule['validator'](value):
                raise ValidationError(rule['message'])
        return True

    def _get_nested_value(self, config: Dict, path: str):
        """Return the value at a dotted path, or None if any key is missing."""
        value = config
        for key in path.split('.'):
            # Stop early if an intermediate value is not a mapping
            if not isinstance(value, dict):
                return None
            value = value.get(key)
            if value is None:
                return None
        return value
```

### Common validation rules

```python
# Numeric range validation
validator.add_rule("training.optimizer.lr", lambda x: 0 < x <= 1, "Learning rate must be in the range (0, 1]")
validator.add_rule("model.dropout_rate", lambda x: 0 <= x < 1, "Dropout rate must be in the range [0, 1)")

# List validation
validator.add_rule("model.hidden_dims", lambda x: isinstance(x, list) and len(x) > 0