File size: 11,717 Bytes
399f3c6
 
 
 
 
9cce495
90b33eb
 
94a7032
 
7a4bc96
 
 
90b33eb
399f3c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401184c
 
 
 
 
 
 
399f3c6
401184c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399f3c6
 
 
 
401184c
 
 
 
399f3c6
 
401184c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399f3c6
 
9cce495
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399f3c6
 
 
 
 
 
9cce495
 
 
 
 
 
 
 
 
 
 
 
399f3c6
 
 
9cce495
399f3c6
 
9cce495
 
 
 
399f3c6
 
 
 
 
 
401184c
 
 
 
 
 
 
399f3c6
 
 
401184c
399f3c6
9cce495
 
399f3c6
 
 
 
 
 
9cce495
 
 
399f3c6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
"""
路由器和评分器模块
包含查询路由、文档相关性评分、答案质量评分和幻觉检测
"""

from typing import List
try:
    from langchain_core.prompts import PromptTemplate
except ImportError:
    try:
        from langchain_core.prompts import PromptTemplate
    except ImportError:
        from langchain.prompts import PromptTemplate

from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser

from config import LOCAL_LLM


class QueryRouter:
    """查询路由器,决定使用向量存储还是网络搜索"""
    
    def __init__(self):
        self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
        self.prompt = PromptTemplate(
            template="""你是一个专家,负责将用户问题路由到向量存储或网络搜索。
            对于关于LLM智能体、提示工程和对抗性攻击的问题,使用向量存储。
            你不需要严格匹配问题中与这些主题相关的关键词。
            否则,使用网络搜索。根据问题给出二进制选择'web_search'或'vectorstore'。
            返回一个只包含'datasource'键的JSON,不要前言或解释。
            要路由的问题:{question}""",
            input_variables=["question"],
        )
        self.router = self.prompt | self.llm | JsonOutputParser()
    
    def route(self, question: str) -> str:
        """路由问题到相应的数据源"""
        result = self.router.invoke({"question": question})
        return result.get("datasource", "web_search")


class DocumentGrader:
    """文档相关性评分器"""
    
    def __init__(self):
        self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
        self.prompt = PromptTemplate(
            template="""你是一个评分员,评估检索到的文档是否与用户问题相关。
            如果文档包含与用户问题相关的关键词或语义,请给出'yes'分数。
            给出二进制分数'yes'或'no',以表明文档是否与问题相关。
            将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。
            
            检索到的文档:

 {document} 


            用户问题:{question}""",
            input_variables=["question", "document"],
        )
        self.grader = self.prompt | self.llm | JsonOutputParser()
    
    def grade(self, question: str, document: str) -> str:
        """评估文档与问题的相关性"""
        result = self.grader.invoke({"question": question, "document": document})
        return result.get("score", "no")


class AnswerGrader:
    """答案质量评分器"""
    
    def __init__(self):
        self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
        self.prompt = PromptTemplate(
            template="""你是一个评分员,评估答案是否有助于解决问题。
            这里是答案:
            \n ------- \n
            {generation} 
            \n ------- \n
            这里是问题:{question}
            给出二进制分数'yes'或'no',表示答案是否有助于解决问题。
            将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。""",
            input_variables=["generation", "question"],
        )
        self.grader = self.prompt | self.llm | JsonOutputParser()
    
    def grade(self, question: str, generation: str) -> str:
        """评估答案质量"""
        result = self.grader.invoke({"question": question, "generation": generation})
        return result.get("score", "no")


class HallucinationGrader:
    """
    幻觉检测器 - 使用专业模型(Vectara + NLI)
    相比 LLM-as-a-Judge 方法:
    - 准确率从 60-75% 提升到 85-95%
    - 速度提升 5-10 倍
    - 成本降低 90%
    """
    
    def __init__(self, method: str = "hybrid"):
        """
        初始化幻觉检测器
        
        Args:
            method: 'vectara', 'nli', 或 'hybrid' (推荐)
        """
        # 尝试加载专业检测模型
        try:
            from hallucination_detector import initialize_hallucination_detector
            self.detector = initialize_hallucination_detector(method=method)
            self.use_professional_detector = True
            print(f"✅ 使用专业幻觉检测器: {method}")
        except Exception as e:
            print(f"⚠️ 专业检测器加载失败,回退到 LLM 方法: {e}")
            self.use_professional_detector = False
            # 回退到原有的 LLM 方法
            self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
            self.prompt = PromptTemplate(
                template="""你是一个评分员,评估LLM生成是否基于/支持一组检索到的事实。
                给出二进制分数'yes'或'no'。'yes'意味着答案基于/支持文档。
                将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。
                
                检索到的文档:

 {documents} 


                LLM生成:{generation}""",
                input_variables=["generation", "documents"],
            )
            self.grader = self.prompt | self.llm | JsonOutputParser()
    
    def grade(self, generation: str, documents) -> str:
        """
        检测生成内容是否存在幻觉
        
        Args:
            generation: LLM 生成的内容
            documents: 参考文档
            
        Returns:
            "yes" 表示无幻觉,"no" 表示有幻觉
        """
        if self.use_professional_detector:
            # 使用专业检测器
            return self.detector.grade(generation, documents)
        else:
            # 回退到 LLM 方法
            result = self.grader.invoke({"generation": generation, "documents": documents})
            return result.get("score", "no")


