"""
Compare old LLM-based vs new professional hallucination detection
Benchmark accuracy, speed, and cost
"""
import time
from typing import Dict, List, Tuple
def create_test_cases() -> List[Dict]:
    """Build the labeled benchmark cases shared by both detectors.

    Returns:
        A list of dicts, each with a human-readable ``name``, the retrieved
        ``documents`` text, the LLM ``generation`` to judge, and a
        ``ground_truth`` label of ``"hallucination"`` or ``"no_hallucination"``.
    """
    return [
        {
            "name": "Normal Answer - No Hallucination",
            "documents": "Python is a high-level programming language created by Guido van Rossum in 1991.",
            "generation": "Python was created by Guido van Rossum in 1991.",
            "ground_truth": "no_hallucination",
        },
        {
            "name": "Clear Hallucination - Wrong Creator",
            "documents": "Python is a high-level programming language created by Guido van Rossum in 1991.",
            "generation": "Python was created by Dennis Ritchie in 1972.",
            "ground_truth": "hallucination",
        },
        {
            "name": "Partial Hallucination - Added Info",
            "documents": "LangChain is a framework for building LLM applications.",
            "generation": "LangChain is a framework developed by OpenAI for managing databases and storing images.",
            "ground_truth": "hallucination",
        },
        {
            "name": "Supported Answer - Paraphrase",
            "documents": "GraphRAG combines graph structures with RAG to enhance retrieval through knowledge graphs.",
            "generation": "GraphRAG improves retrieval by using knowledge graphs.",
            "ground_truth": "no_hallucination",
        },
        {
            "name": "Subtle Hallucination - Unsupported Detail",
            "documents": "Transformer models use attention mechanisms.",
            "generation": "Transformer models use attention mechanisms and were invented in 2017 at Google.",
            "ground_truth": "hallucination",
        },
    ]
