|
|
""" |
|
|
主CLI入口点 |
|
|
Main CLI Entry Point for emotion and physiological state prediction model |
|
|
|
|
|
该模块提供了统一的命令行界面,支持: |
|
|
- train: 模型训练 |
|
|
- predict: 模型预测 |
|
|
- evaluate: 模型评估 |
|
|
- inference: 推理脚本 |
|
|
- benchmark: 性能基准测试 |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import sys |
|
|
import os |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import List, Optional |
|
|
|
|
|
|
|
|
# Make the repository root importable so the absolute `src.*` imports below
# resolve when this script is executed directly (not installed as a package).
# This file appears to live three levels below the root — TODO confirm layout.
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
|
|
|
|
|
from src.utils.logger import setup_logger |
|
|
|
|
|
|
|
|
def create_train_parser(subparsers):
    """Register the 'train' sub-command on *subparsers*.

    Attaches ``run_train`` as the dispatch function and returns the
    created sub-parser so callers may extend it further.
    """
    parser = subparsers.add_parser(
        'train',
        help='训练模型',
        description='训练情绪与生理状态变化预测模型',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
训练示例:
  # 使用配置文件训练
  emotion-train --config configs/training_config.yaml

  # 指定输出目录
  emotion-train --config configs/training_config.yaml --output-dir ./models

  # 使用GPU训练
  emotion-train --config configs/training_config.yaml --device cuda
""",
    )

    parser.add_argument('--config', '-c', type=str, required=True,
                        help='训练配置文件路径 (.yaml)')
    parser.add_argument('--output-dir', '-o', type=str, default='./outputs',
                        help='输出目录 (默认: ./outputs)')
    parser.add_argument('--device', type=str, choices=['auto', 'cpu', 'cuda'],
                        default='auto', help='计算设备 (默认: auto)')
    parser.add_argument('--resume', type=str, help='从检查点恢复训练')
    # The next three are optional overrides of the YAML config values.
    parser.add_argument('--epochs', type=int, help='覆盖配置文件中的训练轮数')
    parser.add_argument('--batch-size', type=int, help='覆盖配置文件中的批次大小')
    parser.add_argument('--learning-rate', type=float, help='覆盖配置文件中的学习率')
    parser.add_argument('--seed', type=int, default=42, help='随机种子 (默认: 42)')
    parser.add_argument('--verbose', '-v', action='store_true', help='详细输出')
    parser.add_argument('--log-level', type=str,
                        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
                        default='INFO', help='日志级别 (默认: INFO)')

    parser.set_defaults(func=run_train)
    return parser
