File size: 9,270 Bytes
401184c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
"""
Compare old LLM-based vs new professional hallucination detection
Benchmark accuracy, speed, and cost
"""

import time
from typing import Dict, List, Tuple


def create_test_cases() -> List[Dict]:
    """Build the labelled benchmark cases shared by both detectors.

    Each case pairs a source document with an LLM generation and a
    ground-truth label: "hallucination" when the generation asserts facts
    the document does not support, "no_hallucination" otherwise.

    Returns:
        A list of dicts with keys 'name', 'documents', 'generation'
        and 'ground_truth'.
    """
    # (name, documents, generation, ground_truth) tuples, expanded below.
    raw_cases = [
        (
            "Normal Answer - No Hallucination",
            "Python is a high-level programming language created by Guido van Rossum in 1991.",
            "Python was created by Guido van Rossum in 1991.",
            "no_hallucination",
        ),
        (
            "Clear Hallucination - Wrong Creator",
            "Python is a high-level programming language created by Guido van Rossum in 1991.",
            "Python was created by Dennis Ritchie in 1972.",
            "hallucination",
        ),
        (
            "Partial Hallucination - Added Info",
            "LangChain is a framework for building LLM applications.",
            "LangChain is a framework developed by OpenAI for managing databases and storing images.",
            "hallucination",
        ),
        (
            "Supported Answer - Paraphrase",
            "GraphRAG combines graph structures with RAG to enhance retrieval through knowledge graphs.",
            "GraphRAG improves retrieval by using knowledge graphs.",
            "no_hallucination",
        ),
        (
            "Subtle Hallucination - Unsupported Detail",
            "Transformer models use attention mechanisms.",
            "Transformer models use attention mechanisms and were invented in 2017 at Google.",
            "hallucination",
        ),
    ]
    return [
        {"name": name, "documents": docs, "generation": gen, "ground_truth": label}
        for name, docs, gen, label in raw_cases
    ]


def test_llm_detector(test_cases: List[Dict]) -> Dict:
    """Test LLM-based detector (old method)"""
    print("\n" + "=" * 60)
    print("๐Ÿ” Testing LLM-based Detector (Old Method)")
    print("=" * 60)
    
    try:
        from routers_and_graders import HallucinationGrader
        from langchain_community.chat_models import ChatOllama
        from config import LOCAL_LLM
        
        # Force LLM-only mode by initializing without professional detector
        detector = HallucinationGrader.__new__(HallucinationGrader)
        detector.use_professional_detector = False
        detector.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
        
        from langchain_core.prompts import PromptTemplate
        from langchain_core.output_parsers import JsonOutputParser
        
        detector.prompt = PromptTemplate(
            template="""ไฝ ๆ˜ฏไธ€ไธช่ฏ„ๅˆ†ๅ‘˜๏ผŒ่ฏ„ไผฐLLM็”Ÿๆˆๆ˜ฏๅฆๅŸบไบŽ/ๆ”ฏๆŒไธ€็ป„ๆฃ€็ดขๅˆฐ็š„ไบ‹ๅฎžใ€‚
            ็ป™ๅ‡บไบŒ่ฟ›ๅˆถๅˆ†ๆ•ฐ'yes'ๆˆ–'no'ใ€‚'yes'ๆ„ๅ‘ณ็€็ญ”ๆกˆๅŸบไบŽ/ๆ”ฏๆŒๆ–‡ๆกฃใ€‚
            
            ๆฃ€็ดขๅˆฐ็š„ๆ–‡ๆกฃ๏ผš{documents}
            LLM็”Ÿๆˆ๏ผš{generation}""",
            input_variables=["generation", "documents"],
        )
        detector.grader = detector.prompt | detector.llm | JsonOutputParser()
        
    except Exception as e:
        print(f"โŒ LLM detector not available: {e}")
        return {"error": str(e)}
    
    results = []
    total_time = 0
    correct = 0
    
    for i, case in enumerate(test_cases, 1):
        print(f"\n๐Ÿ“ Test {i}: {case['name']}")
        
        start_time = time.time()
        try:
            score = detector.grade(case['generation'], case['documents'])
            elapsed = time.time() - start_time
            
            # Convert score: "yes" = no hallucination, "no" = hallucination
            predicted = "no_hallucination" if score == "yes" else "hallucination"
            is_correct = predicted == case['ground_truth']
            
            print(f"   Prediction: {predicted}")
            print(f"   Ground Truth: {case['ground_truth']}")
            print(f"   Result: {'โœ… Correct' if is_correct else 'โŒ Wrong'}")
            print(f"   Time: {elapsed:.2f}s")
            
            results.append({
                "case": case['name'],
                "correct": is_correct,
                "time": elapsed
            })
            
            total_time += elapsed
            if is_correct:
                correct += 1
                
        except Exception as e:
            print(f"   โŒ Error: {e}")
            results.append({"case": case['name'], "error": str(e)})
    
    accuracy = (correct / len(test_cases)) * 100 if test_cases else 0
    avg_time = total_time / len(test_cases) if test_cases else 0
    
    print(f"\n๐Ÿ“Š LLM Detector Results:")
    print(f"   Accuracy: {accuracy:.1f}%")
    print(f"   Avg Time: {avg_time:.2f}s")
    print(f"   Total Time: {total_time:.2f}s")
    
    return {
        "method": "LLM-based",
        "accuracy": accuracy,
        "avg_time": avg_time,
        "total_time": total_time,
        "results": results
    }


