Spaces:
Paused
Paused
File size: 11,717 Bytes
399f3c6 9cce495 90b33eb 94a7032 7a4bc96 90b33eb 399f3c6 401184c 399f3c6 401184c 399f3c6 401184c 399f3c6 401184c 399f3c6 9cce495 399f3c6 9cce495 399f3c6 9cce495 399f3c6 9cce495 399f3c6 401184c 399f3c6 401184c 399f3c6 9cce495 399f3c6 9cce495 399f3c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 |
"""
路由器和评分器模块
包含查询路由、文档相关性评分、答案质量评分和幻觉检测
"""
from typing import List
try:
from langchain_core.prompts import PromptTemplate
except ImportError:
try:
from langchain_core.prompts import PromptTemplate
except ImportError:
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from config import LOCAL_LLM
class QueryRouter:
"""查询路由器,决定使用向量存储还是网络搜索"""
def __init__(self):
self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
self.prompt = PromptTemplate(
template="""你是一个专家,负责将用户问题路由到向量存储或网络搜索。
对于关于LLM智能体、提示工程和对抗性攻击的问题,使用向量存储。
你不需要严格匹配问题中与这些主题相关的关键词。
否则,使用网络搜索。根据问题给出二进制选择'web_search'或'vectorstore'。
返回一个只包含'datasource'键的JSON,不要前言或解释。
要路由的问题:{question}""",
input_variables=["question"],
)
self.router = self.prompt | self.llm | JsonOutputParser()
def route(self, question: str) -> str:
"""路由问题到相应的数据源"""
result = self.router.invoke({"question": question})
return result.get("datasource", "web_search")
class DocumentGrader:
"""文档相关性评分器"""
def __init__(self):
self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
self.prompt = PromptTemplate(
template="""你是一个评分员,评估检索到的文档是否与用户问题相关。
如果文档包含与用户问题相关的关键词或语义,请给出'yes'分数。
给出二进制分数'yes'或'no',以表明文档是否与问题相关。
将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。
检索到的文档:
{document}
用户问题:{question}""",
input_variables=["question", "document"],
)
self.grader = self.prompt | self.llm | JsonOutputParser()
def grade(self, question: str, document: str) -> str:
"""评估文档与问题的相关性"""
result = self.grader.invoke({"question": question, "document": document})
return result.get("score", "no")
class AnswerGrader:
"""答案质量评分器"""
def __init__(self):
self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
self.prompt = PromptTemplate(
template="""你是一个评分员,评估答案是否有助于解决问题。
这里是答案:
\n ------- \n
{generation}
\n ------- \n
这里是问题:{question}
给出二进制分数'yes'或'no',表示答案是否有助于解决问题。
将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。""",
input_variables=["generation", "question"],
)
self.grader = self.prompt | self.llm | JsonOutputParser()
def grade(self, question: str, generation: str) -> str:
"""评估答案质量"""
result = self.grader.invoke({"question": question, "generation": generation})
return result.get("score", "no")
class HallucinationGrader:
"""
幻觉检测器 - 使用专业模型(Vectara + NLI)
相比 LLM-as-a-Judge 方法:
- 准确率从 60-75% 提升到 85-95%
- 速度提升 5-10 倍
- 成本降低 90%
"""
def __init__(self, method: str = "hybrid"):
"""
初始化幻觉检测器
Args:
method: 'vectara', 'nli', 或 'hybrid' (推荐)
"""
# 尝试加载专业检测模型
try:
from hallucination_detector import initialize_hallucination_detector
self.detector = initialize_hallucination_detector(method=method)
self.use_professional_detector = True
print(f"✅ 使用专业幻觉检测器: {method}")
except Exception as e:
print(f"⚠️ 专业检测器加载失败,回退到 LLM 方法: {e}")
self.use_professional_detector = False
# 回退到原有的 LLM 方法
self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
self.prompt = PromptTemplate(
template="""你是一个评分员,评估LLM生成是否基于/支持一组检索到的事实。
给出二进制分数'yes'或'no'。'yes'意味着答案基于/支持文档。
将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。
检索到的文档:
{documents}
LLM生成:{generation}""",
input_variables=["generation", "documents"],
)
self.grader = self.prompt | self.llm | JsonOutputParser()
def grade(self, generation: str, documents) -> str:
"""
检测生成内容是否存在幻觉
Args:
generation: LLM 生成的内容
documents: 参考文档
Returns:
"yes" 表示无幻觉,"no" 表示有幻觉
"""
if self.use_professional_detector:
# 使用专业检测器
return self.detector.grade(generation, documents)
else:
# 回退到 LLM 方法
result = self.grader.invoke({"generation": generation, "documents": documents})
return result.get("score", "no")
class QueryDecomposer:
"""查询分解器,将复杂的多跳问题分解为子问题序列"""
def __init__(self):
self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
self.prompt = PromptTemplate(
template="""你是一个查询分解专家。你的任务是将一个复杂的多跳问题分解为一系列简单的子问题,这些子问题可以按顺序检索来回答原始问题。
分解规则:
1. 如果问题很简单,不需要分解,返回只包含原始问题的列表。
2. 如果问题需要多步推理(例如"A的作者的大学在哪里"),分解为逻辑步骤:
- 步骤1: "谁是A的作者?"
- 步骤2: "该作者在哪个大学?"
3. 保持子问题简洁明了。
4. 即使返回单个问题,也必须包装在JSON的 sub_queries 列表中。
输出格式:返回一个包含 'sub_queries' 键的 JSON,其值为字符串列表。
不要输出任何前言或解释。
复杂问题: {question}""",
input_variables=["question"],
)
self.decomposer = self.prompt | self.llm | JsonOutputParser()
def decompose(self, question: str) -> List[str]:
"""分解问题"""
print(f"---分解问题: {question}---")
try:
result = self.decomposer.invoke({"question": question})
sub_queries = result.get("sub_queries", [question])
# 确保至少包含原始问题
if not sub_queries:
sub_queries = [question]
print(f"---子问题: {sub_queries}---")
return sub_queries
except Exception as e:
print(f"⚠️ 分解失败: {e},使用原始问题")
return [question]
class AnswerabilityGrader:
"""答案可回答性评分器,用于判断当前检索到的文档是否足够回答原始问题"""
def __init__(self):
self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
self.prompt = PromptTemplate(
template="""你是一个专家评分员,负责评估检索到的文档是否包含足够的信息来回答用户的问题。
原始问题: {question}
目前检索到的文档集合:
{documents}
任务:
判断上述文档是否已经包含了回答原始问题所需的全部关键信息。
- 如果信息充足,可以终止进一步的检索,返回 'yes'。
- 如果信息缺失,需要继续检索更多信息,返回 'no'。
输出格式:
返回一个只包含 'score' 键的 JSON,值为 'yes' 或 'no'。
不要输出任何前言或解释。""",
input_variables=["question", "documents"],
)
self.grader = self.prompt | self.llm | JsonOutputParser()
def grade(self, question: str, documents: str) -> str:
"""评估文档是否足以回答问题"""
result = self.grader.invoke({"question": question, "documents": documents})
return result.get("score", "no")
class QueryRewriter:
"""查询重写器,优化查询以获得更好的检索结果"""
def __init__(self):
self.llm = ChatOllama(model=LOCAL_LLM, temperature=0)
self.prompt = PromptTemplate(
template="""你是一个问题重写器,负责将输入问题转换为更适合向量存储检索的更好版本。
你的目标是根据原始问题和(可选的)之前的检索上下文,生成一个新的查询,以便检索到回答问题所需的缺失信息。
如果提供了之前的上下文,请分析其中缺少什么信息,并针对缺失的信息构建查询。
初始问题: {question}
之前的上下文(如果有):
{context}
改进的问题(只输出问题,无前言):""",
input_variables=["question", "context"],
)
self.rewriter = self.prompt | self.llm | StrOutputParser()
def rewrite(self, question: str, context: str = "") -> str:
"""重写查询以获得更好的检索效果"""
print(f"---原始查询: {question}---")
if context:
print(f"---参考上下文长度: {len(context)} 字符---")
rewritten_query = self.rewriter.invoke({"question": question, "context": context})
print(f"---重写查询: {rewritten_query}---")
return rewritten_query
def initialize_graders_and_router():
"""初始化所有评分器和路由器"""
# Load detection method from config
try:
from hallucination_config import HALLUCINATION_DETECTION_METHOD
detection_method = HALLUCINATION_DETECTION_METHOD
except ImportError:
detection_method = "hybrid" # Default to hybrid
query_router = QueryRouter()
document_grader = DocumentGrader()
answer_grader = AnswerGrader()
hallucination_grader = HallucinationGrader(method=detection_method)
query_rewriter = QueryRewriter()
query_decomposer = QueryDecomposer()
answerability_grader = AnswerabilityGrader()
return {
"query_router": query_router,
"document_grader": document_grader,
"answer_grader": answer_grader,
"hallucination_grader": hallucination_grader,
"query_rewriter": query_rewriter,
"query_decomposer": query_decomposer,
"answerability_grader": answerability_grader
} |