hr-eval-api-v2 / scripts /test_sentiment_model.py
KarenYYH
Initial commit - HR Evaluation API v2
c8b1f17
"""
测试情绪分析模型效果
对比规则引擎与微调模型(若已训练)在 HR 场景测试用例上的准确率
"""
import sys
import os
from pathlib import Path
# 添加项目路径
sys.path.insert(0, str(Path(__file__).parent.parent))
# 直接导入模块
from models.sentiment import SentimentAnalyzer
def _evaluate_cases(analyzer, test_cases):
    """Run *analyzer* over labelled cases, print one row per case plus an accuracy line.

    Args:
        analyzer: object exposing ``analyze_turn(text)`` that returns a dict with
            an ``'emotion'`` key and, optionally, ``'confidence'`` and ``'method'``.
        test_cases: sequence of ``(text, expected_label)`` pairs.

    Returns:
        Accuracy as a percentage (0-100).
    """
    correct = 0
    for text, expected in test_cases:
        result = analyzer.analyze_turn(text)
        predicted = result['emotion']
        confidence = result.get('confidence', 0)
        method = result.get('method', 'unknown')
        hit = predicted == expected
        if hit:
            correct += 1
        status = "✓" if hit else "✗"
        print(f"{status} {text[:35]:35} -> {predicted:8} (期望: {expected:8}) 置信度: {confidence:.2f} 方法: {method}")
    total = len(test_cases)
    # Guard against an empty case list so the percentage never divides by zero.
    accuracy = correct / total * 100 if total else 0.0
    print(f"\n准确率: {correct}/{total} = {accuracy:.1f}%")
    return accuracy


def test_sentiment_analyzer():
    """Evaluate the sentiment analyzer on a fixed set of HR-dialogue utterances.

    First runs the rule engine (``SentimentAnalyzer(model_path=None)``). Then it
    tries the fine-tuned checkpoints under ``models/sentiment-hr/`` (resolved
    against the project root) in priority order. It evaluates the first one that
    exists, reports ``use_model`` as true, and runs without raising. Per-case
    predictions, accuracies and the accuracy difference (in percentage points)
    are printed; nothing is returned.
    """
    # HR-scenario test cases: (utterance, expected emotion label).
    test_cases = [
        # Positive
        ("好的,谢谢你的帮助,非常满意!", "positive"),
        ("没问题,我理解公司的规定", "positive"),
        ("太好了,感谢你的解答", "positive"),
        ("流程很规范,我很认可", "positive"),
        ("专业高效的回复,感谢支持", "positive"),
        # Neutral
        ("您好,请问有什么可以帮您?", "neutral"),
        ("请问申请年假需要什么材料", "neutral"),
        ("我需要了解培训的具体时间", "neutral"),
        ("好的,我知道了", "neutral"),
        ("请问还有什么需要补充的吗", "neutral"),
        # Negative
        ("我对这个处理结果很不满意", "negative"),
        ("这个制度太不合理了,我很生气", "negative"),
        ("为什么要强制执行这个规定", "negative"),
        ("你们的做法让我很失望", "negative"),
        ("我要投诉这个处理方式", "negative"),
    ]
    print("=" * 80)
    print("情绪分析模型测试")
    print("=" * 80)

    # Method 1: rule engine (model_path=None forces the analyzer not to load a model).
    print("\n【方法1: 规则引擎】")
    print("-" * 80)
    analyzer_rule = SentimentAnalyzer(model_path=None)
    accuracy_rule = _evaluate_cases(analyzer_rule, test_cases)

    # Method 2: fine-tuned model, if one has been trained. Candidates are tried in
    # priority order (final model first, then newest checkpoint).
    finetuned_model_paths = [
        "./models/sentiment-hr/final_model",
        "./models/sentiment-hr/checkpoint-epoch-3",
        "./models/sentiment-hr/checkpoint-epoch-2",
        "./models/sentiment-hr/checkpoint-epoch-1",
    ]
    project_root = Path(__file__).parent.parent
    accuracy_finetuned = None  # stays None unless some candidate is evaluated
    improvement = None
    for model_path in finetuned_model_paths:
        full_path = project_root / model_path
        if not full_path.exists():
            continue
        print(f"\n【方法2: 微调模型】({model_path})")
        print("-" * 80)
        # Best-effort: any failure on this candidate is reported and the next
        # candidate is tried, rather than aborting the whole comparison.
        try:
            analyzer_finetuned = SentimentAnalyzer(model_path=str(full_path))
            if not analyzer_finetuned.use_model:
                # The directory exists but the analyzer did not enable the model;
                # say so instead of leaving an empty section.
                print("模型未启用 (use_model=False),尝试下一个候选路径")
                continue
            accuracy_finetuned = _evaluate_cases(analyzer_finetuned, test_cases)
            improvement = accuracy_finetuned - accuracy_rule
            print(f"\n相比规则引擎提升: {improvement:+.1f}%")
            break
        except Exception as e:
            print(f"加载微调模型失败: {e}")

    if accuracy_finetuned is None:
        print("\n【方法2: 微调模型】未找到微调模型")
        print("提示: 运行以下命令训练模型")
        print(" python scripts/prepare_sentiment_data.py")
        print(" python scripts/train_sentiment.py --train_data ./data/processed/sentiment/train.json")

    # Summary
    print("\n" + "=" * 80)
    print("总结")
    print("=" * 80)
    print(f"规则引擎准确率: {accuracy_rule:.1f}%")
    if accuracy_finetuned is not None:
        print(f"微调模型准确率: {accuracy_finetuned:.1f}%")
        print(f"性能提升: {improvement:+.1f}%")
    print("=" * 80)
