|
|
""" |
|
|
推理脚本 |
|
|
Inference Script for emotion and physiological state prediction |
|
|
|
|
|
该脚本实现了完整的推理功能,支持: |
|
|
- 单样本和批量推理 |
|
|
- 多种输入格式(JSON、CSV、命令行参数) |
|
|
- 多种输出格式(JSON、CSV、文本) |
|
|
- 输入数据验证和预处理 |
|
|
- 性能基准测试 |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import json |
|
|
import csv |
|
|
import sys |
|
|
import os |
|
|
import logging |
|
|
import time |
|
|
from pathlib import Path |
|
|
from typing import List, Dict, Any, Union, Optional |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
|
|
|
|
|
|
# Make the project root importable so the `src.*` absolute imports below
# resolve when this script is executed directly (not as a package module).
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
|
|
|
|
|
from src.utils.inference_engine import InferenceEngine, create_inference_engine |
|
|
from src.utils.logger import setup_logger |
|
|
|
|
|
|
|
|
def parse_command_line_input(args: List[str]) -> np.ndarray:
    """Parse raw command-line values into a model input array.

    Args:
        args: Exactly 7 string values, in the order: user_pleasure,
            user_arousal, user_dominance, vitality, current_pleasure,
            current_arousal, current_dominance.

    Returns:
        A (1, 7) float32 array (the inference engine later augments it
        to 10 features).

    Raises:
        ValueError: If the count is not 7 or any value is not numeric.
    """
    if len(args) != 7:
        raise ValueError(f"需要7个输入参数,但提供了{len(args)}个。参数顺序:user_pleasure, user_arousal, user_dominance, vitality, current_pleasure, current_arousal, current_dominance")

    try:
        values = [float(token) for token in args]
    except ValueError as e:
        raise ValueError(f"输入参数必须是数字: {e}")

    # Single sample -> one row.
    return np.asarray(values, dtype=np.float32).reshape(1, -1)
|
|
|
|
|
|
|
|
def load_json_input(input_path: str) -> np.ndarray:
    """Load inference input from a JSON file.

    Supported layouts: a dict with a ``data`` or ``features`` array, a
    dict of named scalar fields (missing fields default to 0), a flat
    list (one sample), or a list of lists (a batch).

    Args:
        input_path: Path to the JSON file.

    Returns:
        A (N, 7) float32 input array.

    Raises:
        ValueError: If the file cannot be parsed, the layout is
            unsupported, or the feature count is not 7.
    """
    feature_keys = (
        'user_pleasure', 'user_arousal', 'user_dominance', 'vitality',
        'current_pleasure', 'current_arousal', 'current_dominance',
    )
    try:
        with open(input_path, 'r', encoding='utf-8') as fh:
            payload = json.load(fh)

        if isinstance(payload, dict):
            if 'data' in payload:
                matrix = np.array(payload['data'], dtype=np.float32)
            elif 'features' in payload:
                matrix = np.array(payload['features'], dtype=np.float32)
            else:
                # Named scalar fields; absent keys default to 0.
                row = [payload.get(key, 0) for key in feature_keys]
                matrix = np.array(row, dtype=np.float32).reshape(1, -1)
        elif isinstance(payload, list):
            matrix = np.array(payload, dtype=np.float32)
            if not (payload and isinstance(payload[0], list)):
                # Flat list means a single sample -> one row.
                matrix = matrix.reshape(1, -1)
        else:
            raise ValueError("不支持的JSON格式")

        if matrix.ndim == 1:
            matrix = matrix.reshape(1, -1)
        elif matrix.ndim > 2:
            raise ValueError("输入数据应该是1维或2维的")

        if matrix.shape[1] != 7:
            raise ValueError(f"输入数据应该有7个特征,但得到{matrix.shape[1]}个")

        return matrix

    except Exception as e:
        # Rewrap any failure (including the validations above) so the
        # caller sees a single error type with file context attached.
        raise ValueError(f"无法解析JSON文件 {input_path}: {e}")
|
|
|
|
|
|
|
|
def load_csv_input(input_path: str,
                   feature_columns: Optional[List[str]] = None) -> np.ndarray:
    """Load inference input from a CSV file.

    Only the first 7 columns are used; they are renamed to the
    canonical feature names regardless of the file's own headers.

    Args:
        input_path: Path to the CSV file.
        feature_columns: Optional subset of column names to read.
            NOTE(review): pandas ``usecols`` keeps the file's column
            order, not this list's order — confirm callers pass columns
            already in the expected order.

    Returns:
        A (N, 7) float32 input array.

    Raises:
        ValueError: If the file cannot be parsed or has fewer than 7 columns.
    """
    default_columns = [
        'user_pleasure', 'user_arousal', 'user_dominance',
        'vitality', 'current_pleasure', 'current_arousal', 'current_dominance'
    ]

    try:
        if feature_columns:
            df = pd.read_csv(input_path, usecols=feature_columns)
        else:
            df = pd.read_csv(input_path)

        # The original code had an unreachable `elif len(df.columns) == 7`
        # branch (shadowed by the `>= 7` case). The reachable logic is:
        # reject short files, then keep and rename the first 7 columns.
        if len(df.columns) < 7:
            raise ValueError(f"CSV文件应该至少有7列,但得到{len(df.columns)}列")

        df = df.iloc[:, :7]
        df.columns = default_columns

        return df.values.astype(np.float32)

    except Exception as e:
        # Rewrap with file context so the caller sees one error type.
        raise ValueError(f"无法解析CSV文件 {input_path}: {e}")
|
|
|
|
|
|
|
|
def save_json_output(results: List[Dict[str, Any]], output_path: str) -> None:
    """Write prediction results to *output_path* as pretty-printed JSON.

    The payload wraps the raw result dicts under ``predictions`` and
    attaches a small ``metadata`` section.

    Args:
        results: Per-sample prediction dictionaries.
        output_path: Destination file path.
    """
    metadata = {
        'total_samples': len(results),
        'output_format': 'json',
        'description': 'Emotion and physiological state prediction results'
    }
    payload = {'predictions': results, 'metadata': metadata}

    with open(output_path, 'w', encoding='utf-8') as sink:
        # ensure_ascii=False keeps any non-ASCII text readable in the file.
        json.dump(payload, sink, indent=2, ensure_ascii=False)
|
|
|
|
|
|
|
|
def save_csv_output(results: List[Dict[str, Any]], output_path: str) -> None:
    """Write prediction results to *output_path* as a flat CSV table.

    One row per sample with the PAD deltas, pressure delta, confidence
    and inference time. Does nothing when *results* is empty.

    Args:
        results: Per-sample prediction dictionaries.
        output_path: Destination file path.
    """
    if not results:
        return

    def _flatten(index: int, prediction: Dict[str, Any]) -> Dict[str, Any]:
        # Scalars are unpacked from the single-element sequences the
        # engine returns for each prediction field.
        return {
            'sample_id': index,
            'delta_pleasure': prediction['delta_pad'][0],
            'delta_arousal': prediction['delta_pad'][1],
            'delta_dominance': prediction['delta_pad'][2],
            'delta_pressure': prediction['delta_pressure'][0],
            'confidence': prediction['confidence'][0],
            'inference_time': prediction.get('inference_time', 0)
        }

    rows = [_flatten(i, prediction) for i, prediction in enumerate(results)]

    with open(output_path, 'w', newline='', encoding='utf-8') as sink:
        writer = csv.DictWriter(sink, fieldnames=rows[0].keys())
        writer.writeheader()
        writer.writerows(rows)
|
|
|
|
|
|
|
|
def save_text_output(results: List[Dict[str, Any]], output_path: str) -> None:
    """Write prediction results to *output_path* as human-readable text.

    Produces a header followed by one labelled section per sample.

    Args:
        results: Per-sample prediction dictionaries.
        output_path: Destination file path.
    """
    sample_separator = "-" * 30 + "\n"
    with open(output_path, 'w', encoding='utf-8') as sink:
        sink.write("情绪与生理状态变化预测结果\n")
        sink.write("=" * 50 + "\n\n")

        for index, prediction in enumerate(results, start=1):
            pad = prediction['delta_pad']
            sink.write(f"样本 {index}:\n")
            sink.write(f" ΔPAD (情绪变化):\n")
            sink.write(f" 快乐度变化: {pad[0]:.6f}\n")
            sink.write(f" 激活度变化: {pad[1]:.6f}\n")
            sink.write(f" 支配度变化: {pad[2]:.6f}\n")
            sink.write(f" Δ压力: {prediction['delta_pressure'][0]:.6f}\n")
            sink.write(f" 置信度: {prediction['confidence'][0]:.6f}\n")
            sink.write(f" 推理时间: {prediction.get('inference_time', 0):.6f}秒\n")
            sink.write(sample_separator)
|
|
|
|
|
|
|
|
def print_results(results: List[Dict[str, Any]], verbose: bool = True) -> None:
    """Print prediction results to stdout.

    Args:
        results: Per-sample prediction dictionaries.
        verbose: When False, emit one compact line per sample; when
            True, emit a detailed multi-line report per sample.
    """
    if not verbose:
        # Compact mode: one summary line per sample.
        for index, prediction in enumerate(results, start=1):
            print(f"样本{index}: ΔPAD={prediction['delta_pad']}, Δ压力={prediction['delta_pressure'][0]:.4f}, 置信度={prediction['confidence'][0]:.4f}")
        return

    print("\n情绪与生理状态变化预测结果")
    print("=" * 60)

    for index, prediction in enumerate(results, start=1):
        pad = prediction['delta_pad']
        print(f"\n样本 {index}:")
        print(f" ΔPAD (情绪变化):")
        print(f" 快乐度变化: {pad[0]:+.6f}")
        print(f" 激活度变化: {pad[1]:+.6f}")
        print(f" 支配度变化: {pad[2]:+.6f}")
        print(f" Δ压力: {prediction['delta_pressure'][0]:+.6f}")
        print(f" 置信度: {prediction['confidence'][0]:.6f}")
        # Timing is optional; only shown when the engine recorded it.
        if 'inference_time' in prediction:
            print(f" 推理时间: {prediction['inference_time']*1000:.2f}ms")
        print("-" * 40)
|
|
|
|
|
|
|
|
def run_benchmark(engine: InferenceEngine,
                  num_samples: int = 1000,
                  batch_size: int = 32) -> None:
    """Run a throughput/latency benchmark on *engine* and print the report.

    Any failure (including missing keys in the engine's stats dict) is
    caught and reported rather than raised.

    Args:
        engine: Inference engine exposing a ``benchmark`` method.
        num_samples: Number of samples to push through the engine.
        batch_size: Batch size passed to the engine's benchmark.
    """
    print(f"\n运行性能基准测试...")
    print(f"测试样本数: {num_samples}")
    print(f"批次大小: {batch_size}")

    try:
        report = engine.benchmark(num_samples, batch_size)

        print("\n基准测试结果:")
        print(f" 总样本数: {report['total_samples']}")
        print(f" 总时间: {report['total_time']:.4f}秒")
        print(f" 吞吐量: {report['throughput']:.2f} 样本/秒")
        print(f" 平均延迟: {report['avg_latency']:.2f}ms")
        print(f" 最小延迟: {report['min_time']*1000:.2f}ms")
        print(f" 最大延迟: {report['max_time']*1000:.2f}ms")
        print(f" P95延迟: {report['p95_latency']:.2f}ms")
        print(f" P99延迟: {report['p99_latency']:.2f}ms")

    except Exception as e:
        print(f"基准测试失败: {e}")
|
|
|
|
|
|
|
|
def main():
    """CLI entry point for the emotion/physiological-state inference script.

    Flow: build the argument parser, create the inference engine, then
    either run a benchmark or load input (CLI / JSON / CSV), run
    inference, and print and/or save the results. Exits with status 1
    on any failure.
    """
    parser = argparse.ArgumentParser(
        description="情绪与生理状态变化预测推理脚本",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
输入格式说明:
1. 命令行参数: --input-cli 0.5 0.3 -0.2 80 0.1 0.4 -0.1
2. JSON文件: --input-json data.json
3. CSV文件: --input-csv data.csv

输出格式说明:
- JSON: 结构化数据,便于程序处理
- CSV: 表格数据,便于Excel处理
- TXT: 人类可读的文本格式

示例用法:
# 单样本推理
python inference.py --model model.pth --input-cli 0.5 0.3 -0.2 80 0.1 0.4 -0.1

# 批量推理
python inference.py --model model.pth --input-json batch_data.json --output-json results.json

# 基准测试
python inference.py --model model.pth --benchmark --num-samples 1000
"""
    )

    # Model and device options.
    parser.add_argument('--model', '-m', type=str, required=True,
                        help='模型文件路径 (.pth)')
    parser.add_argument('--preprocessor', '-p', type=str,
                        help='预处理器文件路径')
    parser.add_argument('--device', type=str, choices=['auto', 'cpu', 'cuda'],
                        default='auto', help='计算设备')

    # Exactly one input source must be supplied.
    # NOTE(review): the group is required even in --benchmark mode, so a
    # benchmark run still has to pass a (unused) input option — confirm
    # whether that is intended.
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument('--input-cli', nargs='+', metavar='VALUE',
                             help='命令行输入 (7个数值: user_pleasure user_arousal user_dominance vitality current_pleasure current_arousal current_dominance)')
    input_group.add_argument('--input-json', type=str, metavar='FILE',
                             help='JSON输入文件路径')
    input_group.add_argument('--input-csv', type=str, metavar='FILE',
                             help='CSV输入文件路径')

    # Output destinations; any combination may be given.
    parser.add_argument('--output-json', type=str, metavar='FILE',
                        help='JSON输出文件路径')
    parser.add_argument('--output-csv', type=str, metavar='FILE',
                        help='CSV输出文件路径')
    parser.add_argument('--output-txt', type=str, metavar='FILE',
                        help='文本输出文件路径')
    parser.add_argument('--quiet', '-q', action='store_true',
                        help='静默模式,不打印结果')

    # Batch-inference tuning.
    parser.add_argument('--batch-size', type=int, default=32,
                        help='批量推理的批次大小')

    # Benchmark mode.
    parser.add_argument('--benchmark', action='store_true',
                        help='运行性能基准测试')
    parser.add_argument('--num-samples', type=int, default=1000,
                        help='基准测试的样本数量')

    # Logging / verbosity.
    parser.add_argument('--verbose', '-v', action='store_true',
                        help='详细输出')
    parser.add_argument('--log-level', type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
                        default='INFO', help='日志级别')

    args = parser.parse_args()

    # Configure logging before anything that may log.
    setup_logger(level=args.log_level)
    logger = logging.getLogger(__name__)

    try:
        # Build the engine (loads the model and optional preprocessor).
        logger.info("初始化推理引擎...")
        engine = create_inference_engine(
            model_path=args.model,
            preprocessor_path=args.preprocessor,
            device=args.device
        )

        # Optional model summary.
        if args.verbose:
            model_info = engine.get_model_info()
            print(f"\n模型信息:")
            print(f" 设备: {model_info['device']}")
            print(f" 总参数量: {model_info['total_parameters']:,}")
            print(f" 输入维度: {model_info['input_dim']}")
            print(f" 输出维度: {model_info['output_dim']}")

        # Benchmark mode short-circuits normal inference.
        if args.benchmark:
            run_benchmark(engine, args.num_samples, args.batch_size)
            return

        # Load input from whichever (mutually exclusive) source was given.
        logger.info("加载输入数据...")
        if args.input_cli:
            input_data = parse_command_line_input(args.input_cli)
        elif args.input_json:
            input_data = load_json_input(args.input_json)
        elif args.input_csv:
            input_data = load_csv_input(args.input_csv)

        logger.info(f"加载了 {len(input_data)} 个样本")

        # Run inference: single-sample path vs batched path.
        logger.info("执行推理...")
        start_time = time.time()

        if len(input_data) == 1:
            # Single sample: predict directly on the 1-D row.
            result = engine.predict(input_data[0])
            results = [result.to_dict()]
        else:
            # Batch: engine handles chunking by --batch-size.
            results = engine.predict_batch(input_data, args.batch_size)
            results = [result.to_dict() for result in results]

        total_time = time.time() - start_time

        logger.info(f"推理完成,总时间: {total_time:.4f}秒")

        # Report to stdout unless silenced.
        if not args.quiet:
            print_results(results, verbose=args.verbose)

        # Persist results in every requested format.
        if args.output_json:
            save_json_output(results, args.output_json)
            print(f"结果已保存到: {args.output_json}")

        if args.output_csv:
            save_csv_output(results, args.output_csv)
            print(f"结果已保存到: {args.output_csv}")

        if args.output_txt:
            save_text_output(results, args.output_txt)
            print(f"结果已保存到: {args.output_txt}")

        # Optional aggregate engine timing statistics.
        if args.verbose:
            stats = engine.get_performance_stats()
            print(f"\n性能统计:")
            print(f" 总推理次数: {stats['total_inferences']}")
            print(f" 平均时间: {stats['avg_time']*1000:.2f}ms")
            print(f" 最小时间: {stats['min_time']*1000:.2f}ms")
            print(f" 最大时间: {stats['max_time']*1000:.2f}ms")

    except Exception as e:
        # Top-level boundary: log, optionally dump traceback, exit non-zero.
        logger.error(f"推理失败: {e}")
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)
|
|
|
|
|
|
|
|
# Script entry point guard: run the CLI only when executed directly.
if __name__ == "__main__":
    main()