""" 主CLI入口点 Main CLI Entry Point for emotion and physiological state prediction model 该模块提供了统一的命令行界面,支持: - train: 模型训练 - predict: 模型预测 - evaluate: 模型评估 - inference: 推理脚本 - benchmark: 性能基准测试 """ import argparse import sys import os import logging from pathlib import Path from typing import List, Optional # 添加项目根目录到Python路径 project_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(project_root)) from src.utils.logger import setup_logger def create_train_parser(subparsers): """创建训练子命令解析器""" train_parser = subparsers.add_parser( 'train', help='训练模型', description='训练情绪与生理状态变化预测模型', formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" 训练示例: # 使用配置文件训练 emotion-train --config configs/training_config.yaml # 指定输出目录 emotion-train --config configs/training_config.yaml --output-dir ./models # 使用GPU训练 emotion-train --config configs/training_config.yaml --device cuda """ ) # 必需参数 train_parser.add_argument( '--config', '-c', type=str, required=True, help='训练配置文件路径 (.yaml)' ) # 可选参数 train_parser.add_argument( '--output-dir', '-o', type=str, default='./outputs', help='输出目录 (默认: ./outputs)' ) train_parser.add_argument( '--device', type=str, choices=['auto', 'cpu', 'cuda'], default='auto', help='计算设备 (默认: auto)' ) train_parser.add_argument( '--resume', type=str, help='从检查点恢复训练' ) train_parser.add_argument( '--epochs', type=int, help='覆盖配置文件中的训练轮数' ) train_parser.add_argument( '--batch-size', type=int, help='覆盖配置文件中的批次大小' ) train_parser.add_argument( '--learning-rate', type=float, help='覆盖配置文件中的学习率' ) train_parser.add_argument( '--seed', type=int, default=42, help='随机种子 (默认: 42)' ) train_parser.add_argument( '--verbose', '-v', action='store_true', help='详细输出' ) train_parser.add_argument( '--log-level', type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default='INFO', help='日志级别 (默认: INFO)' ) train_parser.set_defaults(func=run_train) return train_parser def create_predict_parser(subparsers): """创建预测子命令解析器""" predict_parser = subparsers.add_parser( 'predict', help='预测', description='使用训练好的模型进行预测', formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" 预测示例: # 交互式预测 emotion-predict --model model.pth # 快速预测 emotion-predict --model model.pth --quick 0.5 0.3 -0.2 80 0.1 0.4 -0.1 # 批量预测 emotion-predict --model model.pth --batch input.json --output results.json """ ) # 必需参数 predict_parser.add_argument( '--model', '-m', type=str, required=True, help='模型文件路径 (.pth)' ) # 可选参数 predict_parser.add_argument( '--preprocessor', '-p', type=str, help='预处理器文件路径' ) # 模式选择 mode_group = predict_parser.add_mutually_exclusive_group() mode_group.add_argument( '--interactive', '-i', action='store_true', help='交互式模式' ) mode_group.add_argument( '--quick', nargs=7, type=float, metavar='VALUE', help='快速预测模式 (7个数值: user_pleasure user_arousal user_dominance vitality current_pleasure current_arousal current_dominance)' ) mode_group.add_argument( '--batch', type=str, metavar='FILE', help='批量预测模式 (输入文件)' ) predict_parser.add_argument( '--output', '-o', type=str, help='输出文件路径 (批量模式)' ) predict_parser.add_argument( '--device', type=str, choices=['auto', 'cpu', 'cuda'], default='auto', help='计算设备 (默认: auto)' ) predict_parser.add_argument( '--verbose', '-v', action='store_true', help='详细输出' ) predict_parser.add_argument( '--log-level', type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default='WARNING', help='日志级别 (默认: WARNING)' ) predict_parser.set_defaults(func=run_predict) return predict_parser def create_evaluate_parser(subparsers): """创建评估子命令解析器""" evaluate_parser = subparsers.add_parser( 'evaluate', help='评估模型', description='评估模型性能', formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" 评估示例: # 评估模型 emotion-evaluate --model model.pth --data test_data.csv # 生成详细报告 emotion-evaluate --model model.pth --data test_data.csv --report detailed_report.html # 指定指标 emotion-evaluate --model model.pth --data test_data.csv --metrics mse mae r2 """ ) # 必需参数 evaluate_parser.add_argument( '--model', '-m', type=str, required=True, help='模型文件路径 (.pth)' ) evaluate_parser.add_argument( '--data', '-d', type=str, required=True, help='测试数据文件路径' ) # 可选参数 evaluate_parser.add_argument( '--preprocessor', '-p', type=str, help='预处理器文件路径' ) evaluate_parser.add_argument( '--output', '-o', type=str, help='评估结果输出路径' ) evaluate_parser.add_argument( '--report', type=str, help='生成详细报告文件路径' ) evaluate_parser.add_argument( '--metrics', nargs='+', choices=['mse', 'mae', 'rmse', 'r2', 'mape'], default=['mse', 'mae', 'r2'], help='评估指标 (默认: mse mae r2)' ) evaluate_parser.add_argument( '--batch-size', type=int, default=32, help='批次大小 (默认: 32)' ) evaluate_parser.add_argument( '--device', type=str, choices=['auto', 'cpu', 'cuda'], default='auto', help='计算设备 (默认: auto)' ) evaluate_parser.add_argument( '--verbose', '-v', action='store_true', help='详细输出' ) evaluate_parser.add_argument( '--log-level', type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default='INFO', help='日志级别 (默认: INFO)' ) evaluate_parser.set_defaults(func=run_evaluate) return evaluate_parser def create_inference_parser(subparsers): """创建推理子命令解析器""" inference_parser = subparsers.add_parser( 'inference', help='推理脚本', description='使用推理脚本进行高级推理', formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" 推理示例: # 单样本推理 emotion-inference --model model.pth --input-cli 0.5 0.3 -0.2 80 0.1 0.4 -0.1 # JSON文件推理 emotion-inference --model model.pth --input-json data.json --output-json results.json # CSV文件推理 emotion-inference --model model.pth --input-csv data.csv --output-csv results.csv # 基准测试 emotion-inference --model model.pth --benchmark --num-samples 1000 """ ) # 模型相关参数 inference_parser.add_argument( '--model', '-m', type=str, required=True, help='模型文件路径 (.pth)' ) inference_parser.add_argument( '--preprocessor', '-p', type=str, help='预处理器文件路径' ) inference_parser.add_argument( '--device', type=str, choices=['auto', 'cpu', 'cuda'], default='auto', help='计算设备 (默认: auto)' ) # 输入相关参数 input_group = inference_parser.add_mutually_exclusive_group(required=True) input_group.add_argument( '--input-cli', nargs='+', metavar='VALUE', help='命令行输入 (7个数值)' ) input_group.add_argument( '--input-json', type=str, metavar='FILE', help='JSON输入文件路径' ) input_group.add_argument( '--input-csv', type=str, metavar='FILE', help='CSV输入文件路径' ) # 输出相关参数 inference_parser.add_argument( '--output-json', type=str, metavar='FILE', help='JSON输出文件路径' ) inference_parser.add_argument( '--output-csv', type=str, metavar='FILE', help='CSV输出文件路径' ) inference_parser.add_argument( '--output-txt', type=str, metavar='FILE', help='文本输出文件路径' ) inference_parser.add_argument( '--quiet', '-q', action='store_true', help='静默模式,不打印结果' ) # 推理参数 inference_parser.add_argument( '--batch-size', type=int, default=32, help='批量推理的批次大小 (默认: 32)' ) # 基准测试参数 inference_parser.add_argument( '--benchmark', action='store_true', help='运行性能基准测试' ) inference_parser.add_argument( '--num-samples', type=int, default=1000, help='基准测试的样本数量 (默认: 1000)' ) inference_parser.add_argument( '--verbose', '-v', action='store_true', help='详细输出' ) inference_parser.add_argument( '--log-level', type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default='INFO', help='日志级别 (默认: INFO)' ) inference_parser.set_defaults(func=run_inference) return inference_parser def create_benchmark_parser(subparsers): """创建基准测试子命令解析器""" benchmark_parser = subparsers.add_parser( 'benchmark', help='性能基准测试', description='运行模型性能基准测试', formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" 基准测试示例: # 标准基准测试 emotion-benchmark --model model.pth # 自定义测试参数 emotion-benchmark --model model.pth --num-samples 5000 --batch-size 64 # 生成性能报告 emotion-benchmark --model model.pth --report performance_report.json """ ) # 必需参数 benchmark_parser.add_argument( '--model', '-m', type=str, required=True, help='模型文件路径 (.pth)' ) # 可选参数 benchmark_parser.add_argument( '--preprocessor', '-p', type=str, help='预处理器文件路径' ) benchmark_parser.add_argument( '--num-samples', type=int, default=1000, help='测试样本数量 (默认: 1000)' ) benchmark_parser.add_argument( '--batch-size', type=int, default=32, help='批次大小 (默认: 32)' ) benchmark_parser.add_argument( '--device', type=str, choices=['auto', 'cpu', 'cuda'], default='auto', help='计算设备 (默认: auto)' ) benchmark_parser.add_argument( '--report', type=str, help='生成性能报告文件路径' ) benchmark_parser.add_argument( '--warmup', type=int, default=10, help='预热轮数 (默认: 10)' ) benchmark_parser.add_argument( '--verbose', '-v', action='store_true', help='详细输出' ) benchmark_parser.add_argument( '--log-level', type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default='INFO', help='日志级别 (默认: INFO)' ) benchmark_parser.set_defaults(func=run_benchmark) return benchmark_parser def run_train(args): """运行训练""" try: from src.scripts.train import main as train_main # 构建训练参数 train_args = [ '--config', args.config, '--output-dir', args.output_dir, '--device', args.device, '--seed', str(args.seed), '--log-level', args.log_level ] if args.resume: train_args.extend(['--resume', args.resume]) if args.epochs: train_args.extend(['--epochs', str(args.epochs)]) if args.batch_size: train_args.extend(['--batch-size', str(args.batch_size)]) if args.learning_rate: train_args.extend(['--learning-rate', str(args.learning_rate)]) if args.verbose: train_args.append('--verbose') # 临时修改sys.argv original_argv = sys.argv sys.argv = ['train'] + train_args try: train_main() finally: sys.argv = original_argv except ImportError as e: print(f"错误: 无法导入训练模块: {e}") print("请确保训练脚本存在: src/scripts/train.py") sys.exit(1) except Exception as e: print(f"训练失败: {e}") sys.exit(1) def run_predict(args): """运行预测""" try: from src.scripts.predict import main as predict_main # 构建预测参数 predict_args = ['--model', args.model] if args.preprocessor: predict_args.extend(['--preprocessor', args.preprocessor]) if args.interactive: predict_args.append('--interactive') if args.quick: predict_args.extend(['--quick'] + [str(v) for v in args.quick]) if args.batch: predict_args.extend(['--batch', args.batch]) if args.output: predict_args.extend(['--output', args.output]) predict_args.extend(['--device', args.device]) if args.verbose: predict_args.append('--verbose') predict_args.extend(['--log-level', args.log_level]) # 临时修改sys.argv original_argv = sys.argv sys.argv = ['predict'] + predict_args try: predict_main() finally: sys.argv = original_argv except ImportError as e: print(f"错误: 无法导入预测模块: {e}") print("请确保预测脚本存在: src/scripts/predict.py") sys.exit(1) except Exception as e: print(f"预测失败: {e}") sys.exit(1) def run_evaluate(args): """运行评估""" try: from src.scripts.evaluate import main as evaluate_main # 构建评估参数 evaluate_args = [ '--model', args.model, '--data', args.data, '--batch-size', str(args.batch_size), '--device', args.device, '--log-level', args.log_level ] if args.preprocessor: evaluate_args.extend(['--preprocessor', args.preprocessor]) if args.output: evaluate_args.extend(['--output', args.output]) if args.report: evaluate_args.extend(['--report', args.report]) if args.metrics: evaluate_args.extend(['--metrics'] + args.metrics) if args.verbose: evaluate_args.append('--verbose') # 临时修改sys.argv original_argv = sys.argv sys.argv = ['evaluate'] + evaluate_args try: evaluate_main() finally: sys.argv = original_argv except ImportError as e: print(f"错误: 无法导入评估模块: {e}") print("请确保评估脚本存在: src/scripts/evaluate.py") sys.exit(1) except Exception as e: print(f"评估失败: {e}") sys.exit(1) def run_inference(args): """运行推理""" try: from src.scripts.inference import main as inference_main # 构建推理参数 inference_args = [ '--model', args.model, '--device', args.device, '--batch-size', str(args.batch_size), '--log-level', args.log_level ] if args.preprocessor: inference_args.extend(['--preprocessor', args.preprocessor]) if args.input_cli: inference_args.extend(['--input-cli'] + args.input_cli) if args.input_json: inference_args.extend(['--input-json', args.input_json]) if args.input_csv: inference_args.extend(['--input-csv', args.input_csv]) if args.output_json: inference_args.extend(['--output-json', args.output_json]) if args.output_csv: inference_args.extend(['--output-csv', args.output_csv]) if args.output_txt: inference_args.extend(['--output-txt', args.output_txt]) if args.quiet: inference_args.append('--quiet') if args.benchmark: inference_args.extend(['--benchmark', '--num-samples', str(args.num_samples)]) if args.verbose: inference_args.append('--verbose') # 临时修改sys.argv original_argv = sys.argv sys.argv = ['inference'] + inference_args try: inference_main() finally: sys.argv = original_argv except ImportError as e: print(f"错误: 无法导入推理模块: {e}") print("请确保推理脚本存在: src/scripts/inference.py") sys.exit(1) except Exception as e: print(f"推理失败: {e}") sys.exit(1) def run_benchmark(args): """运行基准测试""" try: from src.utils.inference_engine import create_inference_engine import json # 设置日志 setup_logger(level=args.log_level) logger = logging.getLogger(__name__) # 创建推理引擎 logger.info("初始化推理引擎...") engine = create_inference_engine( model_path=args.model, preprocessor_path=args.preprocessor, device=args.device ) # 运行基准测试 logger.info(f"运行基准测试...") stats = engine.benchmark(args.num_samples, args.batch_size) # 显示结果 print("\n性能基准测试结果") print("=" * 50) print(f"模型: {args.model}") print(f"设备: {engine.device}") print(f"测试样本数: {stats['total_samples']}") print(f"批次大小: {stats['batch_size']}") print(f"总时间: {stats['total_time']:.4f}秒") print(f"吞吐量: {stats['throughput']:.2f} 样本/秒") print(f"平均延迟: {stats['avg_latency']:.2f}ms") print(f"最小延迟: {stats['min_time']*1000:.2f}ms") print(f"最大延迟: {stats['max_time']*1000:.2f}ms") print(f"P95延迟: {stats['p95_latency']:.2f}ms") print(f"P99延迟: {stats['p99_latency']:.2f}ms") # 保存报告 if args.report: report_data = { 'model_info': engine.get_model_info(), 'benchmark_stats': stats, 'test_config': { 'num_samples': args.num_samples, 'batch_size': args.batch_size, 'device': args.device, 'warmup': args.warmup } } with open(args.report, 'w', encoding='utf-8') as f: json.dump(report_data, f, indent=2, ensure_ascii=False) print(f"\n性能报告已保存到: {args.report}") except Exception as e: print(f"基准测试失败: {e}") sys.exit(1) def main(): """主函数""" parser = argparse.ArgumentParser( prog='emotion-prediction', description='情绪与生理状态变化预测模型工具集', formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" 使用示例: %(prog)s train --config configs/training_config.yaml %(prog)s predict --model model.pth --quick 0.5 0.3 -0.2 80 0.1 0.4 -0.1 %(prog)s evaluate --model model.pth --data test.csv %(prog)s inference --model model.pth --input-json data.json %(prog)s benchmark --model model.pth --num-samples 1000 子命令帮助: %(prog)s --help """ ) parser.add_argument( '--version', action='version', version='%(prog)s 1.0.0' ) # 创建子命令解析器 subparsers = parser.add_subparsers( dest='command', help='可用命令', metavar='COMMAND' ) # 添加各种子命令 create_train_parser(subparsers) create_predict_parser(subparsers) create_evaluate_parser(subparsers) create_inference_parser(subparsers) create_benchmark_parser(subparsers) # 解析参数 args = parser.parse_args() # 如果没有提供子命令,显示帮助 if not hasattr(args, 'func'): parser.print_help() sys.exit(1) # 设置日志 if hasattr(args, 'log_level'): setup_logger(level=args.log_level) # 执行对应的函数 try: args.func(args) except KeyboardInterrupt: print("\n用户中断操作") sys.exit(1) except Exception as e: print(f"执行失败: {e}") sys.exit(1) if __name__ == "__main__": main()