# Chordia / src/scripts/inference.py
# Author: Corolin — first commit (0a6452f)
"""
推理脚本
Inference Script for emotion and physiological state prediction
该脚本实现了完整的推理功能,支持:
- 单样本和批量推理
- 多种输入格式(JSON、CSV、命令行参数)
- 多种输出格式(JSON、CSV、文本)
- 输入数据验证和预处理
- 性能基准测试
"""
import argparse
import json
import csv
import sys
import os
import logging
import time
from pathlib import Path
from typing import List, Dict, Any, Union, Optional
import numpy as np
import pandas as pd
# 添加项目根目录到Python路径
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from src.utils.inference_engine import InferenceEngine, create_inference_engine
from src.utils.logger import setup_logger
def parse_command_line_input(args: List[str]) -> np.ndarray:
    """Parse raw command-line values into a model input array.

    Args:
        args: Exactly 7 numeric strings, in the order user_pleasure,
            user_arousal, user_dominance, vitality, current_pleasure,
            current_arousal, current_dominance.

    Returns:
        A (1, 7) float32 array (the inference engine later augments it
        to 10 dimensions).

    Raises:
        ValueError: If the argument count is not 7 or a value is not numeric.
    """
    if len(args) != 7:
        raise ValueError(f"需要7个输入参数,但提供了{len(args)}个。参数顺序:user_pleasure, user_arousal, user_dominance, vitality, current_pleasure, current_arousal, current_dominance")
    try:
        values = [float(v) for v in args]
    except ValueError as e:
        raise ValueError(f"输入参数必须是数字: {e}")
    return np.asarray(values, dtype=np.float32).reshape(1, -1)
def load_json_input(input_path: str) -> np.ndarray:
    """Load model input features from a JSON file.

    Supported layouts:
      * ``{"data": [[...], ...]}`` or ``{"features": [[...], ...]}``
      * a flat mapping with the seven named feature keys (missing keys -> 0)
      * a bare list: one sample (``[...]``) or a batch (``[[...], ...]``)

    Args:
        input_path: Path of the JSON file to read.

    Returns:
        A (n_samples, 7) float32 array.

    Raises:
        ValueError: On unreadable files, unsupported layouts, or a wrong
            feature count (message includes the file path for context).
    """
    try:
        with open(input_path, 'r', encoding='utf-8') as fh:
            payload = json.load(fh)

        if isinstance(payload, dict):
            if 'data' in payload:
                # Layout: {"data": [[...], [...], ...]}
                arr = np.array(payload['data'], dtype=np.float32)
            elif 'features' in payload:
                # Layout: {"features": [[...], [...], ...]}
                arr = np.array(payload['features'], dtype=np.float32)
            else:
                # Layout: named scalar fields for a single sample;
                # absent keys default to 0.
                keys = ('user_pleasure', 'user_arousal', 'user_dominance',
                        'vitality', 'current_pleasure', 'current_arousal',
                        'current_dominance')
                sample = [payload.get(k, 0) for k in keys]
                arr = np.array(sample, dtype=np.float32).reshape(1, -1)
        elif isinstance(payload, list):
            # Layout: [[...], ...] (batch) or [...] (single sample).
            is_batch = bool(payload) and isinstance(payload[0], list)
            arr = np.array(payload, dtype=np.float32)
            if not is_batch:
                arr = arr.reshape(1, -1)
        else:
            raise ValueError("不支持的JSON格式")

        # Normalize to 2-D and validate the feature count.
        if arr.ndim == 1:
            arr = arr.reshape(1, -1)
        elif arr.ndim > 2:
            raise ValueError("输入数据应该是1维或2维的")
        if arr.shape[1] != 7:
            raise ValueError(f"输入数据应该有7个特征,但得到{arr.shape[1]}个")
        return arr
    except Exception as e:
        raise ValueError(f"无法解析JSON文件 {input_path}: {e}")
