|
|
|
|
|
""" |
|
|
Quick Start Tutorial for Emotion and Physiological State Prediction Model

This script demonstrates how to get started quickly with the emotion and
physiological state change prediction model:

    1. Generate synthetic data
    2. Train the model
    3. Run prediction inference

Usage:

    python quick_start.py
|
|
""" |
|
|
|
|
|
import sys |
|
|
import os |
|
|
from pathlib import Path |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import torch |
|
|
from typing import Dict, Any |
|
|
|
|
|
|
|
|
# Make the repository root (the parent of the directory containing this
# script) importable so the `src.*` imports below resolve. The path is
# resolved to an absolute one first: with a bare relative `__file__`
# (e.g. the script started from its own directory on older interpreters),
# `.parent.parent` would otherwise collapse to "." and point at the wrong
# directory.
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
|
|
|
|
|
from src.data.synthetic_generator import SyntheticDataGenerator |
|
|
from src.models.pad_predictor import PADPredictor |
|
|
from src.data.preprocessor import DataPreprocessor |
|
|
from src.utils.trainer import ModelTrainer |
|
|
from src.utils.inference_engine import create_inference_engine |
|
|
from src.utils.logger import setup_logger |
|
|
|
|
|
def main():
    """Run the quick-start tutorial from start to finish.

    Prints a banner, configures logging at INFO level, and then runs the
    three tutorial stages in order: synthetic data generation, model
    training, and inference with the model file returned by training.
    """
    separator = "=" * 60

    print(separator)
    print("情绪与生理状态变化预测模型 - 快速开始教程")
    print("Emotion and Physiological State Prediction Model - Quick Start")
    print(separator)

    setup_logger(level='INFO')

    # Stage 1: write the synthetic training CSV that stage 2 reads back.
    print("\n1. 生成合成数据...")
    generate_synthetic_data()

    # Stage 2: train and persist the model; the saved path feeds stage 3.
    print("\n2. 训练模型...")
    model_path = train_model()

    # Stage 3: load the saved artifacts and run sample predictions.
    print("\n3. 进行推理预测...")
    perform_inference(model_path)

    print("\n" + separator)
    print("快速开始教程完成!")
    print("Quick Start Tutorial Completed!")
    print(separator)
|
|
|
|
|
def generate_synthetic_data():
    """Generate a synthetic dataset and save it as CSV.

    Builds a ``SyntheticDataGenerator`` (1000 samples, seed 42), generates
    features and labels with noise and correlations enabled, writes them to
    ``<project_root>/examples/data/training_data.csv``, and prints the range
    of the per-column means reported by the generator's statistics.

    Returns:
        The ``(features, labels)`` pair returned by
        ``generator.generate_data``.
    """
    print(" - 创建数据生成器...")
    generator = SyntheticDataGenerator(num_samples=1000, seed=42)

    print(" - 生成训练数据...")
    features, labels = generator.generate_data(
        add_noise=True,
        add_correlations=True,
    )
    print(f" - 数据形状: 特征 {features.shape}, 标签 {labels.shape}")

    # parents=True so a missing intermediate directory does not abort the
    # tutorial; exist_ok=True keeps re-runs idempotent.
    output_dir = Path(project_root) / "examples" / "data"
    output_dir.mkdir(parents=True, exist_ok=True)
    data_file = output_dir / "training_data.csv"

    generator.save_data(features, labels, data_file, format='csv')
    print(f" - 数据已保存到: {data_file}")

    print(" - 数据统计信息:")
    stats = generator.get_data_statistics(features, labels)

    # stats[...]['mean'] is used as a mapping here (its .values() are the
    # per-column means) -- NOTE(review): confirm against the generator.
    feature_means = stats['features']['mean'].values()
    label_means = stats['labels']['mean'].values()
    print(f" 特征均值范围: [{min(feature_means):.3f}, {max(feature_means):.3f}]")
    print(f" 标签均值范围: [{min(label_means):.3f}, {max(label_means):.3f}]")

    return features, labels