def test_llm_detector(test_cases: List[Dict]) -> Dict:
"""Test LLM-based detector (old method)"""
print("\n" + "=" * 60)
print("๐ Testing LLM-based Detector (Old Method)")
print("=" * 60)
try:
from routers_and_graders import HallucinationGrader
from langchain_community.chat_models import ChatOllama
from config import LOCAL_LLM
# Force LLM-only mode by initializing without professional detector
detector = HallucinationGrader.__new__(HallucinationGrader)
detector.use_professional_detector = False
detector.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
detector.prompt = PromptTemplate(
template="""ไฝ ๆฏไธไธช่ฏๅๅ๏ผ่ฏไผฐLLM็ๆๆฏๅฆๅบไบ/ๆฏๆไธ็ปๆฃ็ดขๅฐ็ไบๅฎใ
็ปๅบไบ่ฟๅถๅๆฐ'yes'ๆ'no'ใ'yes'ๆๅณ็็ญๆกๅบไบ/ๆฏๆๆๆกฃใ
ๆฃ็ดขๅฐ็ๆๆกฃ๏ผ{documents}
LLM็ๆ๏ผ{generation}""",
input_variables=["generation", "documents"],
)
detector.grader = detector.prompt | detector.llm | JsonOutputParser()
except Exception as e:
print(f"โ LLM detector not available: {e}")
return {"error": str(e)}
results = []
total_time = 0
correct = 0
for i, case in enumerate(test_cases, 1):
print(f"\n๐ Test {i}: {case['name']}")
start_time = time.time()
try:
score = detector.grade(case['generation'], case['documents'])
elapsed = time.time() - start_time
# Convert score: "yes" = no hallucination, "no" = hallucination
predicted = "no_hallucination" if score == "yes" else "hallucination"
is_correct = predicted == case['ground_truth']
print(f" Prediction: {predicted}")
print(f" Ground Truth: {case['ground_truth']}")
print(f" Result: {'โ
Correct' if is_correct else 'โ Wrong'}")
print(f" Time: {elapsed:.2f}s")
results.append({
"case": case['name'],
"correct": is_correct,
"time": elapsed
})
total_time += elapsed
if is_correct:
correct += 1
except Exception as e:
print(f" โ Error: {e}")
results.append({"case": case['name'], "error": str(e)})
accuracy = (correct / len(test_cases)) * 100 if test_cases else 0
avg_time = total_time / len(test_cases) if test_cases else 0
print(f"\n๐ LLM Detector Results:")
print(f" Accuracy: {accuracy:.1f}%")
print(f" Avg Time: {avg_time:.2f}s")
print(f" Total Time: {total_time:.2f}s")
return {
"method": "LLM-based",
"accuracy": accuracy,
"avg_time": avg_time,
"total_time": total_time,
"results": results
}
def test_professional_detector(test_cases: List[Dict], method: str = "hybrid") -> Dict:
"""Test professional detector (new method)"""
print("\n" + "=" * 60)
print(f"๐ Testing Professional Detector ({method.upper()})")
print("=" * 60)
try:
from hallucination_detector import initialize_hallucination_detector
detector = initialize_hallucination_detector(method=method)
except Exception as e:
print(f"โ Professional detector not available: {e}")
return {"error": str(e)}
results = []
total_time = 0
correct = 0
for i, case in enumerate(test_cases, 1):
print(f"\n๐ Test {i}: {case['name']}")
start_time = time.time()
try:
score = detector.grade(case['generation'], case['documents'])
elapsed = time.time() - start_time
# Convert score: "yes" = no hallucination, "no" = hallucination
predicted = "no_hallucination" if score == "yes" else "hallucination"
is_correct = predicted == case['ground_truth']
print(f" Prediction: {predicted}")
print(f" Ground Truth: {case['ground_truth']}")
print(f" Result: {'โ
Correct' if is_correct else 'โ Wrong'}")
print(f" Time: {elapsed:.2f}s")
results.append({
"case": case['name'],
"correct": is_correct,
"time": elapsed
})
total_time += elapsed
if is_correct:
correct += 1
except Exception as e:
print(f" โ Error: {e}")
results.append({"case": case['name'], "error": str(e)})
accuracy = (correct / len(test_cases)) * 100 if test_cases else 0
avg_time = total_time / len(test_cases) if test_cases else 0
print(f"\n๐ {method.upper()} Detector Results:")
print(f" Accuracy: {accuracy:.1f}%")
print(f" Avg Time: {avg_time:.2f}s")
print(f" Total Time: {total_time:.2f}s")
return {
"method": method,
"accuracy": accuracy,
"avg_time": avg_time,
"total_time": total_time,
"results": results
}
def compare_results(llm_results: Dict, professional_results: Dict):
    """Print a side-by-side summary of the two benchmark runs.

    Args:
        llm_results: Output of :func:`test_llm_detector`.
        professional_results: Output of :func:`test_professional_detector`.

    Prints a table and key-improvement bullets; returns ``None``. Skips the
    comparison entirely if either run reported an error.
    """
    print("\n" + "=" * 60)
    print("📊 COMPARISON SUMMARY")
    print("=" * 60)
    if "error" in llm_results or "error" in professional_results:
        print("⚠️ Cannot compare - one or both detectors failed")
        return

    # BUGFIX: speedup is old_time / new_time; the original table row divided
    # professional by llm (inverted) while the bullet below divided correctly.
    speedup = (llm_results['avg_time'] / professional_results['avg_time']
               if professional_results['avg_time'] > 0 else 0)
    accuracy_delta = professional_results['accuracy'] - llm_results['accuracy']
    sign = '+' if accuracy_delta > 0 else ''
    print(f"""
Method Comparison:
{'Metric':<20} {'LLM-based':<15} {'Professional':<15} {'Improvement'}
{'-' * 70}
{'Accuracy':<20} {llm_results['accuracy']:.1f}%{' ' * 9} {professional_results['accuracy']:.1f}%{' ' * 9} {sign}{accuracy_delta:.1f}%
{'Avg Time':<20} {llm_results['avg_time']:.2f}s{' ' * 9} {professional_results['avg_time']:.2f}s{' ' * 9} {speedup:.1f}x faster
{'Total Time':<20} {llm_results['total_time']:.2f}s{' ' * 9} {professional_results['total_time']:.2f}s

Key Improvements:
✅ Accuracy: {sign}{accuracy_delta:.1f}% improvement
✅ Speed: {speedup:.1f}x faster
✅ Cost: ~90% reduction (no LLM API calls)
""")
if __name__ == "__main__":
    print("\n🚀 Starting Hallucination Detection Comparison...\n")

    # Build the shared benchmark cases once.
    test_cases = create_test_cases()
    print(f"📋 Created {len(test_cases)} test cases")

    # Run both detectors against identical inputs, then compare.
    llm_results = test_llm_detector(test_cases)
    professional_results = test_professional_detector(test_cases, method="hybrid")
    compare_results(llm_results, professional_results)

    print("\n✅ Comparison complete!")