|
|
|
|
|
|
|
|
def create_predict_parser(subparsers):
    """Register the 'predict' sub-command on *subparsers*.

    Attaches ``run_predict`` as the dispatch function and returns the
    created sub-parser.
    """
    parser = subparsers.add_parser(
        'predict',
        help='预测',
        description='使用训练好的模型进行预测',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
预测示例:
  # 交互式预测
  emotion-predict --model model.pth

  # 快速预测
  emotion-predict --model model.pth --quick 0.5 0.3 -0.2 80 0.1 0.4 -0.1

  # 批量预测
  emotion-predict --model model.pth --batch input.json --output results.json
""",
    )

    parser.add_argument('--model', '-m', type=str, required=True,
                        help='模型文件路径 (.pth)')
    parser.add_argument('--preprocessor', '-p', type=str,
                        help='预处理器文件路径')

    # Exactly one prediction mode may be chosen (none selected is allowed;
    # the downstream script decides the default behavior).
    modes = parser.add_mutually_exclusive_group()
    modes.add_argument('--interactive', '-i', action='store_true',
                       help='交互式模式')
    modes.add_argument('--quick', nargs=7, type=float, metavar='VALUE',
                       help='快速预测模式 (7个数值: user_pleasure user_arousal user_dominance vitality current_pleasure current_arousal current_dominance)')
    modes.add_argument('--batch', type=str, metavar='FILE',
                       help='批量预测模式 (输入文件)')

    parser.add_argument('--output', '-o', type=str,
                        help='输出文件路径 (批量模式)')
    parser.add_argument('--device', type=str, choices=['auto', 'cpu', 'cuda'],
                        default='auto', help='计算设备 (默认: auto)')
    parser.add_argument('--verbose', '-v', action='store_true', help='详细输出')
    parser.add_argument('--log-level', type=str,
                        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
                        default='WARNING', help='日志级别 (默认: WARNING)')

    parser.set_defaults(func=run_predict)
    return parser
|
|
|
|
|
|
|
|
def create_evaluate_parser(subparsers):
    """Register the 'evaluate' sub-command on *subparsers*.

    Attaches ``run_evaluate`` as the dispatch function and returns the
    created sub-parser.
    """
    parser = subparsers.add_parser(
        'evaluate',
        help='评估模型',
        description='评估模型性能',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
评估示例:
  # 评估模型
  emotion-evaluate --model model.pth --data test_data.csv

  # 生成详细报告
  emotion-evaluate --model model.pth --data test_data.csv --report detailed_report.html

  # 指定指标
  emotion-evaluate --model model.pth --data test_data.csv --metrics mse mae r2
""",
    )

    parser.add_argument('--model', '-m', type=str, required=True,
                        help='模型文件路径 (.pth)')
    parser.add_argument('--data', '-d', type=str, required=True,
                        help='测试数据文件路径')
    parser.add_argument('--preprocessor', '-p', type=str,
                        help='预处理器文件路径')
    parser.add_argument('--output', '-o', type=str,
                        help='评估结果输出路径')
    parser.add_argument('--report', type=str,
                        help='生成详细报告文件路径')
    parser.add_argument('--metrics', nargs='+',
                        choices=['mse', 'mae', 'rmse', 'r2', 'mape'],
                        default=['mse', 'mae', 'r2'],
                        help='评估指标 (默认: mse mae r2)')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='批次大小 (默认: 32)')
    parser.add_argument('--device', type=str, choices=['auto', 'cpu', 'cuda'],
                        default='auto', help='计算设备 (默认: auto)')
    parser.add_argument('--verbose', '-v', action='store_true', help='详细输出')
    parser.add_argument('--log-level', type=str,
                        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
                        default='INFO', help='日志级别 (默认: INFO)')

    parser.set_defaults(func=run_evaluate)
    return parser