|
|
|
|
|
def train_model():
    """Train a ``PADPredictor`` on the saved synthetic data and persist it.

    Reads ``<project_root>/examples/data/training_data.csv`` (the file
    ``generate_synthetic_data`` writes), fits and applies a
    ``DataPreprocessor``, splits the data 80/20 into training and validation
    sets, trains the model through ``ModelTrainer``, and saves both the
    model and the preprocessor under ``<project_root>/examples/models``.

    Returns:
        str: Path of the saved model file.
    """
    # Imported locally, as in the original script; random_split is pulled
    # in alongside the other two so the function uses one import style.
    from torch.utils.data import DataLoader, TensorDataset, random_split

    print(" - 准备训练数据...")
    data_path = Path(project_root) / "examples" / "data" / "training_data.csv"
    data = pd.read_csv(data_path)

    feature_columns = [
        'user_pleasure', 'user_arousal', 'user_dominance',
        'vitality', 'current_pleasure', 'current_arousal', 'current_dominance',
    ]
    label_columns = [
        'delta_pleasure', 'delta_arousal', 'delta_dominance',
        'delta_pressure', 'confidence',
    ]

    features = data[feature_columns].values
    labels = data[label_columns].values

    print(" - 数据预处理...")
    preprocessor = DataPreprocessor()
    preprocessor.fit(features, labels)
    processed_features, processed_labels = preprocessor.transform(features, labels)

    dataset = TensorDataset(
        torch.FloatTensor(processed_features),
        torch.FloatTensor(processed_labels),
    )

    # 80/20 split; the validation size is the remainder so the two parts
    # always add up to len(dataset).
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    print(" - 创建模型...")
    # Layer sizes follow the column lists so the model cannot drift out of
    # sync with the data if a column is added or removed.
    model = PADPredictor(
        input_dim=len(feature_columns),
        output_dim=len(label_columns),
        hidden_dims=[128, 64, 32],
        dropout_rate=0.3,
    )

    print(" - 开始训练...")
    trainer = ModelTrainer(model, preprocessor)

    # Create the output directory up front: the save calls below need it
    # whether or not the trainer creates save_dir on its own.
    models_dir = Path(project_root) / "examples" / "models"
    models_dir.mkdir(parents=True, exist_ok=True)

    training_config = {
        'epochs': 50,
        'learning_rate': 0.001,
        'weight_decay': 1e-4,
        'patience': 10,
        'save_dir': models_dir,
    }

    history = trainer.train(
        train_loader=train_loader,
        val_loader=val_loader,
        config=training_config,
    )

    model_save_path = models_dir / "quick_start_model.pth"
    preprocessor_save_path = models_dir / "quick_start_preprocessor.pkl"

    model.save_model(str(model_save_path))
    preprocessor.save(str(preprocessor_save_path))

    print(f" - 模型已保存到: {model_save_path}")
    print(f" - 预处理器已保存到: {preprocessor_save_path}")

    # The last entries of the loss histories are reported as final losses.
    final_train_loss = history['train_loss'][-1]
    final_val_loss = history['val_loss'][-1]

    print(" - 训练完成:")
    print(f" 最终训练损失: {final_train_loss:.4f}")
    print(f" 最终验证损失: {final_val_loss:.4f}")

    return str(model_save_path)
