# Data Preprocessing Module
This module implements the data preprocessing pipeline for a model that predicts changes in emotional and physiological state.
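## Features
- **Dataset class**: handles samples with 7-dimensional inputs and 5-dimensional outputs
- **Data loaders**: support train/validation/test splits
- **Preprocessing**: standardization, cleaning, and outlier handling
- **Synthetic data generation**: produces simulated data that matches the expected format and value ranges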
## Data Format
### Input Features (7 dimensions)
- User PAD: Pleasure, Arousal, Dominance (3 dims), range [-1, 1]
- Vitality: physiological vitality value (1 dim), range [0, 100]
- Current PAD: current-state Pleasure, Arousal, Dominance (3 dims), range [-1, 1]
### Output Labels (5 dimensions)
- ΔPAD: change in PAD state (3 dims), range [-0.5, 0.5]
- ΔPressure: change in pressure (1 dim), range [-0.3, 0.3]
- Confidence: prediction confidence (1 dim), range [0, 1]
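
The input and output layouts above can be expressed as fixed index positions. The sketch below is a minimal illustration. It assumes the columns are stored in the order listed: user PAD, Vitality, then current PAD for features, and ΔPAD, ΔPressure, then Confidence for labels. The constant and helper names are hypothetical and are not part of `src.data`.

```python
import numpy as np

# Hypothetical column layout, assuming the order listed above.
USER_PAD = slice(0, 3)      # features[0:3] -> user Pleasure, Arousal, Dominance
VITALITY = 3                # features[3]   -> vitality in [0, 100]
CURRENT_PAD = slice(4, 7)   # features[4:7] -> current Pleasure, Arousal, Dominance

DELTA_PAD = slice(0, 3)     # labels[0:3]   -> ΔPAD
DELTA_PRESSURE = 3          # labels[3]     -> ΔPressure
CONFIDENCE = 4              # labels[4]     -> confidence


def build_feature_vector(user_pad, vitality, current_pad):
    """Concatenate the three input groups into one 7-dimensional vector."""
    vec = np.concatenate([
        np.asarray(user_pad, dtype=np.float32),
        np.asarray([vitality], dtype=np.float32),
        np.asarray(current_pad, dtype=np.float32),
    ])
    if vec.shape != (7,):
        raise ValueError(f"expected 7 features, got {vec.shape}")
    return vec


def split_label_vector(label):
    """Unpack a 5-dimensional label into (ΔPAD, ΔPressure, confidence)."""
    label = np.asarray(label, dtype=np.float32)
    if label.shape != (5,):
        raise ValueError(f"expected 5 labels, got {label.shape}")
    return label[DELTA_PAD], float(label[DELTA_PRESSURE]), float(label[CONFIDENCE])


x = build_feature_vector([0.2, -0.1, 0.5], 75.0, [0.1, 0.0, 0.3])
print(x.shape)  # (7,)
d_pad, d_pressure, conf = split_label_vector([0.05, -0.02, 0.1, -0.1, 0.8])
```

Keeping the slices in one place makes it harder to mix up the two PAD groups. A range check would not catch such a mix-up, because both groups share the same [-1, 1] range.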
## Usage Examples
### 1. Generating Synthetic Data
```python
from src.data import generate_synthetic_data, SyntheticDataGenerator

# Generate data with the convenience function
features, labels = generate_synthetic_data(num_samples=1000)
print(f"Features: {features.shape}, Labels: {labels.shape}")

# Use the generator class
generator = SyntheticDataGenerator(num_samples=1000, seed=42)
features, labels = generator.generate_data()

# Generate data that follows specific patterns
features, labels = generator.generate_dataset_with_patterns(
    patterns=['stress', 'relaxation', 'excitement'],
    pattern_weights=[0.3, 0.4, 0.3]
)
```
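
To make the expected shapes and ranges concrete, here is a minimal sketch that draws every column uniformly from the documented ranges. It is not how `SyntheticDataGenerator` works: that class also supports patterns such as `'stress'`, and its logic is not documented here. `uniform_synthetic_data` is a hypothetical helper. Its features and labels are independent, so the output is only suitable for shape and range smoke tests.

```python
import numpy as np


def uniform_synthetic_data(num_samples, seed=None):
    """Draw features and labels uniformly from the documented ranges.

    Hypothetical helper: features and labels are independent here, so the
    result is only useful for shape/range smoke tests, not for training.
    """
    rng = np.random.default_rng(seed)
    user_pad = rng.uniform(-1.0, 1.0, size=(num_samples, 3))
    vitality = rng.uniform(0.0, 100.0, size=(num_samples, 1))
    current_pad = rng.uniform(-1.0, 1.0, size=(num_samples, 3))
    features = np.hstack([user_pad, vitality, current_pad]).astype(np.float32)

    delta_pad = rng.uniform(-0.5, 0.5, size=(num_samples, 3))
    delta_pressure = rng.uniform(-0.3, 0.3, size=(num_samples, 1))
    confidence = rng.uniform(0.0, 1.0, size=(num_samples, 1))
    labels = np.hstack([delta_pad, delta_pressure, confidence]).astype(np.float32)
    return features, labels


features, labels = uniform_synthetic_data(1000, seed=42)
print(features.shape, labels.shape)  # (1000, 7) (1000, 5)
```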
### 2. Preprocessing Data
```python
# If DataPreprocessor is not re-exported by src.data, import it from src.data.preprocessor
from src.data import create_preprocessor, DataPreprocessor

# Create a preprocessor (`features` and `labels` come from Example 1)
preprocessor = create_preprocessor()

# Fit on the data and transform it
features_scaled, labels_scaled = preprocessor.fit_transform(features, labels)

# Retrieve statistics
feature_stats = preprocessor.get_feature_statistics()
label_stats = preprocessor.get_label_statistics()

# Save the fitted preprocessor
preprocessor.save_preprocessor('preprocessor.pkl')

# Load it back (load_preprocessor is called on the class)
preprocessor = DataPreprocessor.load_preprocessor('preprocessor.pkl')
```
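
The save/load step matters because a preprocessor must apply the statistics computed at fit time to every later batch. The sketch below shows that workflow with a hypothetical `SimpleStandardizer`, which does plain z-score scaling and persists its statistics with `pickle`. It illustrates the idea under that assumption and is not the implementation of `DataPreprocessor`, whose file format is not documented here.

```python
import os
import pickle
import tempfile

import numpy as np


class SimpleStandardizer:
    """Hypothetical z-score scaler: x' = (x - mean) / std, with stats from fit()."""

    def __init__(self):
        self.mean_ = None
        self.std_ = None

    def fit(self, x):
        x = np.asarray(x, dtype=np.float64)
        self.mean_ = x.mean(axis=0)
        std = x.std(axis=0)
        # Constant columns get std 1 to avoid division by zero.
        self.std_ = np.where(std > 0, std, 1.0)
        return self

    def transform(self, x):
        if self.mean_ is None:
            raise RuntimeError("call fit() before transform()")
        return (np.asarray(x, dtype=np.float64) - self.mean_) / self.std_

    def fit_transform(self, x):
        return self.fit(x).transform(x)

    def save(self, path):
        # Persist only the fitted statistics (plain NumPy arrays).
        with open(path, "wb") as f:
            pickle.dump({"mean": self.mean_, "std": self.std_}, f)

    @classmethod
    def load(cls, path):
        with open(path, "rb") as f:
            state = pickle.load(f)
        obj = cls()
        obj.mean_, obj.std_ = state["mean"], state["std"]
        return obj


rng = np.random.default_rng(0)
train = rng.normal(5.0, 2.0, size=(200, 7))
new_data = rng.normal(5.0, 2.0, size=(10, 7))

scaler = SimpleStandardizer()
train_scaled = scaler.fit_transform(train)  # statistics come from train only

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "standardizer.pkl")
    scaler.save(path)
    restored = SimpleStandardizer.load(path)

# The reloaded scaler applies exactly the same training statistics.
print(np.allclose(restored.transform(new_data), scaler.transform(new_data)))  # True
```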
### 3. Creating a Dataset
```python
from src.data import EmotionDataset

# Create from NumPy arrays
dataset = EmotionDataset(features, labels)

# Create from a file
dataset = EmotionDataset('data.csv')

# Get a single sample
sample_features, sample_labels = dataset[0]

# Get feature statistics
stats = dataset.get_feature_statistics()
```
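
`EmotionDataset` presumably follows PyTorch's map-style dataset protocol, since torch is a listed dependency. That protocol only requires `__len__` and `__getitem__`. The hypothetical `ArrayEmotionDataset` below sketches the protocol over NumPy arrays. Its `from_csv` constructor assumes a header row followed by 12 numeric columns: 7 features, then 5 labels. This README does not specify the file format that `EmotionDataset('data.csv')` actually accepts.

```python
import csv
import os
import tempfile

import numpy as np


class ArrayEmotionDataset:
    """Hypothetical map-style dataset over (N, 7) features and (N, 5) labels."""

    def __init__(self, features, labels):
        self.features = np.asarray(features, dtype=np.float32)
        self.labels = np.asarray(labels, dtype=np.float32)
        if self.features.ndim != 2 or self.features.shape[1] != 7:
            raise ValueError(f"features must be (N, 7), got {self.features.shape}")
        if self.labels.shape != (len(self.features), 5):
            raise ValueError(f"labels must be (N, 5), got {self.labels.shape}")

    @classmethod
    def from_csv(cls, path):
        # Assumed layout: one header row, then 7 feature + 5 label columns per row.
        with open(path, newline="") as f:
            reader = csv.reader(f)
            next(reader)  # skip the header row
            rows = np.array([[float(v) for v in row] for row in reader], dtype=np.float32)
        return cls(rows[:, :7], rows[:, 7:])

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        # One sample: a (7,) feature vector and a (5,) label vector.
        return self.features[idx], self.labels[idx]

    def feature_statistics(self):
        # Per-column summary of the 7 input features.
        return {
            "mean": self.features.mean(axis=0),
            "std": self.features.std(axis=0),
            "min": self.features.min(axis=0),
            "max": self.features.max(axis=0),
        }


# Usage: build from arrays, then round-trip through a CSV file.
rng = np.random.default_rng(0)
ds = ArrayEmotionDataset(rng.random((4, 7)), rng.random((4, 5)))
x0, y0 = ds[0]
print(len(ds), x0.shape, y0.shape)  # 4 (7,) (5,)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "data.csv")
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([f"f{i}" for i in range(7)] + [f"y{i}" for i in range(5)])
        writer.writerows(np.hstack([ds.features, ds.labels]).tolist())
    ds_csv = ArrayEmotionDataset.from_csv(path)
```

Because the class implements `__len__` and `__getitem__`, you can pass an instance to `torch.utils.data.DataLoader`. Its default collation converts the NumPy arrays to tensors.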
### 4. Data Loaders
```python
from src.data import create_data_loader

# Create the data loader factory
loader = create_data_loader(batch_size=32, shuffle=True)

# Get the train, validation, and test loaders at once
train_loader, val_loader, test_loader = loader.get_all_loaders(
    data=features, labels=labels
)

# Get individual loaders
train_loader = loader.get_train_loader(data=features, labels=labels)
val_loader = loader.get_val_loader(data=features, labels=labels)

# Use synthetic data
train_loader, val_loader, test_loader = loader.get_synthetic_loaders(
    num_samples=1000
)
```
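
A loader's output can be illustrated without PyTorch: shuffled mini-batches of `(features, labels)` that cover every sample once per epoch. The generator below is a NumPy-only sketch with a hypothetical name, `iterate_minibatches`. Like the default `torch.utils.data.DataLoader` (`drop_last=False`), it keeps the final, smaller batch. It does not reproduce the loaders returned by `create_data_loader`.

```python
import numpy as np


def iterate_minibatches(features, labels, batch_size=32, shuffle=True, seed=None):
    """Yield (features_batch, labels_batch) pairs covering every sample once."""
    n = len(features)
    order = np.arange(n)
    if shuffle:
        np.random.default_rng(seed).shuffle(order)
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        yield features[idx], labels[idx]  # the last batch may be smaller


features = np.arange(100 * 7, dtype=np.float32).reshape(100, 7)
labels = np.zeros((100, 5), dtype=np.float32)
sizes = [len(xb) for xb, _ in iterate_minibatches(features, labels, batch_size=32, seed=0)]
print(sizes)  # [32, 32, 32, 4]
```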
### 5. Loading from a Configuration File
```python
from src.data import load_data_from_config

# Build the data loaders from a configuration file
train_loader, val_loader, test_loader = load_data_from_config(
    'configs/training_config.yaml'
)
```
## Configuration Options
### Preprocessing Configuration
```python
from src.data import create_preprocessor

config = {
    'feature_scaling': {
        'method': 'standard',  # options: standard, min_max, robust, none
        'pad_features': 'standard',
        'vitality_feature': 'min_max'
    },
    'missing_values': {
        'strategy': 'mean',  # options: mean, median, most_frequent, constant, knn
        'knn_neighbors': 5
    },
    'outliers': {
        'method': 'isolation_forest',  # options: isolation_forest, z_score, iqr
        'contamination': 0.1
    }
}
preprocessor = create_preprocessor(config)
```
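
Setting `pad_features` to `'standard'` and `vitality_feature` to `'min_max'` implies per-column scaling. The six PAD columns get z-score scaling, `(x - mean) / std`, and Vitality gets `(x - min) / (max - min)`. A NumPy sketch of that idea follows, assuming the column order listed for the 7 input features. The function names are hypothetical, and this README does not document how `create_preprocessor` combines `method` with the per-group settings.

```python
import numpy as np

PAD_COLS = [0, 1, 2, 4, 5, 6]  # user PAD + current PAD (assumed order)
VITALITY_COL = 3               # vitality (assumed order)


def fit_mixed_scaler(train_features):
    """Compute z-score stats for the PAD columns and min/max for vitality."""
    x = np.asarray(train_features, dtype=np.float64)
    pad = x[:, PAD_COLS]
    vit = x[:, VITALITY_COL]
    std = pad.std(axis=0)
    return {
        "pad_mean": pad.mean(axis=0),
        "pad_std": np.where(std > 0, std, 1.0),  # avoid division by zero
        "vit_min": vit.min(),
        "vit_range": max(vit.max() - vit.min(), 1e-12),
    }


def apply_mixed_scaler(features, stats):
    """Standardize PAD columns and min-max scale vitality, keeping column order."""
    out = np.array(features, dtype=np.float64)  # copy, so the input is untouched
    out[:, PAD_COLS] = (out[:, PAD_COLS] - stats["pad_mean"]) / stats["pad_std"]
    out[:, VITALITY_COL] = (out[:, VITALITY_COL] - stats["vit_min"]) / stats["vit_range"]
    return out


rng = np.random.default_rng(1)
train = np.hstack([
    rng.uniform(-1, 1, (500, 3)), rng.uniform(0, 100, (500, 1)), rng.uniform(-1, 1, (500, 3))
])
stats = fit_mixed_scaler(train)
train_scaled = apply_mixed_scaler(train, stats)
```

Transform validation and test features with the `stats` computed on the training split rather than refitting, so that all splits share one scale.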
### Data Loader Configuration
```python
from src.data import create_data_loader

config = {
    'batch_size': 32,
    'num_workers': 4,
    'train_split': 0.7,
    'val_split': 0.15,
    'test_split': 0.15,  # the three split ratios sum to 1.0
    'normalize_features': True,
    'normalize_labels': False
}
loader = create_data_loader(config)
```
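
The three split ratios in this configuration sum to 1.0. A minimal sketch of turning such ratios into disjoint index sets follows. `split_indices` and its seed are hypothetical, and this README does not document the module's own rounding and shuffling rules. This version floors the train and validation sizes and assigns the remainder to the test split, so no sample is dropped.

```python
import math

import numpy as np


def split_indices(num_samples, config, seed=42):
    """Return shuffled, disjoint (train, val, test) index arrays."""
    ratios = (config["train_split"], config["val_split"], config["test_split"])
    if not math.isclose(sum(ratios), 1.0, abs_tol=1e-6):
        raise ValueError(f"splits must sum to 1.0, got {sum(ratios)}")
    order = np.random.default_rng(seed).permutation(num_samples)
    n_train = int(num_samples * ratios[0])  # floor
    n_val = int(num_samples * ratios[1])    # floor
    train = order[:n_train]
    val = order[n_train:n_train + n_val]
    test = order[n_train + n_val:]  # remainder, so no sample is dropped
    return train, val, test


config = {"train_split": 0.7, "val_split": 0.15, "test_split": 0.15}
train_idx, val_idx, test_idx = split_indices(1000, config)
print(len(train_idx), len(val_idx), len(test_idx))  # 700 150 150
```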
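## Data Validation
The module includes the following data validation checks:
- **Range checks**: verify that PAD, Vitality, and confidence values fall within their expected ranges
- **Missing-value detection**: automatically detects and handles NaN values
- **Outlier detection**: detects outliers with the methods selectable in the `outliers` configuration (isolation forest, z-score, IQR)
- **Dimension checks**: ensure that the data has the correct dimensions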
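
The checks listed above can be sketched as one NumPy function that reports problems rather than fixing them. It assumes the column order listed in the data-format section and uses the documented ranges. For outliers it uses only a z-score test with an illustrative threshold of 3, standing in for the configurable methods (isolation forest, z-score, IQR). `validate_arrays` and its report keys are hypothetical, not the module's API.

```python
import numpy as np

# Documented ranges, assuming the column order listed in the data-format section.
FEATURE_RANGES = [(-1, 1)] * 3 + [(0, 100)] + [(-1, 1)] * 3
LABEL_RANGES = [(-0.5, 0.5)] * 3 + [(-0.3, 0.3)] + [(0, 1)]


def validate_arrays(features, labels, z_threshold=3.0):
    """Return a dict of problems found; an empty dict means all checks passed."""
    features = np.asarray(features, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.float64)
    report = {}

    # Dimension check: (N, 7) features and (N, 5) labels with matching N.
    if features.ndim != 2 or features.shape[1] != 7:
        report["feature_shape"] = features.shape
    if labels.ndim != 2 or labels.shape[1] != 5 or len(labels) != len(features):
        report["label_shape"] = labels.shape
    if report:
        return report  # the remaining checks assume correct shapes

    for name, arr, ranges in (("features", features, FEATURE_RANGES),
                              ("labels", labels, LABEL_RANGES)):
        # Missing-value check.
        n_nan = int(np.isnan(arr).sum())
        if n_nan:
            report[f"{name}_nan"] = n_nan
        # Per-column range check (NaN compares False, so it is not counted here).
        lo = np.array([r[0] for r in ranges])
        hi = np.array([r[1] for r in ranges])
        out_of_range = ((arr < lo) | (arr > hi)).sum(axis=0)
        if out_of_range.any():
            report[f"{name}_out_of_range"] = out_of_range.tolist()

    # Simple z-score outlier check on features, using NaN-aware statistics.
    mean = np.nanmean(features, axis=0)
    std = np.nanstd(features, axis=0)
    z = np.abs(features - mean) / np.where(std > 0, std, 1.0)
    n_outliers = int((z > z_threshold).any(axis=1).sum())
    if n_outliers:
        report["feature_outlier_rows"] = n_outliers
    return report


rng = np.random.default_rng(0)
good_x = np.hstack([rng.uniform(-1, 1, (50, 3)), rng.uniform(0, 100, (50, 1)),
                    rng.uniform(-1, 1, (50, 3))])
good_y = np.hstack([rng.uniform(-0.5, 0.5, (50, 3)), rng.uniform(-0.3, 0.3, (50, 1)),
                    rng.uniform(0, 1, (50, 1))])
print(validate_arrays(good_x, good_y))  # {}

bad_x = good_x.copy()
bad_x[0, 3] = 150.0   # vitality above 100
bad_x[1, 0] = np.nan  # missing user Pleasure
bad_report = validate_arrays(bad_x, good_y)
```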
## File Structure
```
src/data/
├── __init__.py              # module exports
├── dataset.py               # EmotionDataset class
├── data_loader.py           # data loader factory
├── preprocessor.py          # data preprocessing class
├── synthetic_generator.py   # synthetic data generator
└── README.md                # usage documentation
```
## Requirements
- torch >= 1.12.0
- numpy >= 1.21.0
- pandas >= 1.3.0
- scikit-learn >= 1.0.0
- scipy >= 1.7.0
- loguru >= 0.6.0
## Testing
Run the test scripts to verify the module:
```bash
# Run inside the virtual environment
python simple_test.py

# Full test suite (requires torch)
python test_data_module.py
```
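## Notes
1. Make sure all dependencies are installed in the virtual environment.
2. PAD values must lie within [-1, 1].
3. Vitality values must lie within [0, 100].
4. Confidence values must lie within [0, 1].
5. Fit the preprocessor before using it to transform data (for example, with `fit_transform`).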