|
|
|
|
|
|
|
|
def create_inference_parser(subparsers):
    """Register the 'inference' sub-command on *subparsers*.

    Attaches ``run_inference`` as the dispatch function and returns the
    created sub-parser.
    """
    parser = subparsers.add_parser(
        'inference',
        help='推理脚本',
        description='使用推理脚本进行高级推理',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
推理示例:
  # 单样本推理
  emotion-inference --model model.pth --input-cli 0.5 0.3 -0.2 80 0.1 0.4 -0.1

  # JSON文件推理
  emotion-inference --model model.pth --input-json data.json --output-json results.json

  # CSV文件推理
  emotion-inference --model model.pth --input-csv data.csv --output-csv results.csv

  # 基准测试
  emotion-inference --model model.pth --benchmark --num-samples 1000
""",
    )

    parser.add_argument('--model', '-m', type=str, required=True,
                        help='模型文件路径 (.pth)')
    parser.add_argument('--preprocessor', '-p', type=str,
                        help='预处理器文件路径')
    parser.add_argument('--device', type=str, choices=['auto', 'cpu', 'cuda'],
                        default='auto', help='计算设备 (默认: auto)')

    # Exactly one input source is required.
    sources = parser.add_mutually_exclusive_group(required=True)
    sources.add_argument('--input-cli', nargs='+', metavar='VALUE',
                         help='命令行输入 (7个数值)')
    sources.add_argument('--input-json', type=str, metavar='FILE',
                         help='JSON输入文件路径')
    sources.add_argument('--input-csv', type=str, metavar='FILE',
                         help='CSV输入文件路径')

    parser.add_argument('--output-json', type=str, metavar='FILE',
                        help='JSON输出文件路径')
    parser.add_argument('--output-csv', type=str, metavar='FILE',
                        help='CSV输出文件路径')
    parser.add_argument('--output-txt', type=str, metavar='FILE',
                        help='文本输出文件路径')
    parser.add_argument('--quiet', '-q', action='store_true',
                        help='静默模式,不打印结果')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='批量推理的批次大小 (默认: 32)')
    parser.add_argument('--benchmark', action='store_true',
                        help='运行性能基准测试')
    parser.add_argument('--num-samples', type=int, default=1000,
                        help='基准测试的样本数量 (默认: 1000)')
    parser.add_argument('--verbose', '-v', action='store_true', help='详细输出')
    parser.add_argument('--log-level', type=str,
                        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
                        default='INFO', help='日志级别 (默认: INFO)')

    parser.set_defaults(func=run_inference)
    return parser
|
|
|
|
|
|
|
|
def create_benchmark_parser(subparsers):
    """Register the 'benchmark' sub-command on *subparsers*.

    Attaches ``run_benchmark`` as the dispatch function and returns the
    created sub-parser.
    """
    parser = subparsers.add_parser(
        'benchmark',
        help='性能基准测试',
        description='运行模型性能基准测试',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
基准测试示例:
  # 标准基准测试
  emotion-benchmark --model model.pth

  # 自定义测试参数
  emotion-benchmark --model model.pth --num-samples 5000 --batch-size 64

  # 生成性能报告
  emotion-benchmark --model model.pth --report performance_report.json
""",
    )

    parser.add_argument('--model', '-m', type=str, required=True,
                        help='模型文件路径 (.pth)')
    parser.add_argument('--preprocessor', '-p', type=str,
                        help='预处理器文件路径')
    parser.add_argument('--num-samples', type=int, default=1000,
                        help='测试样本数量 (默认: 1000)')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='批次大小 (默认: 32)')
    parser.add_argument('--device', type=str, choices=['auto', 'cpu', 'cuda'],
                        default='auto', help='计算设备 (默认: auto)')
    parser.add_argument('--report', type=str,
                        help='生成性能报告文件路径')
    parser.add_argument('--warmup', type=int, default=10,
                        help='预热轮数 (默认: 10)')
    parser.add_argument('--verbose', '-v', action='store_true', help='详细输出')
    parser.add_argument('--log-level', type=str,
                        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
                        default='INFO', help='日志级别 (默认: INFO)')

    parser.set_defaults(func=run_benchmark)
    return parser
|
|
|
|
|
|
|
|
def run_train(args):
    """Dispatch to the training script with the parsed CLI options.

    Rebuilds an argv for ``src.scripts.train.main`` from *args*,
    temporarily swaps ``sys.argv``, and always restores it. Exits the
    process with status 1 if the training module cannot be imported or
    training raises.
    """
    try:
        from src.scripts.train import main as train_main

        train_args = [
            '--config', args.config,
            '--output-dir', args.output_dir,
            '--device', args.device,
            '--seed', str(args.seed),
            '--log-level', args.log_level,
        ]

        if args.resume:
            train_args.extend(['--resume', args.resume])
        # Compare against None explicitly: the previous truthiness checks
        # silently dropped legitimate zero overrides (e.g. --learning-rate 0.0).
        if args.epochs is not None:
            train_args.extend(['--epochs', str(args.epochs)])
        if args.batch_size is not None:
            train_args.extend(['--batch-size', str(args.batch_size)])
        if args.learning_rate is not None:
            train_args.extend(['--learning-rate', str(args.learning_rate)])
        if args.verbose:
            train_args.append('--verbose')

        # The training script parses sys.argv itself, so swap it in and
        # restore the original no matter what happens.
        original_argv = sys.argv
        sys.argv = ['train'] + train_args
        try:
            train_main()
        finally:
            sys.argv = original_argv

    except ImportError as e:
        print(f"错误: 无法导入训练模块: {e}")
        print("请确保训练脚本存在: src/scripts/train.py")
        sys.exit(1)
    except Exception as e:
        print(f"训练失败: {e}")
        sys.exit(1)