class QueryDecomposer:
    """查询分解器,将复杂的多跳问题分解为子问题序列"""
    
    def __init__(self):
        self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
        self.prompt = PromptTemplate(
            template="""你是一个查询分解专家。你的任务是将一个复杂的多跳问题分解为一系列简单的子问题,这些子问题可以按顺序检索来回答原始问题。
            
            分解规则:
            1. 如果问题很简单,不需要分解,返回只包含原始问题的列表。
            2. 如果问题需要多步推理(例如"A的作者的大学在哪里"),分解为逻辑步骤:
               - 步骤1: "谁是A的作者?"
               - 步骤2: "该作者在哪个大学?"
            3. 保持子问题简洁明了。
            4. 即使返回单个问题,也必须包装在JSON的 sub_queries 列表中。
            
            输出格式:返回一个包含 'sub_queries' 键的 JSON,其值为字符串列表。
            不要输出任何前言或解释。
            
            复杂问题: {question}""",
            input_variables=["question"],
        )
        self.decomposer = self.prompt | self.llm | JsonOutputParser()
    
    def decompose(self, question: str) -> List[str]:
        """分解问题"""
        print(f"---分解问题: {question}---")
        try:
            result = self.decomposer.invoke({"question": question})
            sub_queries = result.get("sub_queries", [question])
            # 确保至少包含原始问题
            if not sub_queries:
                sub_queries = [question]
            print(f"---子问题: {sub_queries}---")
            return sub_queries
        except Exception as e:
            print(f"⚠️ 分解失败: {e},使用原始问题")
            return [question]


class AnswerabilityGrader:
    """答案可回答性评分器,用于判断当前检索到的文档是否足够回答原始问题"""
    
    def __init__(self):
        self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
        self.prompt = PromptTemplate(
            template="""你是一个专家评分员,负责评估检索到的文档是否包含足够的信息来回答用户的问题。
            
            原始问题: {question}
            
            目前检索到的文档集合:
            {documents}
            
            任务:
            判断上述文档是否已经包含了回答原始问题所需的全部关键信息。
            - 如果信息充足,可以终止进一步的检索,返回 'yes'。
            - 如果信息缺失,需要继续检索更多信息,返回 'no'。
            
            输出格式:
            返回一个只包含 'score' 键的 JSON,值为 'yes' 或 'no'。
            不要输出任何前言或解释。""",
            input_variables=["question", "documents"],
        )
        self.grader = self.prompt | self.llm | JsonOutputParser()
    
    def grade(self, question: str, documents: str) -> str:
        """评估文档是否足以回答问题"""
        result = self.grader.invoke({"question": question, "documents": documents})
        return result.get("score", "no")


class QueryRewriter:
    """查询重写器,优化查询以获得更好的检索结果"""
    
    def __init__(self):
        self.llm = ChatOllama(model=LOCAL_LLM, temperature=0)
        self.prompt = PromptTemplate(
            template="""你是一个问题重写器,负责将输入问题转换为更适合向量存储检索的更好版本。
            
            你的目标是根据原始问题和(可选的)之前的检索上下文,生成一个新的查询,以便检索到回答问题所需的缺失信息。
            如果提供了之前的上下文,请分析其中缺少什么信息,并针对缺失的信息构建查询。
            
            初始问题: {question}
            
            之前的上下文(如果有):
            {context}
            
            改进的问题(只输出问题,无前言):""",
            input_variables=["question", "context"],
        )
        self.rewriter = self.prompt | self.llm | StrOutputParser()
    
    def rewrite(self, question: str, context: str = "") -> str:
        """重写查询以获得更好的检索效果"""
        print(f"---原始查询: {question}---")
        if context:
            print(f"---参考上下文长度: {len(context)} 字符---")
            
        rewritten_query = self.rewriter.invoke({"question": question, "context": context})
        print(f"---重写查询: {rewritten_query}---")
        return rewritten_query


def initialize_graders_and_router():
    """初始化所有评分器和路由器"""
    # Load detection method from config
    try:
        from hallucination_config import HALLUCINATION_DETECTION_METHOD
        detection_method = HALLUCINATION_DETECTION_METHOD
    except ImportError:
        detection_method = "hybrid"  # Default to hybrid
    
    query_router = QueryRouter()
    document_grader = DocumentGrader()
    answer_grader = AnswerGrader()
    hallucination_grader = HallucinationGrader(method=detection_method)
    query_rewriter = QueryRewriter()
    query_decomposer = QueryDecomposer()
    answerability_grader = AnswerabilityGrader()
    
    return {
        "query_router": query_router,
        "document_grader": document_grader,
        "answer_grader": answer_grader,
        "hallucination_grader": hallucination_grader,
        "query_rewriter": query_rewriter,
        "query_decomposer": query_decomposer,
        "answerability_grader": answerability_grader
    }