|
|
|
|
|
def perform_inference(model_path: str):
    """Run sample predictions with the trained model and print the results.

    Builds an inference engine from ``model_path`` and the preprocessor
    saved at ``<project_root>/examples/models/quick_start_preprocessor.pkl``,
    predicts five hard-coded samples one by one (printing each prediction
    with its textual interpretation), repeats them as a single batch, and
    finally prints throughput and average latency from the engine's
    benchmark.

    Args:
        model_path: Path of the saved model file to load.
    """
    print(" - 创建推理引擎...")
    engine = create_inference_engine(
        model_path=model_path,
        preprocessor_path=Path(project_root) / "examples" / "models" / "quick_start_preprocessor.pkl",
        device='auto',
    )

    # Each sample is laid out as:
    # [user P, user A, user D, vitality, current P, current A, current D]
    sample_inputs = [
        [0.5, 0.3, -0.2, 80.0, 0.1, 0.4, -0.1],
        [-0.3, 0.6, 0.2, 45.0, -0.1, 0.7, 0.1],
        [0.8, -0.4, 0.6, 92.0, 0.7, -0.3, 0.5],
        [-0.7, -0.5, -0.3, 25.0, -0.6, -0.4, -0.2],
        [0.2, 0.1, 0.0, 60.0, 0.3, 0.0, 0.1],
    ]

    print(" - 进行预测...")
    for i, input_data in enumerate(sample_inputs):
        result = engine.predict(input_data)
        delta_pad = result['delta_pad']

        print(f"\n 样本 {i+1}:")
        print(f" 输入: User PAD=[{input_data[0]:.2f}, {input_data[1]:.2f}, {input_data[2]:.2f}], "
              f"Vitality={input_data[3]:.1f}, Current PAD=[{input_data[4]:.2f}, {input_data[5]:.2f}, {input_data[6]:.2f}]")

        print(" 预测:")
        print(f" ΔPAD: [{delta_pad[0]:.3f}, {delta_pad[1]:.3f}, {delta_pad[2]:.3f}]")
        print(f" ΔPressure: {result['delta_pressure']:.3f}")
        print(f" Confidence: {result['confidence']:.3f}")

        interpretation = interpret_prediction(result)
        print(f" 解释: {interpretation}")

    print("\n - 批量预测...")
    batch_results = engine.predict_batch(sample_inputs)
    print(f" - 批量预测完成,处理了 {len(batch_results)} 个样本")

    print("\n - 性能基准测试...")
    stats = engine.benchmark(num_samples=100, batch_size=32)

    print(" - 性能统计:")
    print(f" 吞吐量: {stats['throughput']:.2f} 样本/秒")
    print(f" 平均延迟: {stats['avg_latency']:.2f}ms")
|
|
|
|
|
def interpret_prediction(result: Dict[str, Any]) -> str:
    """Turn a prediction result into a short human-readable summary.

    A change is reported only when its magnitude exceeds a threshold
    (0.05 for each of the three ``delta_pad`` components, 0.03 for
    ``delta_pressure``). If no change is reported, the summary says the
    emotional state is relatively stable. A confidence label is always
    appended last: high above 0.8, medium above 0.6, low otherwise.

    Args:
        result: Mapping with ``'delta_pad'`` (indexable, at least three
            numbers), ``'delta_pressure'`` and ``'confidence'`` (numbers).

    Returns:
        The summary phrases joined with commas.
    """
    delta_pad = result['delta_pad']
    delta_pressure = result['delta_pressure']
    confidence = result['confidence']

    # (value, threshold, phrase when the value rises, phrase when it falls)
    rules = (
        (delta_pad[0], 0.05, "情绪趋向积极", "情绪趋向消极"),
        (delta_pad[1], 0.05, "激活度增加", "激活度降低"),
        (delta_pad[2], 0.05, "支配感增强", "支配感减弱"),
        (delta_pressure, 0.03, "压力增加", "压力缓解"),
    )
    interpretations = [
        rising if value > 0 else falling
        for value, threshold, rising, falling in rules
        if abs(value) > threshold
    ]

    # This check must run before the confidence label is added: a label is
    # always appended, so testing afterwards could never find the list
    # empty and the "stable" message would be unreachable.
    if not interpretations:
        interpretations.append("情绪状态相对稳定")

    if confidence > 0.8:
        interpretations.append("高置信度预测")
    elif confidence > 0.6:
        interpretations.append("中等置信度预测")
    else:
        interpretations.append("低置信度预测")

    return ",".join(interpretations)
|
|
|
|
|
# Run the tutorial only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()