|
|
|
|
|
|
|
|
def run_predict(args):
    """Dispatch to the prediction script with the parsed CLI options.

    Rebuilds an argv for ``src.scripts.predict.main``, temporarily swaps
    ``sys.argv``, and always restores it. Exits with status 1 on failure.
    """
    try:
        from src.scripts.predict import main as predict_main

        forwarded = ['--model', args.model]
        if args.preprocessor:
            forwarded += ['--preprocessor', args.preprocessor]
        if args.interactive:
            forwarded.append('--interactive')
        if args.quick:
            forwarded += ['--quick', *(str(v) for v in args.quick)]
        if args.batch:
            forwarded += ['--batch', args.batch]
        if args.output:
            forwarded += ['--output', args.output]
        forwarded += ['--device', args.device]
        if args.verbose:
            forwarded.append('--verbose')
        forwarded += ['--log-level', args.log_level]

        # The prediction script parses sys.argv itself.
        saved_argv = sys.argv
        sys.argv = ['predict'] + forwarded
        try:
            predict_main()
        finally:
            sys.argv = saved_argv

    except ImportError as e:
        print(f"错误: 无法导入预测模块: {e}")
        print("请确保预测脚本存在: src/scripts/predict.py")
        sys.exit(1)
    except Exception as e:
        print(f"预测失败: {e}")
        sys.exit(1)
|
|
|
|
|
|
|
|
def run_evaluate(args):
    """Dispatch to the evaluation script with the parsed CLI options.

    Rebuilds an argv for ``src.scripts.evaluate.main``, temporarily swaps
    ``sys.argv``, and always restores it. Exits with status 1 on failure.
    """
    try:
        from src.scripts.evaluate import main as evaluate_main

        forwarded = [
            '--model', args.model,
            '--data', args.data,
            '--batch-size', str(args.batch_size),
            '--device', args.device,
            '--log-level', args.log_level,
        ]

        # Optional single-value flags, forwarded only when set.
        for flag, value in (('--preprocessor', args.preprocessor),
                            ('--output', args.output),
                            ('--report', args.report)):
            if value:
                forwarded += [flag, value]
        if args.metrics:
            forwarded += ['--metrics', *args.metrics]
        if args.verbose:
            forwarded.append('--verbose')

        # The evaluation script parses sys.argv itself.
        saved_argv = sys.argv
        sys.argv = ['evaluate'] + forwarded
        try:
            evaluate_main()
        finally:
            sys.argv = saved_argv

    except ImportError as e:
        print(f"错误: 无法导入评估模块: {e}")
        print("请确保评估脚本存在: src/scripts/evaluate.py")
        sys.exit(1)
    except Exception as e:
        print(f"评估失败: {e}")
        sys.exit(1)
|
|
|
|
|
|
|
|
def run_inference(args):
    """Dispatch to the inference script with the parsed CLI options.

    Rebuilds an argv for ``src.scripts.inference.main``, temporarily swaps
    ``sys.argv``, and always restores it. Exits with status 1 on failure.
    """
    try:
        from src.scripts.inference import main as inference_main

        forwarded = [
            '--model', args.model,
            '--device', args.device,
            '--batch-size', str(args.batch_size),
            '--log-level', args.log_level,
        ]

        if args.preprocessor:
            forwarded += ['--preprocessor', args.preprocessor]
        if args.input_cli:
            forwarded += ['--input-cli', *args.input_cli]
        # Remaining file-path options are forwarded only when set.
        for flag, value in (('--input-json', args.input_json),
                            ('--input-csv', args.input_csv),
                            ('--output-json', args.output_json),
                            ('--output-csv', args.output_csv),
                            ('--output-txt', args.output_txt)):
            if value:
                forwarded += [flag, value]
        if args.quiet:
            forwarded.append('--quiet')
        if args.benchmark:
            forwarded += ['--benchmark', '--num-samples', str(args.num_samples)]
        if args.verbose:
            forwarded.append('--verbose')

        # The inference script parses sys.argv itself.
        saved_argv = sys.argv
        sys.argv = ['inference'] + forwarded
        try:
            inference_main()
        finally:
            sys.argv = saved_argv

    except ImportError as e:
        print(f"错误: 无法导入推理模块: {e}")
        print("请确保推理脚本存在: src/scripts/inference.py")
        sys.exit(1)
    except Exception as e:
        print(f"推理失败: {e}")
        sys.exit(1)
