# Chordia / examples/quick_start.py — commit 0a6452f ("first commit") by Corolin
#!/usr/bin/env python3
"""
快速开始教程
Quick Start Tutorial for Emotion and Physiological State Prediction Model
这个脚本演示了如何快速开始使用情绪与生理状态变化预测模型:
1. 生成合成数据
2. 训练模型
3. 进行预测推理
运行方式:
python quick_start.py
"""
import sys
import os
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from typing import Dict, Any
# 添加项目根目录到Python路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from src.data.synthetic_generator import SyntheticDataGenerator
from src.models.pad_predictor import PADPredictor
from src.data.preprocessor import DataPreprocessor
from src.utils.trainer import ModelTrainer
from src.utils.inference_engine import create_inference_engine
from src.utils.logger import setup_logger
def main():
    """Run the tutorial end to end: synthetic data -> training -> inference.

    Prints a banner, configures logging, then runs the three stages in
    order. The checkpoint path returned by the training stage is handed to
    the inference stage.
    """
    rule = "=" * 60
    print(
        rule,
        "情绪与生理状态变化预测模型 - 快速开始教程",
        "Emotion and Physiological State Prediction Model - Quick Start",
        rule,
        sep="\n",
    )

    # Configure the project's logger at INFO level before any stage runs.
    setup_logger(level='INFO')

    # Stage 1: write the synthetic training CSV.
    print("\n1. 生成合成数据...")
    generate_synthetic_data()

    # Stage 2: train and persist the model; keep the checkpoint path.
    print("\n2. 训练模型...")
    trained_model_path = train_model()

    # Stage 3: load the checkpoint and run the prediction demos.
    print("\n3. 进行推理预测...")
    perform_inference(trained_model_path)

    print(
        "\n" + rule,
        "快速开始教程完成!",
        "Quick Start Tutorial Completed!",
        rule,
        sep="\n",
    )
def generate_synthetic_data():
    """Generate the synthetic training set and write it to examples/data.

    Builds a 1000-sample ``SyntheticDataGenerator`` (seed 42) and generates
    features/labels with noise and correlations enabled. The result is saved
    as ``examples/data/training_data.csv``, and the min/max of the per-column
    means is printed for both features and labels.

    Returns:
        tuple: ``(features, labels)`` exactly as returned by
        ``generator.generate_data``.
    """
    print(" - 创建数据生成器...")
    gen = SyntheticDataGenerator(num_samples=1000, seed=42)

    print(" - 生成训练数据...")
    sample_features, sample_labels = gen.generate_data(
        add_noise=True,
        add_correlations=True,
    )
    print(f" - 数据形状: 特征 {sample_features.shape}, 标签 {sample_labels.shape}")

    csv_file = Path(project_root) / "examples" / "data" / "training_data.csv"
    # Only the leaf "data" directory is created (no parents=True), so the
    # examples/ directory itself is expected to exist already.
    csv_file.parent.mkdir(exist_ok=True)
    gen.save_data(sample_features, sample_labels, csv_file, format='csv')
    print(f" - 数据已保存到: {csv_file}")

    print(" - 数据统计信息:")
    summary = gen.get_data_statistics(sample_features, sample_labels)

    def mean_span(section):
        # summary[section]['mean'] is used as a mapping of per-column means.
        column_means = list(summary[section]['mean'].values())
        return min(column_means), max(column_means)

    feat_lo, feat_hi = mean_span('features')
    print(f" 特征均值范围: [{feat_lo:.3f}, {feat_hi:.3f}]")
    lab_lo, lab_hi = mean_span('labels')
    print(f" 标签均值范围: [{lab_lo:.3f}, {lab_hi:.3f}]")
    return sample_features, sample_labels
