""" 简化的预测CLI工具 Simplified CLI Tool for emotion and physiological state prediction 该工具提供了简化的命令行界面,支持: - 交互式输入 - 批量文件处理 - 清晰的输出格式和解释 - 快速预测模式 """ import argparse import sys import os import json import csv import logging from pathlib import Path from typing import List, Dict, Any, Optional import numpy as np # 添加项目根目录到Python路径 project_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(project_root)) from src.utils.inference_engine import create_inference_engine from src.utils.logger import setup_logger class PredictCLI: """简化的预测CLI类""" def __init__(self, model_path: str, preprocessor_path: Optional[str] = None): """ 初始化预测CLI Args: model_path: 模型文件路径 preprocessor_path: 预处理器文件路径 """ self.logger = logging.getLogger(__name__) try: # 创建推理引擎 self.engine = create_inference_engine( model_path=model_path, preprocessor_path=preprocessor_path, device='auto' ) # 获取模型信息 self.model_info = self.engine.get_model_info() self.logger.info("预测CLI初始化成功") except Exception as e: self.logger.error(f"预测CLI初始化失败: {e}") raise def interactive_mode(self): """交互式预测模式""" print("\n" + "="*60) print("情绪与生理状态变化预测工具 - 交互式模式") print("="*60) print("\n请输入以下7个参数:") print("1. 用户快乐度 (User Pleasure): [-1.0, 1.0]") print("2. 用户激活度 (User Arousal): [-1.0, 1.0]") print("3. 用户支配度 (User Dominance): [-1.0, 1.0]") print("4. 活力值 (Vitality): [0.0, 100.0]") print("5. 当前快乐度 (Current Pleasure): [-1.0, 1.0]") print("6. 当前激活度 (Current Arousal): [-1.0, 1.0]") print("7. 当前支配度 (Current Dominance): [-1.0, 1.0]") print("\n输入 'quit' 退出交互模式") print("-"*60) while True: try: print("\n请输入7个数值 (用空格分隔):") user_input = input("> ").strip() if user_input.lower() in ['quit', 'exit', 'q']: print("退出交互模式") break # 解析输入 values = user_input.split() if len(values) != 7: print(f"错误: 需要输入7个数值,但得到{len(values)}个") continue try: input_data = np.array([float(v) for v in values], dtype=np.float32) except ValueError: print("错误: 输入必须是数字") continue # 验证输入范围 if not self._validate_input_ranges(input_data): continue # 执行预测 result = self.predict_single(input_data) # 显示结果 self._display_result(result, input_data) except KeyboardInterrupt: print("\n\n用户中断,退出交互模式") break except Exception as e: print(f"预测出错: {e}") def _validate_input_ranges(self, input_data: np.ndarray) -> bool: """验证输入数据范围""" user_pad = input_data[:3] vitality = input_data[3] current_pad = input_data[4:] # 检查PAD值范围 if np.any(np.abs(user_pad) > 1.5): print("警告: 用户PAD值超出正常范围 [-1.0, 1.0]") response = input("是否继续? (y/n): ").strip().lower() if response != 'y': return False if np.any(np.abs(current_pad) > 1.5): print("警告: 当前PAD值超出正常范围 [-1.0, 1.0]") response = input("是否继续? (y/n): ").strip().lower() if response != 'y': return False # 检查活力值范围 if not (0 <= vitality <= 150): print("警告: 活力值超出正常范围 [0.0, 100.0]") response = input("是否继续? (y/n): ").strip().lower() if response != 'y': return False return True def predict_single(self, input_data: np.ndarray): """预测单个样本""" try: result = self.engine.predict(input_data) return result except Exception as e: raise RuntimeError(f"预测失败: {e}") def _display_result(self, result, input_data: np.ndarray): """显示预测结果""" print("\n" + "="*50) print("预测结果") print("="*50) # 显示输入信息 print(f"\n输入信息:") print(f" 用户PAD: 快乐度={input_data[0]:+.3f}, 激活度={input_data[1]:+.3f}, 支配度={input_data[2]:+.3f}") print(f" 活力值: {input_data[3]:.1f}") print(f" 当前PAD: 快乐度={input_data[4]:+.3f}, 激活度={input_data[5]:+.3f}, 支配度={input_data[6]:+.3f}") # 显示预测结果 delta_pad = result.delta_pad[0] delta_pressure = result.delta_pressure[0] confidence = result.confidence[0] print(f"\n预测变化:") print(f" 情绪变化 (ΔPAD):") print(f" 快乐度变化: {delta_pad:+.6f} {'↗' if delta_pad > 0 else '↘' if delta_pad < 0 else '→'}") print(f" 激活度变化: {delta_pad:+.6f} {'↗' if delta_pad > 0 else '↘' if delta_pad < 0 else '→'}") print(f" 支配度变化: {delta_pad:+.6f} {'↗' if delta_pad > 0 else '↘' if delta_pad < 0 else '→'}") print(f" 压力变化: {delta_pressure:+.6f} {'↗' if delta_pressure > 0 else '↘' if delta_pressure < 0 else '→'}") print(f" 预测置信度: {confidence:.6f} ({confidence*100:.1f}%)") # 提供解释 self._provide_interpretation(delta_pad, delta_pressure, confidence) # 显示性能信息 print(f"\n性能信息:") print(f" 推理时间: {result.inference_time*1000:.2f}ms") print("="*50) def _provide_interpretation(self, delta_pad: np.ndarray, delta_pressure: float, confidence: float): """提供预测结果解释""" print(f"\n结果解释:") # 情绪变化解释 pleasure_change = delta_pad[0] arousal_change = delta_pad[1] dominance_change = delta_pad[2] if abs(pleasure_change) > 0.1: if pleasure_change > 0: print(" • 情绪趋向积极愉快") else: print(" • 情绪趋向消极低落") if abs(arousal_change) > 0.1: if arousal_change > 0: print(" • 激活度提升,趋向兴奋") else: print(" • 激活度降低,趋向平静") if abs(dominance_change) > 0.1: if dominance_change > 0: print(" • 支配感增强,趋向自信") else: print(" • 支配感减弱,趋向顺从") # 压力变化解释 if abs(delta_pressure) > 0.05: if delta_pressure > 0: print(" • 压力水平可能上升") else: print(" • 压力水平可能下降") # 置信度解释 if confidence > 0.8: print(" • 预测置信度很高") elif confidence > 0.6: print(" • 预测置信度中等") else: print(" • 预测置信度较低,结果可能不太准确") def batch_predict(self, input_file: str, output_file: Optional[str] = None): """批量预测""" print(f"\n批量预测模式") print(f"输入文件: {input_file}") try: # 加载输入数据 input_data = self._load_batch_input(input_file) print(f"加载了 {len(input_data)} 个样本") # 执行批量预测 print("执行批量预测...") results = self.engine.predict_batch(input_data) # 处理结果 processed_results = [] for i, result in enumerate(results): processed_results.append({ 'sample_id': i + 1, 'delta_pleasure': float(result.delta_pad[0]), 'delta_arousal': float(result.delta_pad[1]), 'delta_dominance': float(result.delta_pad[2]), 'delta_pressure': float(result.delta_pressure[0]), 'confidence': float(result.confidence[0]), 'inference_time': float(result.inference_time) }) # 保存结果 if output_file: self._save_batch_results(processed_results, output_file) print(f"结果已保存到: {output_file}") else: self._display_batch_summary(processed_results) except Exception as e: print(f"批量预测失败: {e}") raise def _load_batch_input(self, input_file: str) -> np.ndarray: """加载批量输入数据""" file_path = Path(input_file) if not file_path.exists(): raise FileNotFoundError(f"输入文件不存在: {input_file}") if file_path.suffix.lower() == '.json': return self._load_json_batch(file_path) elif file_path.suffix.lower() == '.csv': return self._load_csv_batch(file_path) else: raise ValueError(f"不支持的文件格式: {file_path.suffix}") def _load_json_batch(self, file_path: Path) -> np.ndarray: """加载JSON格式的批量数据""" with open(file_path, 'r', encoding='utf-8') as f: data = json.load(f) if isinstance(data, list): return np.array(data, dtype=np.float32) elif isinstance(data, dict) and 'data' in data: return np.array(data['data'], dtype=np.float32) else: raise ValueError("JSON格式不正确,需要数据数组或包含'data'字段的对象") def _load_csv_batch(self, file_path: Path) -> np.ndarray: """加载CSV格式的批量数据""" import pandas as pd df = pd.read_csv(file_path) # 检查列数 if len(df.columns) < 7: raise ValueError(f"CSV文件至少需要7列,但得到{len(df.columns)}列") # 使用前7列 data = df.iloc[:, :7].values return data.astype(np.float32) def _save_batch_results(self, results: List[Dict[str, Any]], output_file: str): """保存批量结果""" file_path = Path(output_file) if file_path.suffix.lower() == '.json': with open(file_path, 'w', encoding='utf-8') as f: json.dump({ 'results': results, 'summary': { 'total_samples': len(results), 'avg_confidence': np.mean([r['confidence'] for r in results]), 'avg_inference_time': np.mean([r['inference_time'] for r in results]) } }, f, indent=2, ensure_ascii=False) elif file_path.suffix.lower() == '.csv': with open(file_path, 'w', newline='', encoding='utf-8') as f: if results: fieldnames = results[0].keys() writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() writer.writerows(results) else: raise ValueError(f"不支持的输出格式: {file_path.suffix}") def _display_batch_summary(self, results: List[Dict[str, Any]]): """显示批量预测摘要""" if not results: print("没有预测结果") return print(f"\n批量预测摘要:") print(f"总样本数: {len(results)}") # 统计信息 confidences = [r['confidence'] for r in results] inference_times = [r['inference_time'] for r in results] print(f"平均置信度: {np.mean(confidences):.4f}") print(f"置信度范围: [{np.min(confidences):.4f}, {np.max(confidences):.4f}]") print(f"平均推理时间: {np.mean(inference_times)*1000:.2f}ms") print(f"总推理时间: {np.sum(inference_times):.4f}s") # 显示前几个结果 print(f"\n前5个样本结果:") for i, result in enumerate(results[:5]): print(f"样本{result['sample_id']}: " f"ΔPAD=[{result['delta_pleasure']:+.3f}, {result['delta_arousal']:+.3f}, {result['delta_dominance']:+.3f}], " f"Δ压力={result['delta_pressure']:+.3f}, " f"置信度={result['confidence']:.3f}") def quick_predict(self, values: List[float]): """快速预测模式""" if len(values) != 7: raise ValueError(f"需要7个输入值,但得到{len(values)}个") input_data = np.array(values, dtype=np.float32) try: result = self.predict_single(input_data) # 简洁输出 delta_pad = result.delta_pad[0] delta_pressure = result.delta_pressure[0] confidence = result.confidence[0] print(f"ΔPAD: [{delta_pad[0]:+.4f}, {delta_pad[1]:+.4f}, {delta_pad[2]:+.4f}], " f"Δ压力: {delta_pressure:+.4f}, " f"置信度: {confidence:.4f}") except Exception as e: print(f"预测失败: {e}") return False return True def main(): """主函数""" parser = argparse.ArgumentParser( description="情绪与生理状态变化预测工具", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" 使用示例: # 交互式模式 python predict.py --model model.pth # 快速预测 python predict.py --model model.pth --quick 0.5 0.3 -0.2 80 0.1 0.4 -0.1 # 批量预测 python predict.py --model model.pth --batch input.json --output results.json """ ) # 必需参数 parser.add_argument('--model', '-m', type=str, required=True, help='模型文件路径 (.pth)') parser.add_argument('--preprocessor', '-p', type=str, help='预处理器文件路径') # 模式选择 mode_group = parser.add_mutually_exclusive_group() mode_group.add_argument('--interactive', '-i', action='store_true', help='交互式模式') mode_group.add_argument('--quick', nargs=7, type=float, metavar='VALUE', help='快速预测模式 (7个数值)') mode_group.add_argument('--batch', type=str, metavar='FILE', help='批量预测模式 (输入文件)') # 输出选项 parser.add_argument('--output', '-o', type=str, help='输出文件路径 (批量模式)') parser.add_argument('--verbose', '-v', action='store_true', help='详细输出') parser.add_argument('--log-level', type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default='WARNING', help='日志级别') args = parser.parse_args() # 设置日志 setup_logger(level=args.log_level) try: # 创建预测CLI cli = PredictCLI(args.model, args.preprocessor) # 显示模型信息 if args.verbose: print(f"模型信息:") print(f" 设备: {cli.model_info['device']}") print(f" 总参数量: {cli.model_info['total_parameters']:,}") print(f" 输入维度: {cli.model_info['input_dim']}") print(f" 输出维度: {cli.model_info['output_dim']}") # 根据模式执行 if args.interactive or (not args.quick and not args.batch): # 默认进入交互模式 cli.interactive_mode() elif args.quick: # 快速预测 success = cli.quick_predict(args.quick) if not success: sys.exit(1) elif args.batch: # 批量预测 cli.batch_predict(args.batch, args.output) except Exception as e: print(f"错误: {e}") if args.verbose: import traceback traceback.print_exc() sys.exit(1) if __name__ == "__main__": main()