# adaptive_rag/compare_hallucination_methods.py
# Author: Lanny Xu
# Last change: delete urls (commit 401184c)
"""
Compare old LLM-based vs new professional hallucination detection
Benchmark accuracy, speed, and cost
"""
import time
from typing import Dict, List, Tuple
def create_test_cases() -> List[Dict]:
    """Return labelled benchmark cases for hallucination detection.

    Each case pairs a source document with a generated answer and a
    ground-truth label: 'hallucination' or 'no_hallucination'.
    """
    # (name, documents, generation, ground_truth) quadruples.
    raw_cases = (
        (
            "Normal Answer - No Hallucination",
            "Python is a high-level programming language created by Guido van Rossum in 1991.",
            "Python was created by Guido van Rossum in 1991.",
            "no_hallucination",
        ),
        (
            "Clear Hallucination - Wrong Creator",
            "Python is a high-level programming language created by Guido van Rossum in 1991.",
            "Python was created by Dennis Ritchie in 1972.",
            "hallucination",
        ),
        (
            "Partial Hallucination - Added Info",
            "LangChain is a framework for building LLM applications.",
            "LangChain is a framework developed by OpenAI for managing databases and storing images.",
            "hallucination",
        ),
        (
            "Supported Answer - Paraphrase",
            "GraphRAG combines graph structures with RAG to enhance retrieval through knowledge graphs.",
            "GraphRAG improves retrieval by using knowledge graphs.",
            "no_hallucination",
        ),
        (
            "Subtle Hallucination - Unsupported Detail",
            "Transformer models use attention mechanisms.",
            "Transformer models use attention mechanisms and were invented in 2017 at Google.",
            "hallucination",
        ),
    )
    return [
        {"name": name, "documents": docs, "generation": gen, "ground_truth": label}
        for name, docs, gen, label in raw_cases
    ]