|
|
|
|
|
|
|
|
def run_benchmark(args):
    """Run an inference performance benchmark and print a summary.

    Builds the inference engine, measures throughput/latency stats, prints
    a human-readable table, and optionally writes a JSON report with the
    model info and test configuration. Exits with status 1 on any failure.
    """
    try:
        from src.utils.inference_engine import create_inference_engine
        import json

        setup_logger(level=args.log_level)
        logger = logging.getLogger(__name__)

        logger.info("初始化推理引擎...")
        engine = create_inference_engine(
            model_path=args.model,
            preprocessor_path=args.preprocessor,
            device=args.device,
        )

        logger.info("运行基准测试...")
        stats = engine.benchmark(args.num_samples, args.batch_size)

        # Assemble the summary once, then emit it line by line.
        summary = [
            "\n性能基准测试结果",
            "=" * 50,
            f"模型: {args.model}",
            f"设备: {engine.device}",
            f"测试样本数: {stats['total_samples']}",
            f"批次大小: {stats['batch_size']}",
            f"总时间: {stats['total_time']:.4f}秒",
            f"吞吐量: {stats['throughput']:.2f} 样本/秒",
            f"平均延迟: {stats['avg_latency']:.2f}ms",
            f"最小延迟: {stats['min_time']*1000:.2f}ms",
            f"最大延迟: {stats['max_time']*1000:.2f}ms",
            f"P95延迟: {stats['p95_latency']:.2f}ms",
            f"P99延迟: {stats['p99_latency']:.2f}ms",
        ]
        for line in summary:
            print(line)

        if args.report:
            report = {
                'model_info': engine.get_model_info(),
                'benchmark_stats': stats,
                'test_config': {
                    'num_samples': args.num_samples,
                    'batch_size': args.batch_size,
                    'device': args.device,
                    'warmup': args.warmup,
                },
            }
            with open(args.report, 'w', encoding='utf-8') as fh:
                json.dump(report, fh, indent=2, ensure_ascii=False)
            print(f"\n性能报告已保存到: {args.report}")

    except Exception as e:
        print(f"基准测试失败: {e}")
        sys.exit(1)
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: build the full parser tree and dispatch the command."""
    parser = argparse.ArgumentParser(
        prog='emotion-prediction',
        description='情绪与生理状态变化预测模型工具集',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
使用示例:
  %(prog)s train --config configs/training_config.yaml
  %(prog)s predict --model model.pth --quick 0.5 0.3 -0.2 80 0.1 0.4 -0.1
  %(prog)s evaluate --model model.pth --data test.csv
  %(prog)s inference --model model.pth --input-json data.json
  %(prog)s benchmark --model model.pth --num-samples 1000

子命令帮助:
  %(prog)s <command> --help
""",
    )
    parser.add_argument('--version', action='version',
                        version='%(prog)s 1.0.0')

    subparsers = parser.add_subparsers(dest='command', help='可用命令',
                                       metavar='COMMAND')

    # Each builder registers one sub-command and sets its `func` default.
    for register in (create_train_parser, create_predict_parser,
                     create_evaluate_parser, create_inference_parser,
                     create_benchmark_parser):
        register(subparsers)

    args = parser.parse_args()

    # No sub-command selected: show usage and fail.
    if not hasattr(args, 'func'):
        parser.print_help()
        sys.exit(1)

    if hasattr(args, 'log_level'):
        setup_logger(level=args.log_level)

    try:
        args.func(args)
    except KeyboardInterrupt:
        print("\n用户中断操作")
        sys.exit(1)
    except Exception as e:
        print(f"执行失败: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()