def interactive_test():
    """Run an interactive read-analyse-print loop on standard input.

    Uses the fine-tuned model at ``models/sentiment-hr/final_model`` (resolved
    against the project root) when it exists, loads without raising, and reports
    ``use_model`` as true. Otherwise it falls back to the rule engine
    (``SentimentAnalyzer(model_path=None)``). Typing ``quit``/``exit``/``q`` (any
    case), sending EOF, or pressing Ctrl-C ends the session. Each non-empty line
    is analysed and its emotion, confidence, method and (if present) the
    per-label probability distribution are printed.
    """
    print("\n" + "=" * 80)
    print("交互式情绪分析测试")
    print("=" * 80)
    print("输入文本进行分析,输入 'quit' 退出")

    # Prefer the fine-tuned model; fall back to the rule engine if it is missing,
    # fails to load, or does not enable model inference.
    model_path = "./models/sentiment-hr/final_model"
    full_path = Path(__file__).parent.parent / model_path
    analyzer = None
    if full_path.exists():
        try:
            candidate = SentimentAnalyzer(model_path=str(full_path))
        except Exception as e:
            print(f"加载微调模型失败: {e}")
        else:
            if candidate.use_model:
                print(f"使用微调模型: {model_path}")
                analyzer = candidate
    if analyzer is None:
        print("未找到微调模型,使用规则引擎")
        analyzer = SentimentAnalyzer(model_path=None)

    while True:
        try:
            text = input("\n请输入文本: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / end of piped input / Ctrl-C: leave cleanly instead of a traceback.
            print()
            break
        if text.lower() in ('quit', 'exit', 'q'):
            break
        if not text:
            continue
        result = analyzer.analyze_turn(text)
        print("\n分析结果:")
        print(f" 情绪: {result['emotion']}")
        print(f" 置信度: {result.get('confidence', 0):.4f}")
        print(f" 方法: {result.get('method', 'unknown')}")
        if 'probabilities' in result:
            print(" 概率分布:")
            for label, prob in result['probabilities'].items():
                print(f" {label}: {prob:.4f}")
if __name__ == "__main__":
    import argparse

    # Command-line entry point: batch evaluation by default, REPL with -i.
    cli = argparse.ArgumentParser(description="测试情绪分析模型")
    cli.add_argument(
        "--interactive",
        "-i",
        action="store_true",
        help="启用交互式测试模式",
    )
    options = cli.parse_args()
    entry_point = interactive_test if options.interactive else test_sentiment_analyzer
    entry_point()