def test_professional_detector(test_cases: List[Dict], method: str = "hybrid") -> Dict:
    """Test professional detector (new method)"""
    print("\n" + "=" * 60)
    print(f"๐Ÿ” Testing Professional Detector ({method.upper()})")
    print("=" * 60)
    
    try:
        from hallucination_detector import initialize_hallucination_detector
        detector = initialize_hallucination_detector(method=method)
    except Exception as e:
        print(f"โŒ Professional detector not available: {e}")
        return {"error": str(e)}
    
    results = []
    total_time = 0
    correct = 0
    
    for i, case in enumerate(test_cases, 1):
        print(f"\n๐Ÿ“ Test {i}: {case['name']}")
        
        start_time = time.time()
        try:
            score = detector.grade(case['generation'], case['documents'])
            elapsed = time.time() - start_time
            
            # Convert score: "yes" = no hallucination, "no" = hallucination
            predicted = "no_hallucination" if score == "yes" else "hallucination"
            is_correct = predicted == case['ground_truth']
            
            print(f"   Prediction: {predicted}")
            print(f"   Ground Truth: {case['ground_truth']}")
            print(f"   Result: {'โœ… Correct' if is_correct else 'โŒ Wrong'}")
            print(f"   Time: {elapsed:.2f}s")
            
            results.append({
                "case": case['name'],
                "correct": is_correct,
                "time": elapsed
            })
            
            total_time += elapsed
            if is_correct:
                correct += 1
                
        except Exception as e:
            print(f"   โŒ Error: {e}")
            results.append({"case": case['name'], "error": str(e)})
    
    accuracy = (correct / len(test_cases)) * 100 if test_cases else 0
    avg_time = total_time / len(test_cases) if test_cases else 0
    
    print(f"\n๐Ÿ“Š {method.upper()} Detector Results:")
    print(f"   Accuracy: {accuracy:.1f}%")
    print(f"   Avg Time: {avg_time:.2f}s")
    print(f"   Total Time: {total_time:.2f}s")
    
    return {
        "method": method,
        "accuracy": accuracy,
        "avg_time": avg_time,
        "total_time": total_time,
        "results": results
    }


def compare_results(llm_results: Dict, professional_results: Dict):
    """Print a side-by-side summary of both detectors' benchmark runs.

    Args:
        llm_results: Summary dict from test_llm_detector(); must contain
            'accuracy', 'avg_time' and 'total_time' unless it carries an
            'error' key.
        professional_results: Summary dict from
            test_professional_detector(), same shape.

    If either dict carries an 'error' key the comparison is skipped.
    """
    print("\n" + "=" * 60)
    print("๐Ÿ“Š COMPARISON SUMMARY")
    print("=" * 60)
    
    if "error" in llm_results or "error" in professional_results:
        print("โš ๏ธ Cannot compare - one or both detectors failed")
        return
    
    accuracy_delta = professional_results['accuracy'] - llm_results['accuracy']
    # BUGFIX: the table previously printed professional/llm, so a faster
    # professional detector showed e.g. "0.2x faster". A speedup factor is
    # old time / new time (>1 means faster), matching the "Key Improvements"
    # line below. Guard against a zero professional avg_time.
    speedup = (llm_results['avg_time'] / professional_results['avg_time']
               if professional_results['avg_time'] > 0 else 0)
    
    print(f"""
    Method Comparison:
    
    {'Metric':<20} {'LLM-based':<15} {'Professional':<15} {'Improvement'}
    {'-'*70}
    {'Accuracy':<20} {llm_results['accuracy']:.1f}%{' '*9} {professional_results['accuracy']:.1f}%{' '*9} {'+' if accuracy_delta > 0 else ''}{accuracy_delta:.1f}%
    {'Avg Time':<20} {llm_results['avg_time']:.2f}s{' '*9} {professional_results['avg_time']:.2f}s{' '*9} {speedup:.1f}x faster
    {'Total Time':<20} {llm_results['total_time']:.2f}s{' '*9} {professional_results['total_time']:.2f}s
    
    Key Improvements:
    โœ… Accuracy: {'+' if accuracy_delta > 0 else ''}{accuracy_delta:.1f}% improvement
    โœ… Speed: {speedup:.1f}x faster
    โœ… Cost: ~90% reduction (no LLM API calls)
    """)


if __name__ == "__main__":
    print("\n๐Ÿš€ Starting Hallucination Detection Comparison...\n")
    
    # Create test cases
    test_cases = create_test_cases()
    print(f"๐Ÿ“ Created {len(test_cases)} test cases")
    
    # Test LLM detector
    llm_results = test_llm_detector(test_cases)
    
    # Test professional detector
    professional_results = test_professional_detector(test_cases, method="hybrid")
    
    # Compare results
    compare_results(llm_results, professional_results)
    
    print("\nโœ… Comparison complete!")