File size: 16,143 Bytes
0a6452f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 |
"""
推理脚本
Inference Script for emotion and physiological state prediction
该脚本实现了完整的推理功能,支持:
- 单样本和批量推理
- 多种输入格式(JSON、CSV、命令行参数)
- 多种输出格式(JSON、CSV、文本)
- 输入数据验证和预处理
- 性能基准测试
"""
import argparse
import json
import csv
import sys
import os
import logging
import time
from pathlib import Path
from typing import List, Dict, Any, Union, Optional
import numpy as np
import pandas as pd
# 添加项目根目录到Python路径
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from src.utils.inference_engine import InferenceEngine, create_inference_engine
from src.utils.logger import setup_logger
def parse_command_line_input(args: List[str]) -> np.ndarray:
"""
解析命令行输入数据
Args:
args: 命令行参数列表
Returns:
输入数据数组(7维,将被推理引擎增强到10维)
"""
if len(args) != 7:
raise ValueError(f"需要7个输入参数,但提供了{len(args)}个。参数顺序:user_pleasure, user_arousal, user_dominance, vitality, current_pleasure, current_arousal, current_dominance")
try:
data = np.array([float(arg) for arg in args], dtype=np.float32)
return data.reshape(1, -1)
except ValueError as e:
raise ValueError(f"输入参数必须是数字: {e}")
def load_json_input(input_path: str) -> np.ndarray:
"""
从JSON文件加载输入数据
Args:
input_path: JSON文件路径
Returns:
输入数据数组
"""
try:
with open(input_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# 处理不同的JSON格式
if isinstance(data, dict):
if 'data' in data:
# 格式: {"data": [[...], [...], ...]}
input_data = np.array(data['data'], dtype=np.float32)
elif 'features' in data:
# 格式: {"features": [[...], [...], ...]}
input_data = np.array(data['features'], dtype=np.float32)
else:
# 格式: {"user_pleasure": ..., "user_arousal": ..., ...}
single_sample = [
data.get('user_pleasure', 0),
data.get('user_arousal', 0),
data.get('user_dominance', 0),
data.get('vitality', 0),
data.get('current_pleasure', 0),
data.get('current_arousal', 0),
data.get('current_dominance', 0)
]
input_data = np.array(single_sample, dtype=np.float32).reshape(1, -1)
elif isinstance(data, list):
# 格式: [[...], [...], ...] 或 [...]
if len(data) > 0 and isinstance(data[0], list):
input_data = np.array(data, dtype=np.float32)
else:
input_data = np.array(data, dtype=np.float32).reshape(1, -1)
else:
raise ValueError("不支持的JSON格式")
# 验证数据维度
if input_data.ndim == 1:
input_data = input_data.reshape(1, -1)
elif input_data.ndim > 2:
raise ValueError("输入数据应该是1维或2维的")
if input_data.shape[1] != 7:
raise ValueError(f"输入数据应该有7个特征,但得到{input_data.shape[1]}个")
return input_data
except Exception as e:
raise ValueError(f"无法解析JSON文件 {input_path}: {e}")
def load_csv_input(input_path: str,
feature_columns: Optional[List[str]] = None) -> np.ndarray:
"""
从CSV文件加载输入数据
Args:
input_path: CSV文件路径
feature_columns: 特征列名列表
Returns:
输入数据数组
"""
try:
# 默认列名
default_columns = [
'user_pleasure', 'user_arousal', 'user_dominance',
'vitality', 'current_pleasure', 'current_arousal', 'current_dominance'
]
# 读取CSV文件
if feature_columns:
df = pd.read_csv(input_path, usecols=feature_columns)
else:
df = pd.read_csv(input_path)
# 自动检测列名
if len(df.columns) >= 7:
df = df.iloc[:, :7] # 使用前7列
df.columns = default_columns
elif len(df.columns) == 7:
df.columns = default_columns
else:
raise ValueError(f"CSV文件应该至少有7列,但得到{len(df.columns)}列")
# 转换为numpy数组
input_data = df.values.astype(np.float32)
return input_data
except Exception as e:
raise ValueError(f"无法解析CSV文件 {input_path}: {e}")
def save_json_output(results: List[Dict[str, Any]], output_path: str) -> None:
"""
保存结果为JSON格式
Args:
results: 推理结果列表
output_path: 输出路径
"""
output_data = {
'predictions': results,
'metadata': {
'total_samples': len(results),
'output_format': 'json',
'description': 'Emotion and physiological state prediction results'
}
}
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(output_data, f, indent=2, ensure_ascii=False)
def save_csv_output(results: List[Dict[str, Any]], output_path: str) -> None:
"""
保存结果为CSV格式
Args:
results: 推理结果列表
output_path: 输出路径
"""
if not results:
return
# 展开结果数据
rows = []
for i, result in enumerate(results):
row = {
'sample_id': i,
'delta_pleasure': result['delta_pad'][0],
'delta_arousal': result['delta_pad'][1],
'delta_dominance': result['delta_pad'][2],
'delta_pressure': result['delta_pressure'][0],
'confidence': result['confidence'][0],
'inference_time': result.get('inference_time', 0)
}
rows.append(row)
# 写入CSV文件
with open(output_path, 'w', newline='', encoding='utf-8') as f:
fieldnames = rows[0].keys()
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(rows)
def save_text_output(results: List[Dict[str, Any]], output_path: str) -> None:
"""
保存结果为文本格式
Args:
results: 推理结果列表
output_path: 输出路径
"""
with open(output_path, 'w', encoding='utf-8') as f:
f.write("情绪与生理状态变化预测结果\n")
f.write("=" * 50 + "\n\n")
for i, result in enumerate(results):
f.write(f"样本 {i+1}:\n")
f.write(f" ΔPAD (情绪变化):\n")
f.write(f" 快乐度变化: {result['delta_pad'][0]:.6f}\n")
f.write(f" 激活度变化: {result['delta_pad'][1]:.6f}\n")
f.write(f" 支配度变化: {result['delta_pad'][2]:.6f}\n")
f.write(f" Δ压力: {result['delta_pressure'][0]:.6f}\n")
f.write(f" 置信度: {result['confidence'][0]:.6f}\n")
f.write(f" 推理时间: {result.get('inference_time', 0):.6f}秒\n")
f.write("-" * 30 + "\n")
def print_results(results: List[Dict[str, Any]], verbose: bool = True) -> None:
"""
打印推理结果
Args:
results: 推理结果列表
verbose: 是否显示详细信息
"""
if not verbose:
# 简洁输出
for i, result in enumerate(results):
print(f"样本{i+1}: ΔPAD={result['delta_pad']}, Δ压力={result['delta_pressure'][0]:.4f}, 置信度={result['confidence'][0]:.4f}")
return
# 详细输出
print("\n情绪与生理状态变化预测结果")
print("=" * 60)
for i, result in enumerate(results):
print(f"\n样本 {i+1}:")
print(f" ΔPAD (情绪变化):")
print(f" 快乐度变化: {result['delta_pad'][0]:+.6f}")
print(f" 激活度变化: {result['delta_pad'][1]:+.6f}")
print(f" 支配度变化: {result['delta_pad'][2]:+.6f}")
print(f" Δ压力: {result['delta_pressure'][0]:+.6f}")
print(f" 置信度: {result['confidence'][0]:.6f}")
if 'inference_time' in result:
print(f" 推理时间: {result['inference_time']*1000:.2f}ms")
print("-" * 40)
def run_benchmark(engine: InferenceEngine,
num_samples: int = 1000,
batch_size: int = 32) -> None:
"""
运行性能基准测试
Args:
engine: 推理引擎
num_samples: 测试样本数量
batch_size: 批次大小
"""
print(f"\n运行性能基准测试...")
print(f"测试样本数: {num_samples}")
print(f"批次大小: {batch_size}")
try:
stats = engine.benchmark(num_samples, batch_size)
print("\n基准测试结果:")
print(f" 总样本数: {stats['total_samples']}")
print(f" 总时间: {stats['total_time']:.4f}秒")
print(f" 吞吐量: {stats['throughput']:.2f} 样本/秒")
print(f" 平均延迟: {stats['avg_latency']:.2f}ms")
print(f" 最小延迟: {stats['min_time']*1000:.2f}ms")
print(f" 最大延迟: {stats['max_time']*1000:.2f}ms")
print(f" P95延迟: {stats['p95_latency']:.2f}ms")
print(f" P99延迟: {stats['p99_latency']:.2f}ms")
except Exception as e:
print(f"基准测试失败: {e}")
def main():
"""主函数"""
parser = argparse.ArgumentParser(
description="情绪与生理状态变化预测推理脚本",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
输入格式说明:
1. 命令行参数: --input-cli 0.5 0.3 -0.2 80 0.1 0.4 -0.1
2. JSON文件: --input-json data.json
3. CSV文件: --input-csv data.csv
输出格式说明:
- JSON: 结构化数据,便于程序处理
- CSV: 表格数据,便于Excel处理
- TXT: 人类可读的文本格式
示例用法:
# 单样本推理
python inference.py --model model.pth --input-cli 0.5 0.3 -0.2 80 0.1 0.4 -0.1
# 批量推理
python inference.py --model model.pth --input-json batch_data.json --output-json results.json
# 基准测试
python inference.py --model model.pth --benchmark --num-samples 1000
"""
)
# 模型相关参数
parser.add_argument('--model', '-m', type=str, required=True,
help='模型文件路径 (.pth)')
parser.add_argument('--preprocessor', '-p', type=str,
help='预处理器文件路径')
parser.add_argument('--device', type=str, choices=['auto', 'cpu', 'cuda'],
default='auto', help='计算设备')
# 输入相关参数
input_group = parser.add_mutually_exclusive_group(required=True)
input_group.add_argument('--input-cli', nargs='+', metavar='VALUE',
help='命令行输入 (7个数值: user_pleasure user_arousal user_dominance vitality current_pleasure current_arousal current_dominance)')
input_group.add_argument('--input-json', type=str, metavar='FILE',
help='JSON输入文件路径')
input_group.add_argument('--input-csv', type=str, metavar='FILE',
help='CSV输入文件路径')
# 输出相关参数
parser.add_argument('--output-json', type=str, metavar='FILE',
help='JSON输出文件路径')
parser.add_argument('--output-csv', type=str, metavar='FILE',
help='CSV输出文件路径')
parser.add_argument('--output-txt', type=str, metavar='FILE',
help='文本输出文件路径')
parser.add_argument('--quiet', '-q', action='store_true',
help='静默模式,不打印结果')
# 推理参数
parser.add_argument('--batch-size', type=int, default=32,
help='批量推理的批次大小')
# 基准测试参数
parser.add_argument('--benchmark', action='store_true',
help='运行性能基准测试')
parser.add_argument('--num-samples', type=int, default=1000,
help='基准测试的样本数量')
# 其他参数
parser.add_argument('--verbose', '-v', action='store_true',
help='详细输出')
parser.add_argument('--log-level', type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
default='INFO', help='日志级别')
args = parser.parse_args()
# 设置日志
setup_logger(level=args.log_level)
logger = logging.getLogger(__name__)
try:
# 创建推理引擎
logger.info("初始化推理引擎...")
engine = create_inference_engine(
model_path=args.model,
preprocessor_path=args.preprocessor,
device=args.device
)
# 打印模型信息
if args.verbose:
model_info = engine.get_model_info()
print(f"\n模型信息:")
print(f" 设备: {model_info['device']}")
print(f" 总参数量: {model_info['total_parameters']:,}")
print(f" 输入维度: {model_info['input_dim']}")
print(f" 输出维度: {model_info['output_dim']}")
# 运行基准测试
if args.benchmark:
run_benchmark(engine, args.num_samples, args.batch_size)
return
# 加载输入数据
logger.info("加载输入数据...")
if args.input_cli:
input_data = parse_command_line_input(args.input_cli)
elif args.input_json:
input_data = load_json_input(args.input_json)
elif args.input_csv:
input_data = load_csv_input(args.input_csv)
logger.info(f"加载了 {len(input_data)} 个样本")
# 执行推理
logger.info("执行推理...")
start_time = time.time()
if len(input_data) == 1:
# 单样本推理
result = engine.predict(input_data[0])
results = [result.to_dict()]
else:
# 批量推理
results = engine.predict_batch(input_data, args.batch_size)
results = [result.to_dict() for result in results]
total_time = time.time() - start_time
logger.info(f"推理完成,总时间: {total_time:.4f}秒")
# 打印结果
if not args.quiet:
print_results(results, verbose=args.verbose)
# 保存结果
if args.output_json:
save_json_output(results, args.output_json)
print(f"结果已保存到: {args.output_json}")
if args.output_csv:
save_csv_output(results, args.output_csv)
print(f"结果已保存到: {args.output_csv}")
if args.output_txt:
save_text_output(results, args.output_txt)
print(f"结果已保存到: {args.output_txt}")
# 性能统计
if args.verbose:
stats = engine.get_performance_stats()
print(f"\n性能统计:")
print(f" 总推理次数: {stats['total_inferences']}")
print(f" 平均时间: {stats['avg_time']*1000:.2f}ms")
print(f" 最小时间: {stats['min_time']*1000:.2f}ms")
print(f" 最大时间: {stats['max_time']*1000:.2f}ms")
except Exception as e:
logger.error(f"推理失败: {e}")
if args.verbose:
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main() |