def train_model():
    """Train the PAD predictor on the CSV written to examples/data.

    Steps:
      1. Load ``examples/data/training_data.csv``.
      2. Make a seeded 80/20 train/validation split.
      3. Fit the ``DataPreprocessor`` on the training rows only.
      4. Train a ``PADPredictor`` with ``ModelTrainer``.
      5. Save the model and the fitted preprocessor under ``examples/models``.

    Returns:
        str: Filesystem path of the saved model checkpoint.
    """
    from torch.utils.data import DataLoader, Subset, TensorDataset

    print(" - 准备训练数据...")
    data_path = Path(project_root) / "examples" / "data" / "training_data.csv"
    data = pd.read_csv(data_path)

    # Model inputs and regression targets, selected by CSV column name.
    feature_columns = [
        'user_pleasure', 'user_arousal', 'user_dominance',
        'vitality', 'current_pleasure', 'current_arousal', 'current_dominance',
    ]
    label_columns = [
        'delta_pleasure', 'delta_arousal', 'delta_dominance',
        'delta_pressure', 'confidence',
    ]
    features = data[feature_columns].values
    labels = data[label_columns].values

    # Split *before* preprocessing and with a fixed seed. The run is then
    # reproducible, and validation rows cannot influence the statistics the
    # preprocessor learns (which would optimistically bias validation loss).
    rng = np.random.default_rng(42)
    order = rng.permutation(len(features))
    train_size = int(0.8 * len(order))
    train_idx, val_idx = order[:train_size], order[train_size:]

    print(" - 数据预处理...")
    preprocessor = DataPreprocessor()
    preprocessor.fit(features[train_idx], labels[train_idx])
    processed_features, processed_labels = preprocessor.transform(features, labels)

    dataset = TensorDataset(
        torch.FloatTensor(processed_features),
        torch.FloatTensor(processed_labels),
    )
    train_loader = DataLoader(Subset(dataset, train_idx.tolist()), batch_size=32, shuffle=True)
    val_loader = DataLoader(Subset(dataset, val_idx.tolist()), batch_size=32, shuffle=False)

    print(" - 创建模型...")
    model = PADPredictor(
        input_dim=len(feature_columns),   # 7 features
        output_dim=len(label_columns),    # 5 targets
        hidden_dims=[128, 64, 32],
        dropout_rate=0.3,
    )

    print(" - 开始训练...")
    trainer = ModelTrainer(model, preprocessor)

    models_dir = Path(project_root) / "examples" / "models"
    # Create the output directory up front. The checkpoint and preprocessor
    # are written into it below, and writing into a missing directory would
    # typically fail.
    models_dir.mkdir(parents=True, exist_ok=True)

    training_config = {
        'epochs': 50,
        'learning_rate': 0.001,
        'weight_decay': 1e-4,
        'patience': 10,  # presumably early-stopping patience — confirm in ModelTrainer
        'save_dir': models_dir,
    }
    history = trainer.train(
        train_loader=train_loader,
        val_loader=val_loader,
        config=training_config,
    )

    model_save_path = models_dir / "quick_start_model.pth"
    preprocessor_save_path = models_dir / "quick_start_preprocessor.pkl"
    model.save_model(str(model_save_path))
    preprocessor.save(str(preprocessor_save_path))
    print(f" - 模型已保存到: {model_save_path}")
    print(f" - 预处理器已保存到: {preprocessor_save_path}")

    # Report the last recorded epoch's losses.
    final_train_loss = history['train_loss'][-1]
    final_val_loss = history['val_loss'][-1]
    print(" - 训练完成:")
    print(f" 最终训练损失: {final_train_loss:.4f}")
    print(f" 最终验证损失: {final_val_loss:.4f}")
    return str(model_save_path)
