adaptive_rag / routers_and_graders.py
lanny xu
add react
9cce495
"""
路由器和评分器模块
包含查询路由、文档相关性评分、答案质量评分和幻觉检测
"""
from typing import List
try:
from langchain_core.prompts import PromptTemplate
except ImportError:
try:
from langchain_core.prompts import PromptTemplate
except ImportError:
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from config import LOCAL_LLM
class QueryRouter:
"""查询路由器,决定使用向量存储还是网络搜索"""
def __init__(self):
self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
self.prompt = PromptTemplate(
template="""你是一个专家,负责将用户问题路由到向量存储或网络搜索。
对于关于LLM智能体、提示工程和对抗性攻击的问题,使用向量存储。
你不需要严格匹配问题中与这些主题相关的关键词。
否则,使用网络搜索。根据问题给出二进制选择'web_search'或'vectorstore'。
返回一个只包含'datasource'键的JSON,不要前言或解释。
要路由的问题:{question}""",
input_variables=["question"],
)
self.router = self.prompt | self.llm | JsonOutputParser()
def route(self, question: str) -> str:
"""路由问题到相应的数据源"""
result = self.router.invoke({"question": question})
return result.get("datasource", "web_search")
class DocumentGrader:
"""文档相关性评分器"""
def __init__(self):
self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
self.prompt = PromptTemplate(
template="""你是一个评分员,评估检索到的文档是否与用户问题相关。
如果文档包含与用户问题相关的关键词或语义,请给出'yes'分数。
给出二进制分数'yes'或'no',以表明文档是否与问题相关。
将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。
检索到的文档:
{document}
用户问题:{question}""",
input_variables=["question", "document"],
)
self.grader = self.prompt | self.llm | JsonOutputParser()
def grade(self, question: str, document: str) -> str:
"""评估文档与问题的相关性"""
result = self.grader.invoke({"question": question, "document": document})
return result.get("score", "no")
class AnswerGrader:
"""答案质量评分器"""
def __init__(self):
self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
self.prompt = PromptTemplate(
template="""你是一个评分员,评估答案是否有助于解决问题。
这里是答案:
\n ------- \n
{generation}
\n ------- \n
这里是问题:{question}
给出二进制分数'yes'或'no',表示答案是否有助于解决问题。
将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。""",
input_variables=["generation", "question"],
)
self.grader = self.prompt | self.llm | JsonOutputParser()
def grade(self, question: str, generation: str) -> str:
"""评估答案质量"""
result = self.grader.invoke({"question": question, "generation": generation})
return result.get("score", "no")
class HallucinationGrader:
"""
幻觉检测器 - 使用专业模型(Vectara + NLI)
相比 LLM-as-a-Judge 方法:
- 准确率从 60-75% 提升到 85-95%
- 速度提升 5-10 倍
- 成本降低 90%
"""
def __init__(self, method: str = "hybrid"):
"""
初始化幻觉检测器
Args:
method: 'vectara', 'nli', 或 'hybrid' (推荐)
"""
# 尝试加载专业检测模型
try:
from hallucination_detector import initialize_hallucination_detector
self.detector = initialize_hallucination_detector(method=method)
self.use_professional_detector = True
print(f"✅ 使用专业幻觉检测器: {method}")
except Exception as e:
print(f"⚠️ 专业检测器加载失败,回退到 LLM 方法: {e}")
self.use_professional_detector = False
# 回退到原有的 LLM 方法
self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
self.prompt = PromptTemplate(
template="""你是一个评分员,评估LLM生成是否基于/支持一组检索到的事实。
给出二进制分数'yes'或'no'。'yes'意味着答案基于/支持文档。
将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。
检索到的文档:
{documents}
LLM生成:{generation}""",
input_variables=["generation", "documents"],
)
self.grader = self.prompt | self.llm | JsonOutputParser()
def grade(self, generation: str, documents) -> str:
"""
检测生成内容是否存在幻觉
Args:
generation: LLM 生成的内容
documents: 参考文档
Returns:
"yes" 表示无幻觉,"no" 表示有幻觉
"""
if self.use_professional_detector:
# 使用专业检测器
return self.detector.grade(generation, documents)
else:
# 回退到 LLM 方法
result = self.grader.invoke({"generation": generation, "documents": documents})
return result.get("score", "no")
class QueryDecomposer:
"""查询分解器,将复杂的多跳问题分解为子问题序列"""
def __init__(self):
self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
self.prompt = PromptTemplate(
template="""你是一个查询分解专家。你的任务是将一个复杂的多跳问题分解为一系列简单的子问题,这些子问题可以按顺序检索来回答原始问题。
分解规则:
1. 如果问题很简单,不需要分解,返回只包含原始问题的列表。
2. 如果问题需要多步推理(例如"A的作者的大学在哪里"),分解为逻辑步骤:
- 步骤1: "谁是A的作者?"
- 步骤2: "该作者在哪个大学?"
3. 保持子问题简洁明了。
4. 即使返回单个问题,也必须包装在JSON的 sub_queries 列表中。
输出格式:返回一个包含 'sub_queries' 键的 JSON,其值为字符串列表。
不要输出任何前言或解释。
复杂问题: {question}""",
input_variables=["question"],
)
self.decomposer = self.prompt | self.llm | JsonOutputParser()
def decompose(self, question: str) -> List[str]:
"""分解问题"""
print(f"---分解问题: {question}---")
try:
result = self.decomposer.invoke({"question": question})
sub_queries = result.get("sub_queries", [question])
# 确保至少包含原始问题
if not sub_queries:
sub_queries = [question]
print(f"---子问题: {sub_queries}---")
return sub_queries
except Exception as e:
print(f"⚠️ 分解失败: {e},使用原始问题")
return [question]
class AnswerabilityGrader:
"""答案可回答性评分器,用于判断当前检索到的文档是否足够回答原始问题"""
def __init__(self):
self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
self.prompt = PromptTemplate(
template="""你是一个专家评分员,负责评估检索到的文档是否包含足够的信息来回答用户的问题。
原始问题: {question}
目前检索到的文档集合:
{documents}
任务:
判断上述文档是否已经包含了回答原始问题所需的全部关键信息。
- 如果信息充足,可以终止进一步的检索,返回 'yes'。
- 如果信息缺失,需要继续检索更多信息,返回 'no'。
输出格式:
返回一个只包含 'score' 键的 JSON,值为 'yes' 或 'no'。
不要输出任何前言或解释。""",
input_variables=["question", "documents"],
)
self.grader = self.prompt | self.llm | JsonOutputParser()
def grade(self, question: str, documents: str) -> str:
"""评估文档是否足以回答问题"""
result = self.grader.invoke({"question": question, "documents": documents})
return result.get("score", "no")
class QueryRewriter:
"""查询重写器,优化查询以获得更好的检索结果"""
def __init__(self):
self.llm = ChatOllama(model=LOCAL_LLM, temperature=0)
self.prompt = PromptTemplate(
template="""你是一个问题重写器,负责将输入问题转换为更适合向量存储检索的更好版本。
你的目标是根据原始问题和(可选的)之前的检索上下文,生成一个新的查询,以便检索到回答问题所需的缺失信息。
如果提供了之前的上下文,请分析其中缺少什么信息,并针对缺失的信息构建查询。
初始问题: {question}
之前的上下文(如果有):
{context}
改进的问题(只输出问题,无前言):""",
input_variables=["question", "context"],
)
self.rewriter = self.prompt | self.llm | StrOutputParser()
def rewrite(self, question: str, context: str = "") -> str:
"""重写查询以获得更好的检索效果"""
print(f"---原始查询: {question}---")
if context:
print(f"---参考上下文长度: {len(context)} 字符---")
rewritten_query = self.rewriter.invoke({"question": question, "context": context})
print(f"---重写查询: {rewritten_query}---")
return rewritten_query
def initialize_graders_and_router():
"""初始化所有评分器和路由器"""
# Load detection method from config
try:
from hallucination_config import HALLUCINATION_DETECTION_METHOD
detection_method = HALLUCINATION_DETECTION_METHOD
except ImportError:
detection_method = "hybrid" # Default to hybrid
query_router = QueryRouter()
document_grader = DocumentGrader()
answer_grader = AnswerGrader()
hallucination_grader = HallucinationGrader(method=detection_method)
query_rewriter = QueryRewriter()
query_decomposer = QueryDecomposer()
answerability_grader = AnswerabilityGrader()
return {
"query_router": query_router,
"document_grader": document_grader,
"answer_grader": answer_grader,
"hallucination_grader": hallucination_grader,
"query_rewriter": query_rewriter,
"query_decomposer": query_decomposer,
"answerability_grader": answerability_grader
}