|
|
""" |
|
|
简化的预测CLI工具 |
|
|
Simplified CLI Tool for emotion and physiological state prediction |
|
|
|
|
|
该工具提供了简化的命令行界面,支持: |
|
|
- 交互式输入 |
|
|
- 批量文件处理 |
|
|
- 清晰的输出格式和解释 |
|
|
- 快速预测模式 |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import sys |
|
|
import os |
|
|
import json |
|
|
import csv |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import List, Dict, Any, Optional |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
project_root = Path(__file__).parent.parent.parent |
|
|
sys.path.insert(0, str(project_root)) |
|
|
|
|
|
from src.utils.inference_engine import create_inference_engine |
|
|
from src.utils.logger import setup_logger |
|
|
|
|
|
|
|
|
class PredictCLI: |
|
|
"""简化的预测CLI类""" |
|
|
|
|
|
def __init__(self, model_path: str, preprocessor_path: Optional[str] = None): |
|
|
""" |
|
|
初始化预测CLI |
|
|
|
|
|
Args: |
|
|
model_path: 模型文件路径 |
|
|
preprocessor_path: 预处理器文件路径 |
|
|
""" |
|
|
self.logger = logging.getLogger(__name__) |
|
|
|
|
|
try: |
|
|
|
|
|
self.engine = create_inference_engine( |
|
|
model_path=model_path, |
|
|
preprocessor_path=preprocessor_path, |
|
|
device='auto' |
|
|
) |
|
|
|
|
|
|
|
|
self.model_info = self.engine.get_model_info() |
|
|
self.logger.info("预测CLI初始化成功") |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.error(f"预测CLI初始化失败: {e}") |
|
|
raise |
|
|
|
|
|
def interactive_mode(self): |
|
|
"""交互式预测模式""" |
|
|
print("\n" + "="*60) |
|
|
print("情绪与生理状态变化预测工具 - 交互式模式") |
|
|
print("="*60) |
|
|
print("\n请输入以下7个参数:") |
|
|
print("1. 用户快乐度 (User Pleasure): [-1.0, 1.0]") |
|
|
print("2. 用户激活度 (User Arousal): [-1.0, 1.0]") |
|
|
print("3. 用户支配度 (User Dominance): [-1.0, 1.0]") |
|
|
print("4. 活力值 (Vitality): [0.0, 100.0]") |
|
|
print("5. 当前快乐度 (Current Pleasure): [-1.0, 1.0]") |
|
|
print("6. 当前激活度 (Current Arousal): [-1.0, 1.0]") |
|
|
print("7. 当前支配度 (Current Dominance): [-1.0, 1.0]") |
|
|
print("\n输入 'quit' 退出交互模式") |
|
|
print("-"*60) |
|
|
|
|
|
while True: |
|
|
try: |
|
|
print("\n请输入7个数值 (用空格分隔):") |
|
|
user_input = input("> ").strip() |
|
|
|
|
|
if user_input.lower() in ['quit', 'exit', 'q']: |
|
|
print("退出交互模式") |
|
|
break |
|
|
|
|
|
|
|
|
values = user_input.split() |
|
|
if len(values) != 7: |
|
|
print(f"错误: 需要输入7个数值,但得到{len(values)}个") |
|
|
continue |
|
|
|
|
|
try: |
|
|
input_data = np.array([float(v) for v in values], dtype=np.float32) |
|
|
except ValueError: |
|
|
print("错误: 输入必须是数字") |
|
|
continue |
|
|
|
|
|
|
|
|
if not self._validate_input_ranges(input_data): |
|
|
continue |
|
|
|
|
|
|
|
|
result = self.predict_single(input_data) |
|
|
|
|
|
|
|
|
self._display_result(result, input_data) |
|
|
|
|
|
except KeyboardInterrupt: |
|
|
print("\n\n用户中断,退出交互模式") |
|
|
break |
|
|
except Exception as e: |
|
|
print(f"预测出错: {e}") |
|
|
|
|
|
def _validate_input_ranges(self, input_data: np.ndarray) -> bool: |
|
|
"""验证输入数据范围""" |
|
|
user_pad = input_data[:3] |
|
|
vitality = input_data[3] |
|
|
current_pad = input_data[4:] |
|
|
|
|
|
|
|
|
if np.any(np.abs(user_pad) > 1.5): |
|
|
print("警告: 用户PAD值超出正常范围 [-1.0, 1.0]") |
|
|
response = input("是否继续? (y/n): ").strip().lower() |
|
|
if response != 'y': |
|
|
return False |
|
|
|
|
|
if np.any(np.abs(current_pad) > 1.5): |
|
|
print("警告: 当前PAD值超出正常范围 [-1.0, 1.0]") |
|
|
response = input("是否继续? (y/n): ").strip().lower() |
|
|
if response != 'y': |
|
|
return False |
|
|
|
|
|
|
|
|
if not (0 <= vitality <= 150): |
|
|
print("警告: 活力值超出正常范围 [0.0, 100.0]") |
|
|
response = input("是否继续? (y/n): ").strip().lower() |
|
|
if response != 'y': |
|
|
return False |
|
|
|
|
|
return True |
|
|
|
|
|
def predict_single(self, input_data: np.ndarray): |
|
|
"""预测单个样本""" |
|
|
try: |
|
|
result = self.engine.predict(input_data) |
|
|
return result |
|
|
except Exception as e: |
|
|
raise RuntimeError(f"预测失败: {e}") |
|
|
|
|
|
def _display_result(self, result, input_data: np.ndarray): |
|
|
"""显示预测结果""" |
|
|
print("\n" + "="*50) |
|
|
print("预测结果") |
|
|
print("="*50) |
|
|
|
|
|
|
|
|
print(f"\n输入信息:") |
|
|
print(f" 用户PAD: 快乐度={input_data[0]:+.3f}, 激活度={input_data[1]:+.3f}, 支配度={input_data[2]:+.3f}") |
|
|
print(f" 活力值: {input_data[3]:.1f}") |
|
|
print(f" 当前PAD: 快乐度={input_data[4]:+.3f}, 激活度={input_data[5]:+.3f}, 支配度={input_data[6]:+.3f}") |
|
|
|
|
|
|
|
|
delta_pad = result.delta_pad[0] |
|
|
delta_pressure = result.delta_pressure[0] |
|
|
confidence = result.confidence[0] |
|
|
|
|
|
print(f"\n预测变化:") |
|
|
print(f" 情绪变化 (ΔPAD):") |
|
|
print(f" 快乐度变化: {delta_pad:+.6f} {'↗' if delta_pad > 0 else '↘' if delta_pad < 0 else '→'}") |
|
|
print(f" 激活度变化: {delta_pad:+.6f} {'↗' if delta_pad > 0 else '↘' if delta_pad < 0 else '→'}") |
|
|
print(f" 支配度变化: {delta_pad:+.6f} {'↗' if delta_pad > 0 else '↘' if delta_pad < 0 else '→'}") |
|
|
print(f" 压力变化: {delta_pressure:+.6f} {'↗' if delta_pressure > 0 else '↘' if delta_pressure < 0 else '→'}") |
|
|
print(f" 预测置信度: {confidence:.6f} ({confidence*100:.1f}%)") |
|
|
|
|
|
|
|
|
self._provide_interpretation(delta_pad, delta_pressure, confidence) |
|
|
|
|
|
|
|
|
print(f"\n性能信息:") |
|
|
print(f" 推理时间: {result.inference_time*1000:.2f}ms") |
|
|
|
|
|
print("="*50) |
|
|
|
|
|
def _provide_interpretation(self, delta_pad: np.ndarray, delta_pressure: float, confidence: float): |
|
|
"""提供预测结果解释""" |
|
|
print(f"\n结果解释:") |
|
|
|
|
|
|
|
|
pleasure_change = delta_pad[0] |
|
|
arousal_change = delta_pad[1] |
|
|
dominance_change = delta_pad[2] |
|
|
|
|
|
if abs(pleasure_change) > 0.1: |
|
|
if pleasure_change > 0: |
|
|
print(" • 情绪趋向积极愉快") |
|
|
else: |
|
|
print(" • 情绪趋向消极低落") |
|
|
|
|
|
if abs(arousal_change) > 0.1: |
|
|
if arousal_change > 0: |
|
|
print(" • 激活度提升,趋向兴奋") |
|
|
else: |
|
|
print(" • 激活度降低,趋向平静") |
|
|
|
|
|
if abs(dominance_change) > 0.1: |
|
|
if dominance_change > 0: |
|
|
print(" • 支配感增强,趋向自信") |
|
|
else: |
|
|
print(" • 支配感减弱,趋向顺从") |
|
|
|
|
|
|
|
|
if abs(delta_pressure) > 0.05: |
|
|
if delta_pressure > 0: |
|
|
print(" • 压力水平可能上升") |
|
|
else: |
|
|
print(" • 压力水平可能下降") |
|
|
|
|
|
|
|
|
if confidence > 0.8: |
|
|
print(" • 预测置信度很高") |
|
|
elif confidence > 0.6: |
|
|
print(" • 预测置信度中等") |
|
|
else: |
|
|
print(" • 预测置信度较低,结果可能不太准确") |
|
|
|
|
|
def batch_predict(self, input_file: str, output_file: Optional[str] = None): |
|
|
"""批量预测""" |
|
|
print(f"\n批量预测模式") |
|
|
print(f"输入文件: {input_file}") |
|
|
|
|
|
try: |
|
|
|
|
|
input_data = self._load_batch_input(input_file) |
|
|
print(f"加载了 {len(input_data)} 个样本") |
|
|
|
|
|
|
|
|
print("执行批量预测...") |
|
|
results = self.engine.predict_batch(input_data) |
|
|
|
|
|
|
|
|
processed_results = [] |
|
|
for i, result in enumerate(results): |
|
|
processed_results.append({ |
|
|
'sample_id': i + 1, |
|
|
'delta_pleasure': float(result.delta_pad[0]), |
|
|
'delta_arousal': float(result.delta_pad[1]), |
|
|
'delta_dominance': float(result.delta_pad[2]), |
|
|
'delta_pressure': float(result.delta_pressure[0]), |
|
|
'confidence': float(result.confidence[0]), |
|
|
'inference_time': float(result.inference_time) |
|
|
}) |
|
|
|
|
|
|
|
|
if output_file: |
|
|
self._save_batch_results(processed_results, output_file) |
|
|
print(f"结果已保存到: {output_file}") |
|
|
else: |
|
|
self._display_batch_summary(processed_results) |
|
|
|
|
|
except Exception as e: |
|
|
print(f"批量预测失败: {e}") |
|
|
raise |
|
|
|
|
|
def _load_batch_input(self, input_file: str) -> np.ndarray: |
|
|
"""加载批量输入数据""" |
|
|
file_path = Path(input_file) |
|
|
|
|
|
if not file_path.exists(): |
|
|
raise FileNotFoundError(f"输入文件不存在: {input_file}") |
|
|
|
|
|
if file_path.suffix.lower() == '.json': |
|
|
return self._load_json_batch(file_path) |
|
|
elif file_path.suffix.lower() == '.csv': |
|
|
return self._load_csv_batch(file_path) |
|
|
else: |
|
|
raise ValueError(f"不支持的文件格式: {file_path.suffix}") |
|
|
|
|
|
def _load_json_batch(self, file_path: Path) -> np.ndarray: |
|
|
"""加载JSON格式的批量数据""" |
|
|
with open(file_path, 'r', encoding='utf-8') as f: |
|
|
data = json.load(f) |
|
|
|
|
|
if isinstance(data, list): |
|
|
return np.array(data, dtype=np.float32) |
|
|
elif isinstance(data, dict) and 'data' in data: |
|
|
return np.array(data['data'], dtype=np.float32) |
|
|
else: |
|
|
raise ValueError("JSON格式不正确,需要数据数组或包含'data'字段的对象") |
|
|
|
|
|
def _load_csv_batch(self, file_path: Path) -> np.ndarray: |
|
|
"""加载CSV格式的批量数据""" |
|
|
import pandas as pd |
|
|
|
|
|
df = pd.read_csv(file_path) |
|
|
|
|
|
|
|
|
if len(df.columns) < 7: |
|
|
raise ValueError(f"CSV文件至少需要7列,但得到{len(df.columns)}列") |
|
|
|
|
|
|
|
|
data = df.iloc[:, :7].values |
|
|
return data.astype(np.float32) |
|
|
|
|
|
def _save_batch_results(self, results: List[Dict[str, Any]], output_file: str): |
|
|
"""保存批量结果""" |
|
|
file_path = Path(output_file) |
|
|
|
|
|
if file_path.suffix.lower() == '.json': |
|
|
with open(file_path, 'w', encoding='utf-8') as f: |
|
|
json.dump({ |
|
|
'results': results, |
|
|
'summary': { |
|
|
'total_samples': len(results), |
|
|
'avg_confidence': np.mean([r['confidence'] for r in results]), |
|
|
'avg_inference_time': np.mean([r['inference_time'] for r in results]) |
|
|
} |
|
|
}, f, indent=2, ensure_ascii=False) |
|
|
|
|
|
elif file_path.suffix.lower() == '.csv': |
|
|
with open(file_path, 'w', newline='', encoding='utf-8') as f: |
|
|
if results: |
|
|
fieldnames = results[0].keys() |
|
|
writer = csv.DictWriter(f, fieldnames=fieldnames) |
|
|
writer.writeheader() |
|
|
writer.writerows(results) |
|
|
|
|
|
else: |
|
|
raise ValueError(f"不支持的输出格式: {file_path.suffix}") |
|
|
|
|
|
def _display_batch_summary(self, results: List[Dict[str, Any]]): |
|
|
"""显示批量预测摘要""" |
|
|
if not results: |
|
|
print("没有预测结果") |
|
|
return |
|
|
|
|
|
print(f"\n批量预测摘要:") |
|
|
print(f"总样本数: {len(results)}") |
|
|
|
|
|
|
|
|
confidences = [r['confidence'] for r in results] |
|
|
inference_times = [r['inference_time'] for r in results] |
|
|
|
|
|
print(f"平均置信度: {np.mean(confidences):.4f}") |
|
|
print(f"置信度范围: [{np.min(confidences):.4f}, {np.max(confidences):.4f}]") |
|
|
print(f"平均推理时间: {np.mean(inference_times)*1000:.2f}ms") |
|
|
print(f"总推理时间: {np.sum(inference_times):.4f}s") |
|
|
|
|
|
|
|
|
print(f"\n前5个样本结果:") |
|
|
for i, result in enumerate(results[:5]): |
|
|
print(f"样本{result['sample_id']}: " |
|
|
f"ΔPAD=[{result['delta_pleasure']:+.3f}, {result['delta_arousal']:+.3f}, {result['delta_dominance']:+.3f}], " |
|
|
f"Δ压力={result['delta_pressure']:+.3f}, " |
|
|
f"置信度={result['confidence']:.3f}") |
|
|
|
|
|
def quick_predict(self, values: List[float]): |
|
|
"""快速预测模式""" |
|
|
if len(values) != 7: |
|
|
raise ValueError(f"需要7个输入值,但得到{len(values)}个") |
|
|
|
|
|
input_data = np.array(values, dtype=np.float32) |
|
|
|
|
|
try: |
|
|
result = self.predict_single(input_data) |
|
|
|
|
|
|
|
|
delta_pad = result.delta_pad[0] |
|
|
delta_pressure = result.delta_pressure[0] |
|
|
confidence = result.confidence[0] |
|
|
|
|
|
print(f"ΔPAD: [{delta_pad[0]:+.4f}, {delta_pad[1]:+.4f}, {delta_pad[2]:+.4f}], " |
|
|
f"Δ压力: {delta_pressure:+.4f}, " |
|
|
f"置信度: {confidence:.4f}") |
|
|
|
|
|
except Exception as e: |
|
|
print(f"预测失败: {e}") |
|
|
return False |
|
|
|
|
|
return True |
|
|
|
|
|
|
|
|
def main(): |
|
|
"""主函数""" |
|
|
parser = argparse.ArgumentParser( |
|
|
description="情绪与生理状态变化预测工具", |
|
|
formatter_class=argparse.RawDescriptionHelpFormatter, |
|
|
epilog=""" |
|
|
使用示例: |
|
|
# 交互式模式 |
|
|
python predict.py --model model.pth |
|
|
|
|
|
# 快速预测 |
|
|
python predict.py --model model.pth --quick 0.5 0.3 -0.2 80 0.1 0.4 -0.1 |
|
|
|
|
|
# 批量预测 |
|
|
python predict.py --model model.pth --batch input.json --output results.json |
|
|
""" |
|
|
) |
|
|
|
|
|
|
|
|
parser.add_argument('--model', '-m', type=str, required=True, |
|
|
help='模型文件路径 (.pth)') |
|
|
parser.add_argument('--preprocessor', '-p', type=str, |
|
|
help='预处理器文件路径') |
|
|
|
|
|
|
|
|
mode_group = parser.add_mutually_exclusive_group() |
|
|
mode_group.add_argument('--interactive', '-i', action='store_true', |
|
|
help='交互式模式') |
|
|
mode_group.add_argument('--quick', nargs=7, type=float, metavar='VALUE', |
|
|
help='快速预测模式 (7个数值)') |
|
|
mode_group.add_argument('--batch', type=str, metavar='FILE', |
|
|
help='批量预测模式 (输入文件)') |
|
|
|
|
|
|
|
|
parser.add_argument('--output', '-o', type=str, |
|
|
help='输出文件路径 (批量模式)') |
|
|
parser.add_argument('--verbose', '-v', action='store_true', |
|
|
help='详细输出') |
|
|
parser.add_argument('--log-level', type=str, |
|
|
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], |
|
|
default='WARNING', help='日志级别') |
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
setup_logger(level=args.log_level) |
|
|
|
|
|
try: |
|
|
|
|
|
cli = PredictCLI(args.model, args.preprocessor) |
|
|
|
|
|
|
|
|
if args.verbose: |
|
|
print(f"模型信息:") |
|
|
print(f" 设备: {cli.model_info['device']}") |
|
|
print(f" 总参数量: {cli.model_info['total_parameters']:,}") |
|
|
print(f" 输入维度: {cli.model_info['input_dim']}") |
|
|
print(f" 输出维度: {cli.model_info['output_dim']}") |
|
|
|
|
|
|
|
|
if args.interactive or (not args.quick and not args.batch): |
|
|
|
|
|
cli.interactive_mode() |
|
|
|
|
|
elif args.quick: |
|
|
|
|
|
success = cli.quick_predict(args.quick) |
|
|
if not success: |
|
|
sys.exit(1) |
|
|
|
|
|
elif args.batch: |
|
|
|
|
|
cli.batch_predict(args.batch, args.output) |
|
|
|
|
|
except Exception as e: |
|
|
print(f"错误: {e}") |
|
|
if args.verbose: |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
sys.exit(1) |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |