# Chordia / src/scripts/predict.py
# Corolin — first commit (0a6452f)
"""
简化的预测CLI工具
Simplified CLI Tool for emotion and physiological state prediction
该工具提供了简化的命令行界面,支持:
- 交互式输入
- 批量文件处理
- 清晰的输出格式和解释
- 快速预测模式
"""
import argparse
import sys
import os
import json
import csv
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional
import numpy as np
# 添加项目根目录到Python路径
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from src.utils.inference_engine import create_inference_engine
from src.utils.logger import setup_logger
class PredictCLI:
    """Command-line front-end around an inference engine.

    The 7-value input vector (as listed by ``interactive_mode``) is
    ``[user P, user A, user D, vitality, current P, current A, current D]``.
    It can be supplied interactively, from JSON/CSV batch files, or in a
    one-shot "quick" mode.
    """

    def __init__(self, model_path: str, preprocessor_path: Optional[str] = None):
        """Create the inference engine and cache its model info.

        Args:
            model_path: Path to the model file.
            preprocessor_path: Optional path to the preprocessor file.

        Raises:
            Exception: Anything raised while building the engine or reading
                its model info is logged and re-raised unchanged.
        """
        self.logger = logging.getLogger(__name__)
        try:
            self.engine = create_inference_engine(
                model_path=model_path,
                preprocessor_path=preprocessor_path,
                device='auto'
            )
            self.model_info = self.engine.get_model_info()
            self.logger.info("预测CLI初始化成功")
        except Exception as e:
            self.logger.error(f"预测CLI初始化失败: {e}")
            raise

    def interactive_mode(self):
        """Repeatedly read 7 space-separated values and display predictions.

        The loop ends on 'quit' / 'exit' / 'q' (case-insensitive), Ctrl+C, or
        end of input (Ctrl+D / closed stdin). Parse and prediction errors are
        reported and the prompt is shown again.
        """
        print("\n" + "="*60)
        print("情绪与生理状态变化预测工具 - 交互式模式")
        print("="*60)
        print("\n请输入以下7个参数:")
        print("1. 用户快乐度 (User Pleasure): [-1.0, 1.0]")
        print("2. 用户激活度 (User Arousal): [-1.0, 1.0]")
        print("3. 用户支配度 (User Dominance): [-1.0, 1.0]")
        print("4. 活力值 (Vitality): [0.0, 100.0]")
        print("5. 当前快乐度 (Current Pleasure): [-1.0, 1.0]")
        print("6. 当前激活度 (Current Arousal): [-1.0, 1.0]")
        print("7. 当前支配度 (Current Dominance): [-1.0, 1.0]")
        print("\n输入 'quit' 退出交互模式")
        print("-"*60)
        while True:
            try:
                print("\n请输入7个数值 (用空格分隔):")
                user_input = input("> ").strip()
                if user_input.lower() in ['quit', 'exit', 'q']:
                    print("退出交互模式")
                    break
                values = user_input.split()
                if len(values) != 7:
                    print(f"错误: 需要输入7个数值,但得到{len(values)}个")
                    continue
                try:
                    input_data = np.array([float(v) for v in values], dtype=np.float32)
                except ValueError:
                    print("错误: 输入必须是数字")
                    continue
                if not self._validate_input_ranges(input_data):
                    continue
                result = self.predict_single(input_data)
                self._display_result(result, input_data)
            except KeyboardInterrupt:
                print("\n\n用户中断,退出交互模式")
                break
            except EOFError:
                # Once stdin is exhausted every input() call raises EOFError;
                # letting the generic handler below catch it would loop forever.
                print("\n退出交互模式")
                break
            except Exception as e:
                print(f"预测出错: {e}")

    def _confirm_continue(self, warning: str) -> bool:
        """Print *warning*, ask the user whether to continue, return True on 'y'."""
        print(warning)
        response = input("是否继续? (y/n): ").strip().lower()
        return response == 'y'

    def _validate_input_ranges(self, input_data: np.ndarray) -> bool:
        """Check the input ranges, asking for confirmation when values look off.

        Indices 0-2 are the user PAD, 3 is vitality and 4-6 are the current PAD.
        Warnings fire only for |PAD| > 1.5 or vitality outside [0, 150].
        Those bounds are wider than the ranges named in the messages,
        presumably as a deliberate tolerance margin.

        Returns:
            True when every check passes or the user confirms each warning.
            False as soon as the user declines one.
        """
        user_pad = input_data[:3]
        vitality = input_data[3]
        current_pad = input_data[4:]
        if np.any(np.abs(user_pad) > 1.5):
            if not self._confirm_continue("警告: 用户PAD值超出正常范围 [-1.0, 1.0]"):
                return False
        if np.any(np.abs(current_pad) > 1.5):
            if not self._confirm_continue("警告: 当前PAD值超出正常范围 [-1.0, 1.0]"):
                return False
        if not (0 <= vitality <= 150):
            if not self._confirm_continue("警告: 活力值超出正常范围 [0.0, 100.0]"):
                return False
        return True

    def predict_single(self, input_data: np.ndarray):
        """Run the engine on a single sample.

        Raises:
            RuntimeError: Wraps any engine error; the original is kept as
                ``__cause__``.
        """
        try:
            return self.engine.predict(input_data)
        except Exception as e:
            raise RuntimeError(f"预测失败: {e}") from e

    @staticmethod
    def _trend_arrow(value) -> str:
        """Return an up/down/flat arrow for the sign of *value*."""
        if value > 0:
            return '↗'
        if value < 0:
            return '↘'
        return '→'

    def _display_result(self, result, input_data: np.ndarray):
        """Pretty-print one prediction along with its inputs and timing.

        ``result.delta_pad[0]`` is the ΔPAD 3-vector of the single sample,
        indexed the same way as in ``quick_predict``.
        """
        print("\n" + "="*50)
        print("预测结果")
        print("="*50)
        print(f"\n输入信息:")
        print(f" 用户PAD: 快乐度={input_data[0]:+.3f}, 激活度={input_data[1]:+.3f}, 支配度={input_data[2]:+.3f}")
        print(f" 活力值: {input_data[3]:.1f}")
        print(f" 当前PAD: 快乐度={input_data[4]:+.3f}, 激活度={input_data[5]:+.3f}, 支配度={input_data[6]:+.3f}")
        delta_pad = result.delta_pad[0]
        delta_pressure = result.delta_pressure[0]
        confidence = result.confidence[0]
        # Each line must show its own ΔPAD component: formatting the whole
        # vector with ':+.6f' (or comparing it to 0) raises on an ndarray.
        pleasure, arousal, dominance = delta_pad[0], delta_pad[1], delta_pad[2]
        print(f"\n预测变化:")
        print(f" 情绪变化 (ΔPAD):")
        print(f" 快乐度变化: {pleasure:+.6f} {self._trend_arrow(pleasure)}")
        print(f" 激活度变化: {arousal:+.6f} {self._trend_arrow(arousal)}")
        print(f" 支配度变化: {dominance:+.6f} {self._trend_arrow(dominance)}")
        print(f" 压力变化: {delta_pressure:+.6f} {self._trend_arrow(delta_pressure)}")
        print(f" 预测置信度: {confidence:.6f} ({confidence*100:.1f}%)")
        self._provide_interpretation(delta_pad, delta_pressure, confidence)
        print(f"\n性能信息:")
        print(f" 推理时间: {result.inference_time*1000:.2f}ms")
        print("="*50)

    def _provide_interpretation(self, delta_pad: np.ndarray, delta_pressure: float, confidence: float):
        """Print plain-language hints for a prediction.

        A PAD component produces a hint when its magnitude exceeds 0.1, and
        pressure when its magnitude exceeds 0.05. Confidence is always
        classified as > 0.8 (high), > 0.6 (medium), or otherwise low.
        """
        print(f"\n结果解释:")
        pad_hints = (
            (delta_pad[0], " • 情绪趋向积极愉快", " • 情绪趋向消极低落"),
            (delta_pad[1], " • 激活度提升,趋向兴奋", " • 激活度降低,趋向平静"),
            (delta_pad[2], " • 支配感增强,趋向自信", " • 支配感减弱,趋向顺从"),
        )
        for change, rising, falling in pad_hints:
            if abs(change) > 0.1:
                print(rising if change > 0 else falling)
        if abs(delta_pressure) > 0.05:
            print(" • 压力水平可能上升" if delta_pressure > 0 else " • 压力水平可能下降")
        if confidence > 0.8:
            print(" • 预测置信度很高")
        elif confidence > 0.6:
            print(" • 预测置信度中等")
        else:
            print(" • 预测置信度较低,结果可能不太准确")

    @staticmethod
    def _to_record(sample_id: int, result) -> Dict[str, Any]:
        """Convert one per-sample engine result into a flat JSON/CSV-friendly dict."""
        # Flatten before indexing so a ΔPAD of shape (3,) and one of shape (1, 3)
        # both yield the three components. NOTE(review): the per-item shape
        # returned by predict_batch is not visible here; confirm in inference_engine.
        pad = np.asarray(result.delta_pad, dtype=np.float64).reshape(-1)
        pressure = np.asarray(result.delta_pressure, dtype=np.float64).reshape(-1)
        confidence = np.asarray(result.confidence, dtype=np.float64).reshape(-1)
        return {
            'sample_id': sample_id,
            'delta_pleasure': float(pad[0]),
            'delta_arousal': float(pad[1]),
            'delta_dominance': float(pad[2]),
            'delta_pressure': float(pressure[0]),
            'confidence': float(confidence[0]),
            'inference_time': float(result.inference_time)
        }

    def batch_predict(self, input_file: str, output_file: Optional[str] = None):
        """Predict every sample in *input_file* (.json or .csv).

        Results are saved to *output_file* when it is given (.json or .csv).
        Otherwise a summary is printed.

        Raises:
            Exception: Any load, predict or save error is printed and re-raised.
        """
        print(f"\n批量预测模式")
        print(f"输入文件: {input_file}")
        try:
            input_data = self._load_batch_input(input_file)
            print(f"加载了 {len(input_data)} 个样本")
            print("执行批量预测...")
            results = self.engine.predict_batch(input_data)
            processed_results = [
                self._to_record(i, result) for i, result in enumerate(results, start=1)
            ]
            if output_file:
                self._save_batch_results(processed_results, output_file)
                print(f"结果已保存到: {output_file}")
            else:
                self._display_batch_summary(processed_results)
        except Exception as e:
            print(f"批量预测失败: {e}")
            raise

    def _load_batch_input(self, input_file: str) -> np.ndarray:
        """Load batch input, dispatching on the file extension (case-insensitive).

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the extension is neither .json nor .csv.
        """
        file_path = Path(input_file)
        if not file_path.exists():
            raise FileNotFoundError(f"输入文件不存在: {input_file}")
        suffix = file_path.suffix.lower()
        if suffix == '.json':
            return self._load_json_batch(file_path)
        if suffix == '.csv':
            return self._load_csv_batch(file_path)
        raise ValueError(f"不支持的文件格式: {file_path.suffix}")

    def _load_json_batch(self, file_path: Path) -> np.ndarray:
        """Load a JSON array of samples, or an object holding it under 'data'.

        Raises:
            ValueError: For any other top-level JSON structure.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        if isinstance(data, list):
            return np.array(data, dtype=np.float32)
        if isinstance(data, dict) and 'data' in data:
            return np.array(data['data'], dtype=np.float32)
        raise ValueError("JSON格式不正确,需要数据数组或包含'data'字段的对象")

    def _load_csv_batch(self, file_path: Path) -> np.ndarray:
        """Load a CSV file and return its first 7 columns as float32.

        The first row is read as the header (pandas default).

        Raises:
            ValueError: If the file has fewer than 7 columns.
        """
        import pandas as pd
        df = pd.read_csv(file_path)
        if len(df.columns) < 7:
            raise ValueError(f"CSV文件至少需要7列,但得到{len(df.columns)}列")
        return df.iloc[:, :7].values.astype(np.float32)

    def _save_batch_results(self, results: List[Dict[str, Any]], output_file: str):
        """Write batch records to .json (with a summary) or .csv.

        An empty result set still produces a JSON file whose averages are
        null, or an empty CSV file.

        Raises:
            ValueError: If the output extension is neither .json nor .csv.
        """
        file_path = Path(output_file)
        suffix = file_path.suffix.lower()
        if suffix == '.json':
            # np.mean([]) gives NaN (with a RuntimeWarning), and NaN is not valid
            # strict JSON, so report None for an empty result set instead.
            summary = {
                'total_samples': len(results),
                'avg_confidence': float(np.mean([r['confidence'] for r in results])) if results else None,
                'avg_inference_time': float(np.mean([r['inference_time'] for r in results])) if results else None
            }
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump({'results': results, 'summary': summary}, f, indent=2, ensure_ascii=False)
        elif suffix == '.csv':
            with open(file_path, 'w', newline='', encoding='utf-8') as f:
                if results:
                    writer = csv.DictWriter(f, fieldnames=results[0].keys())
                    writer.writeheader()
                    writer.writerows(results)
        else:
            raise ValueError(f"不支持的输出格式: {file_path.suffix}")

    def _display_batch_summary(self, results: List[Dict[str, Any]]):
        """Print aggregate statistics and the first 5 records of a batch run."""
        if not results:
            print("没有预测结果")
            return
        print(f"\n批量预测摘要:")
        print(f"总样本数: {len(results)}")
        confidences = [r['confidence'] for r in results]
        inference_times = [r['inference_time'] for r in results]
        print(f"平均置信度: {np.mean(confidences):.4f}")
        print(f"置信度范围: [{np.min(confidences):.4f}, {np.max(confidences):.4f}]")
        print(f"平均推理时间: {np.mean(inference_times)*1000:.2f}ms")
        print(f"总推理时间: {np.sum(inference_times):.4f}s")
        print(f"\n前5个样本结果:")
        for result in results[:5]:
            print(f"样本{result['sample_id']}: "
                  f"ΔPAD=[{result['delta_pleasure']:+.3f}, {result['delta_arousal']:+.3f}, {result['delta_dominance']:+.3f}], "
                  f"Δ压力={result['delta_pressure']:+.3f}, "
                  f"置信度={result['confidence']:.3f}")

    def quick_predict(self, values: List[float]):
        """Predict one sample and print a single compact result line.

        Returns:
            True on success. False if prediction failed (the error is printed).

        Raises:
            ValueError: If *values* does not contain exactly 7 numbers.
        """
        if len(values) != 7:
            raise ValueError(f"需要7个输入值,但得到{len(values)}个")
        input_data = np.array(values, dtype=np.float32)
        try:
            result = self.predict_single(input_data)
            delta_pad = result.delta_pad[0]
            delta_pressure = result.delta_pressure[0]
            confidence = result.confidence[0]
            print(f"ΔPAD: [{delta_pad[0]:+.4f}, {delta_pad[1]:+.4f}, {delta_pad[2]:+.4f}], "
                  f"Δ压力: {delta_pressure:+.4f}, "
                  f"置信度: {confidence:.4f}")
        except Exception as e:
            print(f"预测失败: {e}")
            return False
        return True
def _build_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the prediction CLI."""
    usage_examples = (
        "\n"
        "使用示例:\n"
        "# 交互式模式\n"
        "python predict.py --model model.pth\n"
        "# 快速预测\n"
        "python predict.py --model model.pth --quick 0.5 0.3 -0.2 80 0.1 0.4 -0.1\n"
        "# 批量预测\n"
        "python predict.py --model model.pth --batch input.json --output results.json\n"
    )
    parser = argparse.ArgumentParser(
        description="情绪与生理状态变化预测工具",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )

    # Model artefacts (--model is mandatory).
    parser.add_argument('--model', '-m', type=str, required=True,
                        help='模型文件路径 (.pth)')
    parser.add_argument('--preprocessor', '-p', type=str,
                        help='预处理器文件路径')

    # At most one run mode may be selected.
    modes = parser.add_mutually_exclusive_group()
    modes.add_argument('--interactive', '-i', action='store_true',
                       help='交互式模式')
    modes.add_argument('--quick', nargs=7, type=float, metavar='VALUE',
                       help='快速预测模式 (7个数值)')
    modes.add_argument('--batch', type=str, metavar='FILE',
                       help='批量预测模式 (输入文件)')

    # Output and diagnostics.
    parser.add_argument('--output', '-o', type=str,
                        help='输出文件路径 (批量模式)')
    parser.add_argument('--verbose', '-v', action='store_true',
                        help='详细输出')
    parser.add_argument('--log-level', type=str,
                        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
                        default='WARNING', help='日志级别')
    return parser


def main():
    """CLI entry point: parse arguments, build the predictor, run one mode.

    Interactive mode runs when --interactive is given or when neither --quick
    nor --batch is. Any exception is printed (with a traceback under
    --verbose) and the process exits with status 1. A failed quick prediction
    also exits with status 1.
    """
    args = _build_parser().parse_args()
    setup_logger(level=args.log_level)
    try:
        cli = PredictCLI(args.model, args.preprocessor)

        if args.verbose:
            info = cli.model_info
            print(f"模型信息:")
            print(f" 设备: {info['device']}")
            print(f" 总参数量: {info['total_parameters']:,}")
            print(f" 输入维度: {info['input_dim']}")
            print(f" 输出维度: {info['output_dim']}")

        run_interactive = args.interactive or not (args.quick or args.batch)
        if run_interactive:
            cli.interactive_mode()
        elif args.quick:
            if not cli.quick_predict(args.quick):
                sys.exit(1)
        else:
            # Only reachable when --batch was supplied.
            cli.batch_predict(args.batch, args.output)
    except Exception as exc:
        print(f"错误: {exc}")
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()