def perform_inference(model_path: str):
    """Load the trained model into an inference engine and demonstrate it.

    Runs three demos:
      1. A single-sample prediction, with a printed interpretation, for each
         of five hand-written inputs.
      2. One batch prediction over the same inputs.
      3. The engine's built-in benchmark.

    Args:
        model_path: Checkpoint path (as returned by ``train_model``).
    """
    print(" - 创建推理引擎...")
    models_dir = Path(project_root) / "examples" / "models"
    engine = create_inference_engine(
        model_path=model_path,
        preprocessor_path=models_dir / "quick_start_preprocessor.pkl",
        device='auto'
    )

    # Each row: user P, A, D, vitality, current P, A, D.
    sample_inputs = [
        [0.5, 0.3, -0.2, 80.0, 0.1, 0.4, -0.1],      # positive emotion, high vitality
        [-0.3, 0.6, 0.2, 45.0, -0.1, 0.7, 0.1],      # negative emotion, medium vitality
        [0.8, -0.4, 0.6, 92.0, 0.7, -0.3, 0.5],      # happy, low arousal, high vitality
        [-0.7, -0.5, -0.3, 25.0, -0.6, -0.4, -0.2],  # negative emotion, low vitality
        [0.2, 0.1, 0.0, 60.0, 0.3, 0.0, 0.1]         # neutral emotion, medium vitality
    ]

    def joined(seq, positions, spec):
        # Format seq[k] for each k with the given spec, comma-separated.
        return ", ".join(format(seq[k], spec) for k in positions)

    print(" - 进行预测...")
    for sample_no, sample in enumerate(sample_inputs, start=1):
        result = engine.predict(sample)
        print(f"\n 样本 {sample_no}:")
        user_pad = joined(sample, (0, 1, 2), ".2f")
        current_pad = joined(sample, (4, 5, 6), ".2f")
        print(
            f" 输入: User PAD=[{user_pad}], "
            f"Vitality={sample[3]:.1f}, Current PAD=[{current_pad}]"
        )
        print(" 预测:")
        delta_pad_text = joined(result['delta_pad'], (0, 1, 2), ".3f")
        print(f" ΔPAD: [{delta_pad_text}]")
        print(f" ΔPressure: {result['delta_pressure']:.3f}")
        print(f" Confidence: {result['confidence']:.3f}")
        explanation = interpret_prediction(result)
        print(f" 解释: {explanation}")

    print("\n - 批量预测...")
    batch_results = engine.predict_batch(sample_inputs)
    print(f" - 批量预测完成,处理了 {len(batch_results)} 个样本")

    print("\n - 性能基准测试...")
    bench = engine.benchmark(num_samples=100, batch_size=32)
    print(" - 性能统计:")
    print(f" 吞吐量: {bench['throughput']:.2f} 样本/秒")
    print(f" 平均延迟: {bench['avg_latency']:.2f}ms")
def interpret_prediction(result: Dict[str, Any]) -> str:
    """Turn one prediction into a short, comma-joined Chinese summary.

    Args:
        result: Prediction mapping with these keys:
            - ``'delta_pad'``: indexable; positions 0/1/2 are read as the
              pleasure/arousal/dominance deltas.
            - ``'delta_pressure'``: number.
            - ``'confidence'``: number.

    Returns:
        Comma-joined phrases, in this order:
          - one phrase per PAD dimension whose |delta| > 0.05;
          - one pressure phrase when |delta_pressure| > 0.03;
          - "情绪状态相对稳定" if none of the above fired;
          - always, finally, a confidence bucket (> 0.8 high, > 0.6 medium,
            otherwise low).
    """
    delta_pad = result['delta_pad']
    delta_pressure = result['delta_pressure']
    confidence = result['confidence']

    # (value, threshold, phrase when positive, phrase when negative)
    signals = [
        (delta_pad[0], 0.05, "情绪趋向积极", "情绪趋向消极"),  # pleasure
        (delta_pad[1], 0.05, "激活度增加", "激活度降低"),      # arousal
        (delta_pad[2], 0.05, "支配感增强", "支配感减弱"),      # dominance
        (delta_pressure, 0.03, "压力增加", "压力缓解"),        # pressure
    ]
    interpretations = [
        rising if value > 0 else falling
        for value, threshold, rising, falling in signals
        if abs(value) > threshold
    ]

    # The stable case must be decided before the confidence phrase is
    # appended. That phrase is always added, so checking afterwards (as
    # before) made this branch unreachable.
    if not interpretations:
        interpretations.append("情绪状态相对稳定")

    if confidence > 0.8:
        interpretations.append("高置信度预测")
    elif confidence > 0.6:
        interpretations.append("中等置信度预测")
    else:
        interpretations.append("低置信度预测")
    return ",".join(interpretations)
# Run the tutorial only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()