Chordia / src /cli /main.py
Corolin's picture
first commit
0a6452f
"""
主CLI入口点
Main CLI Entry Point for emotion and physiological state prediction model
该模块提供了统一的命令行界面,支持:
- train: 模型训练
- predict: 模型预测
- evaluate: 模型评估
- inference: 推理脚本
- benchmark: 性能基准测试
"""
import argparse
import sys
import os
import logging
from pathlib import Path
from typing import List, Optional
# 添加项目根目录到Python路径
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from src.utils.logger import setup_logger
def create_train_parser(subparsers):
"""创建训练子命令解析器"""
train_parser = subparsers.add_parser(
'train',
help='训练模型',
description='训练情绪与生理状态变化预测模型',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
训练示例:
# 使用配置文件训练
emotion-train --config configs/training_config.yaml
# 指定输出目录
emotion-train --config configs/training_config.yaml --output-dir ./models
# 使用GPU训练
emotion-train --config configs/training_config.yaml --device cuda
"""
)
# 必需参数
train_parser.add_argument(
'--config', '-c',
type=str,
required=True,
help='训练配置文件路径 (.yaml)'
)
# 可选参数
train_parser.add_argument(
'--output-dir', '-o',
type=str,
default='./outputs',
help='输出目录 (默认: ./outputs)'
)
train_parser.add_argument(
'--device',
type=str,
choices=['auto', 'cpu', 'cuda'],
default='auto',
help='计算设备 (默认: auto)'
)
train_parser.add_argument(
'--resume',
type=str,
help='从检查点恢复训练'
)
train_parser.add_argument(
'--epochs',
type=int,
help='覆盖配置文件中的训练轮数'
)
train_parser.add_argument(
'--batch-size',
type=int,
help='覆盖配置文件中的批次大小'
)
train_parser.add_argument(
'--learning-rate',
type=float,
help='覆盖配置文件中的学习率'
)
train_parser.add_argument(
'--seed',
type=int,
default=42,
help='随机种子 (默认: 42)'
)
train_parser.add_argument(
'--verbose', '-v',
action='store_true',
help='详细输出'
)
train_parser.add_argument(
'--log-level',
type=str,
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
default='INFO',
help='日志级别 (默认: INFO)'
)
train_parser.set_defaults(func=run_train)
return train_parser
def create_predict_parser(subparsers):
"""创建预测子命令解析器"""
predict_parser = subparsers.add_parser(
'predict',
help='预测',
description='使用训练好的模型进行预测',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
预测示例:
# 交互式预测
emotion-predict --model model.pth
# 快速预测
emotion-predict --model model.pth --quick 0.5 0.3 -0.2 80 0.1 0.4 -0.1
# 批量预测
emotion-predict --model model.pth --batch input.json --output results.json
"""
)
# 必需参数
predict_parser.add_argument(
'--model', '-m',
type=str,
required=True,
help='模型文件路径 (.pth)'
)
# 可选参数
predict_parser.add_argument(
'--preprocessor', '-p',
type=str,
help='预处理器文件路径'
)
# 模式选择
mode_group = predict_parser.add_mutually_exclusive_group()
mode_group.add_argument(
'--interactive', '-i',
action='store_true',
help='交互式模式'
)
mode_group.add_argument(
'--quick',
nargs=7,
type=float,
metavar='VALUE',
help='快速预测模式 (7个数值: user_pleasure user_arousal user_dominance vitality current_pleasure current_arousal current_dominance)'
)
mode_group.add_argument(
'--batch',
type=str,
metavar='FILE',
help='批量预测模式 (输入文件)'
)
predict_parser.add_argument(
'--output', '-o',
type=str,
help='输出文件路径 (批量模式)'
)
predict_parser.add_argument(
'--device',
type=str,
choices=['auto', 'cpu', 'cuda'],
default='auto',
help='计算设备 (默认: auto)'
)
predict_parser.add_argument(
'--verbose', '-v',
action='store_true',
help='详细输出'
)
predict_parser.add_argument(
'--log-level',
type=str,
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
default='WARNING',
help='日志级别 (默认: WARNING)'
)
predict_parser.set_defaults(func=run_predict)
return predict_parser
def create_evaluate_parser(subparsers):
"""创建评估子命令解析器"""
evaluate_parser = subparsers.add_parser(
'evaluate',
help='评估模型',
description='评估模型性能',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
评估示例:
# 评估模型
emotion-evaluate --model model.pth --data test_data.csv
# 生成详细报告
emotion-evaluate --model model.pth --data test_data.csv --report detailed_report.html
# 指定指标
emotion-evaluate --model model.pth --data test_data.csv --metrics mse mae r2
"""
)
# 必需参数
evaluate_parser.add_argument(
'--model', '-m',
type=str,
required=True,
help='模型文件路径 (.pth)'
)
evaluate_parser.add_argument(
'--data', '-d',
type=str,
required=True,
help='测试数据文件路径'
)
# 可选参数
evaluate_parser.add_argument(
'--preprocessor', '-p',
type=str,
help='预处理器文件路径'
)
evaluate_parser.add_argument(
'--output', '-o',
type=str,
help='评估结果输出路径'
)
evaluate_parser.add_argument(
'--report',
type=str,
help='生成详细报告文件路径'
)
evaluate_parser.add_argument(
'--metrics',
nargs='+',
choices=['mse', 'mae', 'rmse', 'r2', 'mape'],
default=['mse', 'mae', 'r2'],
help='评估指标 (默认: mse mae r2)'
)
evaluate_parser.add_argument(
'--batch-size',
type=int,
default=32,
help='批次大小 (默认: 32)'
)
evaluate_parser.add_argument(
'--device',
type=str,
choices=['auto', 'cpu', 'cuda'],
default='auto',
help='计算设备 (默认: auto)'
)
evaluate_parser.add_argument(
'--verbose', '-v',
action='store_true',
help='详细输出'
)
evaluate_parser.add_argument(
'--log-level',
type=str,
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
default='INFO',
help='日志级别 (默认: INFO)'
)
evaluate_parser.set_defaults(func=run_evaluate)
return evaluate_parser
def create_inference_parser(subparsers):
"""创建推理子命令解析器"""
inference_parser = subparsers.add_parser(
'inference',
help='推理脚本',
description='使用推理脚本进行高级推理',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
推理示例:
# 单样本推理
emotion-inference --model model.pth --input-cli 0.5 0.3 -0.2 80 0.1 0.4 -0.1
# JSON文件推理
emotion-inference --model model.pth --input-json data.json --output-json results.json
# CSV文件推理
emotion-inference --model model.pth --input-csv data.csv --output-csv results.csv
# 基准测试
emotion-inference --model model.pth --benchmark --num-samples 1000
"""
)
# 模型相关参数
inference_parser.add_argument(
'--model', '-m',
type=str,
required=True,
help='模型文件路径 (.pth)'
)
inference_parser.add_argument(
'--preprocessor', '-p',
type=str,
help='预处理器文件路径'
)
inference_parser.add_argument(
'--device',
type=str,
choices=['auto', 'cpu', 'cuda'],
default='auto',
help='计算设备 (默认: auto)'
)
# 输入相关参数
input_group = inference_parser.add_mutually_exclusive_group(required=True)
input_group.add_argument(
'--input-cli',
nargs='+',
metavar='VALUE',
help='命令行输入 (7个数值)'
)
input_group.add_argument(
'--input-json',
type=str,
metavar='FILE',
help='JSON输入文件路径'
)
input_group.add_argument(
'--input-csv',
type=str,
metavar='FILE',
help='CSV输入文件路径'
)
# 输出相关参数
inference_parser.add_argument(
'--output-json',
type=str,
metavar='FILE',
help='JSON输出文件路径'
)
inference_parser.add_argument(
'--output-csv',
type=str,
metavar='FILE',
help='CSV输出文件路径'
)
inference_parser.add_argument(
'--output-txt',
type=str,
metavar='FILE',
help='文本输出文件路径'
)
inference_parser.add_argument(
'--quiet', '-q',
action='store_true',
help='静默模式,不打印结果'
)
# 推理参数
inference_parser.add_argument(
'--batch-size',
type=int,
default=32,
help='批量推理的批次大小 (默认: 32)'
)
# 基准测试参数
inference_parser.add_argument(
'--benchmark',
action='store_true',
help='运行性能基准测试'
)
inference_parser.add_argument(
'--num-samples',
type=int,
default=1000,
help='基准测试的样本数量 (默认: 1000)'
)
inference_parser.add_argument(
'--verbose', '-v',
action='store_true',
help='详细输出'
)
inference_parser.add_argument(
'--log-level',
type=str,
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
default='INFO',
help='日志级别 (默认: INFO)'
)
inference_parser.set_defaults(func=run_inference)
return inference_parser
def create_benchmark_parser(subparsers):
"""创建基准测试子命令解析器"""
benchmark_parser = subparsers.add_parser(
'benchmark',
help='性能基准测试',
description='运行模型性能基准测试',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
基准测试示例:
# 标准基准测试
emotion-benchmark --model model.pth
# 自定义测试参数
emotion-benchmark --model model.pth --num-samples 5000 --batch-size 64
# 生成性能报告
emotion-benchmark --model model.pth --report performance_report.json
"""
)
# 必需参数
benchmark_parser.add_argument(
'--model', '-m',
type=str,
required=True,
help='模型文件路径 (.pth)'
)
# 可选参数
benchmark_parser.add_argument(
'--preprocessor', '-p',
type=str,
help='预处理器文件路径'
)
benchmark_parser.add_argument(
'--num-samples',
type=int,
default=1000,
help='测试样本数量 (默认: 1000)'
)
benchmark_parser.add_argument(
'--batch-size',
type=int,
default=32,
help='批次大小 (默认: 32)'
)
benchmark_parser.add_argument(
'--device',
type=str,
choices=['auto', 'cpu', 'cuda'],
default='auto',
help='计算设备 (默认: auto)'
)
benchmark_parser.add_argument(
'--report',
type=str,
help='生成性能报告文件路径'
)
benchmark_parser.add_argument(
'--warmup',
type=int,
default=10,
help='预热轮数 (默认: 10)'
)
benchmark_parser.add_argument(
'--verbose', '-v',
action='store_true',
help='详细输出'
)
benchmark_parser.add_argument(
'--log-level',
type=str,
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
default='INFO',
help='日志级别 (默认: INFO)'
)
benchmark_parser.set_defaults(func=run_benchmark)
return benchmark_parser
def run_train(args):
"""运行训练"""
try:
from src.scripts.train import main as train_main
# 构建训练参数
train_args = [
'--config', args.config,
'--output-dir', args.output_dir,
'--device', args.device,
'--seed', str(args.seed),
'--log-level', args.log_level
]
if args.resume:
train_args.extend(['--resume', args.resume])
if args.epochs:
train_args.extend(['--epochs', str(args.epochs)])
if args.batch_size:
train_args.extend(['--batch-size', str(args.batch_size)])
if args.learning_rate:
train_args.extend(['--learning-rate', str(args.learning_rate)])
if args.verbose:
train_args.append('--verbose')
# 临时修改sys.argv
original_argv = sys.argv
sys.argv = ['train'] + train_args
try:
train_main()
finally:
sys.argv = original_argv
except ImportError as e:
print(f"错误: 无法导入训练模块: {e}")
print("请确保训练脚本存在: src/scripts/train.py")
sys.exit(1)
except Exception as e:
print(f"训练失败: {e}")
sys.exit(1)
def run_predict(args):
"""运行预测"""
try:
from src.scripts.predict import main as predict_main
# 构建预测参数
predict_args = ['--model', args.model]
if args.preprocessor:
predict_args.extend(['--preprocessor', args.preprocessor])
if args.interactive:
predict_args.append('--interactive')
if args.quick:
predict_args.extend(['--quick'] + [str(v) for v in args.quick])
if args.batch:
predict_args.extend(['--batch', args.batch])
if args.output:
predict_args.extend(['--output', args.output])
predict_args.extend(['--device', args.device])
if args.verbose:
predict_args.append('--verbose')
predict_args.extend(['--log-level', args.log_level])
# 临时修改sys.argv
original_argv = sys.argv
sys.argv = ['predict'] + predict_args
try:
predict_main()
finally:
sys.argv = original_argv
except ImportError as e:
print(f"错误: 无法导入预测模块: {e}")
print("请确保预测脚本存在: src/scripts/predict.py")
sys.exit(1)
except Exception as e:
print(f"预测失败: {e}")
sys.exit(1)
def run_evaluate(args):
"""运行评估"""
try:
from src.scripts.evaluate import main as evaluate_main
# 构建评估参数
evaluate_args = [
'--model', args.model,
'--data', args.data,
'--batch-size', str(args.batch_size),
'--device', args.device,
'--log-level', args.log_level
]
if args.preprocessor:
evaluate_args.extend(['--preprocessor', args.preprocessor])
if args.output:
evaluate_args.extend(['--output', args.output])
if args.report:
evaluate_args.extend(['--report', args.report])
if args.metrics:
evaluate_args.extend(['--metrics'] + args.metrics)
if args.verbose:
evaluate_args.append('--verbose')
# 临时修改sys.argv
original_argv = sys.argv
sys.argv = ['evaluate'] + evaluate_args
try:
evaluate_main()
finally:
sys.argv = original_argv
except ImportError as e:
print(f"错误: 无法导入评估模块: {e}")
print("请确保评估脚本存在: src/scripts/evaluate.py")
sys.exit(1)
except Exception as e:
print(f"评估失败: {e}")
sys.exit(1)
def run_inference(args):
"""运行推理"""
try:
from src.scripts.inference import main as inference_main
# 构建推理参数
inference_args = [
'--model', args.model,
'--device', args.device,
'--batch-size', str(args.batch_size),
'--log-level', args.log_level
]
if args.preprocessor:
inference_args.extend(['--preprocessor', args.preprocessor])
if args.input_cli:
inference_args.extend(['--input-cli'] + args.input_cli)
if args.input_json:
inference_args.extend(['--input-json', args.input_json])
if args.input_csv:
inference_args.extend(['--input-csv', args.input_csv])
if args.output_json:
inference_args.extend(['--output-json', args.output_json])
if args.output_csv:
inference_args.extend(['--output-csv', args.output_csv])
if args.output_txt:
inference_args.extend(['--output-txt', args.output_txt])
if args.quiet:
inference_args.append('--quiet')
if args.benchmark:
inference_args.extend(['--benchmark', '--num-samples', str(args.num_samples)])
if args.verbose:
inference_args.append('--verbose')
# 临时修改sys.argv
original_argv = sys.argv
sys.argv = ['inference'] + inference_args
try:
inference_main()
finally:
sys.argv = original_argv
except ImportError as e:
print(f"错误: 无法导入推理模块: {e}")
print("请确保推理脚本存在: src/scripts/inference.py")
sys.exit(1)
except Exception as e:
print(f"推理失败: {e}")
sys.exit(1)
def run_benchmark(args):
"""运行基准测试"""
try:
from src.utils.inference_engine import create_inference_engine
import json
# 设置日志
setup_logger(level=args.log_level)
logger = logging.getLogger(__name__)
# 创建推理引擎
logger.info("初始化推理引擎...")
engine = create_inference_engine(
model_path=args.model,
preprocessor_path=args.preprocessor,
device=args.device
)
# 运行基准测试
logger.info(f"运行基准测试...")
stats = engine.benchmark(args.num_samples, args.batch_size)
# 显示结果
print("\n性能基准测试结果")
print("=" * 50)
print(f"模型: {args.model}")
print(f"设备: {engine.device}")
print(f"测试样本数: {stats['total_samples']}")
print(f"批次大小: {stats['batch_size']}")
print(f"总时间: {stats['total_time']:.4f}秒")
print(f"吞吐量: {stats['throughput']:.2f} 样本/秒")
print(f"平均延迟: {stats['avg_latency']:.2f}ms")
print(f"最小延迟: {stats['min_time']*1000:.2f}ms")
print(f"最大延迟: {stats['max_time']*1000:.2f}ms")
print(f"P95延迟: {stats['p95_latency']:.2f}ms")
print(f"P99延迟: {stats['p99_latency']:.2f}ms")
# 保存报告
if args.report:
report_data = {
'model_info': engine.get_model_info(),
'benchmark_stats': stats,
'test_config': {
'num_samples': args.num_samples,
'batch_size': args.batch_size,
'device': args.device,
'warmup': args.warmup
}
}
with open(args.report, 'w', encoding='utf-8') as f:
json.dump(report_data, f, indent=2, ensure_ascii=False)
print(f"\n性能报告已保存到: {args.report}")
except Exception as e:
print(f"基准测试失败: {e}")
sys.exit(1)
def main():
"""主函数"""
parser = argparse.ArgumentParser(
prog='emotion-prediction',
description='情绪与生理状态变化预测模型工具集',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
使用示例:
%(prog)s train --config configs/training_config.yaml
%(prog)s predict --model model.pth --quick 0.5 0.3 -0.2 80 0.1 0.4 -0.1
%(prog)s evaluate --model model.pth --data test.csv
%(prog)s inference --model model.pth --input-json data.json
%(prog)s benchmark --model model.pth --num-samples 1000
子命令帮助:
%(prog)s <command> --help
"""
)
parser.add_argument(
'--version',
action='version',
version='%(prog)s 1.0.0'
)
# 创建子命令解析器
subparsers = parser.add_subparsers(
dest='command',
help='可用命令',
metavar='COMMAND'
)
# 添加各种子命令
create_train_parser(subparsers)
create_predict_parser(subparsers)
create_evaluate_parser(subparsers)
create_inference_parser(subparsers)
create_benchmark_parser(subparsers)
# 解析参数
args = parser.parse_args()
# 如果没有提供子命令,显示帮助
if not hasattr(args, 'func'):
parser.print_help()
sys.exit(1)
# 设置日志
if hasattr(args, 'log_level'):
setup_logger(level=args.log_level)
# 执行对应的函数
try:
args.func(args)
except KeyboardInterrupt:
print("\n用户中断操作")
sys.exit(1)
except Exception as e:
print(f"执行失败: {e}")
sys.exit(1)
if __name__ == "__main__":
main()