def load_csv_input(input_path: str,
                   feature_columns: Optional[List[str]] = None) -> np.ndarray:
    """Load model input features from a CSV file.

    Args:
        input_path: Path of the CSV file to read.
        feature_columns: Optional explicit list of columns to read; when
            omitted the first 7 columns of the file are used.

    Returns:
        A (n_samples, 7) float32 array.

    Raises:
        ValueError: If the file cannot be read, has fewer than 7 usable
            columns, or contains non-numeric values.
    """
    try:
        # Canonical feature order expected by the model.
        default_columns = [
            'user_pleasure', 'user_arousal', 'user_dominance',
            'vitality', 'current_pleasure', 'current_arousal', 'current_dominance'
        ]

        if feature_columns:
            df = pd.read_csv(input_path, usecols=feature_columns)
        else:
            df = pd.read_csv(input_path)

        # Keep the first 7 columns and normalize their names.
        # BUGFIX: the previous `elif len(df.columns) == 7` branch was
        # unreachable because `>= 7` already covers it; it has been removed.
        if len(df.columns) >= 7:
            df = df.iloc[:, :7]
            df.columns = default_columns
        else:
            raise ValueError(f"CSV文件应该至少有7列,但得到{len(df.columns)}列")

        # Convert to a float32 numpy array (raises on non-numeric values).
        return df.values.astype(np.float32)
    except Exception as e:
        raise ValueError(f"无法解析CSV文件 {input_path}: {e}")
def save_json_output(results: List[Dict[str, Any]], output_path: str) -> None:
    """Write inference results to *output_path* as a JSON document.

    The document wraps the predictions together with a small metadata
    section (sample count, format tag, description).

    Args:
        results: List of per-sample result dictionaries.
        output_path: Destination file path.
    """
    metadata = {
        'total_samples': len(results),
        'output_format': 'json',
        'description': 'Emotion and physiological state prediction results'
    }
    payload = {'predictions': results, 'metadata': metadata}
    with open(output_path, 'w', encoding='utf-8') as fh:
        json.dump(payload, fh, indent=2, ensure_ascii=False)
def save_csv_output(results: List[Dict[str, Any]], output_path: str) -> None:
    """Write inference results to *output_path* as a flat CSV table.

    One row per sample with the PAD deltas, pressure delta, confidence
    and per-sample inference time. Writes nothing when *results* is empty.

    Args:
        results: List of per-sample result dictionaries.
        output_path: Destination file path.
    """
    if not results:
        return

    fieldnames = ['sample_id', 'delta_pleasure', 'delta_arousal',
                  'delta_dominance', 'delta_pressure', 'confidence',
                  'inference_time']
    with open(output_path, 'w', newline='', encoding='utf-8') as fh:
        writer = csv.DictWriter(fh, fieldnames=fieldnames)
        writer.writeheader()
        for idx, res in enumerate(results):
            writer.writerow({
                'sample_id': idx,
                'delta_pleasure': res['delta_pad'][0],
                'delta_arousal': res['delta_pad'][1],
                'delta_dominance': res['delta_pad'][2],
                'delta_pressure': res['delta_pressure'][0],
                'confidence': res['confidence'][0],
                # inference_time is optional on the result dict.
                'inference_time': res.get('inference_time', 0),
            })
def save_text_output(results: List[Dict[str, Any]], output_path: str) -> None:
    """Write inference results to *output_path* as a human-readable report.

    Args:
        results: List of per-sample result dictionaries.
        output_path: Destination file path.
    """
    # Accumulate the report and write it in one pass.
    lines = ["情绪与生理状态变化预测结果\n", "=" * 50 + "\n\n"]
    for idx, res in enumerate(results, start=1):
        lines.append(f"样本 {idx}:\n")
        lines.append(f" ΔPAD (情绪变化):\n")
        lines.append(f" 快乐度变化: {res['delta_pad'][0]:.6f}\n")
        lines.append(f" 激活度变化: {res['delta_pad'][1]:.6f}\n")
        lines.append(f" 支配度变化: {res['delta_pad'][2]:.6f}\n")
        lines.append(f" Δ压力: {res['delta_pressure'][0]:.6f}\n")
        lines.append(f" 置信度: {res['confidence'][0]:.6f}\n")
        lines.append(f" 推理时间: {res.get('inference_time', 0):.6f}秒\n")
        lines.append("-" * 30 + "\n")
    with open(output_path, 'w', encoding='utf-8') as fh:
        fh.writelines(lines)
def print_results(results: List[Dict[str, Any]], verbose: bool = True) -> None:
    """Print inference results to stdout.

    Args:
        results: List of per-sample result dictionaries.
        verbose: When False, emit one compact line per sample; otherwise
            a multi-line detailed report.
    """
    if not verbose:
        # Compact one-line-per-sample output.
        for idx, res in enumerate(results, start=1):
            print(f"样本{idx}: ΔPAD={res['delta_pad']}, Δ压力={res['delta_pressure'][0]:.4f}, 置信度={res['confidence'][0]:.4f}")
        return

    # Detailed multi-line report.
    print("\n情绪与生理状态变化预测结果")
    print("=" * 60)
    for idx, res in enumerate(results, start=1):
        print(f"\n样本 {idx}:")
        print(f" ΔPAD (情绪变化):")
        print(f" 快乐度变化: {res['delta_pad'][0]:+.6f}")
        print(f" 激活度变化: {res['delta_pad'][1]:+.6f}")
        print(f" 支配度变化: {res['delta_pad'][2]:+.6f}")
        print(f" Δ压力: {res['delta_pressure'][0]:+.6f}")
        print(f" 置信度: {res['confidence'][0]:.6f}")
        # inference_time is optional; only show it when present.
        if 'inference_time' in res:
            print(f" 推理时间: {res['inference_time']*1000:.2f}ms")
        print("-" * 40)
def run_benchmark(engine: InferenceEngine,
                  num_samples: int = 1000,
                  batch_size: int = 32) -> None:
    """Run a throughput/latency benchmark on *engine* and print the report.

    Args:
        engine: Inference engine exposing a ``benchmark(num_samples,
            batch_size)`` method returning a stats mapping.
        num_samples: Number of synthetic samples to push through.
        batch_size: Batch size used for the benchmark run.
    """
    # Header (single write, identical bytes to three separate prints).
    print(f"\n运行性能基准测试...\n测试样本数: {num_samples}\n批次大小: {batch_size}")

    # Any failure — including missing keys in the returned stats —
    # is reported rather than propagated.
    try:
        stats = engine.benchmark(num_samples, batch_size)
        print("\n基准测试结果:")
        print(f" 总样本数: {stats['total_samples']}")
        print(f" 总时间: {stats['total_time']:.4f}秒")
        print(f" 吞吐量: {stats['throughput']:.2f} 样本/秒")
        print(f" 平均延迟: {stats['avg_latency']:.2f}ms")
        print(f" 最小延迟: {stats['min_time']*1000:.2f}ms")
        print(f" 最大延迟: {stats['max_time']*1000:.2f}ms")
        print(f" P95延迟: {stats['p95_latency']:.2f}ms")
        print(f" P99延迟: {stats['p99_latency']:.2f}ms")
    except Exception as e:
        print(f"基准测试失败: {e}")
def main():
    """Entry point: parse CLI arguments, run inference (or a benchmark), and emit results."""
    parser = argparse.ArgumentParser(
        description="情绪与生理状态变化预测推理脚本",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
输入格式说明:
1. 命令行参数: --input-cli 0.5 0.3 -0.2 80 0.1 0.4 -0.1
2. JSON文件: --input-json data.json
3. CSV文件: --input-csv data.csv
输出格式说明:
- JSON: 结构化数据,便于程序处理
- CSV: 表格数据,便于Excel处理
- TXT: 人类可读的文本格式
示例用法:
# 单样本推理
python inference.py --model model.pth --input-cli 0.5 0.3 -0.2 80 0.1 0.4 -0.1
# 批量推理
python inference.py --model model.pth --input-json batch_data.json --output-json results.json
# 基准测试
python inference.py --model model.pth --benchmark --num-samples 1000
"""
    )
    # Model-related arguments.
    parser.add_argument('--model', '-m', type=str, required=True,
                        help='模型文件路径 (.pth)')
    parser.add_argument('--preprocessor', '-p', type=str,
                        help='预处理器文件路径')
    parser.add_argument('--device', type=str, choices=['auto', 'cpu', 'cuda'],
                        default='auto', help='计算设备')
    # Input arguments — exactly one source must be given.
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument('--input-cli', nargs='+', metavar='VALUE',
                             help='命令行输入 (7个数值: user_pleasure user_arousal user_dominance vitality current_pleasure current_arousal current_dominance)')
    input_group.add_argument('--input-json', type=str, metavar='FILE',
                             help='JSON输入文件路径')
    input_group.add_argument('--input-csv', type=str, metavar='FILE',
                             help='CSV输入文件路径')
    # Output arguments — any combination may be requested.
    parser.add_argument('--output-json', type=str, metavar='FILE',
                        help='JSON输出文件路径')
    parser.add_argument('--output-csv', type=str, metavar='FILE',
                        help='CSV输出文件路径')
    parser.add_argument('--output-txt', type=str, metavar='FILE',
                        help='文本输出文件路径')
    parser.add_argument('--quiet', '-q', action='store_true',
                        help='静默模式,不打印结果')
    # Inference arguments.
    parser.add_argument('--batch-size', type=int, default=32,
                        help='批量推理的批次大小')
    # Benchmark arguments.
    parser.add_argument('--benchmark', action='store_true',
                        help='运行性能基准测试')
    parser.add_argument('--num-samples', type=int, default=1000,
                        help='基准测试的样本数量')
    # Miscellaneous arguments.
    parser.add_argument('--verbose', '-v', action='store_true',
                        help='详细输出')
    parser.add_argument('--log-level', type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
                        default='INFO', help='日志级别')
    args = parser.parse_args()
    # Configure logging before anything else runs.
    setup_logger(level=args.log_level)
    logger = logging.getLogger(__name__)
    try:
        # Create the inference engine (loads model + optional preprocessor).
        logger.info("初始化推理引擎...")
        engine = create_inference_engine(
            model_path=args.model,
            preprocessor_path=args.preprocessor,
            device=args.device
        )
        # Print model information when verbose.
        if args.verbose:
            model_info = engine.get_model_info()
            print(f"\n模型信息:")
            print(f" 设备: {model_info['device']}")
            print(f" 总参数量: {model_info['total_parameters']:,}")
            print(f" 输入维度: {model_info['input_dim']}")
            print(f" 输出维度: {model_info['output_dim']}")
        # Benchmark mode: run it and exit without doing inference.
        # NOTE(review): the mutually exclusive input group is still required=True,
        # so --benchmark also demands an input argument — confirm if intended.
        if args.benchmark:
            run_benchmark(engine, args.num_samples, args.batch_size)
            return
        # Load input data from whichever source was given.
        logger.info("加载输入数据...")
        if args.input_cli:
            input_data = parse_command_line_input(args.input_cli)
        elif args.input_json:
            input_data = load_json_input(args.input_json)
        elif args.input_csv:
            input_data = load_csv_input(args.input_csv)
        logger.info(f"加载了 {len(input_data)} 个样本")
        # Run inference, timing the whole pass.
        logger.info("执行推理...")
        start_time = time.time()
        if len(input_data) == 1:
            # Single-sample inference.
            result = engine.predict(input_data[0])
            results = [result.to_dict()]
        else:
            # Batch inference.
            results = engine.predict_batch(input_data, args.batch_size)
            results = [result.to_dict() for result in results]
        total_time = time.time() - start_time
        logger.info(f"推理完成,总时间: {total_time:.4f}秒")
        # Print results unless --quiet was given.
        if not args.quiet:
            print_results(results, verbose=args.verbose)
        # Save results in each requested format.
        if args.output_json:
            save_json_output(results, args.output_json)
            print(f"结果已保存到: {args.output_json}")
        if args.output_csv:
            save_csv_output(results, args.output_csv)
            print(f"结果已保存到: {args.output_csv}")
        if args.output_txt:
            save_text_output(results, args.output_txt)
            print(f"结果已保存到: {args.output_txt}")
        # Engine-level performance statistics when verbose.
        if args.verbose:
            stats = engine.get_performance_stats()
            print(f"\n性能统计:")
            print(f" 总推理次数: {stats['total_inferences']}")
            print(f" 平均时间: {stats['avg_time']*1000:.2f}ms")
            print(f" 最小时间: {stats['min_time']*1000:.2f}ms")
            print(f" 最大时间: {stats['max_time']*1000:.2f}ms")
    except Exception as e:
        # Top-level boundary: log, optionally dump the traceback, exit non-zero.
        logger.error(f"推理失败: {e}")
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)
# Script entry point.
if __name__ == "__main__":
    main()