Spaces:
Sleeping
Sleeping
| """ | |
| 测试情绪分析模型效果 | |
| 对比规则引擎、官方模型和微调模型 | |
| """ | |
| import sys | |
| import os | |
| from pathlib import Path | |
| # 添加项目路径 | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| # 直接导入模块 | |
| from models.sentiment import SentimentAnalyzer | |
def _run_cases(analyzer, test_cases):
    """Evaluate *analyzer* on labelled samples, printing one row per sample.

    Args:
        analyzer: object exposing ``analyze_turn(text) -> dict``; the dict must
            contain ``'emotion'`` and may contain ``'confidence'`` / ``'method'``.
        test_cases: sequence of ``(text, expected_label)`` pairs.

    Returns:
        Accuracy as a percentage in ``[0, 100]`` (``0.0`` for an empty list).
    """
    correct = 0
    for text, expected in test_cases:
        result = analyzer.analyze_turn(text)
        predicted = result['emotion']
        confidence = result.get('confidence', 0)
        method = result.get('method', 'unknown')
        hit = predicted == expected
        correct += hit
        status = "✓" if hit else "✗"
        print(f"{status} {text[:35]:35} -> {predicted:8} (期望: {expected:8}) 置信度: {confidence:.2f} 方法: {method}")
    total = len(test_cases)
    # Guard against ZeroDivisionError if the case list is ever emptied.
    accuracy = correct / total * 100 if total else 0.0
    print(f"\n准确率: {correct}/{total} = {accuracy:.1f}%")
    return accuracy


def test_sentiment_analyzer():
    """Compare the rule engine with a fine-tuned model on HR-domain samples.

    Always evaluates the rule engine (``SentimentAnalyzer(model_path=None)``).
    Then tries each fine-tuned checkpoint path in priority order and evaluates
    the first one that exists on disk and reports ``use_model`` as true.
    Prints per-sample results, per-method accuracy, and a final summary.
    """
    # HR-scenario test cases: (text, expected label)
    test_cases = [
        # positive
        ("好的,谢谢你的帮助,非常满意!", "positive"),
        ("没问题,我理解公司的规定", "positive"),
        ("太好了,感谢你的解答", "positive"),
        ("流程很规范,我很认可", "positive"),
        ("专业高效的回复,感谢支持", "positive"),
        # neutral
        ("您好,请问有什么可以帮您?", "neutral"),
        ("请问申请年假需要什么材料", "neutral"),
        ("我需要了解培训的具体时间", "neutral"),
        ("好的,我知道了", "neutral"),
        ("请问还有什么需要补充的吗", "neutral"),
        # negative
        ("我对这个处理结果很不满意", "negative"),
        ("这个制度太不合理了,我很生气", "negative"),
        ("为什么要强制执行这个规定", "negative"),
        ("你们的做法让我很失望", "negative"),
        ("我要投诉这个处理方式", "negative"),
    ]

    print("=" * 80)
    print("情绪分析模型测试")
    print("=" * 80)

    # Method 1: rule engine (model loading explicitly disabled).
    print("\n【方法1: 规则引擎】")
    print("-" * 80)
    accuracy_rule = _run_cases(SentimentAnalyzer(model_path=None), test_cases)

    # Method 2: fine-tuned model, most-trained checkpoint first.
    finetuned_model_paths = [
        "./models/sentiment-hr/final_model",
        "./models/sentiment-hr/checkpoint-epoch-3",
        "./models/sentiment-hr/checkpoint-epoch-2",
        "./models/sentiment-hr/checkpoint-epoch-1",
    ]
    project_root = Path(__file__).parent.parent
    accuracy_finetuned = None  # stays None unless a model is evaluated
    for model_path in finetuned_model_paths:
        full_path = project_root / model_path
        if not full_path.exists():
            continue
        print(f"\n【方法2: 微调模型】({model_path})")
        print("-" * 80)
        # Keep load failures and inference failures distinct so the message
        # reported to the user matches what actually went wrong.
        try:
            analyzer_finetuned = SentimentAnalyzer(model_path=str(full_path))
        except Exception as e:
            print(f"加载微调模型失败: {e}")
            continue
        if not analyzer_finetuned.use_model:
            print("微调模型未启用 (use_model=False),尝试下一个路径")
            continue
        try:
            accuracy_finetuned = _run_cases(analyzer_finetuned, test_cases)
        except Exception as e:
            print(f"微调模型推理失败: {e}")
            continue
        print(f"\n相比规则引擎提升: {accuracy_finetuned - accuracy_rule:+.1f}%")
        break

    if accuracy_finetuned is None:
        print("\n【方法2: 微调模型】未找到微调模型")
        print("提示: 运行以下命令训练模型")
        print("  python scripts/prepare_sentiment_data.py")
        print("  python scripts/train_sentiment.py --train_data ./data/processed/sentiment/train.json")

    # Summary
    print("\n" + "=" * 80)
    print("总结")
    print("=" * 80)
    print(f"规则引擎准确率: {accuracy_rule:.1f}%")
    if accuracy_finetuned is not None:
        print(f"微调模型准确率: {accuracy_finetuned:.1f}%")
        print(f"性能提升: {accuracy_finetuned - accuracy_rule:+.1f}%")
    print("=" * 80)
def interactive_test():
    """Run an interactive loop that analyzes user-entered text.

    Uses the fine-tuned model at ``./models/sentiment-hr/final_model``
    (relative to the project root) when that path exists. Otherwise it uses
    ``SentimentAnalyzer(model_path=None)``. Typing ``quit``/``exit``/``q``,
    or sending EOF/interrupt at the prompt, ends the loop.
    """
    print("\n" + "=" * 80)
    print("交互式情绪分析测试")
    print("=" * 80)
    print("输入文本进行分析,输入 'quit' 退出")

    # Try to load the fine-tuned model.
    model_path = "./models/sentiment-hr/final_model"
    full_path = Path(__file__).parent.parent / model_path
    if full_path.exists():
        print(f"使用微调模型: {model_path}")
        analyzer = SentimentAnalyzer(model_path=str(full_path))
        # The path existing does not guarantee the model actually loaded.
        if not analyzer.use_model:
            print("警告: 微调模型未启用 (use_model=False)")
    else:
        print("未找到微调模型,使用规则引擎")
        analyzer = SentimentAnalyzer(model_path=None)

    while True:
        try:
            text = input("\n请输入文本: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt: exit cleanly instead of a traceback.
            print()
            break
        if text.lower() in ('quit', 'exit', 'q'):
            break
        if not text:
            continue

        result = analyzer.analyze_turn(text)
        print("\n分析结果:")
        print(f"  情绪: {result['emotion']}")
        print(f"  置信度: {result.get('confidence', 0):.4f}")
        print(f"  方法: {result.get('method', 'unknown')}")
        if 'probabilities' in result:
            print("  概率分布:")
            for label, prob in result['probabilities'].items():
                print(f"    {label}: {prob:.4f}")
if __name__ == "__main__":
    # CLI entry point: default is the batch comparison; -i starts the REPL.
    import argparse

    cli = argparse.ArgumentParser(description="测试情绪分析模型")
    cli.add_argument(
        "--interactive",
        "-i",
        action="store_true",
        help="启用交互式测试模式",
    )
    options = cli.parse_args()

    runner = interactive_test if options.interactive else test_sentiment_analyzer
    runner()