def test_llm_detector(test_cases: List[Dict]) -> Dict:
"""Test LLM-based detector (old method)"""
print("\n" + "=" * 60)
print("🔍 Testing LLM-based Detector (Old Method)")
print("=" * 60)
try:
from routers_and_graders import HallucinationGrader
from langchain_community.chat_models import ChatOllama
from config import LOCAL_LLM
# Force LLM-only mode by initializing without professional detector
detector = HallucinationGrader.__new__(HallucinationGrader)
detector.use_professional_detector = False
detector.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
detector.prompt = PromptTemplate(
template="""你是一个评分员,评估LLM生成是否基于/支持一组检索到的事实。
给出二进制分数'yes'或'no'。'yes'意味着答案基于/支持文档。
检索到的文档:{documents}
LLM生成:{generation}""",
input_variables=["generation", "documents"],
)
detector.grader = detector.prompt | detector.llm | JsonOutputParser()
except Exception as e:
print(f"❌ LLM detector not available: {e}")
return {"error": str(e)}
results = []
total_time = 0
correct = 0
for i, case in enumerate(test_cases, 1):
print(f"\n📝 Test {i}: {case['name']}")
start_time = time.time()
try:
score = detector.grade(case['generation'], case['documents'])
elapsed = time.time() - start_time
# Convert score: "yes" = no hallucination, "no" = hallucination
predicted = "no_hallucination" if score == "yes" else "hallucination"
is_correct = predicted == case['ground_truth']
print(f" Prediction: {predicted}")
print(f" Ground Truth: {case['ground_truth']}")
print(f" Result: {'✅ Correct' if is_correct else '❌ Wrong'}")
print(f" Time: {elapsed:.2f}s")
results.append({
"case": case['name'],
"correct": is_correct,
"time": elapsed
})
total_time += elapsed
if is_correct:
correct += 1
except Exception as e:
print(f" ❌ Error: {e}")
results.append({"case": case['name'], "error": str(e)})
accuracy = (correct / len(test_cases)) * 100 if test_cases else 0
avg_time = total_time / len(test_cases) if test_cases else 0
print(f"\n📊 LLM Detector Results:")
print(f" Accuracy: {accuracy:.1f}%")
print(f" Avg Time: {avg_time:.2f}s")
print(f" Total Time: {total_time:.2f}s")
return {
"method": "LLM-based",
"accuracy": accuracy,
"avg_time": avg_time,
"total_time": total_time,
"results": results
}
def test_professional_detector(test_cases: List[Dict], method: str = "hybrid") -> Dict:
"""Test professional detector (new method)"""
print("\n" + "=" * 60)
print(f"🔍 Testing Professional Detector ({method.upper()})")
print("=" * 60)
try:
from hallucination_detector import initialize_hallucination_detector
detector = initialize_hallucination_detector(method=method)
except Exception as e:
print(f"❌ Professional detector not available: {e}")
return {"error": str(e)}
results = []
total_time = 0
correct = 0
for i, case in enumerate(test_cases, 1):
print(f"\n📝 Test {i}: {case['name']}")
start_time = time.time()
try:
score = detector.grade(case['generation'], case['documents'])
elapsed = time.time() - start_time
# Convert score: "yes" = no hallucination, "no" = hallucination
predicted = "no_hallucination" if score == "yes" else "hallucination"
is_correct = predicted == case['ground_truth']
print(f" Prediction: {predicted}")
print(f" Ground Truth: {case['ground_truth']}")
print(f" Result: {'✅ Correct' if is_correct else '❌ Wrong'}")
print(f" Time: {elapsed:.2f}s")
results.append({
"case": case['name'],
"correct": is_correct,
"time": elapsed
})
total_time += elapsed
if is_correct:
correct += 1
except Exception as e:
print(f" ❌ Error: {e}")
results.append({"case": case['name'], "error": str(e)})
accuracy = (correct / len(test_cases)) * 100 if test_cases else 0
avg_time = total_time / len(test_cases) if test_cases else 0
print(f"\n📊 {method.upper()} Detector Results:")
print(f" Accuracy: {accuracy:.1f}%")
print(f" Avg Time: {avg_time:.2f}s")
print(f" Total Time: {total_time:.2f}s")
return {
"method": method,
"accuracy": accuracy,
"avg_time": avg_time,
"total_time": total_time,
"results": results
}
def compare_results(llm_results: Dict, professional_results: Dict):
    """Print a side-by-side summary of the two benchmark runs.

    Args:
        llm_results: Summary dict from test_llm_detector() with
            'accuracy', 'avg_time' and 'total_time'.
        professional_results: Summary dict from
            test_professional_detector() with the same keys.

    If either input carries an 'error' key the comparison is skipped.
    """
    print("\n" + "=" * 60)
    print("📊 COMPARISON SUMMARY")
    print("=" * 60)
    if "error" in llm_results or "error" in professional_results:
        print("⚠️ Cannot compare - one or both detectors failed")
        return

    acc_delta = professional_results['accuracy'] - llm_results['accuracy']
    sign = '+' if acc_delta > 0 else ''
    # Speedup = old time / new time, so > 1 means the new detector is
    # faster. The previous version inverted this ratio in the table row
    # (professional/llm), reporting e.g. "0.5x faster" for a 2x speedup.
    speedup = (llm_results['avg_time'] / professional_results['avg_time']
               if professional_results['avg_time'] > 0 else 0)
    print(f"""
Method Comparison:
{'Metric':<20} {'LLM-based':<15} {'Professional':<15} {'Improvement'}
{'-'*70}
{'Accuracy':<20} {llm_results['accuracy']:.1f}%{' '*9} {professional_results['accuracy']:.1f}%{' '*9} {sign}{acc_delta:.1f}%
{'Avg Time':<20} {llm_results['avg_time']:.2f}s{' '*9} {professional_results['avg_time']:.2f}s{' '*9} {speedup:.1f}x faster
{'Total Time':<20} {llm_results['total_time']:.2f}s{' '*9} {professional_results['total_time']:.2f}s
Key Improvements:
✅ Accuracy: {sign}{acc_delta:.1f}% improvement
✅ Speed: {speedup:.1f}x faster
✅ Cost: ~90% reduction (no LLM API calls)
""")
def _main() -> None:
    """Run both detector benchmarks on the shared cases and summarize."""
    print("\n🚀 Starting Hallucination Detection Comparison...\n")

    benchmark_cases = create_test_cases()
    print(f"📝 Created {len(benchmark_cases)} test cases")

    # Benchmark the legacy and professional detectors on identical input,
    # then print the side-by-side comparison.
    baseline = test_llm_detector(benchmark_cases)
    candidate = test_professional_detector(benchmark_cases, method="hybrid")
    compare_results(baseline, candidate)

    print("\n✅ Comparison complete!")


if __name__ == "__main__":
    _main()